diff --git a/pybamm/expression_tree/operations/evaluate_python.py b/pybamm/expression_tree/operations/evaluate_python.py
index 93a056d00d..cd7cb86219 100644
--- a/pybamm/expression_tree/operations/evaluate_python.py
+++ b/pybamm/expression_tree/operations/evaluate_python.py
@@ -286,6 +286,8 @@ def find_symbols(
         # Index has a different syntax than other univariate operations
         if isinstance(symbol, pybamm.Index):
             symbol_str = f"{children_vars[0]}[{symbol.slice.start}:{symbol.slice.stop}]"
+        elif isinstance(symbol, pybamm.AbsoluteValue):
+            symbol_str = f"{symbol.name}({children_vars[0]})"
         else:
             symbol_str = symbol.name + children_vars[0]
diff --git a/pybamm/solvers/base_solver.py b/pybamm/solvers/base_solver.py
index d0031ece9d..26bd2efaf2 100644
--- a/pybamm/solvers/base_solver.py
+++ b/pybamm/solvers/base_solver.py
@@ -4,7 +4,9 @@
 import sys
 import warnings
 import platform
-import asyncio
+
+# import asyncio
+import multiprocessing as mp
 
 import casadi
@@ -772,9 +774,7 @@ def _handle_integrate_defaults(
             y0S_list = model.y0S_list
         return inputs_list, batched_inputs, nbatches, batch_size, y0S_list
 
-    def _integrate(
-        self, model, t_eval, inputs_list=None, batched_inputs=None, nproc=None
-    ):
+    def _integrate(self, model, t_eval, inputs_list=None, batched_inputs=None):
         """
         Solve a DAE model defined by residuals with initial conditions y0.
@@ -804,53 +804,55 @@ def _integrate(
         )
 
         # async io is not parallel, but if solve is io bound, it can be faster
-        async def solve_model_batches():
-            async def solve_model_async(y0, y0S, inputs, inputs_array):
-                return self._integrate_batch(
-                    model, t_eval, y0, y0S, inputs, inputs_array
-                )
-
-            coro = []
-            for i in range(nbatches):
-                coro.append(
-                    asyncio.create_task(
-                        solve_model_async(
-                            model.y0_list[i],
-                            y0S_list[i],
-                            inputs_list[i * batch_size : (i + 1) * batch_size],
-                            batched_inputs[i],
-                        )
-                    )
-                )
-            return await asyncio.gather(*coro)
-
-        new_solutions = asyncio.run(solve_model_batches())
+        # async def solve_model_batches():
+        #     async def solve_model_async(y0, y0S, inputs, inputs_array):
+        #         return self._integrate_batch(
+        #             model, t_eval, y0, y0S, inputs, inputs_array
+        #         )
+
+        #     coro = []
+        #     for i in range(nbatches):
+        #         coro.append(
+        #             asyncio.create_task(
+        #                 solve_model_async(
+        #                     model.y0_list[i],
+        #                     y0S_list[i],
+        #                     inputs_list[i * batch_size : (i + 1) * batch_size],
+        #                     batched_inputs[i],
+        #                 )
+        #             )
+        #         )
+        #     return await asyncio.gather(*coro)
+
+        # new_solutions = asyncio.run(solve_model_batches())
 
         # new_solutions = []
         # for i in range(nbatches):
         #     new_solutions.append(self._integrate_batch(model, t_eval, model.y0_list[i], y0S_list[i], inputs_list[i * batch_size : (i + 1) * batch_size], batched_inputs[i]))
 
-        # with mp.get_context(self._mp_context).Pool(processes=nproc) as p:
-        #     model_list = [model] * nbatches
-        #     t_eval_list = [t_eval] * nbatches
-        #     y0_list = model.y0_list
-        #     inputs_list_of_list = [
-        #         inputs_list[i * batch_size : (i + 1) * batch_size]
-        #         for i in range(nbatches)
-        #     ]
-        #     new_solutions = p.starmap(
-        #         self._integrate_batch,
-        #         zip(
-        #             model_list,
-        #             t_eval_list,
-        #             y0_list,
-        #             y0S_list,
-        #             inputs_list_of_list,
-        #             batched_inputs,
-        #         ),
-        #     )
-        #     p.close()
-        #     p.join()
+        threads_per_batch = max(self._base_options["num_threads"] // nbatches, 1)
+        nproc = self._base_options["num_threads"] // threads_per_batch
+        with mp.get_context(self._mp_context).Pool(processes=nproc) as p:
+            model_list = [model] * nbatches
+            t_eval_list = [t_eval] * nbatches
+            y0_list = model.y0_list
+            inputs_list_of_list = [
+                inputs_list[i * batch_size : (i + 1) * batch_size]
+                for i in range(nbatches)
+            ]
+            new_solutions = p.starmap(
+                self._integrate_batch,
+                zip(
+                    model_list,
+                    t_eval_list,
+                    y0_list,
+                    y0S_list,
+                    inputs_list_of_list,
+                    batched_inputs,
+                ),
+            )
+            p.close()
+            p.join()
 
         new_solutions_flat = [sol for sublist in new_solutions for sol in sublist]
 
         return new_solutions_flat
@@ -1700,9 +1702,11 @@ def _zip_state_vector(model, y_diff, y_alg):
     len_alg = model.len_alg + model.len_alg_sens
     batch_size = model.batch_size
     y_diff_list = [
-        y_diff[i * len_rhs : (i + 1) * len_rhs] for i in range(batch_size)
+        y_diff[i * len_rhs : (i + 1) * len_rhs, :] for i in range(batch_size)
+    ]
+    y_alg_list = [
+        y_alg[i * len_alg : (i + 1) * len_alg, :] for i in range(batch_size)
     ]
-    y_alg_list = [y_alg[i * len_alg : (i + 1) * len_alg] for i in range(batch_size)]
     if isinstance(y_diff, casadi.DM):
         y = casadi.vertcat(
             *[val for pair in zip(y_diff_list, y_alg_list) for val in pair]
         )
@@ -1744,10 +1748,7 @@ def map_func_over_inputs_casadi(name, f, vars_for_processing, ninputs, nthreads)
     nstates = vars_for_processing["y_and_S"].shape[0]
     nparams = vars_for_processing["p_casadi_stacked"].shape[0]
 
-    threads_per_input = nthreads // ninputs
-    if threads_per_input > 1:
-        threads_per_input = 1
-    if threads_per_input > 1:
+    if nthreads > 1:
         parallelisation = "thread"
     else:
         parallelisation = "none"
@@ -1773,7 +1774,7 @@ def map_func_over_inputs_casadi(name, f, vars_for_processing, ninputs, nthreads)
     inputs_2d = [t_2d, y_and_S_2d, p_casadi_2d]
     inputs_stacked = [t_stacked, y_and_S_inputs_stacked, p_casadi_inputs_stacked]
 
-    mapped_f = f.map(ninputs, parallelisation, threads_per_input)(*inputs_2d)
+    mapped_f = f.map(ninputs, parallelisation, nthreads)(*inputs_2d)
     if matrix_output:
         # for matrix output we need to stack the outputs in a block diagonal matrix
         splits = [i * nstates for i in range(ninputs + 1)]
@@ -1849,6 +1850,8 @@ def process(
     else:
         inputs_batch = [inputs[0]]
     is_event = "event" in name
+    nbatches = len(inputs) // batch_size
+    nthreads_per_batch = max(nthreads // nbatches, 1)
 
     def report(string):
         # don't log event conversion
@@ -2028,31 +2031,30 @@ def jacp(t, y, inputs):
             name, [t_casadi, y_and_S, p_casadi_stacked], [casadi_expression]
         )
 
-        ninputs = len(inputs_batch)
-        if ninputs > 1:
+        if batch_size > 1:
             func = map_func_over_inputs_casadi(
-                name, func, vars_for_processing, ninputs, nthreads
+                name, func, vars_for_processing, batch_size, nthreads_per_batch
            )
             jac = map_func_over_inputs_casadi(
                 name + "_jac",
                 jac,
                 vars_for_processing,
-                ninputs,
-                nthreads,
+                batch_size,
+                nthreads_per_batch,
             )
             jacp = map_func_over_inputs_casadi(
                 name + "_jacp",
                 jacp,
                 vars_for_processing,
-                ninputs,
-                nthreads,
+                batch_size,
+                nthreads_per_batch,
             )
             jac_action = map_func_over_inputs_casadi(
                 name + "_jac_action",
                 jac_action,
                 vars_for_processing,
-                ninputs,
-                nthreads,
+                batch_size,
+                nthreads_per_batch,
             )
 
     return func, jac, jacp, jac_action
diff --git a/pybamm/solvers/c_solvers/idaklu/CasadiSolverGroup.cpp b/pybamm/solvers/c_solvers/idaklu/CasadiSolverGroup.cpp
index e672df87bc..66ba8221d0 100644
--- a/pybamm/solvers/c_solvers/idaklu/CasadiSolverGroup.cpp
+++ b/pybamm/solvers/c_solvers/idaklu/CasadiSolverGroup.cpp
@@ -56,14 +56,12 @@ std::vector CasadiSolverGroup::solve(np_array t_np, np_array y0_np, np
   const std::size_t solves_per_thread = n_groups / m_solvers.size();
   const std::size_t remainder_solves = n_groups % m_solvers.size();
 
-  const std::size_t nthreads = m_solvers.size();
-
   const realtype *t = t_np.data();
   const realtype *y0 = y0_np.data();
   const realtype *yp0 = yp0_np.data();
   const realtype *inputs_data = inputs.data();
 
-  omp_set_num_threads(nthreads);
+  omp_set_num_threads(m_solvers.size());
   #pragma omp parallel for
   for (int i = 0; i < m_solvers.size(); i++) {
     for (int j = 0; j < solves_per_thread; j++) {
diff --git a/pybamm/solvers/c_solvers/idaklu/CasadiSolverOpenMP.cpp b/pybamm/solvers/c_solvers/idaklu/CasadiSolverOpenMP.cpp
index 8c1f06711f..27eb3315a4 100644
--- a/pybamm/solvers/c_solvers/idaklu/CasadiSolverOpenMP.cpp
+++ b/pybamm/solvers/c_solvers/idaklu/CasadiSolverOpenMP.cpp
@@ -90,6 +90,7 @@ CasadiSolverOpenMP::CasadiSolverOpenMP(
     }
   }
 
+  // todo: should use std::vector for these...
   res = new realtype[max_res_size];
   res_dvar_dy = new realtype[max_res_dvar_dy];
   res_dvar_dp = new realtype[max_res_dvar_dp];
@@ -261,6 +262,7 @@ void CasadiSolverOpenMP::CalcVarsSensitivities(
   for(int k=0; k
 namespace py = pybind11;
+
+// note: we rely on c_style ordering for numpy arrays so don't change this!
 using np_array = py::array_t;
 using np_array_int = py::array_t;
diff --git a/pybamm/solvers/casadi_solver.py b/pybamm/solvers/casadi_solver.py
index 14bdb77a32..f4b2af84fb 100644
--- a/pybamm/solvers/casadi_solver.py
+++ b/pybamm/solvers/casadi_solver.py
@@ -69,6 +69,12 @@ class CasadiSolver(pybamm.BaseSolver):
         The maximum number of integrators that the solver will retain before
         ejecting past integrators using an LRU methodology. A value of 0 or None
         leaves the number of integrators unbound. Default is 100.
+    options: dict, optional
+        List of options, defaults are:
+        options = {
+            # Number of threads available for OpenMP
+            "num_threads": 1,
+        }
     """
 
     def __init__(
         self,
@@ -86,6 +92,7 @@ def __init__(
         return_solution_if_failed_early=False,
         perturb_algebraic_initial_conditions=None,
         integrators_maxcount=100,
+        options=None,
     ):
         super().__init__(
             "problem dependent",
@@ -94,6 +101,7 @@ def __init__(
             root_method,
             root_tol,
             extrap_tol,
+            options=options,
         )
         if mode in ["safe", "fast", "fast with events", "safe without grid"]:
             self.mode = mode
diff --git a/pybamm/solvers/jax_solver.py b/pybamm/solvers/jax_solver.py
index 03e4e5c06b..4447ff36ab 100644
--- a/pybamm/solvers/jax_solver.py
+++ b/pybamm/solvers/jax_solver.py
@@ -4,7 +4,8 @@
 import numpy as onp
 
 import pybamm
-import multiprocessing as mp
+import asyncio
+import copy
 
 if pybamm.have_jax():
     import jax
@@ -43,7 +44,7 @@ class JaxSolver(pybamm.BaseSolver):
         The absolute tolerance for the solver (default is 1e-6).
     extrap_tol : float, optional
         The tolerance to assert whether extrapolation occurs or not (default is 0).
-    extra_options : dict, optional
+    options : dict, optional
         Any options to pass to the solver.
         Please consult `JAX documentation `_
@@ -57,7 +58,7 @@ def __init__(
         rtol=1e-6,
         atol=1e-6,
         extrap_tol=None,
-        extra_options=None,
+        options=None,
     ):
         if not pybamm.have_jax():
             raise ModuleNotFoundError(
@@ -67,15 +68,25 @@ def __init__(
         # note: bdf solver itself calculates consistent initial conditions so can set
         # root_method to none, allow user to override this behavior
         super().__init__(
-            method, rtol, atol, root_method=root_method, extrap_tol=extrap_tol
+            method,
+            rtol,
+            atol,
+            root_method=root_method,
+            extrap_tol=extrap_tol,
+            options=options,
         )
+
         method_options = ["RK45", "BDF"]
         if method not in method_options:
             raise ValueError(f"method must be one of {method_options}")
         self.ode_solver = False
         if method == "RK45":
             self.ode_solver = True
-        self.extra_options = extra_options or {}
+        options = options or {}
+        self.options = copy.copy(options) or {}
+        # remove "num_threads" from options as it is not supported
+        if "num_threads" in self.options:
+            self.options.pop("num_threads", None)
         self.name = f"JAX solver ({method})"
         self._cached_solves = dict()
         pybamm.citations.register("jax2018")
@@ -177,7 +188,7 @@ def solve_model_rk45(y0, inputs: dict | list[dict]):
                 stack_inputs(inputs),
                 rtol=self.rtol,
                 atol=self.atol,
-                **self.extra_options,
+                **self.options,
             )
             return jnp.transpose(y)
@@ -192,7 +203,7 @@ def solve_model_bdf(y0, inputs: dict | list[dict]):
                 rtol=self.rtol,
                 atol=self.atol,
                 mass=mass,
-                **self.extra_options,
+                **self.options,
             )
             return jnp.transpose(y)
@@ -231,35 +242,99 @@ def _integrate(self, model, t_eval, inputs_list=None, batched_inputs=None):
         if model not in self._cached_solves:
             self._cached_solves[model] = self.create_solve(model, t_eval)
 
-        # todo: make this parallel
-
-        solns = []
         batch_size = len(inputs_list) // len(batched_inputs)
         nbatches = len(batched_inputs)
 
-        def solve_batch(i):
-            y0 = model.y0_list[i]
-            inputs_sublist = inputs_list[i * batch_size : (i + 1) * batch_size]
-            y = self._cached_solves[model](y0, inputs_sublist)
-            # convert to a normal numpy array
-            y = onp.array(y)
-            return pybamm.Solution.from_concatenated_state(
-                t_eval,
-                y,
-                model,
-                inputs_sublist,
-                termination="final time",
-                check_solution=False,
+        platform = jax.lib.xla_bridge.get_backend().platform.casefold()
+        if nbatches == 1:
+            y = [self._cached_solves[model](model.y0_list[0], inputs_list)]
+        elif platform.startswith("cpu"):
+            # cpu execution runs faster when multithreaded
+            async def solve_model_for_inputs():
+                async def solve_model_async(y0, inputs_sublist):
+                    return self._cached_solves[model](y0, inputs_sublist)
+
+                coro = []
+                for i in range(nbatches):
+                    y0 = model.y0_list[i]
+                    inputs_sublist = inputs_list[i * batch_size : (i + 1) * batch_size]
+                    coro.append(
+                        asyncio.create_task(solve_model_async(y0, inputs_sublist))
+                    )
+                return await asyncio.gather(*coro)
+
+            y = asyncio.run(solve_model_for_inputs())
+        elif (
+            platform.startswith("gpu")
+            or platform.startswith("tpu")
+            or platform.startswith("metal")
+        ):
+            # gpu execution runs faster when parallelised with vmap
+            # (see also comment below regarding single-program multiple-data
+            # execution (SPMD) using pmap on multiple XLAs)
+
+            # convert inputs (array of dict) to a dict of arrays for vmap
+            inputs_v = {
+                key: jnp.array([dic[key] for dic in inputs_list])
+                for key in inputs_list[0]
+            }
+            y0 = onp.vstack([model.y0_list[i].flatten() for i in range(nbatches)])
+            y.extend(jax.vmap(self._cached_solves[model])(y0, inputs_v))
+        else:
+            # Unknown platform, use serial execution as fallback
+            print(
+                f'Unknown platform requested: "{platform}", '
+                "falling back to serial execution"
             )
-        nproc = None
-        with mp.get_context(self._mp_context).Pool(processes=nproc) as p:
-            solns = p.map(solve_batch, range(nbatches))
-        # flatten list of lists
-        solns = [x for xs in solns for x in xs]
-
+            y = []
+            for y0, inputs_v in zip(model.y0_list, inputs_list):
+                y.append(self._cached_solves[model](y0, inputs_v))
+
+        # This code block implements single-program multiple-data execution
+        # using pmap across multiple XLAs. It is currently commented out
+        # because it produces bus errors for even moderate-sized models.
+        # It is suspected that this is due to either a bug in JAX, insufficient
+        # sparse matrix support in JAX resulting in high memory usage, or a bug
+        # in the BDF solver.
+        #
+        # This issue on github appears related:
+        # https://github.com/google/jax/discussions/13930
+        #
+        # # Split input list based on the number of available xla devices
+        # device_count = jax.local_device_count()
+        # inputs_listoflists = [inputs[x:x + device_count]
+        #                       for x in range(0, len(inputs), device_count)]
+        # if len(inputs_listoflists) > 1:
+        #     print(f"{len(inputs)} parameter sets were provided, "
+        #           f"but only {device_count} XLA devices are available")
+        #     print(f"Parameter sets split into {len(inputs_listoflists)} "
+        #           "lists for parallel processing")
+        # y = []
+        # for k, inputs_list in enumerate(inputs_listoflists):
+        #     if len(inputs_listoflists) > 1:
+        #         print(f" Solving list {k+1} of {len(inputs_listoflists)} "
+        #               f"({len(inputs_list)} parameter sets)")
+        #     # convert inputs to a dict of arrays for pmap
+        #     inputs_v = {key: jnp.array([dic[key] for dic in inputs_list])
+        #                 for key in inputs_list[0]}
+        #     y.extend(jax.pmap(self._cached_solves[model])(inputs_v))
 
         integration_time = timer.time()
-        for sol in solns:
-            sol.integration_time = integration_time
-        return solns
+        termination = "final time"
+        t_event = None
+        y_event = onp.array(None)
+
+        # Extract solutions from y with their associated input dicts
+        solutions = []
+        for i in range(nbatches):
+            state_vec = onp.array(y[i])
+            inputs = inputs_list[i * batch_size : (i + 1) * batch_size]
+            solution_batch = pybamm.Solution.from_concatenated_state(
+                t_eval, state_vec, model, inputs, t_event, y_event, termination
+            )
+            for soln in solution_batch:
+                soln.integration_time = integration_time
+            solutions += solution_batch
+
+        return solutions
diff --git a/tests/unit/test_expression_tree/test_broadcasts.py b/tests/unit/test_expression_tree/test_broadcasts.py
index be8fe1a677..43b01826a8 100644
--- a/tests/unit/test_expression_tree/test_broadcasts.py
+++ b/tests/unit/test_expression_tree/test_broadcasts.py
@@ -340,15 +340,16 @@ def test_diff(self):
         d = b.diff(a)
         self.assertIsInstance(d, pybamm.PrimaryBroadcast)
         self.assertEqual(d.child.evaluate(y=y), 1)
-        # diff of broadcast w.r.t. itself is 1
+        # diff of broadcast w.r.t. itself is a vector of 1
         d = b.diff(b)
-        self.assertIsInstance(d, pybamm.Scalar)
-        self.assertEqual(d.evaluate(y=y), 1)
-        # diff of broadcast of a constant is 0
+        self.assertIsInstance(d, pybamm.Vector)
+        b_size = pybamm.domain_size(["separator"])
+        np.testing.assert_array_equal(d.evaluate(y=y), np.ones((b_size, 1)))
+        # diff of broadcast of a constant is 0 vector
         c = pybamm.PrimaryBroadcast(pybamm.Scalar(4), "separator")
         d = c.diff(a)
-        self.assertIsInstance(d, pybamm.Scalar)
-        self.assertEqual(d.evaluate(y=y), 0)
+        self.assertIsInstance(d, pybamm.Vector)
+        np.testing.assert_array_equal(d.evaluate(y=y), np.zeros((b_size, 1)))
 
     def test_to_from_json_error(self):
         a = pybamm.StateVector(slice(0, 1))
diff --git a/tests/unit/test_solvers/test_base_solver.py b/tests/unit/test_solvers/test_base_solver.py
index cf3bec2bc8..57287281ea 100644
--- a/tests/unit/test_solvers/test_base_solver.py
+++ b/tests/unit/test_solvers/test_base_solver.py
@@ -128,6 +128,10 @@ class ScalarModel:
             def __init__(self):
                 self.y0_list = [np.array([2])]
                 self.rhs = {}
+                self.len_rhs = 0
+                self.len_rhs_sens = 0
+                self.len_alg = 1
+                self.len_alg_sens = 0
                 self.jac_algebraic_eval = None
                 t = casadi.MX.sym("t")
                 y = casadi.MX.sym("y")
@@ -175,7 +179,10 @@ def __init__(self):
                 self.convert_to_format = "casadi"
                 self.bounds = (-np.inf * np.ones(4), np.inf * np.ones(4))
                 self.len_rhs = 1
-                self.len_rhs_and_alg = 4
+                self.len_rhs_sens = 0
+                self.len_alg = len(vec) - 1
+                self.len_alg_sens = 0
+
                 self.events = []
                 self.batch_size = 1
@@ -222,6 +229,10 @@ def __init__(self):
                 self.casadi_algebraic = casadi.Function(
                     "alg", [t, y, p], [self.algebraic_eval(t, y, p)]
                 )
+                self.len_rhs = 0
+                self.len_rhs_sens = 0
+                self.len_alg = 1
+                self.len_alg_sens = 0
                 self.convert_to_format = "casadi"
                 self.bounds = (np.array([-np.inf]), np.array([np.inf]))
                 self.batch_size = 1
diff --git a/tests/unit/test_solvers/test_casadi_algebraic_solver.py b/tests/unit/test_solvers/test_casadi_algebraic_solver.py
index 5cec12e1ce..bc563d9ed9 100644
--- a/tests/unit/test_solvers/test_casadi_algebraic_solver.py
+++ b/tests/unit/test_solvers/test_casadi_algebraic_solver.py
@@ -56,6 +56,10 @@ class Model:
            p = casadi.MX.sym("p")
            rhs = {}
            casadi_algebraic = casadi.Function("alg", [t, y, p], [y**2 + 1])
+           len_rhs = 0
+           len_rhs_sens = 0
+           len_alg = 1
+           len_alg_sens = 0
            bounds = (np.array([-np.inf]), np.array([np.inf]))
            interpolant_extrapolation_events_eval = []
            batch_size = 1
@@ -84,6 +88,10 @@ class NaNModel:
            t = casadi.MX.sym("t")
            y = casadi.MX.sym("y")
            p = casadi.MX.sym("p")
+           len_rhs = 0
+           len_rhs_sens = 0
+           len_alg = 1
+           len_alg_sens = 0
            rhs = {}
            casadi_algebraic = casadi.Function("alg", [t, y, p], [y**0.5])
            bounds = (np.array([-np.inf]), np.array([np.inf]))
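
Reviewer note (not part of the patch): the thread-budget split added to BaseSolver._integrate reduces to two integer divisions. Below is a minimal, self-contained sketch of that arithmetic; the split_threads helper, the example numbers, and the usage comment are illustrative assumptions, only the two formulas are taken from the diff.

def split_threads(num_threads: int, nbatches: int) -> tuple[int, int]:
    # Each batch gets an equal share of the thread budget (at least one thread);
    # the leftover budget sets how many batches the multiprocessing pool runs at once.
    threads_per_batch = max(num_threads // nbatches, 1)
    nproc = num_threads // threads_per_batch
    return threads_per_batch, nproc


if __name__ == "__main__":
    # e.g. a {"num_threads": 8} option shared by 3 batches
    # -> 2 threads per batch, a pool of 4 worker processes
    print(split_threads(8, 3))
    # Passing the option to a solver would look roughly like
    #     solver = pybamm.CasadiSolver(options={"num_threads": 8})
    # (hypothetical usage; see the casadi_solver.py docstring hunk above)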
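Reviewer note (not part of the patch): the GPU/TPU branch of JaxSolver._integrate relies on jax.vmap mapping a per-batch solve over a stacked y0 array and a dict of input arrays. A toy, runnable sketch of that pattern follows; the solve function is a stand-in for self._cached_solves[model] and is not taken from the diff.

import jax
import jax.numpy as jnp


def solve(y0, inputs):
    # stand-in "solver": decay each state at the rate given in the inputs dict
    t = jnp.linspace(0.0, 1.0, 3)
    return jnp.outer(y0, jnp.exp(-inputs["rate"] * t))


y0_batch = jnp.stack([jnp.array([1.0, 2.0]), jnp.array([0.5, 1.5])])
inputs_batch = {"rate": jnp.array([1.0, 2.0])}  # dict of arrays, as built in the diff
y = jax.vmap(solve)(y0_batch, inputs_batch)
print(y.shape)  # (2, 2, 3): one (state x time) solution per batch entry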