Skip to content

Commit

Permalink
Refactor functionality to ExperimentBase class [Resolves #400]
Browse files Browse the repository at this point in the history
The execution-oriented experiment subclasses (SingleThreadedExperiment, MultiCoreExperiment) duplicate a lot of functionality in a way that means they can get out of sync if developers are not careful enough. This also harms readability: if you are looking for code relating to the experiment algorithm (anything that has to do with data science), none of that information should be in the experiment subclasses, but right now it is. The ExperimentBase (and the components it delegates to) should have all of that code.

To make this easier, there are two things introduced:

1. The ModelTester component, which encapsulates prediction, individual importance, and evaluation for a single train/test split. It uses a similar pattern to the ModelTrainer of generating and processing tasks only serializable arguments. This is important because a lot of the extra code in those experiment subclasses was related to model testing.

2. The SerializableDbEngine. One of the worst foes for readability in the ExperimentBase and subclasses was the component 'factory' methods that were introduced as a workaround for database engines not being serializable across process boundaries. Now, all components which use a database engine use a SerializableDbEngine that upon serialization (via `__setstate__`) removes the engine but saves the url, so that it may be reinstantiated later. This removes a lot of extra code from the ExperimentBase and subclasses, allowing instance methods on components (e.g. `self.model_trainer.process_train_task`) to be used as task functions that can be sent to different execution contexts.

The result is that subclasses need only implement four methods: `process_query_tasks`, `process_matrix_build_tasks`, `process_train_tasks`, and `process_model_test_tasks`. These generally involve passing through task arguments to the correct component (as opposed to caring what the task arguments are), which should leave the subclasses uninteresting to readers unless they actually want to mess with execution context.

To test out this refactoring's usefulness, there is a third subclass introduced: The RQExperiment (available through the package extra 'rq', e.g. `pip install triage[rq]`), which simply enqueues each task onto an RQ Queue and waits for the group of tasks to be finished before moving on.
  • Loading branch information
thcrock committed May 1, 2018
1 parent d9a1d37 commit 01425b2
Show file tree
Hide file tree
Showing 22 changed files with 533 additions and 455 deletions.
1 change: 1 addition & 0 deletions README.rst
Original file line number Diff line number Diff line change
Expand Up @@ -187,6 +187,7 @@ Experiment Classes

- *SingleThreadedExperiment*: An experiment that performs all tasks serially in a single thread. Good for simple use on small datasets, or for understanding the general flow of data through a pipeline.
- *MultiCoreExperiment*: An experiment that makes use of the multiprocessing library to parallelize various time-consuming steps. Takes an ``n_processes`` keyword argument to control how many workers to use.
- *RQExperiment*: An experiment that makes use of the python-rq library to enqueue individual tasks onto the default queue, and wait for the jobs to be finished before moving on. python-rq requires Redis and any number of worker processes running the Triage codebase. Triage does not set up any of this needed infrastructure for you. Available through the RQ extra ( `pip install triage[rq]` )

Background
==========
Expand Down
1 change: 1 addition & 0 deletions requirement/extras-rq.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
rq
1 change: 1 addition & 0 deletions requirement/test.txt
Original file line number Diff line number Diff line change
Expand Up @@ -6,3 +6,4 @@ testing.postgresql
pytest==3.2.5
pytest-cov
moto==1.0.1
fakeredis
6 changes: 6 additions & 0 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,8 @@

REQUIREMENTS_TEST_PATH = ROOT_PATH / 'requirement' / 'test.txt'

REQUIREMENTS_RQ_PATH = ROOT_PATH / 'requirement' / 'extras-rq.txt'


def stream_requirements(fd):
"""For a given requirements file descriptor, generate lines of
Expand All @@ -36,6 +38,9 @@ def stream_requirements(fd):
REQUIREMENTS_TEST = REQUIREMENTS[:]
REQUIREMENTS_TEST.extend(stream_requirements(test_requirements_file))

with REQUIREMENTS_RQ_PATH.open() as rq_requirements_file:
RQ_REQUIREMENTS = list(stream_requirements(rq_requirements_file))


setup(
name='triage',
Expand All @@ -49,6 +54,7 @@ def stream_requirements(fd):
package_dir={'': 'src'},
include_package_data=True,
install_requires=REQUIREMENTS,
extras_require={'rq': RQ_REQUIREMENTS},
license=LICENSE_PATH.read_text(),
zip_safe=False,
keywords='triage',
Expand Down
39 changes: 4 additions & 35 deletions src/tests/catwalk_tests/test_evaluation.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,16 +59,11 @@ def test_evaluating_early_warning():
db_engine
)

as_of_date = datetime.date(2016, 5, 5)

# Evaluate the testing metrics and test for all of them.
model_evaluator.evaluate(
trained_model.predict_proba(labels)[:, 1],
fake_test_matrix_store,
model_id,
as_of_date,
as_of_date,
'1y'
)
records = [
row[0] for row in
Expand All @@ -78,7 +73,7 @@ def test_evaluating_early_warning():
where model_id = %s and
evaluation_start_time = %s
order by 1''',
(model_id, as_of_date)
(model_id, fake_test_matrix_store.as_of_dates[0])
)
]
assert records == [
Expand Down Expand Up @@ -120,9 +115,6 @@ def test_evaluating_early_warning():
trained_model.predict_proba(labels)[:, 1],
fake_train_matrix_store,
model_id,
as_of_date,
as_of_date,
'1y'
)
records = [
row[0] for row in
Expand All @@ -132,22 +124,9 @@ def test_evaluating_early_warning():
where model_id = %s and
evaluation_start_time = %s
order by 1''',
(model_id, as_of_date)
)
]
records2 = [
row for row in
db_engine.execute(
'''select *
from train_results.train_evaluations
where model_id = %s and
evaluation_start_time = %s
order by 1''',
(model_id, as_of_date)
(model_id, fake_train_matrix_store.as_of_dates[0])
)
]

print(records2)
assert records == ['accuracy', 'roc_auc']


Expand All @@ -172,10 +151,6 @@ def test_model_scoring_inspections():
training_labels = numpy.array([False, False, True, True, True, False, True, True])
training_prediction_probas = numpy.array([0.6, 0.4, 0.55, 0.70, 0.3, 0.2, 0.8, 0.6])

evaluation_start = datetime.datetime(2016, 4, 1)
evaluation_end = datetime.datetime(2016, 7, 1)
example_as_of_date_frequency = '1d'

fake_train_matrix_store = MockMatrixStore('train', 'efgh', 5, db_engine, training_labels)
fake_test_matrix_store = MockMatrixStore('test', '1234', 5, db_engine, testing_labels)

Expand All @@ -190,15 +165,12 @@ def test_model_scoring_inspections():
testing_prediction_probas,
fake_test_matrix_store,
model_id,
evaluation_start,
evaluation_end,
example_as_of_date_frequency
)
for record in db_engine.execute(
'''select * from test_results.test_evaluations
where model_id = %s and evaluation_start_time = %s
order by 1''',
(model_id, evaluation_start)
(model_id, fake_test_matrix_store.as_of_dates[0])
):
assert record['num_labeled_examples'] == 4
assert record['num_positive_labels'] == 2
Expand All @@ -214,15 +186,12 @@ def test_model_scoring_inspections():
training_prediction_probas,
fake_train_matrix_store,
model_id,
evaluation_start,
evaluation_end,
example_as_of_date_frequency
)
for record in db_engine.execute(
'''select * from train_results.train_evaluations
where model_id = %s and evaluation_start_time = %s
order by 1''',
(model_id, evaluation_start)
(model_id, fake_train_matrix_store.as_of_dates[0])
):
assert record['num_labeled_examples'] == 8
assert record['num_positive_labels'] == 5
Expand Down
6 changes: 2 additions & 4 deletions src/tests/catwalk_tests/test_integration.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,7 +69,8 @@ def test_integration():
'end_time': as_of_date,
'metta-uuid': '1234',
'indices': ['entity_id'],
'matrix_type': 'test'
'matrix_type': 'test',
'as_of_date_frequency': '1month'
}
)
for as_of_date in as_of_dates
Expand Down Expand Up @@ -123,9 +124,6 @@ def test_integration():
predictions_proba,
test_store,
model_id,
as_of_date,
as_of_date,
'6month'
)

# assert
Expand Down
3 changes: 3 additions & 0 deletions src/tests/test_experiments.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
from functools import partial
from tempfile import TemporaryDirectory
from unittest import mock, TestCase
import fakeredis

import pytest
import testing.postgresql
Expand All @@ -17,6 +18,7 @@
from triage.experiments import (
MultiCoreExperiment,
SingleThreadedExperiment,
RQExperiment,
CONFIG_VERSION,
)

Expand All @@ -36,6 +38,7 @@ def num_linked_evaluations(db_engine):
parametrize_experiment_classes = pytest.mark.parametrize(('experiment_class',), [
(SingleThreadedExperiment,),
(partial(MultiCoreExperiment, n_processes=2, n_db_processes=2),),
(partial(RQExperiment, redis_connection=fakeredis.FakeStrictRedis(), queue_kwargs={'async': False}),),
])


Expand Down
34 changes: 25 additions & 9 deletions src/tests/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
from sqlalchemy.orm import sessionmaker

from triage.component import metta
from triage.component.catwalk.storage import CSVMatrixStore
from triage.component.catwalk.storage import CSVMatrixStore, InMemoryMatrixStore
from triage.component.results_schema import Model, Matrix
from triage.experiments import CONFIG_VERSION
from triage.component.catwalk.storage import TrainMatrixType, TestMatrixType
Expand Down Expand Up @@ -49,14 +49,30 @@ class MockTrainedModel(object):
def predict_proba(self, dataset):
return numpy.random.rand(len(dataset), len(dataset))

class MockMatrixStore(object):
def __init__(self, matrix_type, matrix_uuid, label_count, db_engine, init_labels=[]):
if matrix_type == 'train':
self.matrix_type = TrainMatrixType
elif matrix_type == 'test':
self.matrix_type = TestMatrixType
else:
raise Exception('Initialize MockMatrixStore with matrix_type = "train" or "test"')
class MockMatrixStore(InMemoryMatrixStore):
def __init__(self, matrix_type, matrix_uuid, label_count, db_engine, init_labels=None, metadata_overrides=None, matrix=None):
base_metadata = {
'feature_start_time': datetime.date(2014, 1, 1),
'end_time': datetime.date(2015, 1, 1),
'as_of_date_frequency': '1y',
'matrix_id': 'some_matrix',
'label_name': 'label',
'label_timespan': '3month',
'indices': ['entity_id'],
'matrix_type': matrix_type
}
metadata_overrides = metadata_overrides or {}
base_metadata.update(metadata_overrides)
if matrix is None:
matrix = pandas.DataFrame.from_dict({
'entity_id': [1, 2],
'feature_one': [3, 4],
'feature_two': [5, 6],
'label': [7, 8]
}).set_index('entity_id')
super().__init__(matrix=matrix, metadata=base_metadata)
if init_labels is None:
init_labels = []

self.label_count = label_count
self.init_labels = init_labels
Expand Down
17 changes: 13 additions & 4 deletions src/triage/component/architect/builders.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@

import s3fs
from sqlalchemy.orm import sessionmaker
from triage.util.db import SerializableDbEngine
from urllib.parse import urlparse


Expand All @@ -24,10 +25,17 @@ def __init__(
):
self.db_config = db_config
self.matrix_directory = matrix_directory
self.engine = engine
self.serializable_db_engine = SerializableDbEngine(engine)
self.replace = replace
self.include_missing_labels_in_train_as = include_missing_labels_in_train_as
self.sessionmaker = sessionmaker(bind=self.engine)

@property
def sessionmaker(self):
return sessionmaker(bind=self.db_engine)

@property
def db_engine(self):
return self.serializable_db_engine.db_engine

def validate(self):
for expected_db_config_val in [
Expand Down Expand Up @@ -153,7 +161,7 @@ def make_entity_date_table(
)
logging.info('Creating matrix-specific entity-date table for matrix '
'%s with query %s', matrix_uuid, query)
self.engine.execute(query)
self.db_engine.execute(query)

return table_name

Expand Down Expand Up @@ -239,6 +247,7 @@ def build_matrix(
:return: none
:rtype: none
"""
logging.info('popped matrix %s build off the queue', matrix_uuid)
matrix_filename = os.path.join(
matrix_directory,
'{}.csv'.format(matrix_uuid)
Expand Down Expand Up @@ -485,7 +494,7 @@ def write_to_csv(self, query_string, file_name, header='HEADER'):
query=query_string,
head=header
)
conn = self.engine.raw_connection()
conn = self.db_engine.raw_connection()
cur = conn.cursor()
cur.copy_expert(copy_sql, matrix_csv)
finally:
Expand Down
7 changes: 6 additions & 1 deletion src/triage/component/architect/feature_generators.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
from collections import OrderedDict

import sqlalchemy
from triage.util.db import SerializableDbEngine

from triage.util.conf import convert_str_to_relativedelta

Expand Down Expand Up @@ -33,13 +34,17 @@ def __init__(
feature_start_time (string/datetime, optional) point in time before which
should not be included in features
"""
self.db_engine = db_engine
self.serializable_db_engine = SerializableDbEngine(db_engine)
self.features_schema_name = features_schema_name
self.categorical_cache = {}
self.replace = replace
self.feature_start_time = feature_start_time
self.entity_id_column = 'entity_id'

@property
def db_engine(self):
return self.serializable_db_engine.db_engine

def _validate_keys(self, aggregation_config):
for key in ['from_obj', 'intervals', 'groups', 'knowledge_date_column', 'prefix']:
if key not in aggregation_config:
Expand Down
30 changes: 14 additions & 16 deletions src/triage/component/catwalk/evaluation.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@

import numpy
from sqlalchemy.orm import sessionmaker
from triage.util.db import SerializableDbEngine

from . import metrics
from .utils import db_retry, sort_predictions_and_labels
Expand Down Expand Up @@ -97,13 +98,19 @@ def __init__(
"""
self.metric_groups = metric_groups
self.training_metric_groups = training_metric_groups
self.db_engine = db_engine
self.serializable_db_engine = SerializableDbEngine(db_engine)
self.sort_seed = sort_seed or int(time.time())
if custom_metrics:
self._validate_metrics(custom_metrics)
self.available_metrics.update(custom_metrics)
if self.db_engine:
self.sessionmaker = sessionmaker(bind=self.db_engine)

@property
def db_engine(self):
return self.serializable_db_engine.db_engine

@property
def sessionmaker(self):
return sessionmaker(bind=self.db_engine)

def _validate_metrics(
self,
Expand Down Expand Up @@ -302,29 +309,20 @@ def _evaluations_for_group(
)
return evaluations

def evaluate(
self,
predictions_proba,
matrix_store,
model_id,
evaluation_start_time,
evaluation_end_time,
as_of_date_frequency
):
def evaluate(self, predictions_proba, matrix_store, model_id):
"""Evaluate a model based on predictions, and save the results
Args:
predictions_proba (numpy.array) List of prediction probabilities
matrix_store (catwalk.storage.MatrixStore) a wrapper for the
prediction matrix and metadata
model_id (int) The database identifier of the model
evaluation_start_time (datetime.datetime) The time of the first prediction
being evaluated
evaluation_end_time (datetime.datetime) The time of the last prediction being evaluated
as_of_date_frequency (string) How frequently predictions were generated
"""
labels = matrix_store.labels()
matrix_type = matrix_store.matrix_type.string_name
evaluation_start_time = matrix_store.as_of_dates[0]
evaluation_end_time = matrix_store.as_of_dates[-1]
as_of_date_frequency = matrix_store.metadata['as_of_date_frequency']

# Specifies which evaluation table to write to: TestEvaluation or TrainEvaluation
evaluation_table_obj = matrix_store.matrix_type.evaluation_obj
Expand Down
Loading

0 comments on commit 01425b2

Please sign in to comment.