diff --git a/micro_manager/adaptivity/adaptivity.py b/micro_manager/adaptivity/adaptivity.py index a977cabd..d2ad49a0 100644 --- a/micro_manager/adaptivity/adaptivity.py +++ b/micro_manager/adaptivity/adaptivity.py @@ -25,6 +25,7 @@ def __init__(self, configurator, logger) -> None: self._hist_param = configurator.get_adaptivity_hist_param() self._adaptivity_data_names = configurator.get_data_for_adaptivity() self._adaptivity_type = configurator.get_adaptivity_type() + self._micro_file_name = configurator.get_micro_file_name() self._logger = logger diff --git a/micro_manager/adaptivity/global_adaptivity.py b/micro_manager/adaptivity/global_adaptivity.py index 90690ec7..872be912 100644 --- a/micro_manager/adaptivity/global_adaptivity.py +++ b/micro_manager/adaptivity/global_adaptivity.py @@ -6,6 +6,7 @@ Note: All ID variables used in the methods of this class are global IDs, unless they have *local* in their name. """ import hashlib +import importlib from copy import deepcopy from typing import Dict @@ -13,6 +14,7 @@ from mpi4py import MPI from .adaptivity import AdaptivityCalculator +from ..micro_simulation import create_simulation_class class GlobalAdaptivityCalculator(AdaptivityCalculator): @@ -259,6 +261,16 @@ def _update_inactive_sims( # Only handle activation of simulations on this rank -- LOCAL SCOPE HERE ON if self._is_sim_on_this_rank[i]: to_be_activated_local_id = self._global_ids.index(i) + if micro_sims[to_be_activated_local_id] == None: + self._logger.info(f"{i} to be solved, lazy initialization") + micro_problem = getattr( + importlib.import_module( + self._micro_file_name, "MicroSimulation" + ), + "MicroSimulation", + ) + micro_sims[to_be_activated_local_id] = create_simulation_class(micro_problem)(i) + self._logger.info(f"lazy initialization of {i} successful") assoc_active_id = local_sim_is_associated_to[to_be_activated_local_id] if self._is_sim_on_this_rank[ @@ -278,9 +290,13 @@ def _update_inactive_sims( to_be_activated_local_id ] + # TODO: could be moved to before the lazy initialization above sim_states_and_global_ids = [] - for sim in micro_sims: - sim_states_and_global_ids.append((sim.get_state(), sim.get_global_id())) + for local_id, sim in enumerate(micro_sims): + if sim == None: + sim_states_and_global_ids.append((None, self._global_ids[local_id])) + else: + sim_states_and_global_ids.append((sim.get_state(), sim.get_global_id())) recv_reqs = self._p2p_comm( list(to_be_activated_map.keys()), sim_states_and_global_ids diff --git a/micro_manager/config.py b/micro_manager/config.py index 3cd4a3c3..7b828421 100644 --- a/micro_manager/config.py +++ b/micro_manager/config.py @@ -56,6 +56,8 @@ def __init__(self, logger, config_filename): self._output_micro_sim_time = False + self._micro_sims_lazy_init = False + self.read_json(config_filename) def read_json(self, config_filename): @@ -182,6 +184,9 @@ def read_json_micro_manager(self): else: raise Exception("Adaptivity type can be either local or global.") + if data["simulation_params"]["adaptivity_settings"]["lazy_init"] == "True": + self._micro_sims_lazy_init = True + exchange_data = {**self._read_data_names, **self._write_data_names} for dname in self._data["simulation_params"]["adaptivity_settings"]["data"]: self._data_for_adaptivity[dname] = exchange_data[dname] @@ -500,6 +505,18 @@ def is_adaptivity_required_in_every_implicit_iteration(self): """ return self._adaptivity_every_implicit_iteration + def micro_sims_lazy_init(self): + """ + Boolean stating whether micro simulations are created in a lazy manner. + + Returns + ------- + adaptivity : bool + True if micro simulations are created only when needed, False otherwise. + + """ + return self._micro_sims_lazy_init + def get_micro_dt(self): """ Get the size of the micro time window. diff --git a/micro_manager/micro_manager.py b/micro_manager/micro_manager.py index 9e2a62e4..99a57e35 100644 --- a/micro_manager/micro_manager.py +++ b/micro_manager/micro_manager.py @@ -80,6 +80,8 @@ def __init__(self, config_file: str) -> None: self._is_adaptivity_on = self._config.turn_on_adaptivity() + self._micro_sims_lazy_init = self._is_adaptivity_on and self._config.micro_sims_lazy_init() + if self._is_adaptivity_on: self._number_of_sims_for_adaptivity = 0 @@ -140,7 +142,7 @@ def solve(self) -> None: ) # If micro simulations have been initialized, compute adaptivity before starting the coupling - if self._micro_sims_init: + if self._micro_sims_init or self._micro_sims_lazy_init: ( similarity_dists, is_sim_active, @@ -153,6 +155,30 @@ def solve(self) -> None: sim_is_associated_to, self._data_for_adaptivity, ) + if self._micro_sims_lazy_init: + if self._adaptivity_type == "local": + active_sim_ids = np.where(is_sim_active)[0] + elif self._adaptivity_type == "global": + active_sim_ids = np.where( + is_sim_active[ + self._global_ids_of_local_sims[ + 0 + ] : self._global_ids_of_local_sims[-1] + + 1 + ] + )[0] + micro_problem = getattr( + importlib.import_module( + self._config.get_micro_file_name(), "MicroSimulation" + ), + "MicroSimulation", + ) + for i in active_sim_ids: + self._logger.info(f"lazy initialization of micro sim {i} started") + self._micro_sims[i] = create_simulation_class(micro_problem)( + self._global_ids_of_local_sims[i] + ) + self._logger.info(f"lazy initialization of micro sim {i} completed") while self._participant.is_coupling_ongoing(): @@ -161,7 +187,7 @@ def solve(self) -> None: # Write a checkpoint if self._participant.requires_writing_checkpoint(): for i in range(self._local_number_of_sims): - sim_states_cp[i] = self._micro_sims[i].get_state() + sim_states_cp[i] = self._micro_sims[i].get_state() if self._micro_sims[i] else None t_checkpoint = t n_checkpoint = n @@ -232,6 +258,9 @@ def solve(self) -> None: for active_id in active_sim_ids: self._micro_sims_active_steps[active_id] += 1 + if sim_states_cp[active_id] == None: + sim_states_cp[active_id] = self._micro_sims[active_id].get_state() + self._logger.info(f"state of lazily initialized micro sim {self._global_ids_of_local_sims[active_id]} successfully checkpointed") micro_sims_output = self._solve_micro_simulations_with_adaptivity( micro_sims_input, is_sim_active, sim_is_associated_to, dt @@ -273,7 +302,8 @@ def solve(self) -> None: # Revert micro simulations to their last checkpoints if required if self._participant.requires_reading_checkpoint(): for i in range(self._local_number_of_sims): - self._micro_sims[i].set_state(sim_states_cp[i]) + if self._micro_sims[i]: + self._micro_sims[i].set_state(sim_states_cp[i]) n = n_checkpoint t = t_checkpoint @@ -289,8 +319,8 @@ def solve(self) -> None: ): # Time window has converged, now micro output can be generated self._logger.info( "Micro simulations {} - {} have converged at t = {}".format( - self._micro_sims[0].get_global_id(), - self._micro_sims[-1].get_global_id(), + self._global_ids_of_local_sims[0], + self._global_ids_of_local_sims[-1], t, ) ) @@ -298,7 +328,8 @@ def solve(self) -> None: if self._micro_sims_have_output: if n % self._micro_n_out == 0: for sim in self._micro_sims: - sim.output() + if sim: + sim.output() self._participant.finalize() @@ -404,16 +435,16 @@ def initialize(self) -> None: ) # Create micro simulation objects - for i in range(self._local_number_of_sims): - self._micro_sims[i] = create_simulation_class(micro_problem)( - self._global_ids_of_local_sims[i] - ) - - self._logger.info( - "Micro simulations with global IDs {} - {} created.".format( - self._global_ids_of_local_sims[0], self._global_ids_of_local_sims[-1] + if not self._micro_sims_lazy_init: + for i in range(self._local_number_of_sims): + self._micro_sims[i] = create_simulation_class(micro_problem)( + self._global_ids_of_local_sims[i] + ) + self._logger.info( + "Micro simulations with global IDs {} - {} created.".format( + self._global_ids_of_local_sims[0], self._global_ids_of_local_sims[-1] + ) ) - ) if self._is_adaptivity_on: if self._adaptivity_type == "local": @@ -441,6 +472,8 @@ def initialize(self) -> None: if not initial_data: is_initial_data_available = False + if self._micro_sims_lazy_init: + raise Exception("no initial macro data available, lazy initialization would result in only one active simulation.") else: is_initial_data_available = True @@ -451,6 +484,8 @@ def initialize(self) -> None: if hasattr(micro_problem, "initialize") and callable( getattr(micro_problem, "initialize") ): + if self._micro_sims_lazy_init: + raise Exception("Adaptivity can't use data returned by initialize function of micro sims when using lazy initialization.") self._micro_sims_init = True # Starting value before setting try: # Try to get the signature of the initialize() method, if it is written in Python