Skip to content

Commit

Permalink
Merge pull request #628 from dssg/faster_task_generation
Browse files Browse the repository at this point in the history
Faster train/test task generation
  • Loading branch information
thcrock authored Mar 12, 2019
2 parents ae448f2 + 8e2089f commit 041f568
Show file tree
Hide file tree
Showing 2 changed files with 30 additions and 31 deletions.
53 changes: 29 additions & 24 deletions src/triage/component/catwalk/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,22 +31,6 @@ def __init__(
def generate_tasks(self, split, grid_config, model_comment=None):
logging.info("Generating train/test tasks for split %s", split["train_uuid"])
train_store = self.matrix_storage_engine.get_store(split["train_uuid"])
if train_store.empty:
logging.warning(
"""Train matrix for split %s was empty,
no point in training this model. Skipping
""",
split["train_uuid"],
)
return []
if len(train_store.labels.unique()) == 1:
logging.warning(
"""Train Matrix for split %s had only one
unique value, no point in training this model. Skipping
""",
split["train_uuid"],
)
return []
train_tasks = self.model_trainer.generate_train_tasks(
grid_config=grid_config,
misc_db_parameters=dict(test=False, model_comment=model_comment),
Expand All @@ -59,14 +43,6 @@ def generate_tasks(self, split, grid_config, model_comment=None):
):
test_store = self.matrix_storage_engine.get_store(test_uuid)

if test_store.empty:
logging.warning(
"""Test matrix for uuid %s
was empty, no point in generating predictions. Not creating train/test task.
""",
test_uuid,
)
continue
for train_task in train_tasks:
train_test_tasks.append(
{
Expand All @@ -83,6 +59,35 @@ def process_all_tasks(self, tasks):

def process_task(self, test_store, train_store, train_kwargs):
logging.info("Beginning train task %s", train_kwargs)

# If the train or test design matrix empty, or if the train store only
# has one label value, skip training the model.
if train_store.empty:
logging.warning(
"""Train matrix for split %s was empty,
no point in training this model. Skipping
""",
split["train_uuid"],
)
return
if len(train_store.labels.unique()) == 1:
logging.warning(
"""Train Matrix for split %s had only one
unique value, no point in training this model. Skipping
""",
split["train_uuid"],
)
return
if test_store.empty:
logging.warning(
"""Test matrix for uuid %s
was empty, no point in generating predictions. Not processing train/test task.
""",
test_uuid,
)
return

# If the matrices and train labels are OK, train and test the model!
with self.model_trainer.cache_models(), test_store.cache(), train_store.cache():
# will cache any trained models until it goes out of scope (at the end of the task)
# this way we avoid loading the model pickle again for predictions
Expand Down
8 changes: 1 addition & 7 deletions src/triage/component/catwalk/storage.py
Original file line number Diff line number Diff line change
Expand Up @@ -399,11 +399,7 @@ def metadata(self):
"""The raw metadata. Will load from storage into memory if not already loaded"""
if self.__metadata is not None:
return self.__metadata
metadata = self.load_metadata()
if self.should_cache:
self.__metadata = metadata
else:
return metadata
self.__metadata = self.load_metadata()
return self.__metadata

@metadata.setter
Expand Down Expand Up @@ -541,7 +537,6 @@ def save(self):

def clear_cache(self):
self._matrix_label_tuple = None
self.metadata = None

def __getstate__(self):
"""Remove object of a large size upon serialization.
Expand All @@ -550,7 +545,6 @@ def __getstate__(self):
"""
state = self.__dict__.copy()
state['_matrix_label_tuple'] = None
state['__metadata'] = None
return state


Expand Down

0 comments on commit 041f568

Please sign in to comment.