diff --git a/docs/sources/experiments/architecture.md b/docs/sources/experiments/architecture.md index 55db7e134..887858618 100644 --- a/docs/sources/experiments/architecture.md +++ b/docs/sources/experiments/architecture.md @@ -334,7 +334,7 @@ On the other hand, new options that affect only runtime concerns (e.g. performan ## Storage Abstractions -Another important part of enabling different execution contexts is being able to pass large, persisted objects (e.g. matrices or models) by reference to another process or cluster. To achieve this, as well as provide the ability to configure different storage mediums (e.g. S3) and formats (e,g, HDF) without changes to the Experiment class, all references to these large objects within any components are handled through an abstraction layer. +Another important part of enabling different execution contexts is being able to pass large, persisted objects (e.g. matrices or models) by reference to another process or cluster. To achieve this, as well as provide the ability to configure different storage mediums (e.g. S3) without changes to the Experiment class, all references to these large objects within any components are handled through an abstraction layer. ### Matrix Storage diff --git a/docs/sources/experiments/running.md b/docs/sources/experiments/running.md index 457065513..88ec1f574 100644 --- a/docs/sources/experiments/running.md +++ b/docs/sources/experiments/running.md @@ -114,37 +114,6 @@ experiment.run() ``` -## Using HDF5 as a matrix storage format - -Triage by default uses CSV format to store matrices, but this can take up a lot of space. However, this is configurable. Triage ships with an HDF5 storage module that you can use. - -### CLI - -On the command-line, this is configurable using the `--matrix-format` option, and supports `csv` and `hdf`. - -```bash -triage experiment example/config/experiment.yaml --matrix-format hdf -``` - -### Python - -In Python, this is configurable using the `matrix_storage_class` keyword argument. To allow users to write their own storage modules, this is passed in the form of a class. The shipped modules are in `triage.component.catwalk.storage`. If you'd like to write your own storage module, you can use the [existing modules](https://github.com/dssg/triage/blob/master/src/triage/component/catwalk/storage.py) as a guide. - -```python -from triage.experiments import SingleThreadedExperiment -from triage.component.catwalk.storage import HDFMatrixStore - -experiment = SingleThreadedExperiment( - config=experiment_config - db_engine=create_engine(...), - matrix_storage_class=HDFMatrixStore, - project_path='/path/to/directory/to/save/data', -) -experiment.run() -``` - -Note: The HDF storage option is *not* compatible with S3. - ## Validating an Experiment Configuring an experiment is complex, and running an experiment can take a long time as data scales up. If there are any misconfigured values, it's going to help out a lot to figure out what they are before we run the Experiment. So when you have completed your experiment config and want to test it out, it's best to validate the Experiment first. If any problems are detectable in your Experiment, either in configuration or the database tables referenced by it, this method will throw an exception. 
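In code, that check is a single call on the experiment object. A minimal sketch, assuming an `experiment` constructed as in the earlier examples:

```python
# Raises an exception if the configuration or the referenced
# database tables have detectable problems; otherwise returns quietly.
experiment.validate()
```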
For instance, if I refer to the `cat_complaints` table in a feature aggregation but it doesn't exist, I'll see something like this: diff --git a/src/tests/architect_tests/test_builders.py b/src/tests/architect_tests/test_builders.py index cda80f4b8..371d4bb61 100644 --- a/src/tests/architect_tests/test_builders.py +++ b/src/tests/architect_tests/test_builders.py @@ -2,6 +2,7 @@ from unittest import TestCase import pandas as pd +from pandas.testing import assert_frame_equal import testing.postgresql from mock import Mock from sqlalchemy import create_engine @@ -11,8 +12,9 @@ from triage.component.architect.feature_group_creator import FeatureGroup from triage.component.architect.builders import MatrixBuilder from triage.component.catwalk.db import ensure_db -from triage.component.catwalk.storage import ProjectStorage, HDFMatrixStore +from triage.component.catwalk.storage import ProjectStorage from triage.component.results_schema.schema import Matrix +from triage.util.pandas import downcast_matrix from .utils import ( create_schemas, @@ -245,38 +247,6 @@ def get_matrix_storage_engine(): yield ProjectStorage(temp_dir).matrix_storage_engine() -def test_query_to_df(): - """ Test the write_to_csv function by checking whether the csv contains the - correct number of lines. - """ - with testing.postgresql.Postgresql() as postgresql: - # create an engine and generate a table with fake feature data - engine = create_engine(postgresql.url()) - create_schemas( - engine=engine, features_tables=features_tables, labels=labels, states=states - ) - - with get_matrix_storage_engine() as matrix_storage_engine: - builder = MatrixBuilder( - db_config=db_config, - matrix_storage_engine=matrix_storage_engine, - experiment_hash=experiment_hash, - engine=engine, - ) - - # for each table, check that corresponding csv has the correct # of rows - for table in features_tables: - df = builder.query_to_df( - """ - select * - from features.features{} - """.format( - features_tables.index(table) - ) - ) - assert len(df) == len(table) - - def test_make_entity_date_table(): """ Test that the make_entity_date_table function contains the correct values. 
@@ -402,142 +372,7 @@ def test_make_entity_date_table_include_missing_labels(): assert sorted(result.values.tolist()) == sorted(ids_dates.values.tolist()) -def test_load_features_data(): - dates = [datetime.datetime(2016, 1, 1, 0, 0), datetime.datetime(2016, 2, 1, 0, 0)] - - # make dataframe for entity ids and dates - ids_dates = create_entity_date_df( - labels=labels, - states=states, - as_of_dates=dates, - label_name="booking", - label_type="binary", - label_timespan="1 month", - ) - - features = [["f1", "f2"], ["f3", "f4"]] - # make dataframes of features to test against - features_dfs = [] - for i, table in enumerate(features_tables): - cols = ["entity_id", "as_of_date"] + features[i] - temp_df = pd.DataFrame(table, columns=cols) - temp_df["as_of_date"] = convert_string_column_to_date(temp_df["as_of_date"]) - features_dfs.append( - ids_dates.merge( - right=temp_df, how="left", on=["entity_id", "as_of_date"] - ).set_index(["entity_id", "as_of_date"]) - ) - - # create an engine and generate a table with fake feature data - with testing.postgresql.Postgresql() as postgresql: - engine = create_engine(postgresql.url()) - create_schemas( - engine=engine, features_tables=features_tables, labels=labels, states=states - ) - - with get_matrix_storage_engine() as matrix_storage_engine: - builder = MatrixBuilder( - db_config=db_config, - matrix_storage_engine=matrix_storage_engine, - experiment_hash=experiment_hash, - engine=engine, - ) - - # make the entity-date table - entity_date_table_name = builder.make_entity_date_table( - as_of_times=dates, - label_type="binary", - label_name="booking", - state="active", - matrix_type="train", - matrix_uuid="my_uuid", - label_timespan="1 month", - ) - - feature_dictionary = dict( - ("features{}".format(i), feature_list) - for i, feature_list in enumerate(features) - ) - - returned_features_dfs = builder.load_features_data( - as_of_times=dates, - feature_dictionary=feature_dictionary, - entity_date_table_name=entity_date_table_name, - matrix_uuid="my_uuid", - ) - - # get the queries and test them - for result, df in zip(returned_features_dfs, features_dfs): - test = result == df - assert test.all().all() - - -def test_load_labels_data(): - """ Test the load_labels_data function by checking whether the query - produces the correct labels - """ - # set up labeling config variables - dates = [datetime.datetime(2016, 1, 1, 0, 0), datetime.datetime(2016, 2, 1, 0, 0)] - - # make a dataframe of labels to test against - labels_df = pd.DataFrame( - labels, - columns=[ - "entity_id", - "as_of_date", - "label_timespan", - "label_name", - "label_type", - "label", - ], - ) - - labels_df["as_of_date"] = convert_string_column_to_date(labels_df["as_of_date"]) - labels_df.set_index(["entity_id", "as_of_date"]) - - # create an engine and generate a table with fake feature data - with testing.postgresql.Postgresql() as postgresql: - engine = create_engine(postgresql.url()) - create_schemas(engine, features_tables, labels, states) - with get_matrix_storage_engine() as matrix_storage_engine: - builder = MatrixBuilder( - db_config=db_config, - matrix_storage_engine=matrix_storage_engine, - experiment_hash=experiment_hash, - engine=engine, - ) - - # make the entity-date table - entity_date_table_name = builder.make_entity_date_table( - as_of_times=dates, - label_type="binary", - label_name="booking", - state="active", - matrix_type="train", - matrix_uuid="my_uuid", - label_timespan="1 month", - ) - - result = builder.load_labels_data( - label_name=label_name, - 
label_type=label_type, - label_timespan="1 month", - matrix_uuid="my_uuid", - entity_date_table_name=entity_date_table_name, - ) - df = pd.DataFrame.from_dict( - { - "entity_id": [2, 3, 4, 4], - "as_of_date": [dates[1], dates[1], dates[0], dates[1]], - "booking": [0, 0, 1, 0], - } - ).set_index(["entity_id", "as_of_date"]) - - test = result == df - assert test.all().all() - - -def test_load_labels_data_include_missing_labels_as_false(): +def test_label_query_include_missing_labels_in_train_as_False(): """ Test the load_labels_data function by checking whether the query produces the correct labels """ @@ -591,58 +426,22 @@ def test_load_labels_data_include_missing_labels_as_false(): label_timespan="1 month", ) - result = builder.load_labels_data( + result = pd.read_sql(builder.label_load_query( label_name=label_name, label_type=label_type, label_timespan="1 month", - matrix_uuid="my_uuid", entity_date_table_name=entity_date_table_name, - ) - df = pd.DataFrame.from_dict( + ), con=engine)['booking'] + expected = pd.DataFrame.from_dict( { "entity_id": [0, 2, 3, 4, 4], "as_of_date": [dates[2], dates[1], dates[1], dates[0], dates[1]], "booking": [0, 0, 0, 1, 0], } - ).set_index(["entity_id", "as_of_date"]) + )['booking'] # the first row would not be here if we had not configured the Builder # to include missing labels as false - - test = result == df - assert test.all().all() - - -class TestMergeFeatureCSVs(TestCase): - def test_badinput(self): - """We assert column names, so replacing 'date' with 'as_of_date' - should result in an error""" - with get_matrix_storage_engine() as matrix_storage_engine: - builder = MatrixBuilder( - db_config=db_config, - matrix_storage_engine=matrix_storage_engine, - experiment_hash=experiment_hash, - engine=None, - ) - dataframes = [ - pd.DataFrame.from_records( - [(1, 3, 3), (4, 5, 6), (7, 8, 9)], - columns=("entity_id", "date", "f1"), - index=["entity_id", "date"], - ), - pd.DataFrame.from_records( - [(1, 2, 3), (4, 5, 9), (7, 8, 15)], - columns=("entity_id", "date", "f3"), - index=["entity_id", "date"], - ), - pd.DataFrame.from_records( - [(1, 2, 2), (4, 5, 20), (7, 8, 56)], - columns=("entity_id", "date", "f3"), - index=["entity_id", "date"], - ), - ] - - with self.assertRaises(ValueError): - builder.merge_feature_csvs(dataframes, matrix_uuid="1234") + assert expected.tolist() == result.tolist() class TestBuildMatrix(TestCase): @@ -675,6 +474,19 @@ def good_dates(self): datetime.datetime(2016, 3, 1, 0, 0), ] + @property + def expected_design_matrix(self): + return downcast_matrix(pd.DataFrame.from_dict( + { + "entity_id": [2, 3, 4, 4, 4], + "as_of_date": [self.good_dates[1], self.good_dates[1], self.good_dates[0], self.good_dates[1], self.good_dates[2]], + "f1": [9.0, 9.0, 9.0, 9.0, 9.0], + "f2": [9.0, 9.0, 9.0, 9.0, 9.0], + "f3": [2.0, 2.0, 9.0, 9.0, 1.0], + "f4": [3.0, 2.0, 9.0, 9.0, 4.0], + } + ).set_index(["entity_id", "as_of_date"])) + def test_train_matrix(self): with testing.postgresql.Postgresql() as postgresql: # create an engine and generate a table with fake feature data @@ -704,7 +516,12 @@ def test_train_matrix(self): matrix_uuid=uuid, matrix_type="train", ) - assert len(matrix_storage_engine.get_store(uuid).design_matrix) == 5 + matrix_store = matrix_storage_engine.get_store(uuid) + assert_frame_equal( + self.expected_design_matrix, + matrix_store.design_matrix + ) + assert [0.0, 0.0, 1.0, 0.0, 0.0] == matrix_store.labels.tolist() assert builder.sessionmaker().query(Matrix).get(uuid).feature_dictionary ==self.good_feature_dictionary def 
test_test_matrix(self): @@ -737,12 +554,17 @@ def test_test_matrix(self): matrix_uuid=uuid, matrix_type="test", ) + matrix_store = matrix_storage_engine.get_store(uuid) + assert_frame_equal( + self.expected_design_matrix, + matrix_store.design_matrix + ) + assert [0.0, 0.0, 1.0, 0.0, 0.0] == matrix_store.labels.tolist() - assert len(matrix_storage_engine.get_store(uuid).design_matrix) == 5 - - def test_hdf_matrix(self): - with testing.postgresql.Postgresql() as postgresql: - # create an engine and generate a table with fake feature data + def test_bad_index(self): + """The feature load queries assume certain column names for the index names, + so replacing 'as_of_date' with 'mydatecol' should result in an error""" + with testing.postgresql.Postgresql() as postgresql, get_matrix_storage_engine() as matrix_storage_engine: engine = create_engine(postgresql.url()) ensure_db(engine) create_schemas( @@ -751,16 +573,18 @@ def test_hdf_matrix(self): labels=labels, states=states, ) + first_feature_table = next(iter(self.good_feature_dictionary.keys())) + engine.execute( + f"alter table {db_config['features_schema_name']}.{first_feature_table} rename column as_of_date to mydatecol" + ) + builder = MatrixBuilder( + db_config=db_config, + matrix_storage_engine=matrix_storage_engine, + experiment_hash=experiment_hash, + engine=engine, + ) - with get_matrix_storage_engine() as matrix_storage_engine: - matrix_storage_engine.matrix_storage_class = HDFMatrixStore - builder = MatrixBuilder( - db_config=db_config, - matrix_storage_engine=matrix_storage_engine, - experiment_hash=experiment_hash, - engine=engine, - ) - + with self.assertRaises(ValueError): uuid = filename_friendly_hash(self.good_metadata) builder.build_matrix( as_of_times=self.good_dates, @@ -772,67 +596,6 @@ def test_hdf_matrix(self): matrix_type="test", ) - assert len(matrix_storage_engine.get_store(uuid).design_matrix) == 5 - - def test_nullcheck(self): - f0_dict = {(r[0], r[1]): r for r in features0_pre} - f1_dict = {(r[0], r[1]): r for r in features1_pre} - - features0 = sorted(f0_dict.values(), key=lambda x: (x[1], x[0])) - features1 = sorted(f1_dict.values(), key=lambda x: (x[1], x[0])) - - features_tables = [features0, features1] - - with testing.postgresql.Postgresql() as postgresql: - # create an engine and generate a table with fake feature data - engine = create_engine(postgresql.url()) - create_schemas( - engine=engine, - features_tables=features_tables, - labels=labels, - states=states, - ) - - dates = [ - datetime.datetime(2016, 1, 1, 0, 0), - datetime.datetime(2016, 2, 1, 0, 0), - datetime.datetime(2016, 3, 1, 0, 0), - ] - - with get_matrix_storage_engine() as matrix_storage_engine: - builder = MatrixBuilder( - db_config=db_config, - matrix_storage_engine=matrix_storage_engine, - experiment_hash=experiment_hash, - engine=engine, - ) - - feature_dictionary = { - "features0": ["f1", "f2"], - "features1": ["f3", "f4"], - } - matrix_metadata = { - "matrix_id": "hi", - "state": "active", - "label_name": "booking", - "end_time": datetime.datetime(2016, 3, 1, 0, 0), - "feature_start_time": datetime.datetime(2016, 1, 1, 0, 0), - "label_timespan": "1 month", - "test_duration": "1 month", - "indices": ["entity_id", "as_of_date"], - } - uuid = filename_friendly_hash(matrix_metadata) - with self.assertRaises(ValueError): - builder.build_matrix( - as_of_times=dates, - label_name="booking", - label_type="binary", - feature_dictionary=feature_dictionary, - matrix_metadata=matrix_metadata, - matrix_uuid=uuid, - matrix_type="test", - ) - def 
test_replace_false_rerun(self): with testing.postgresql.Postgresql() as postgresql: # create an engine and generate a table with fake feature data diff --git a/src/tests/architect_tests/utils.py b/src/tests/architect_tests/utils.py index e21e3f701..b56f25b32 100644 --- a/src/tests/architect_tests/utils.py +++ b/src/tests/architect_tests/utils.py @@ -138,30 +138,6 @@ def TemporaryDirectory(): shutil.rmtree(name) -@contextmanager -def fake_metta(matrix_dict, metadata): - """Stores matrix and metadata in a metta-data-like form - - Args: - matrix_dict (dict) of form { columns: values }. - Expects an entity_id to be present which it will use as the index - metadata (dict). Any metadata that should be set - - Yields: - tuple of filenames for matrix and metadata - """ - matrix = pd.DataFrame.from_dict(matrix_dict).set_index("entity_id") - with tempfile.NamedTemporaryFile() as matrix_file: - with tempfile.NamedTemporaryFile("w") as metadata_file: - hdf = pd.HDFStore(matrix_file.name) - hdf.put("title", matrix, data_columns=True) - matrix_file.seek(0) - - yaml.dump(metadata, metadata_file) - metadata_file.seek(0) - yield (matrix_file.name, metadata_file.name) - - def fake_labels(length): return numpy.array([random.choice([True, False]) for i in range(0, length)]) diff --git a/src/tests/catwalk_tests/test_storage.py b/src/tests/catwalk_tests/test_storage.py index fb44e044d..f16d0a240 100644 --- a/src/tests/catwalk_tests/test_storage.py +++ b/src/tests/catwalk_tests/test_storage.py @@ -1,4 +1,5 @@ import os +import io import tempfile from collections import OrderedDict @@ -8,15 +9,15 @@ from moto import mock_s3 import boto3 from numpy.testing import assert_almost_equal -from pandas.testing import assert_frame_equal +from pandas.testing import assert_frame_equal, assert_series_equal from unittest import mock +from triage.util.pandas import downcast_matrix import pytest from triage.component.catwalk.storage import ( MatrixStore, CSVMatrixStore, FSStore, - HDFMatrixStore, S3Store, ProjectStorage, ModelStorageEngine, @@ -95,22 +96,16 @@ def matrix_stores(): project_storage = ProjectStorage(tmpdir) tmpcsv = os.path.join(tmpdir, "df.csv.gz") tmpyaml = os.path.join(tmpdir, "df.yaml") - tmphdf = os.path.join(tmpdir, "df.h5") with open(tmpyaml, "w") as outfile: yaml.dump(METADATA, outfile, default_flow_style=False) df.to_csv(tmpcsv, compression="gzip") - df.to_hdf(tmphdf, "matrix") csv = CSVMatrixStore(project_storage, [], "df") - hdf = HDFMatrixStore(project_storage, [], "df") - assert csv.design_matrix.equals(hdf.design_matrix) # first test with caching - with csv.cache(), hdf.cache(): + with csv.cache(): yield csv - yield hdf # with the caching out of scope they will be nuked # and these last two versions will not have any cache yield csv - yield hdf def test_MatrixStore_empty(): @@ -175,25 +170,24 @@ def test_MatrixStore_labels_idempotency(): def test_MatrixStore_save(): - data = { - "entity_id": [1, 2], - "as_of_date": [pd.Timestamp(2017, 1, 1), pd.Timestamp(2017, 1, 1)], - "feature_one": [0.5, 0.6], - "feature_two": [0.5, 0.6], - "label": [1, 0] - } - df = pd.DataFrame.from_dict(data) - labels = df.pop("label") - for matrix_store in matrix_stores(): - matrix_store.metadata = METADATA + data = { + "entity_id": [1, 2], + "as_of_date": [pd.Timestamp(2016, 1, 1), pd.Timestamp(2016, 1, 1)], + "feature_one": [0.5, 0.6], + "feature_two": [0.5, 0.6], + "label": [1, 0] + } + df = pd.DataFrame.from_dict(data) + df.set_index(MatrixStore.indices, inplace=True) + df = downcast_matrix(df) + bytestream = 
io.BytesIO(df.to_csv(None).encode('utf-8')) + + matrix_store.save(bytestream, METADATA) - matrix_store.matrix_label_tuple = df, labels - matrix_store.save() - assert_frame_equal( - matrix_store.design_matrix, - df - ) + labels = df.pop("label") + assert_frame_equal(matrix_store.design_matrix, df) + assert_series_equal(matrix_store.labels, labels) def test_MatrixStore_caching(): @@ -228,16 +222,22 @@ def test_s3_save(): with mock_s3(): client = boto3.client("s3") client.create_bucket(Bucket="fake-matrix-bucket", ACL="public-read-write") - for example in matrix_stores(): - if not isinstance(example, CSVMatrixStore): - continue - project_storage = ProjectStorage("s3://fake-matrix-bucket") - - tosave = CSVMatrixStore(project_storage, [], "test") - tosave.metadata = example.metadata - tosave.matrix_label_tuple = example.matrix_label_tuple - tosave.save() - - tocheck = CSVMatrixStore(project_storage, [], "test") - assert tocheck.metadata == example.metadata - assert tocheck.design_matrix.to_dict() == example.design_matrix.to_dict() + project_storage = ProjectStorage("s3://fake-matrix-bucket") + matrix_store = project_storage.matrix_storage_engine().get_store('1234') + data = { + "entity_id": [1, 2], + "as_of_date": [pd.Timestamp(2017, 1, 1), pd.Timestamp(2017, 1, 1)], + "feature_one": [0.5, 0.6], + "feature_two": [0.5, 0.6], + "label": [1, 0] + } + df = pd.DataFrame.from_dict(data) + df.set_index(MatrixStore.indices, inplace=True) + df = downcast_matrix(df) + bytestream = io.BytesIO(df.to_csv(None).encode('utf-8')) + + matrix_store.save(bytestream, METADATA) + + labels = df.pop("label") + assert_frame_equal(matrix_store.design_matrix, df) + assert_series_equal(matrix_store.labels, labels) diff --git a/src/tests/catwalk_tests/utils.py b/src/tests/catwalk_tests/utils.py index f8e4c8997..da20734c0 100644 --- a/src/tests/catwalk_tests/utils.py +++ b/src/tests/catwalk_tests/utils.py @@ -14,30 +14,6 @@ from triage.util.structs import FeatureNameList -@contextmanager -def fake_metta(matrix_dict, metadata): - """Stores matrix and metadata in a metta-data-like form - - Args: - matrix_dict (dict) of form { columns: values }. - Expects an entity_id to be present which it will use as the index - metadata (dict). 
Any metadata that should be set - - Yields: - tuple of filenames for matrix and metadata - """ - matrix = pandas.DataFrame.from_dict(matrix_dict).set_index("entity_id") - with tempfile.NamedTemporaryFile() as matrix_file: - with tempfile.NamedTemporaryFile("w") as metadata_file: - hdf = pandas.HDFStore(matrix_file.name) - hdf.put("title", matrix, data_columns=True) - matrix_file.seek(0) - - yaml.dump(metadata, metadata_file) - metadata_file.seek(0) - yield (matrix_file.name, metadata_file.name) - - def fake_labels(length): return numpy.array([random.choice([True, False]) for i in range(0, length)]) diff --git a/src/tests/test_experiments.py b/src/tests/test_experiments.py index 7f99e4bdc..eb8ecb9f4 100644 --- a/src/tests/test_experiments.py +++ b/src/tests/test_experiments.py @@ -13,8 +13,8 @@ from sqlalchemy.orm import sessionmaker from tests.utils import sample_config, populate_source_data -from triage.component.catwalk.storage import HDFMatrixStore, CSVMatrixStore from triage.component.results_schema.schema import Experiment +from triage.component.catwalk.storage import CSVMatrixStore from triage.experiments import ( MultiCoreExperiment, @@ -55,7 +55,7 @@ def num_linked_evaluations(db_engine): ) parametrize_matrix_storage_classes = pytest.mark.parametrize( - ("matrix_storage_class",), [(HDFMatrixStore,), (CSVMatrixStore,)] + ("matrix_storage_class",), [(CSVMatrixStore,)] ) diff --git a/src/tests/test_utils_pandas.py b/src/tests/test_utils_pandas.py index 4b28a6f29..8dd814995 100644 --- a/src/tests/test_utils_pandas.py +++ b/src/tests/test_utils_pandas.py @@ -1,5 +1,6 @@ -from triage.util.pandas import downcast_matrix from triage.component.catwalk.storage import MatrixStore +import pandas as pd +from triage.util.pandas import columns_with_nulls, downcast_matrix from .utils import matrix_creator @@ -12,3 +13,15 @@ def test_downcast_matrix(): # make sure the memory usage is lower because there would be no point of this otherwise assert downcasted_df.memory_usage().sum() < df.memory_usage().sum() + + +def test_columns_with_nulls(): + assert columns_with_nulls(pd.DataFrame.from_dict({ + "feature_one": [0.5, 0.6, 0.5, 0.6], + "feature_two": [0.5, 0.6, 0.5, 0.6], + })) == [] + + assert columns_with_nulls(pd.DataFrame.from_dict({ + "feature_one": [0.5, None, 0.5, 0.6], + "feature_two": [0.5, 0.6, 0.5, 0.6], + })) == ["feature_one"] diff --git a/src/tests/utils.py b/src/tests/utils.py index c2d3e2e54..e7a25f933 100644 --- a/src/tests/utils.py +++ b/src/tests/utils.py @@ -1,4 +1,5 @@ import datetime +import io import random import tempfile from contextlib import contextmanager @@ -180,12 +181,10 @@ def get_matrix_store(project_storage, matrix=None, metadata=None, write_to_db=Tr matrix["as_of_date"] = matrix["as_of_date"].apply(pandas.Timestamp) matrix.set_index(MatrixStore.indices, inplace=True) matrix_store = project_storage.matrix_storage_engine().get_store(filename_friendly_hash(metadata)) - matrix_store.metadata = metadata - new_matrix = matrix.copy() - labels = new_matrix.pop(matrix_store.label_column_name) - matrix_store.matrix_label_tuple = new_matrix, labels - matrix_store.save() - matrix_store.clear_cache() + matrix_store.save( + from_fileobj=io.BytesIO(matrix.to_csv(None).encode('utf-8')), + metadata=metadata + ) if write_to_db: if ( session.query(Matrix).filter(Matrix.matrix_uuid == matrix_store.uuid).count() diff --git a/src/triage/cli.py b/src/triage/cli.py index e7a0637d0..a40157b90 100755 --- a/src/triage/cli.py +++ b/src/triage/cli.py @@ -2,6 +2,7 @@ import argparse import 
importlib.util import logging +import time import os import yaml from datetime import datetime @@ -15,7 +16,7 @@ from triage.component.audition import AuditionRunner from triage.component.results_schema import upgrade_db, stamp_db, REVISION_MAPPING from triage.component.timechop.plotting import visualize_chops -from triage.component.catwalk.storage import CSVMatrixStore, HDFMatrixStore, Store, ProjectStorage +from triage.component.catwalk.storage import CSVMatrixStore, Store, ProjectStorage from triage.experiments import ( CONFIG_VERSION, MultiCoreExperiment, @@ -24,8 +25,6 @@ from triage.component.postmodeling.crosstabs import CrosstabsConfigLoader, run_crosstabs from triage.util.db import create_engine -logging.basicConfig(level=logging.INFO) - def natural_number(value): natural = int(value) @@ -157,7 +156,6 @@ class Experiment(Command): matrix_storage_map = { "csv": CSVMatrixStore, - "hdf": HDFMatrixStore, } matrix_storage_default = "csv" diff --git a/src/triage/component/architect/builders.py b/src/triage/component/architect/builders.py index c1d75fb12..5b55f862c 100644 --- a/src/triage/component/architect/builders.py +++ b/src/triage/component/architect/builders.py @@ -1,14 +1,17 @@ -import io import json import logging -import pandas +import contextlib from sqlalchemy.orm import sessionmaker +from functools import partial +from ohio import PipeTextIO +from triage.util.io import IteratorBytesIO from triage.component.results_schema import Matrix -from triage.database_reflection import table_has_data +from triage.database_reflection import table_has_data, table_row_count from triage.tracking import built_matrix, skipped_matrix, errored_matrix from triage.util.pandas import downcast_matrix +from triage.validation_primitives import table_should_have_entity_date_columns, table_should_have_data class BuilderBase(object): @@ -62,6 +65,8 @@ def _outer_join_query( right_column_selections, entity_date_table_name, additional_conditions="", + include_index=False, + column_override=None, ): """ Given a (features or labels) table, a list of times, columns to select, and (optionally) a set of join conditions, perform an outer @@ -82,23 +87,46 @@ def _outer_join_query( """ # put everything into the query - query = """ - SELECT ed.entity_id, - ed.as_of_date{columns} - FROM {entity_date_table_name} ed - LEFT OUTER JOIN {right_table} r - ON ed.entity_id = r.entity_id AND - ed.as_of_date = r.as_of_date - {more} - ORDER BY ed.entity_id, - ed.as_of_date - """.format( - columns="".join(right_column_selections), - feature_schema=self.db_config["features_schema_name"], - entity_date_table_name=entity_date_table_name, - right_table=right_table_name, - more=additional_conditions, - ) + + if include_index: + query = """ + SELECT ed.entity_id, + ed.as_of_date{columns} + FROM {entity_date_table_name} ed + LEFT OUTER JOIN {right_table} r + ON ed.entity_id = r.entity_id AND + ed.as_of_date = r.as_of_date + {more} + ORDER BY ed.entity_id, + ed.as_of_date + """.format( + columns="".join(right_column_selections), + feature_schema=self.db_config["features_schema_name"], + entity_date_table_name=entity_date_table_name, + right_table=right_table_name, + more=additional_conditions, + ) + else: + query = """ + with r as ( + SELECT ed.entity_id, + ed.as_of_date, {columns} + FROM {entity_date_table_name} ed + LEFT OUTER JOIN {right_table} r + ON ed.entity_id = r.entity_id AND + ed.as_of_date = r.as_of_date + {more} + ORDER BY ed.entity_id, + ed.as_of_date + ) select {columns_maybe_override} from r + """.format( + 
columns="".join(right_column_selections)[2:], + columns_maybe_override="".join(right_column_selections)[2:] if not column_override else column_override, + feature_schema=self.db_config["features_schema_name"], + entity_date_table_name=entity_date_table_name, + right_table=right_table_name, + more=additional_conditions, + ) return query def make_entity_date_table( @@ -214,6 +242,7 @@ def _all_valid_entity_dates_query(self, state, as_of_time_strings): class MatrixBuilder(BuilderBase): + def build_matrix( self, as_of_times, @@ -253,11 +282,13 @@ def build_matrix( if self.run_id: errored_matrix(self.run_id, self.db_engine) return + + # what should the labels table look like? + # 1. have data + # 2. entity date/column + labels_table_name = f"{self.db_config['labels_schema_name']}.{self.db_config['labels_table_name']}" if not table_has_data( - "{}.{}".format( - self.db_config["labels_schema_name"], - self.db_config["labels_table_name"], - ), + labels_table_name, self.db_engine, ): logging.warning("labels table is not populated, cannot build matrix") @@ -265,6 +296,21 @@ def build_matrix( errored_matrix(self.run_id, self.db_engine) return + table_should_have_entity_date_columns( + labels_table_name, + self.db_engine + ) + + # what should the feature tables look like? + # 1. have data + # 2. entity/date column + for feature_table in feature_dictionary.keys(): + full_feature_table = \ + f"{self.db_config['features_schema_name']}.{feature_table}" + table_should_have_data(full_feature_table, self.db_engine) + table_should_have_entity_date_columns(full_feature_table, self.db_engine) + + matrix_store = self.matrix_storage_engine.get_store(matrix_uuid) if not self.replace and matrix_store.exists: logging.info("Skipping %s because matrix already exists", matrix_uuid) @@ -289,7 +335,7 @@ def build_matrix( matrix_uuid, matrix_metadata["label_timespan"], ) - except ValueError as e: + except ValueError: logging.warning( "Not able to build entity-date table due to: %s - will not build matrix", exc_info=True, @@ -297,39 +343,22 @@ def build_matrix( if self.run_id: errored_matrix(self.run_id, self.db_engine) return - logging.info( - "Extracting feature group data from database into file " "for matrix %s", - matrix_uuid, - ) - dataframes = self.load_features_data( - as_of_times, feature_dictionary, entity_date_table_name, matrix_uuid - ) - logging.info(f"Feature data extracted for matrix {matrix_uuid}") - logging.info( - "Extracting label data from database into file for " "matrix %s", - matrix_uuid, - ) - labels_df = self.load_labels_data( + feature_queries = self.feature_load_queries(feature_dictionary, entity_date_table_name) + label_query = self.label_load_query( label_name, label_type, entity_date_table_name, - matrix_uuid, matrix_metadata["label_timespan"], ) - dataframes.insert(0, labels_df) - logging.info(f"Label data extracted for matrix {matrix_uuid}") # stitch together the csvs - logging.info("Merging feature files for matrix %s", matrix_uuid) - output = self.merge_feature_csvs(dataframes, matrix_uuid) - logging.info(f"Features data merged for matrix {matrix_uuid}") - - matrix_store.metadata = matrix_metadata - # store the matrix - labels = output.pop(matrix_store.label_column_name) - matrix_store.matrix_label_tuple = output, labels - matrix_store.save() - logging.info("Matrix %s saved", matrix_uuid) + logging.info("Building and saving matrix %s by querying and joining tables", matrix_uuid) + self._save_matrix( + queries=feature_queries + [label_query], + matrix_store=matrix_store, + 
matrix_metadata=matrix_metadata + ) + # If completely archived, save its information to matrices table # At this point, existence of matrix already tested, so no need to delete from db if matrix_type == "train": @@ -337,12 +366,20 @@ def build_matrix( else: lookback = matrix_metadata["test_duration"] + row_count = table_row_count( + '{schema}."{table}"'.format( + schema=self.db_config["features_schema_name"], + table=entity_date_table_name, + ), + self.db_engine + ) + matrix = Matrix( matrix_id=matrix_metadata["matrix_id"], matrix_uuid=matrix_uuid, matrix_type=matrix_type, labeling_window=matrix_metadata["label_timespan"], - num_observations=len(output), + num_observations=row_count, lookback_duration=lookback, feature_start_time=matrix_metadata["feature_start_time"], feature_dictionary=feature_dictionary, @@ -357,12 +394,11 @@ def build_matrix( built_matrix(self.run_id, self.db_engine) - def load_labels_data( + def label_load_query( self, label_name, label_type, entity_date_table_name, - matrix_uuid, label_timespan, ): """ Query the labels table and write the data to disk in csv format. @@ -371,12 +407,10 @@ def load_labels_data( :param label_name: name of the label to be used :param label_type: the type of label to be used :param entity_date_table_name: the name of the entity date table - :param matrix_uuid: a unique id for the matrix :param label_timespan: the time timespan that labels in matrix will include :type label_name: str :type label_type: str :type entity_date_table_name: str - :type matrix_uuid: str :type label_timespan: str :return: name of csv containing labels @@ -412,36 +446,32 @@ def load_labels_data( """.format( name=label_name, type=label_type, timespan=label_timespan ), + include_index=False, + column_override=label_name ) - return self.query_to_df(labels_query) + return labels_query - def load_features_data( - self, as_of_times, feature_dictionary, entity_date_table_name, matrix_uuid - ): + def feature_load_queries(self, feature_dictionary, entity_date_table_name): """ Loop over tables in features schema, writing the data from each to a csv. Return the full list of feature csv names and the list of all features. - :param as_of_times: the times to be included in the matrix :param feature_dictionary: a dictionary of feature tables and features to be included in the matrix :param entity_date_table_name: the name of the entity date table for the matrix - :param matrix_uuid: a human-readable id for the matrix - :type as_of_times: list :type feature_dictionary: dict :type entity_date_table_name: str - :type matrix_uuid: str :return: list of csvs containing feature data :rtype: tuple """ # iterate! 
for each table, make query, write csv, save feature & file names - feature_dfs = [] - for feature_table_name, feature_names in feature_dictionary.items(): - logging.info("Retrieving feature data from %s", feature_table_name) - features_query = self._outer_join_query( + queries = [] + for num, (feature_table_name, feature_names) in enumerate(feature_dictionary.items()): + logging.info("Generating feature query for %s", feature_table_name) + queries.append(self._outer_join_query( right_table_name="{schema}.{table}".format( schema=self.db_config["features_schema_name"], table=feature_table_name, @@ -450,84 +480,43 @@ def load_features_data( schema=self.db_config["features_schema_name"], table=entity_date_table_name, ), - # collate imputation shouldn't leave any nulls and we double-check - # the imputed table in FeatureGenerator.create_all_tables() but as - # a final check, raise a divide by zero error on export if the - # database encounters any during the outer join right_column_selections=[', "{0}"'.format(fn) for fn in feature_names], - ) - feature_dfs.append(self.query_to_df(features_query)) + include_index=True if num==0 else False, + )) + return queries - return feature_dfs - - def query_to_df(self, query_string, header="HEADER"): - """ Given a query, write the requested data to csv. + @property + def _raw_connections(self): + while True: + yield self.db_engine.raw_connection() - :param query_string: query to send - :param file_name: name to save the file as - :header: text to include in query indicating if a header should be saved - in output - :type query_string: str - :type file_name: str - :type header: str + def _save_matrix(self, queries, matrix_store, matrix_metadata): + """Construct and save a matrix CSV from a list of queries - :return: none - :rtype: none - """ - logging.debug("Copying to CSV query %s", query_string) - copy_sql = "COPY ({query}) TO STDOUT WITH CSV {head}".format( - query=query_string, head=header - ) - conn = self.db_engine.raw_connection() - cur = conn.cursor() - out = io.StringIO() - cur.copy_expert(copy_sql, out) - out.seek(0) - df = pandas.read_csv(out, parse_dates=["as_of_date"]) - df.set_index(["entity_id", "as_of_date"], inplace=True) - return downcast_matrix(df) - - def merge_feature_csvs(self, dataframes, matrix_uuid): - """Horizontally merge a list of feature CSVs - Assumptions: - - The first and second columns of each CSV are - the entity_id and date - - That the CSVs have the same list of entity_id/date combinations - in the same order. - - The first CSV is expected to be labels, and only have - entity_id, date, and label. - - All other CSVs do not have any labels (all non entity_id/date columns - will be treated as features) - - The label will be in the *last* column of the merged CSV - - :param source_filenames: the filenames of each feature csv - :param out_filename: the desired filename of the merged csv - :type source_filenames: list - :type out_filename: str + The results of each query are expected to return the same number of rows in the same order. + The columns will be placed alongside each other in the CSV much as a SQL join would. + However, this code does not deduplicate the columns, so the actual row identifiers + (e.g. entity id, as of date) should only be present in one of the queries + unless you want duplicate columns. - :return: none - :rtype: none + The result, and the given metadata, will be given to the supplied MatrixStore for saving. 
- :raises: ValueError if the first two columns in every CSV don't match + Args: + queries (iterable) SQL queries + matrix_store (triage.component.catwalk.storage.CSVMatrixStore) + matrix_metadata (dict) matrix metadata to save alongside the data """ - - for i, df in enumerate(dataframes): - if df.index.names != ["entity_id", "as_of_date"]: - raise ValueError( - f"index must be entity_id and as_of_date, value was {df.index}" - ) - # check for any nulls. the labels, understood to be the first file, - # can have nulls but no features should. therefore, skip the first dataframe - if i > 0: - columns_with_nulls = [ - column for column in df.columns if df[column].isnull().values.any() - ] - if len(columns_with_nulls) > 0: - raise ValueError( - "Imputation failed for the following features: %s" - % columns_with_nulls - ) - i += 1 - - big_df = dataframes[1].join(dataframes[2:] + [dataframes[0]]) - return big_df + copy_sqls = (f"COPY ({query}) TO STDOUT WITH CSV HEADER" for query in queries) + with contextlib.ExitStack() as stack: + connections = (stack.enter_context(contextlib.closing(conn)) + for conn in self._raw_connections) + cursors = (conn.cursor() for conn in connections) + + writers = (partial(cursor.copy_expert, copy_sql) + for (cursor, copy_sql) in zip(cursors, copy_sqls)) + pipes = (stack.enter_context(PipeTextIO(writer)) for writer in writers) + iterable = ( + b','.join(line.rstrip('\r\n').encode('utf-8') for line in join) + b'\n' + for join in zip(*pipes) + ) + matrix_store.save(from_fileobj=IteratorBytesIO(iterable), metadata=matrix_metadata) diff --git a/src/triage/component/architect/feature_generators.py b/src/triage/component/architect/feature_generators.py index d7c386d42..af07b0523 100644 --- a/src/triage/component/architect/feature_generators.py +++ b/src/triage/component/architect/feature_generators.py @@ -635,7 +635,7 @@ def _generate_agg_table_tasks_for(self, aggregation): return table_tasks - def _generate_imp_table_tasks_for(self, aggregation, drop_preagg=True): + def _generate_imp_table_tasks_for(self, aggregation, drop_preagg=False): """Generate SQL statements for preparing, populating, and finalizing imputations, for each feature group table in the given aggregation. diff --git a/src/triage/component/architect/utils.py b/src/triage/component/architect/utils.py index d20337bfd..2f9cf5c29 100644 --- a/src/triage/component/architect/utils.py +++ b/src/triage/component/architect/utils.py @@ -126,30 +126,6 @@ def TemporaryDirectory(): shutil.rmtree(name) -@contextmanager -def fake_metta(matrix_dict, metadata): - """Stores matrix and metadata in a metta-data-like form - - Args: - matrix_dict (dict) of form { columns: values }. - Expects an entity_id to be present which it will use as the index - metadata (dict). 
Any metadata that should be set
-
-    Yields:
-        tuple of filenames for matrix and metadata
-    """
-    matrix = pd.DataFrame.from_dict(matrix_dict).set_index("entity_id")
-    with tempfile.NamedTemporaryFile() as matrix_file:
-        with tempfile.NamedTemporaryFile("w") as metadata_file:
-            hdf = pd.HDFStore(matrix_file.name)
-            hdf.put("title", matrix, data_columns=True)
-            matrix_file.seek(0)
-
-            yaml.dump(metadata, metadata_file)
-            metadata_file.seek(0)
-            yield (matrix_file.name, metadata_file.name)
-
-
 def fake_labels(length):
     return numpy.array([random.choice([True, False]) for i in range(0, length)])
diff --git a/src/triage/component/catwalk/storage.py b/src/triage/component/catwalk/storage.py
index d357ce90b..2b6bf579d 100644
--- a/src/triage/component/catwalk/storage.py
+++ b/src/triage/component/catwalk/storage.py
@@ -13,12 +13,13 @@
     TestPrediction,
     TrainPrediction,
 )
-from triage.util.pandas import downcast_matrix
+from triage.util.pandas import downcast_matrix, columns_with_nulls

 import pandas as pd
 import s3fs
 import yaml
 from boto3.s3.transfer import TransferConfig
+import shutil

 import gzip

@@ -373,6 +374,11 @@ def _preprocess_and_split_matrix(self, matrix_with_labels):
         matrix_with_labels = downcast_matrix(matrix_with_labels)
         labels = matrix_with_labels.pop(self.label_column_name)
         design_matrix = matrix_with_labels
+        nullcols = columns_with_nulls(design_matrix)
+        if nullcols:
+            raise ValueError(f"Matrix {self.uuid} contains null values in feature columns. "
+                             "Inspect matrix, feature tables, and cohort to locate source. "
+                             f"Null columns: {nullcols}")
         return design_matrix, labels

     @property
@@ -548,62 +554,6 @@ def __getstate__(self):
         return state


-class HDFMatrixStore(MatrixStore):
-    """Store and access matrices using HDF
-
-    Instead of overriding head_of_matrix, which cannot be easily and reliably done
-    using HDFStore without loading the whole matrix into memory,
-    we individually override 'empty' and 'columns'
-    to obviate the need for a performant 'head_of_matrix' operation.
- """ - - suffix = "h5" - - def __init__(self, *args, **kwargs): - super().__init__(*args, **kwargs) - if isinstance(self.matrix_base_store, S3Store): - raise ValueError("HDFMatrixStore cannot be used with S3") - - def columns(self, include_label=False): - """The matrix's column list""" - head_of_matrix = pd.read_hdf(self.matrix_base_store.path, start=0, stop=0) - columns = head_of_matrix.columns.tolist() - if include_label: - return columns - else: - return [col for col in columns if col != self.metadata["label_name"]] - - @property - def empty(self): - """Whether or not the matrix has at least one row""" - if not self.matrix_base_store.exists(): - return True - else: - try: - head_of_matrix = pd.read_hdf(self.matrix_base_store.path, start=0, stop=1) - return head_of_matrix.empty - except ValueError: - # There is no known way to make the start/stop operations work all the time - # , there is often a ValueError when trying to load just the first row - # However, if we do get a ValueError that means there is data so it can't be empty - return False - - def _load(self): - return pd.read_hdf(self.matrix_base_store.path) - - def save(self): - hdf = pd.HDFStore( - self.matrix_base_store.path, - mode="w", - complevel=4, - complib="zlib", - ) - hdf.put(f"matrix_{self.matrix_uuid}", self.full_matrix_for_saving.apply(pd.to_numeric), data_columns=True) - hdf.close() - with self.metadata_base_store.open("wb") as fd: - yaml.dump(self.metadata, fd, encoding="utf-8") - - class CSVMatrixStore(MatrixStore): """Store and access compressed matrices using CSV""" @@ -626,10 +576,18 @@ def _load(self): with self.matrix_base_store.open("rb") as fd: return pd.read_csv(fd, compression="gzip", parse_dates=["as_of_date"]) - def save(self): - self.matrix_base_store.write(gzip.compress(self.full_matrix_for_saving.to_csv(None).encode("utf-8"))) - with self.metadata_base_store.open("wb") as fd: - yaml.dump(self.metadata, fd, encoding="utf-8") + def save(self, from_fileobj, metadata): + """Compress and save the matrix from a CSV bytestream file object + + Args: + from_fileobj (file-like): A readable file object containing a CSV bytestream to save + """ + with self.matrix_base_store.open('wb') as fdesc: + with gzip.GzipFile(fileobj=fdesc, mode='w') as compressor: + shutil.copyfileobj(from_fileobj, compressor) + + with self.metadata_base_store.open('wb') as fd: + yaml.dump(metadata, fd, encoding="utf-8") class TestMatrixType(object): diff --git a/src/triage/database_reflection.py b/src/triage/database_reflection.py index 406209bbb..2ff89ed45 100644 --- a/src/triage/database_reflection.py +++ b/src/triage/database_reflection.py @@ -95,7 +95,7 @@ def table_row_count(table_name, db_engine): Returns: (int) The number of rows in the table """ return next( - row for row in db_engine.execute("select count(*) from {}".format(table_name)) + row[0] for row in db_engine.execute(f"select count(*) from {table_name}") ) diff --git a/src/triage/experiments/base.py b/src/triage/experiments/base.py index 13156d707..b065c6d78 100644 --- a/src/triage/experiments/base.py +++ b/src/triage/experiments/base.py @@ -758,7 +758,7 @@ def clean_up_subset_tables(self): def _run_profile(self): cp = cProfile.Profile() - cp.runcall(self._run) + cp.runcall(self.generate_matrices) store = self.project_storage.get_store( ["profiling_stats"], f"{int(time.time())}.profile" diff --git a/src/triage/util/io.py b/src/triage/util/io.py new file mode 100644 index 000000000..cc4142f8d --- /dev/null +++ b/src/triage/util/io.py @@ -0,0 +1,86 @@ +import io +from ohio 
import IOClosed + + +class StreamBytesIOBase(io.BufferedIOBase): + """Readable file-like abstract base class. + Concrete classes may implemented method `__next_chunk__` to return + chunks (or all) of the bytes to be read. + """ + def __init__(self): + self._remainder = '' + + def __next_chunk__(self): + raise NotImplementedError("StreamBytesIOBase subclasses must implement __next_chunk__") + + def readable(self): + if self.closed: + raise IOClosed() + + return True + + def _read1(self, size=None): + while not self._remainder: + try: + self._remainder = self.__next_chunk__() + except StopIteration: + break + + result = self._remainder[:size] + self._remainder = self._remainder[len(result):] + + return result + + def read(self, size=None): + if self.closed: + raise IOClosed() + + if size is not None and size < 0: + size = None + + result = b'' + + while size is None or size > 0: + content = self._read1(size) + if not content: + break + + if size is not None: + size -= len(content) + + result += content + + return result + + def readline(self): + if self.closed: + raise IOClosed() + + result = '' + + while True: + index = self._remainder.find('\n') + if index == -1: + result += self._remainder + try: + self._remainder = self.__next_chunk__() + except StopIteration: + self._remainder = '' + break + else: + result += self._remainder[:(index + 1)] + self._remainder = self._remainder[(index + 1):] + break + + return result + + +class IteratorBytesIO(StreamBytesIOBase): + """Readable file-like interface for iterable byte streams.""" + + def __init__(self, iterable): + super().__init__() + self.__iterator__ = iter(iterable) + + def __next_chunk__(self): + return next(self.__iterator__) diff --git a/src/triage/util/pandas.py b/src/triage/util/pandas.py index 5b9238805..aa05833e4 100644 --- a/src/triage/util/pandas.py +++ b/src/triage/util/pandas.py @@ -27,3 +27,7 @@ def downcast_matrix(df): logging.debug("Downcasted matrix. Final memory usage: %s", new_df.memory_usage()) return new_df + + +def columns_with_nulls(df): + return [column for column in df.columns if df[column].isnull().values.any()] diff --git a/src/triage/validation_primitives.py b/src/triage/validation_primitives.py index 92a7ff45c..0de5e2124 100644 --- a/src/triage/validation_primitives.py +++ b/src/triage/validation_primitives.py @@ -155,3 +155,8 @@ def string_is_tablesafe(string): if not string: return False return all(c.isalpha() or c.isdigit() or c == '_' for c in string) + + +def table_should_have_entity_date_columns(table_name, db_engine): + table_should_have_column(table_name, "entity_id", db_engine) + table_should_have_column(table_name, "as_of_date", db_engine)
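
As an illustration of how the pieces above fit together: `IteratorBytesIO` wraps any iterable of byte chunks in a readable file object, which is what lets `MatrixBuilder._save_matrix` hand the lazily zipped CSV rows straight to `CSVMatrixStore.save` (which streams them through `shutil.copyfileobj`). A minimal sketch with made-up sample rows; only `IteratorBytesIO` itself comes from this changeset:

```python
from triage.util.io import IteratorBytesIO

# Two hypothetical CSV rows, produced lazily, much as the zipped COPY pipes would be.
rows = (line.encode("utf-8") for line in ("1,2016-01-01,0.5\n", "2,2016-01-01,0.6\n"))

fileobj = IteratorBytesIO(rows)

# read() pulls chunks from the underlying iterator on demand, so a consumer
# such as shutil.copyfileobj can stream the matrix without buffering it all.
print(fileobj.read())  # b'1,2016-01-01,0.5\n2,2016-01-01,0.6\n'
```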