diff --git a/pybamm/solvers/base_solver.py b/pybamm/solvers/base_solver.py index 5a89828f50..b0f793e94a 100644 --- a/pybamm/solvers/base_solver.py +++ b/pybamm/solvers/base_solver.py @@ -4,7 +4,8 @@ import sys import warnings import platform -import multiprocessing as mp +import asyncio + import casadi import numpy as np @@ -39,6 +40,12 @@ class BaseSolver: output_variables : list[str], optional List of variables to calculate and return. If none are specified then the complete state vector is returned (can be very large) (default is []) + options: dict, optional + List of options, defaults are: + options = { + # Number of threads available for OpenMP + "num_threads": 1, + } """ def __init__( @@ -50,6 +57,7 @@ def __init__( root_tol=1e-6, extrap_tol=None, output_variables=None, + options=None, ): self.method = method self.rtol = rtol @@ -59,6 +67,17 @@ def __init__( self.extrap_tol = extrap_tol or -1e-10 self.output_variables = [] if output_variables is None else output_variables self._model_set_up = {} + default_options = { + "num_threads": 1, + } + if options is None: + options = default_options + else: + print("options", options) + for key, value in default_options.items(): + if key not in options: + options[key] = value + self._base_options = options # Defaults, can be overwritten by specific solver self.name = "Base solver" @@ -145,6 +164,7 @@ def set_up(self, model, inputs=None, t_eval=None, ics_only=False, batch_size=1): vars_for_processing, inputs, batch_size=batch_size, + nthreads=self._base_options["num_threads"], use_jacobian=False, ) model.initial_conditions_eval = initial_conditions @@ -195,6 +215,7 @@ def set_up(self, model, inputs=None, t_eval=None, ics_only=False, batch_size=1): vars_for_processing, inputs, batch_size=batch_size, + nthreads=self._base_options["num_threads"], ) algebraic, jac_algebraic, jacp_algebraic, jac_algebraic_action = process( @@ -203,6 +224,7 @@ def set_up(self, model, inputs=None, t_eval=None, ics_only=False, batch_size=1): vars_for_processing, inputs, batch_size=batch_size, + nthreads=self._base_options["num_threads"], ) # combine rhs and algebraic functions @@ -226,6 +248,7 @@ def set_up(self, model, inputs=None, t_eval=None, ics_only=False, batch_size=1): vars_for_processing, inputs, batch_size=batch_size, + nthreads=self._base_options["num_threads"], ) ( @@ -295,6 +318,7 @@ def set_up(self, model, inputs=None, t_eval=None, ics_only=False, batch_size=1): vars_for_processing, inputs, batch_size=batch_size, + nthreads=self._base_options["num_threads"], use_jacobian=True, return_jacp_stacked=True, ) @@ -565,6 +589,7 @@ def _set_up_events(self, model, t_eval, inputs, vars_for_processing, batch_size) vars_for_processing, inputs, batch_size=batch_size, + nthreads=self._base_options["num_threads"], use_jacobian=False, )[0] # use the actual casadi object as this will go into the rhs @@ -577,6 +602,7 @@ def _set_up_events(self, model, t_eval, inputs, vars_for_processing, batch_size) vars_for_processing, inputs, batch_size=batch_size, + nthreads=self._base_options["num_threads"], use_jacobian=False, )[0] if event.event_type == pybamm.EventType.TERMINATION: @@ -636,25 +662,28 @@ def _set_initial_conditions( ] # Reuse old solution for algebraic equations y0_from_model = model.y0_list - len_rhs = model.len_rhs + len_rhs = model.len_rhs + model.len_rhs_sens # update model.y0_list, which is used for initialising the algebraic solver if len_rhs == 0: model.y0_list = y0_from_model else: - if isinstance(y0_from_inputs, casadi.DM): - for i in range(len(y0_from_inputs)): - model.y0_list[i] = casadi.vertcat( - y0_from_inputs[i][:len_rhs], y0_from_model[i][len_rhs:] - ) - else: - for i in range(len(y0_from_inputs)): - model.y0_list[i] = np.vstack( - ( - y0_from_inputs[i][:len_rhs], - y0_from_model[i][len_rhs:], - ) - ) + for i in range(len(y0_from_inputs)): + _, y_alg = self._unzip_state_vector(model, y0_from_model[i]) + y_diff, _ = self._unzip_state_vector(model, y0_from_inputs[i]) + model.y0_list[i] = self._zip_state_vector(model, y_diff, y_alg) y0_list = self.calculate_consistent_state(model, time, inputs_list) + + # concatenate batches again + nbatches = len(batched_inputs) + batch_size = len(inputs_list) // nbatches + y0_list_of_list = [ + y0_list[i * batch_size : (i + 1) * batch_size] for i in range(nbatches) + ] + if isinstance(y0_list[0], casadi.DM): + y0_list = [casadi.vertcat(*y0s) for y0s in y0_list_of_list] + else: + y0_list = [np.vstack(y0s) for y0s in y0_list_of_list] + # Make y0 a function of inputs if doing symbolic with casadi model.y0_list = y0_list @@ -765,29 +794,66 @@ def _integrate( self._handle_integrate_defaults(model, inputs_list, batched_inputs) ) - # todo: make this parallel - with mp.get_context(self._mp_context).Pool(processes=nproc) as p: - model_list = [model] * nbatches - t_eval_list = [t_eval] * nbatches - y0_list = model.y0_list - inputs_list_of_list = [ - inputs_list[i * batch_size : (i + 1) * batch_size] - for i in range(nbatches) - ] - new_solutions = p.starmap( - self._integrate_batch, - zip( - model_list, - t_eval_list, - y0_list, - y0S_list, - inputs_list_of_list, - batched_inputs, - ), + if nbatches == 1: + return self._integrate_batch( + model, + t_eval, + model.y0_list[0], + y0S_list[0], + inputs_list, + batched_inputs[0], ) - p.close() - p.join() - return new_solutions + + # async io is not parallel, but if solve is io bound, it can be faster + async def solve_model_batches(): + async def solve_model_async(y0, y0S, inputs, inputs_array): + return self._integrate_batch( + model, t_eval, y0, y0S, inputs, inputs_array + ) + + coro = [] + for i in range(nbatches): + coro.append( + asyncio.create_task( + solve_model_async( + model.y0_list[i], + y0S_list[i], + inputs_list[i * batch_size : (i + 1) * batch_size], + batched_inputs[i], + ) + ) + ) + return await asyncio.gather(*coro) + + new_solutions = asyncio.run(solve_model_batches()) + + # new_solutions = [] + # for i in range(nbatches): + # new_solutions.append(self._integrate_batch(model, t_eval, model.y0_list[i], y0S_list[i], inputs_list[i * batch_size : (i + 1) * batch_size], batched_inputs[i])) + + # with mp.get_context(self._mp_context).Pool(processes=nproc) as p: + # model_list = [model] * nbatches + # t_eval_list = [t_eval] * nbatches + # y0_list = model.y0_list + # inputs_list_of_list = [ + # inputs_list[i * batch_size : (i + 1) * batch_size] + # for i in range(nbatches) + # ] + # new_solutions = p.starmap( + # self._integrate_batch, + # zip( + # model_list, + # t_eval_list, + # y0_list, + # y0S_list, + # inputs_list_of_list, + # batched_inputs, + # ), + # ) + # p.close() + # p.join() + new_solutions_flat = [sol for sublist in new_solutions for sol in sublist] + return new_solutions_flat def _integrate_batch(self, model, t_eval, y0, y0S, inputs_list, inputs): """ @@ -1428,6 +1494,7 @@ def get_termination_reason(solution, events): termination_events = [ x for x in events if x.event_type == pybamm.EventType.TERMINATION ] + if solution.termination == "final time": return ( solution, @@ -1435,6 +1502,10 @@ def get_termination_reason(solution, events): ) elif solution.termination == "event": pybamm.logger.debug("Start post-processing events") + if isinstance(solution.y_event, casadi.DM): + solution_y_event = solution.y_event.full() + else: + solution_y_event = solution.y_event if solution.closest_event_idx is not None: solution.termination = ( f"event: {termination_events[solution.closest_event_idx].name}" @@ -1446,7 +1517,7 @@ def get_termination_reason(solution, events): for event in termination_events: final_event_values[event.name] = event.expression.evaluate( solution.t_event, - solution.y_event, + solution_y_event, inputs=solution.all_inputs[-1], ) termination_event = min(final_event_values, key=final_event_values.get) @@ -1593,8 +1664,58 @@ def _input_dict_to_slices(input_dict: dict): i += inc return input_slices + @staticmethod + def _unzip_state_vector(model, y): + nstates = ( + model.len_rhs + model.len_rhs_sens + model.len_alg + model.len_alg_sens + ) + len_rhs = model.len_rhs + model.len_rhs_sens + batch_size = model.batch_size + + if isinstance(y, casadi.DM): + y_diff = casadi.vertcat( + *[y[i * nstates : i * nstates + len_rhs] for i in range(batch_size)] + ) + y_alg = casadi.vertcat( + *[ + y[i * nstates + len_rhs : (i + 1) * nstates] + for i in range(batch_size) + ] + ) + else: + y_diff = np.vstack( + [y[i * nstates : i * nstates + len_rhs] for i in range(batch_size)] + ) + y_alg = np.vstack( + [ + y[i * nstates + len_rhs : (i + 1) * nstates] + for i in range(batch_size) + ] + ) + + return y_diff, y_alg -def map_func_over_inputs_casadi(name, f, vars_for_processing, ninputs): + @staticmethod + def _zip_state_vector(model, y_diff, y_alg): + len_rhs = model.len_rhs + model.len_rhs_sens + len_alg = model.len_alg + model.len_alg_sens + batch_size = model.batch_size + y_diff_list = [ + y_diff[i * len_rhs : (i + 1) * len_rhs] for i in range(batch_size) + ] + y_alg_list = [y_alg[i * len_alg : (i + 1) * len_alg] for i in range(batch_size)] + if isinstance(y_diff, casadi.DM): + y = casadi.vertcat( + *[val for pair in zip(y_diff_list, y_alg_list) for val in pair] + ) + else: + y = np.vstack( + [val for pair in zip(y_diff_list, y_alg_list) for val in pair] + ) + return y + + +def map_func_over_inputs_casadi(name, f, vars_for_processing, ninputs, nthreads): """ This takes a casadi function f and returns a new casadi function that maps f over the provided number of inputs. Some functions (e.g. jacobian action) require an additional @@ -1611,6 +1732,8 @@ def map_func_over_inputs_casadi(name, f, vars_for_processing, ninputs): dictionary of variables for processing ninputs: int number of inputs to map over + nthreads: int + number of threads to use """ if f is None: return None @@ -1622,30 +1745,33 @@ def map_func_over_inputs_casadi(name, f, vars_for_processing, ninputs): nstates = vars_for_processing["y_and_S"].shape[0] nparams = vars_for_processing["p_casadi_stacked"].shape[0] - parallelisation = "thread" + if nthreads == 1: + parallelisation = "none" + else: + parallelisation = "thread" y_and_S_inputs_stacked = casadi.MX.sym("y_and_S_stacked", nstates * ninputs) p_casadi_inputs_stacked = casadi.MX.sym("p_stacked", nparams * ninputs) v_inputs_stacked = casadi.MX.sym("v_stacked", nstates * ninputs) + t_stacked = casadi.MX.sym("t_stacked", ninputs) y_and_S_2d = y_and_S_inputs_stacked.reshape((nstates, ninputs)) p_casadi_2d = p_casadi_inputs_stacked.reshape((nparams, ninputs)) v_2d = v_inputs_stacked.reshape((nstates, ninputs)) - - t_casadi = vars_for_processing["t_casadi"] + t_2d = t_stacked.reshape((1, ninputs)) if add_v: - inputs_2d = [t_casadi, y_and_S_2d, p_casadi_2d, v_2d] + inputs_2d = [t_2d, y_and_S_2d, p_casadi_2d, v_2d] inputs_stacked = [ - t_casadi, + t_stacked, y_and_S_inputs_stacked, p_casadi_inputs_stacked, v_inputs_stacked, ] else: - inputs_2d = [t_casadi, y_and_S_2d, p_casadi_2d] - inputs_stacked = [t_casadi, y_and_S_inputs_stacked, p_casadi_inputs_stacked] + inputs_2d = [t_2d, y_and_S_2d, p_casadi_2d] + inputs_stacked = [t_stacked, y_and_S_inputs_stacked, p_casadi_inputs_stacked] - mapped_f = f.map(ninputs, parallelisation)(*inputs_2d) + mapped_f = f.map(ninputs, parallelisation, nthreads)(*inputs_2d) if matrix_output: # for matrix output we need to stack the outputs in a block diagonal matrix splits = [i * nstates for i in range(ninputs + 1)] @@ -1669,6 +1795,7 @@ def process( vars_for_processing, inputs: list[dict], batch_size, + nthreads, use_jacobian=None, return_jacp_stacked=None, ): @@ -1901,15 +2028,29 @@ def jacp(t, y, inputs): ninputs = len(inputs_batch) if ninputs > 1: - func = map_func_over_inputs_casadi(name, func, vars_for_processing, ninputs) + func = map_func_over_inputs_casadi( + name, func, vars_for_processing, ninputs, nthreads + ) jac = map_func_over_inputs_casadi( - name + "_jac", jac, vars_for_processing, ninputs + name + "_jac", + jac, + vars_for_processing, + ninputs, + nthreads, ) jacp = map_func_over_inputs_casadi( - name + "_jacp", jacp, vars_for_processing, ninputs + name + "_jacp", + jacp, + vars_for_processing, + ninputs, + nthreads, ) jac_action = map_func_over_inputs_casadi( - name + "_jac_action", jac_action, vars_for_processing, ninputs + name + "_jac_action", + jac_action, + vars_for_processing, + ninputs, + nthreads, ) return func, jac, jacp, jac_action diff --git a/pybamm/solvers/casadi_algebraic_solver.py b/pybamm/solvers/casadi_algebraic_solver.py index da18f399a9..7bc2d2c40b 100644 --- a/pybamm/solvers/casadi_algebraic_solver.py +++ b/pybamm/solvers/casadi_algebraic_solver.py @@ -42,25 +42,37 @@ def _integrate_batch(self, model, t_eval, y0, y0S, inputs_list, inputs): # i.e. the part of the solution vector that corresponds to the differential # equations will be equal to the initial condition provided. This allows this # solver to be used for initialising the DAE solvers + nstates = ( + model.len_rhs + model.len_rhs_sens + model.len_alg + model.len_alg_sens + ) + len_rhs = model.len_rhs + model.len_rhs_sens + len_alg = model.len_alg + model.len_alg_sens + batch_size = len(inputs_list) if model.rhs == {}: - len_rhs = 0 - y0_diff = casadi.DM() + y0_diff_list = [casadi.DM() for _ in range(batch_size)] y0_alg = y0 else: - # Check y0 to see if it includes sensitivities - if model.len_rhs_and_alg == y0.shape[0]: - len_rhs = model.len_rhs - else: - len_rhs = model.len_rhs + model.len_rhs_sens - y0_diff = y0[:len_rhs] - y0_alg = y0[len_rhs:] + y0_diff_list = [ + y0[i * nstates : i * nstates + len_rhs] for i in range(batch_size) + ] + y0_alg_list = [ + y0[i * nstates + len_rhs : (i + 1) * nstates] for i in range(batch_size) + ] + y0_alg = casadi.vertcat(*y0_alg_list) - y_alg = None + y_sol = None # Set up t_sym = casadi.MX.sym("t") - y_alg_sym = casadi.MX.sym("y_alg", y0_alg.shape[0]) - y_sym = casadi.vertcat(y0_diff, y_alg_sym) + y_alg_sym_list = [ + casadi.MX.sym(f"y_alg{i}", len_alg) for i in range(batch_size) + ] + y_alg_sym = casadi.vertcat(*y_alg_sym_list) + + # interleave the differential and algebraic parts + y_sym = casadi.vertcat( + *[val for pair in zip(y0_diff_list, y_alg_sym_list) for val in pair] + ) alg = model.casadi_algebraic(t_sym, y_sym, inputs) @@ -81,7 +93,7 @@ def _integrate_batch(self, model, t_eval, y0, y0S, inputs_list, inputs): { **self.extra_options, "abstol": self.tol, - "constraints": list(constraints[len_rhs:]), + "constraints": list(constraints[len_rhs:]) * batch_size, }, ) @@ -96,8 +108,11 @@ def _integrate_batch(self, model, t_eval, y0, y0S, inputs_list, inputs): success = True message = None # Check final output - y_sol = casadi.vertcat(y0_diff, y_alg_sol) - fun = model.casadi_algebraic(t, y_sol, inputs) + y_alg_sol_list = casadi.vertsplit(y_alg_sol, len_alg) + yi_sol = casadi.vertcat( + *[val for pair in zip(y0_diff_list, y_alg_sol_list) for val in pair] + ) + fun = model.casadi_algebraic(t, yi_sol, inputs) except RuntimeError as err: success = False message = err.args[0] @@ -110,12 +125,11 @@ def _integrate_batch(self, model, t_eval, y0, y0S, inputs_list, inputs): ): # update initial guess for the next iteration y0_alg = y_alg_sol - y0 = casadi.vertcat(y0_diff, y0_alg) # update solution array - if y_alg is None: - y_alg = y_alg_sol + if y_sol is None: + y_sol = yi_sol else: - y_alg = casadi.horzcat(y_alg, y_alg_sol) + y_sol = casadi.horzcat(y_sol, yi_sol) elif not success: raise pybamm.SolverError( f"Could not find acceptable solution: {message}" @@ -133,10 +147,6 @@ def _integrate_batch(self, model, t_eval, y0, y0S, inputs_list, inputs): """ ) - # Concatenate differential part - y_diff = casadi.horzcat(*[y0_diff] * len(t_eval)) - y_sol = casadi.vertcat(y_diff, y_alg) - # Return solution object (no events, so pass None to t_event, y_event) try: diff --git a/pybamm/solvers/casadi_solver.py b/pybamm/solvers/casadi_solver.py index f86105f53e..14bdb77a32 100644 --- a/pybamm/solvers/casadi_solver.py +++ b/pybamm/solvers/casadi_solver.py @@ -525,9 +525,20 @@ def create_integrator(self, model, inputs, y0, t_eval=None, use_event_switch=Fal t = casadi.MX.sym("t") p = casadi.MX.sym("p", inputs.shape[0]) - y_diff = casadi.MX.sym("y_diff", rhs(0, y0, p).shape[0]) - y_alg = casadi.MX.sym("y_alg", algebraic(0, y0, p).shape[0]) - y_full = casadi.vertcat(y_diff, y_alg) + batch_size = model.batch_size + len_rhs = model.len_rhs + model.len_rhs_sens + len_alg = model.len_alg + model.len_alg_sens + + y_diff_list = [ + casadi.MX.sym(f"y_diff{i}", len_rhs) for i in range(batch_size) + ] + y_diff = casadi.vertcat(*y_diff_list) + y_alg_list = [ + casadi.MX.sym(f"y_alg{i}", len_alg) for i in range(batch_size) + ] + y_alg = casadi.vertcat(*y_alg_list) + y_full_list = [val for pair in zip(y_diff_list, y_alg_list) for val in pair] + y_full = casadi.vertcat(*y_full_list) if use_grid is False: time_args = [] @@ -634,23 +645,17 @@ def _run_integrator( else: integrator = self.integrators[model]["no grid"] - len_rhs = model.concatenated_rhs.size * len(inputs_list) - len_alg = model.concatenated_algebraic.size * len(inputs_list) - - # Check y0 to see if it includes sensitivities - if explicit_sensitivities: - num_parameters = model.len_rhs_sens // model.len_rhs - len_rhs = len_rhs * (num_parameters + 1) - len_alg = len_alg * (num_parameters + 1) - - y0_diff = y0[:len_rhs] - y0_alg_exact = y0[len_rhs:] + y0_diff, y0_alg_exact = self._unzip_state_vector(model, y0) + len_alg = model.len_alg + model.len_alg_sens + batch_size = model.batch_size if self.perturb_algebraic_initial_conditions and len_alg > 0: # Add a tiny perturbation to the algebraic initial conditions # For some reason this helps with convergence # The actual value of the initial conditions for the algebraic variables # doesn't matter - y0_alg = y0_alg_exact * (1 + 1e-6 * casadi.DM(np.random.rand(len_alg))) + y0_alg = y0_alg_exact * ( + 1 + 1e-6 * casadi.DM(np.random.rand(len_alg * batch_size)) + ) else: y0_alg = y0_alg_exact pybamm.logger.spam("Finished preliminary setup for integrator run") @@ -667,6 +672,9 @@ def _run_integrator( casadi_sol = integrator( x0=y0_diff, z0=y0_alg, p=inputs_with_tmin, **self.extra_options_call ) + casadi_y = self._zip_state_vector( + model, casadi_sol["xf"], casadi_sol["zf"] + ) except RuntimeError as error: # If it doesn't work raise error pybamm.logger.debug(f"Casadi integrator failed with error {error}") @@ -674,12 +682,7 @@ def _run_integrator( pybamm.logger.debug("Finished casadi integrator") integration_time = timer.time() # Manually add initial conditions and concatenate - x_sol = casadi.horzcat(y0_diff, casadi_sol["xf"]) - if len_alg > 0: - z_sol = casadi.horzcat(y0_alg_exact, casadi_sol["zf"]) - y_sol = casadi.vertcat(x_sol, z_sol) - else: - y_sol = x_sol + y_sol = casadi.horzcat(y0, casadi_y) sol = pybamm.Solution( t_eval, y_sol, @@ -692,10 +695,9 @@ def _run_integrator( return sol else: # Repeated calls to the integrator + y_sol = y0 x = y0_diff z = y0_alg_exact - y_diff = x - y_alg = z for i in range(len(t_eval) - 1): t_min = t_eval[i] t_max = t_eval[i + 1] @@ -705,6 +707,9 @@ def _run_integrator( casadi_sol = integrator( x0=x, z0=z, p=inputs_with_tlims, **self.extra_options_call ) + casadi_y = self._zip_state_vector( + model, casadi_sol["xf"], casadi_sol["zf"] + ) except RuntimeError as error: # If it doesn't work raise error pybamm.logger.debug(f"Casadi integrator failed with error {error}") @@ -712,13 +717,7 @@ def _run_integrator( integration_time = timer.time() x = casadi_sol["xf"] z = casadi_sol["zf"] - y_diff = casadi.horzcat(y_diff, x) - if not z.is_empty(): - y_alg = casadi.horzcat(y_alg, z) - if z.is_empty(): - y_sol = y_diff - else: - y_sol = casadi.vertcat(y_diff, y_alg) + y_sol = casadi.horzcat(y_sol, casadi_y) sol = pybamm.Solution( t_eval, diff --git a/pybamm/solvers/idaklu_solver.py b/pybamm/solvers/idaklu_solver.py index edead7c909..5ea8086b7d 100644 --- a/pybamm/solvers/idaklu_solver.py +++ b/pybamm/solvers/idaklu_solver.py @@ -119,6 +119,10 @@ def __init__( if idaklu_spec is None: # pragma: no cover raise ImportError("KLU is not installed") + base_options = { + "num_threads": options["num_threads"], + } + super().__init__( "ida", rtol, @@ -127,6 +131,7 @@ def __init__( root_tol, extrap_tol, output_variables, + base_options, ) self.name = "IDA KLU solver" @@ -364,12 +369,9 @@ def rootfn(t, y, inputs): return return_root # get ids of rhs and algebraic variables - if model.convert_to_format == "casadi": - rhs_ids = np.ones(model.rhs_eval(0, y0, inputs).shape[0]) - else: - rhs_ids = np.ones(model.rhs_eval(0, y0, inputs).shape[0]) - alg_ids = np.zeros(nstates - len(rhs_ids)) - ids = np.concatenate((rhs_ids, alg_ids)) + ids = np.concatenate( + [np.ones(model.len_rhs), np.zeros(model.len_alg)] * batch_size + ) number_of_sensitivity_parameters = 0 if model.jacp_rhs_algebraic_eval is not None: