diff --git a/src/tests/test_experiments.py b/src/tests/test_experiments.py index 4ff37a550..feceab941 100644 --- a/src/tests/test_experiments.py +++ b/src/tests/test_experiments.py @@ -490,8 +490,8 @@ def test_serializable_engine_check_triage_noconvert(): with testing.postgresql.Postgresql() as postgresql: db_engine = create_engine(postgresql.url()) with TemporaryDirectory() as temp_dir: - with mock.patch('triage.experiments.base.create_engine', wraps=create_engine) as engine_create_mock: - experiment = SingleThreadedExperiment( + with mock.patch('triage.experiments.multicore.create_engine', wraps=create_engine) as engine_create_mock: + experiment = MultiCoreExperiment( config=sample_config(), db_engine=db_engine, project_path=os.path.join(temp_dir, "inspections"), @@ -506,8 +506,8 @@ def test_serializable_engine_check_sqlalchemy_convert(): with testing.postgresql.Postgresql() as postgresql: db_engine = sqlalchemy.create_engine(postgresql.url()) with TemporaryDirectory() as temp_dir: - with mock.patch('triage.experiments.base.create_engine', wraps=create_engine) as engine_create_mock: - experiment = SingleThreadedExperiment( + with mock.patch('triage.experiments.multicore.create_engine', wraps=create_engine) as engine_create_mock: + experiment = MultiCoreExperiment( config=sample_config(), db_engine=db_engine, project_path=os.path.join(temp_dir, "inspections"), diff --git a/src/triage/experiments/base.py b/src/triage/experiments/base.py index e4049cb3d..839c3affc 100644 --- a/src/triage/experiments/base.py +++ b/src/triage/experiments/base.py @@ -46,8 +46,6 @@ from triage.database_reflection import table_has_data from triage.util.conf import dt_from_str -from triage.util.db import create_engine -from triage.util.pickling import can_pickle class ExperimentBase(ABC): @@ -84,15 +82,6 @@ def __init__( self._check_config_version(config) self.config = config - if not can_pickle(db_engine): - logging.warning( - "Raw, unserializable SQLAlchemy engine passed. " - "URL will be used, other options may be lost in multi-process environments" - ) - self.db_engine = create_engine(db_engine.url) - else: - self.db_engine = db_engine - self.project_storage = ProjectStorage(project_path) self.model_storage_engine = ModelStorageEngine(self.project_storage) self.matrix_storage_engine = MatrixStorageEngine( @@ -100,6 +89,7 @@ def __init__( ) self.project_path = project_path self.replace = replace + self.db_engine = db_engine upgrade_db(db_engine=self.db_engine) self.features_schema_name = "features" diff --git a/src/triage/experiments/multicore.py b/src/triage/experiments/multicore.py index 142304e86..f686a6aae 100644 --- a/src/triage/experiments/multicore.py +++ b/src/triage/experiments/multicore.py @@ -2,6 +2,8 @@ import traceback from functools import partial from pebble import ProcessPool +from triage.util.pickling import can_pickle +from triage.util.db import create_engine from triage.component.catwalk.utils import Batch @@ -9,8 +11,15 @@ class MultiCoreExperiment(ExperimentBase): - def __init__(self, n_processes=1, n_db_processes=1, *args, **kwargs): - super(MultiCoreExperiment, self).__init__(*args, **kwargs) + def __init__(self, db_engine, n_processes=1, n_db_processes=1, *args, **kwargs): + if not can_pickle(db_engine): + logging.warning( + "Raw, unserializable SQLAlchemy engine passed. " + "URL will be used, other options may be lost in multi-process environments" + ) + db_engine = create_engine(db_engine.url) + + super(MultiCoreExperiment, self).__init__(db_engine=db_engine, *args, **kwargs) if n_processes < 1: raise ValueError("n_processes must be 1 or greater") if n_db_processes < 1: @@ -24,6 +33,7 @@ def __init__(self, n_processes=1, n_db_processes=1, *args, **kwargs): self.n_processes = n_processes self.n_db_processes = n_db_processes + def generated_chunked_parallelized_results( self, partially_bound_function, tasks, n_processes, chunksize=1 ):