Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Clean up utils #692

Draft
wants to merge 5 commits into
base: master
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion src/tests/architect_tests/test_builders.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
from sqlalchemy import create_engine
from contextlib import contextmanager

from triage.component.catwalk.utils import filename_friendly_hash
from triage.util.hash import filename_friendly_hash
from triage.component.architect.feature_group_creator import FeatureGroup
from triage.component.architect.builders import MatrixBuilder
from triage.component.catwalk.db import ensure_db
Expand Down
50 changes: 0 additions & 50 deletions src/tests/architect_tests/test_database_reflection.py

This file was deleted.

7 changes: 0 additions & 7 deletions src/tests/architect_tests/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,12 +2,9 @@
import shutil
import sys
import tempfile
import random
from contextlib import contextmanager

import pandas as pd
import yaml
import numpy


def convert_string_column_to_date(column):
Expand Down Expand Up @@ -138,10 +135,6 @@ def TemporaryDirectory():
shutil.rmtree(name)


def fake_labels(length):
    """Build a numpy array of ``length`` randomly chosen boolean labels.

    Each entry is drawn independently with ``random.choice`` from
    ``[True, False]``.
    """
    choices = [True, False]
    drawn = []
    for _ in range(length):
        drawn.append(random.choice(choices))
    return numpy.array(drawn)


def assert_index(engine, table, column):
"""Assert that a table has an index on a given column

Expand Down
3 changes: 2 additions & 1 deletion src/tests/catwalk_tests/test_evaluation.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,8 @@
subset_labels_and_predictions,
)
from triage.component.catwalk.metrics import Metric
from triage.component.catwalk.subsetters import get_subset_table_name
from triage.util.hash import filename_friendly_hash
import testing.postgresql
import datetime
import re
Expand All @@ -15,7 +17,6 @@
from numpy.testing import assert_almost_equal, assert_array_equal
import pandas
from sqlalchemy.sql.expression import text
from triage.component.catwalk.utils import filename_friendly_hash, get_subset_table_name
from tests.utils import fake_labels, fake_trained_model, MockMatrixStore
from tests.results_tests.factories import (
ModelFactory,
Expand Down
2 changes: 1 addition & 1 deletion src/tests/catwalk_tests/test_integration.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
from triage.component.catwalk import ModelTrainTester, Predictor, ModelTrainer, ModelEvaluator, IndividualImportanceCalculator
from triage.component.catwalk.utils import save_experiment_and_get_hash
from triage.component.results_schema.utils import save_experiment_and_get_hash
from triage.component.catwalk.model_trainers import flatten_grid_config
from triage.component.catwalk.storage import (
ModelStorageEngine,
Expand Down
48 changes: 48 additions & 0 deletions src/tests/catwalk_tests/test_ranking.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,48 @@
import numpy
from numpy.testing import assert_array_equal
import pytest

from triage.component.catwalk.ranking import sort_predictions_and_labels


def test_sort_predictions_and_labels():
    """Exercise each tiebreaker mode of ``sort_predictions_and_labels``.

    The input holds two tied 0.5 scores with different labels. In every mode
    the scores are expected in descending order; only the arrangement of the
    labels belonging to the tied scores differs between modes.
    """
    scores = numpy.array([0.5, 0.4, 0.6, 0.5])
    truth = numpy.array([0, 0, 1, 1])
    descending_scores = [0.6, 0.5, 0.5, 0.4]

    # deterministic tiebreakers: best sort first, then worst sort
    for mode, expected_labels in (('best', [1, 1, 0, 0]), ('worst', [1, 0, 1, 0])):
        ranked_scores, ranked_labels = sort_predictions_and_labels(
            scores, truth, tiebreaker=mode
        )
        assert_array_equal(ranked_scores, numpy.array(descending_scores))
        assert_array_equal(ranked_labels, numpy.array(expected_labels))

    # the random tiebreaker refuses to run without a seed
    with pytest.raises(ValueError):
        sort_predictions_and_labels(scores, truth, tiebreaker='random')

    # the random tiebreaker is reproducible for a given seed
    for seed, expected_labels in ((1234, [1, 1, 0, 0]), (24376234, [1, 0, 1, 0])):
        ranked_scores, ranked_labels = sort_predictions_and_labels(
            scores,
            truth,
            tiebreaker='random',
            sort_seed=seed,
        )
        assert_array_equal(ranked_scores, numpy.array(descending_scores))
        assert_array_equal(ranked_labels, numpy.array(expected_labels))
155 changes: 0 additions & 155 deletions src/tests/catwalk_tests/test_utils.py
Original file line number Diff line number Diff line change
@@ -1,155 +0,0 @@
from triage.component.catwalk.utils import (
filename_friendly_hash,
save_experiment_and_get_hash,
associate_models_with_experiment,
associate_matrices_with_experiment,
missing_model_hashes,
missing_matrix_uuids,
sort_predictions_and_labels,
)
from triage.component.results_schema.schema import Matrix, Model
from triage.component.catwalk.db import ensure_db
from sqlalchemy import create_engine
import testing.postgresql
import datetime
import re
import numpy
from numpy.testing import assert_array_equal
import pytest


def test_filename_friendly_hash():
    """Check that ``filename_friendly_hash`` is filename-safe and order-independent.

    Asserts that the hash of a flat dict (strings, a datetime, a date and a
    float) is a ``str`` made only of word characters, that reordering the keys
    yields the same hash, and that different data yields a different hash.
    """
    data = {
        "stuff": "stuff",
        "other_stuff": "more_stuff",
        "a_datetime": datetime.datetime(2015, 1, 1),
        "a_date": datetime.date(2016, 1, 1),
        "a_number": 5.0,
    }
    output = filename_friendly_hash(data)
    assert isinstance(output, str)
    # Raw string: "\w" inside a plain literal is an invalid string escape
    # (DeprecationWarning since 3.6, SyntaxWarning in 3.12).
    assert re.match(r"^[\w]+$", output) is not None

    # make sure ordering keys differently doesn't change the hash
    new_output = filename_friendly_hash(
        {
            "other_stuff": "more_stuff",
            "stuff": "stuff",
            "a_datetime": datetime.datetime(2015, 1, 1),
            "a_date": datetime.date(2016, 1, 1),
            "a_number": 5.0,
        }
    )
    assert new_output == output

    # make sure new data hashes to something different
    new_output = filename_friendly_hash({"stuff": "stuff", "a_number": 5.0})
    assert new_output != output


def test_filename_friendly_hash_stability():
    """Pin the hash of a nested dict and check nested key order is irrelevant."""
    original = {"one": "two", "three": {"four": "five", "six": "seven"}}
    reordered = {"one": "two", "three": {"six": "seven", "four": "five"}}

    original_hash = filename_friendly_hash(original)
    # The expected value is hardcoded so that any drift of the hash between
    # runs is caught.
    assert original_hash == "9a844a7ebbfd821010b1c2c13f7391e6"

    # Same content with the inner keys swapped must hash identically.
    reordered_hash = filename_friendly_hash(reordered)
    assert original_hash == reordered_hash


def test_save_experiment_and_get_hash():
    """Saving the same experiment config twice returns the same string hash."""
    # The config's contents are not under test, so a minimal dict suffices.
    config = {"one": "two"}
    with testing.postgresql.Postgresql() as pg:
        db = create_engine(pg.url())
        ensure_db(db)

        first_hash = save_experiment_and_get_hash(config, db)
        assert isinstance(first_hash, str)

        second_hash = save_experiment_and_get_hash(config, db)
        assert second_hash == first_hash


def test_missing_model_hashes():
    """Model hashes tied to an experiment are reported missing until a row exists."""
    with testing.postgresql.Postgresql() as pg:
        engine = create_engine(pg.url())
        ensure_db(engine)

        exp_hash = save_experiment_and_get_hash({}, engine)
        hashes = ['abcd', 'bcde', 'cdef']

        # Associated with the experiment but never trained: all are missing.
        associate_models_with_experiment(exp_hash, hashes, engine)
        assert missing_model_hashes(exp_hash, engine) == hashes

        # Once a model row exists for the first hash, only the rest remain.
        insert_sql = f"insert into {Model.__table__.fullname} (model_hash) values (%s)"
        engine.execute(insert_sql, hashes[0])
        remaining = hashes[1:]
        assert missing_model_hashes(exp_hash, engine) == remaining


def test_missing_matrix_uuids():
    """Matrix uuids tied to an experiment are reported missing until a row exists."""
    with testing.postgresql.Postgresql() as pg:
        engine = create_engine(pg.url())
        ensure_db(engine)

        exp_hash = save_experiment_and_get_hash({}, engine)
        uuids = ['abcd', 'bcde', 'cdef']

        # Associated with the experiment but never built: all are missing.
        associate_matrices_with_experiment(exp_hash, uuids, engine)
        assert missing_matrix_uuids(exp_hash, engine) == uuids

        # Once a matrix row exists for the first uuid, only the rest remain.
        insert_sql = f"insert into {Matrix.__table__.fullname} (matrix_uuid) values (%s)"
        engine.execute(insert_sql, uuids[0])
        remaining = uuids[1:]
        assert missing_matrix_uuids(exp_hash, engine) == remaining


def test_sort_predictions_and_labels():
    """Exercise each tiebreaker mode of ``sort_predictions_and_labels``.

    Two of the four scores tie at 0.5 with different labels. The scores are
    expected in descending order in every mode; the modes differ only in how
    the labels of the tied scores are arranged.
    """
    scores = numpy.array([0.5, 0.4, 0.6, 0.5])
    truth = numpy.array([0, 0, 1, 1])

    def check(expected_labels, **sort_options):
        # Run one sort and compare both returned arrays to the expectation.
        ranked_scores, ranked_labels = sort_predictions_and_labels(
            scores, truth, **sort_options
        )
        assert_array_equal(ranked_scores, numpy.array([0.6, 0.5, 0.5, 0.4]))
        assert_array_equal(ranked_labels, numpy.array(expected_labels))

    # best sort
    check([1, 1, 0, 0], tiebreaker='best')

    # worst sort
    check([1, 0, 1, 0], tiebreaker='worst')

    # random tiebreaker needs a seed
    with pytest.raises(ValueError):
        sort_predictions_and_labels(scores, truth, tiebreaker='random')

    # random tiebreaker respects the seed
    check([1, 1, 0, 0], tiebreaker='random', sort_seed=1234)
    check([1, 0, 1, 0], tiebreaker='random', sort_seed=24376234)
22 changes: 0 additions & 22 deletions src/tests/catwalk_tests/utils.py
Original file line number Diff line number Diff line change
@@ -1,23 +1,11 @@
import datetime
import random
import tempfile
from contextlib import contextmanager
import pytest

import numpy
import pandas
import yaml

from triage.component.catwalk.storage import (
ProjectStorage,
)
from triage.util.structs import FeatureNameList


def fake_labels(length):
    """Return a numpy array holding ``length`` random ``True``/``False`` labels.

    One ``random.choice`` draw from ``[True, False]`` is made per entry.
    """
    options = [True, False]
    picks = []
    for _ in range(length):
        picks.append(random.choice(options))
    return numpy.array(picks)


@pytest.fixture
def sample_metadata():
return {
Expand Down Expand Up @@ -46,13 +34,3 @@ def sample_df():
"label": ["good", "bad"],
}
).set_index("entity_id")


@pytest.fixture
def sample_matrix_store():
    """Yield the matrix store "1234" of a project rooted in a temporary directory.

    The store's ``matrix`` and ``metadata`` attributes are populated from
    ``sample_df`` and ``sample_metadata`` before it is handed to the test.
    """
    with tempfile.TemporaryDirectory() as tempdir:
        project_storage = ProjectStorage(tempdir)
        store = project_storage.matrix_storage_engine().get_store("1234")
        # NOTE(review): sample_metadata is declared as a pytest fixture in this
        # module (sample_df may be as well); pytest >= 4 rejects calling
        # fixtures directly. Confirm, and if so request them as arguments of
        # this fixture instead.
        store.matrix = sample_df()
        store.metadata = sample_metadata()
        # Yield rather than return: returning from inside the ``with`` block
        # would delete the temporary directory before the test could use the
        # store.
        yield store
Loading