Skip to content

Commit

Permalink
Colocate Testing with Training [Resolves #560]
Browse files Browse the repository at this point in the history
Converts the 'train task' into a 'train/test task': instead of training
a split all at once and then testing it all at once, each model is
trained and then immediately tested. Additionally, the tasks are
flattened instead of being grouped by split.
  • Loading branch information
thcrock committed Jan 7, 2019
1 parent 91c67e7 commit 08933d6
Show file tree
Hide file tree
Showing 10 changed files with 257 additions and 128 deletions.
138 changes: 137 additions & 1 deletion src/triage/component/catwalk/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,5 +2,141 @@
from .model_trainers import ModelTrainer
from .predictors import Predictor
from .evaluation import ModelEvaluator
from .individual_importance import IndividualImportanceCalculator
from .model_grouping import ModelGrouper

__all__ = ("ModelTrainer", "Predictor", "ModelEvaluator")
import logging


class ModelTrainTester(object):
    """Train models and test each one immediately after it is trained.

    Tasks are generated per split as a flat list; each task pairs one train
    task with one test matrix.

    Args:
        matrix_storage_engine: object whose ``get_store(uuid)`` returns a
            matrix store
        model_trainer: object providing ``generate_train_tasks``,
            ``process_train_task`` and ``uncache_model``
        model_evaluator: object providing ``needs_evaluations`` and ``evaluate``
        individual_importance_calculator: object providing
            ``calculate_and_save_all_methods_and_dates``
        predictor: object providing ``replace`` and ``predict``
    """

    def __init__(
        self,
        matrix_storage_engine,
        model_trainer,
        model_evaluator,
        individual_importance_calculator,
        predictor
    ):
        self.matrix_storage_engine = matrix_storage_engine
        self.model_trainer = model_trainer
        self.model_evaluator = model_evaluator
        self.individual_importance_calculator = individual_importance_calculator
        self.predictor = predictor

    def generate_tasks(self, split, grid_config, model_comment=None):
        """Build the flat list of train/test tasks for one split.

        Args:
            split (dict) Must contain 'train_uuid', 'test_matrices' and
                'test_uuids'
            grid_config: passed through to ``model_trainer.generate_train_tasks``
            model_comment: stored in the ``misc_db_parameters`` given to the
                model trainer

        Returns: (list) One dict per (non-empty test matrix, train task) pair
            with keys 'test_store', 'train_store' and 'train_kwargs'. Empty if
            the train matrix is empty or its labels have a single unique value.
        """
        logging.info("Generating train/test tasks for split %s", split["train_uuid"])
        train_store = self.matrix_storage_engine.get_store(split["train_uuid"])
        # Messages are kept on a single line so that each warning is one log record
        # without embedded newlines or source indentation.
        if train_store.empty:
            logging.warning(
                "Train matrix for split %s was empty, "
                "no point in training this model. Skipping",
                split["train_uuid"],
            )
            return []
        if len(train_store.labels().unique()) == 1:
            logging.warning(
                "Train Matrix for split %s had only one unique value, "
                "no point in training this model. Skipping",
                split["train_uuid"],
            )
            return []
        train_tasks = self.model_trainer.generate_train_tasks(
            grid_config=grid_config,
            misc_db_parameters=dict(test=False, model_comment=model_comment),
            matrix_store=train_store
        )

        train_test_tasks = []
        # Only the uuid is needed here; zipping with the matrix definitions keeps
        # the two lists in lockstep.
        for _test_matrix_def, test_uuid in zip(
            split["test_matrices"], split["test_uuids"]
        ):
            test_store = self.matrix_storage_engine.get_store(test_uuid)

            if test_store.empty:
                logging.warning(
                    "Test matrix for uuid %s was empty, no point in generating "
                    "predictions. Not creating train/test task.",
                    test_uuid,
                )
                continue
            train_test_tasks.extend(
                {
                    "test_store": test_store,
                    "train_store": train_store,
                    "train_kwargs": train_task,
                }
                for train_task in train_tasks
            )
        return train_test_tasks

    def process_all_tasks(self, tasks):
        """Run every task produced by ``generate_tasks``, one after another.

        Args:
            tasks (iterable of dict) Each dict is unpacked as the keyword
                arguments of ``process_task``
        """
        for task in tasks:
            self.process_task(**task)

    def process_task(self, test_store, train_store, train_kwargs):
        """Train one model, then predict on and evaluate it right away.

        Args:
            test_store: matrix store to test on
            train_store: matrix store the model is trained on; it is also
                predicted on and evaluated, after the test store
            train_kwargs (dict) Keyword arguments for
                ``model_trainer.process_train_task``
        """
        logging.info("Beginning train task %s", train_kwargs)
        model_id = self.model_trainer.process_train_task(**train_kwargs)
        if not model_id:
            logging.warning("No model id returned from ModelTrainer.process_train_task, "
                            "training unsuccessful. Not attempting to test")
            return
        logging.info("Trained task %s and got model id %s", train_kwargs, model_id)
        # Uncache in a `finally` so a failure while testing does not leave the
        # trained model held by the model trainer's cache.
        try:
            as_of_times = test_store.metadata["as_of_times"]
            logging.info(
                "Testing and scoring model id %s with test matrix %s. "
                "as_of_times min: %s max: %s num: %s",
                model_id,
                test_store.uuid,
                min(as_of_times),
                max(as_of_times),
                len(as_of_times),
            )

            self.individual_importance_calculator.calculate_and_save_all_methods_and_dates(
                model_id, test_store
            )

            # Generate predictions for the testing data then training data
            for store in (test_store, train_store):
                if self.predictor.replace or self.model_evaluator.needs_evaluations(
                    store, model_id
                ):
                    logging.info(
                        "The evaluations needed for matrix %s-%s and model %s "
                        "are not all present in db, so predicting and evaluating",
                        store.uuid,
                        store.matrix_type,
                        model_id
                    )
                    predictions_proba = self.predictor.predict(
                        model_id,
                        store,
                        misc_db_parameters=dict(),
                        train_matrix_columns=train_store.columns(),
                    )

                    self.model_evaluator.evaluate(
                        predictions_proba=predictions_proba,
                        matrix_store=store,
                        model_id=model_id,
                    )
                else:
                    logging.info(
                        "The evaluations needed for matrix %s-%s and model %s are all present "
                        "in db from a previous run (or none needed at all), so skipping!",
                        store.uuid,
                        store.matrix_type,
                        model_id
                    )
        finally:
            self.model_trainer.uncache_model(model_id)


# Public API of the catwalk component. Every entry carries its own trailing
# comma: adjacent string literals are implicitly concatenated, so a missing
# comma silently merges two names into one that does not exist.
__all__ = (
    "IndividualImportanceCalculator",
    "ModelEvaluator",
    "ModelGrouper",
    "ModelTrainer",
    "Predictor",
    "ModelTrainTester",
)
Original file line number Diff line number Diff line change
Expand Up @@ -151,4 +151,5 @@ def save(self, importance_records, model_id, as_of_date, method_name):
importance_score=float(importance_record["score"]),
)
db_objects.append(db_object)
print(len(db_objects))
save_db_objects(self.db_engine, db_objects)
4 changes: 4 additions & 0 deletions src/triage/component/catwalk/model_trainers.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
from .utils import (
filename_friendly_hash,
retrieve_model_id_from_hash,
retrieve_model_hash_from_id,
db_retry,
save_db_objects,
)
Expand Down Expand Up @@ -420,3 +421,6 @@ def generate_train_tasks(self, grid_config, misc_db_parameters, matrix_store=Non
)
logging.info("Found %s unique model training tasks", len(tasks))
return tasks

def uncache_model(self, model_id):
    """Ask the model storage engine to drop a model from its cache.

    The hash stored for ``model_id`` is looked up through the database engine
    and handed to ``model_storage_engine.uncache``.

    Args:
        model_id (int) The id of a given model in the database
    """
    model_hash = retrieve_model_hash_from_id(self.db_engine, model_id)
    self.model_storage_engine.uncache(model_hash)
22 changes: 3 additions & 19 deletions src/triage/component/catwalk/predictors.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@

from triage.component.results_schema import Model

from .utils import db_retry
from .utils import db_retry, retrieve_model_hash_from_id


class ModelNotFoundError(ValueError):
Expand All @@ -37,22 +37,6 @@ def __init__(self, model_storage_engine, db_engine, replace=True):
def sessionmaker(self):
return sessionmaker(bind=self.db_engine)

@db_retry
def _retrieve_model_hash(self, model_id):
    """Retrieves the model hash associated with a given model id

    Args:
        model_id (int) The id of a given model in the database

    Returns: (str) the stored hash of the model
    """
    # Open the session before entering the try block: if opening fails there is
    # nothing to close, and a `finally` that referenced the unbound name would
    # replace the real error with an UnboundLocalError.
    session = self.sessionmaker()
    try:
        return session.query(Model).get(model_id).model_hash
    finally:
        session.close()

@db_retry
def load_model(self, model_id):
"""Downloads the cached model associated with a given model id
Expand All @@ -64,7 +48,7 @@ def load_model(self, model_id):
A python object which implements .predict()
"""

model_hash = self._retrieve_model_hash(model_id)
model_hash = retrieve_model_hash_from_id(self.db_engine, model_id)
logging.info("Checking for model_hash %s in store", model_hash)
if self.model_storage_engine.exists(model_hash):
return self.model_storage_engine.load(model_hash)
Expand All @@ -76,7 +60,7 @@ def delete_model(self, model_id):
Args:
model_id (int) The id of a given model in the database
"""
model_hash = self._retrieve_model_hash(model_id)
model_hash = retrieve_model_hash_from_id(self.db_engine, model_id)
self.model_storage_engine.delete(model_hash)

@db_retry
Expand Down
26 changes: 25 additions & 1 deletion src/triage/component/catwalk/storage.py
Original file line number Diff line number Diff line change
Expand Up @@ -205,11 +205,15 @@ class ModelStorageEngine(object):
A project file storage engine
model_directory (string, optional) A directory name for models.
Defaults to 'trained_models'
should_cache (bool, optional) Whether or not the engine should cache written models
in memory in addition to persisting. Defaults to True
"""
def __init__(self, project_storage, model_directory=None):
def __init__(self, project_storage, model_directory=None, should_cache=True):
self.project_storage = project_storage
self.directories = [model_directory or "trained_models"]
self.should_cache = should_cache
self.cache = {}

def write(self, obj, model_hash):
"""Persist a model object using joblib. Also performs compression
Expand All @@ -218,6 +222,9 @@ def write(self, obj, model_hash):
obj (object) A picklable model object
model_hash (string) An identifier, unique within this project, for the model
"""
if self.should_cache:
logging.info("Caching model %s", model_hash)
self.cache[model_hash] = obj
with self._get_store(model_hash).open("wb") as fd:
joblib.dump(obj, fd, compress=True)

Expand All @@ -229,6 +236,9 @@ def load(self, model_hash):
Returns: (object) A model object
"""
if self.should_cache and model_hash in self.cache:
logging.info("Returning model %s from cache", model_hash)
return self.cache[model_hash]
with self._get_store(model_hash).open("rb") as fd:
return joblib.load(fd)

Expand All @@ -250,6 +260,20 @@ def delete(self, model_hash):
"""
return self._get_store(model_hash).delete()

def uncache(self, model_hash):
    """Remove the model identified by this hash from memory

    Only the in-memory ``cache`` dict is touched. A hash that is not cached
    is logged and otherwise ignored.

    Args:
        model_hash (string) An identifier, unique within this project, for the model
    """
    if model_hash in self.cache:
        logging.info("Removing model %s from cache", model_hash)
        del self.cache[model_hash]
    else:
        # The trailing space matters: the two literals are concatenated as-is.
        logging.info(
            "Model %s not in cache (likely was trained in another run), "
            "so no need to remove",
            model_hash,
        )

def _get_store(self, model_hash):
    """Return the project storage's store for a model hash.

    Args:
        model_hash (string) An identifier, unique within this project, for the model

    Returns: whatever ``project_storage.get_store`` returns for this engine's
        directories and the given hash
    """
    storage = self.project_storage
    return storage.get_store(self.directories, model_hash)

Expand Down
16 changes: 16 additions & 0 deletions src/triage/component/catwalk/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -170,6 +170,22 @@ def retrieve_model_id_from_hash(db_engine, model_hash):
session.close()


@db_retry
def retrieve_model_hash_from_id(db_engine, model_id):
    """Look up the hash stored for a model id

    Args:
        db_engine: the engine the lookup session is bound to
        model_id (int) The id of a given model in the database

    Returns: (str) the stored hash of the model
    """
    session_factory = sessionmaker(bind=db_engine)
    session = session_factory()
    try:
        # Read the attribute before the session is closed in `finally`.
        model = session.query(Model).get(model_id)
        return model.model_hash
    finally:
        session.close()


@db_retry
def save_db_objects(db_engine, db_objects):
"""Saves a collection of SQLAlchemy model objects to the database using a COPY command
Expand Down
Loading

0 comments on commit 08933d6

Please sign in to comment.