diff --git a/pybamm/expression_tree/operations/evaluate_python.py b/pybamm/expression_tree/operations/evaluate_python.py index 7261f5e8fc..61c03279f1 100644 --- a/pybamm/expression_tree/operations/evaluate_python.py +++ b/pybamm/expression_tree/operations/evaluate_python.py @@ -424,6 +424,9 @@ def to_python( + "; print(type({0}),np.shape({0}))".format( id_to_python_variable(symbol_id, False) ) + # + "; jax.debug.print(\"{0} = {{x}}\", x={0}.flatten())".format( + # id_to_python_variable(symbol_id, False) + # ) for symbol_id, symbol_line in variable_symbols.items() ] else: @@ -523,15 +526,13 @@ def __call__(self, t=None, y=None, inputs=None): else: nstates = y.shape[0] // self._ninputs nparams = len(inputs) // self._ninputs - print("nstates:", nstates) - print("nparams:", nparams) results = [ self._evaluate( self._constants, t, y[i * nstates : (i + 1) * nstates], - input[i * nparams : (i + 1) * nparams], + inputs[i * nparams : (i + 1) * nparams], ) for i in range(self._ninputs) ] @@ -665,15 +666,19 @@ def __init__( exec(compiled_function) # use vmap to vectorize the function over the inputs if ninputs > 1 - in_axes = ([None] * len(self._arg_list)) + [None, None, 0] + in_axes = ([None] * len(self._arg_list)) + [None, 0, 0] out_axes = 0 ninputs = len(inputs) if ninputs > 1: if is_event: def mapped_evaluate_jax_event(*args): - # change inputs to a 2d array for vmap (inputs is the last arg) - args[-1] = args[-1].reshape(ninputs, -1) + # change inputs and y to a 2d array for vmap (inputs is the last arg) + args = ( + *args[:-2], + args[-2].reshape(ninputs, -1), + args[-1].reshape(ninputs, -1), + ) # exectute the mapped function results = jax.vmap( @@ -685,13 +690,17 @@ def mapped_evaluate_jax_event(*args): alpha = jax.numpy.log(ninputs) / margin return jax.scipy.special.logsumexp(alpha * results) / alpha - self._evaluate_jax = mapped_evaluate_jax_event + self._mapped_evaluate_jax = mapped_evaluate_jax_event else: def mapped_evaluate_jax(*args): - # change inputs to a 2d array for vmap (inputs is the last arg) - args[-1] = args[-1].reshape(ninputs, -1) + # change inputs and y to a 2d array for vmap (inputs is the last arg) + args = ( + *args[:-2], + args[-2].reshape(ninputs, -1), + args[-1].reshape(ninputs, -1), + ) # exectute the mapped function results = jax.vmap( @@ -701,11 +710,13 @@ def mapped_evaluate_jax(*args): # reshape to a column vector return results.reshape(-1, 1) - self._evaluate_jax = mapped_evaluate_jax + self._mapped_evaluate_jax = mapped_evaluate_jax + else: + self._mapped_evaluate_jax = self._evaluate_jax self._static_argnums = tuple(static_argnums) self._jit_evaluate = jax.jit( - self._evaluate_jax, # type:ignore[attr-defined] + self._mapped_evaluate_jax, # type:ignore[attr-defined] static_argnums=self._static_argnums, ) @@ -715,7 +726,7 @@ def get_jacobian(self): return self._get_jacfwd(1 + n) def _get_jacfwd(self, argnum): - jacobian_evaluate = jax.jacfwd(self._evaluate_jax, argnums=argnum) + jacobian_evaluate = jax.jacfwd(self._mapped_evaluate_jax, argnums=argnum) self._jac_evaluate = jax.jit( jacobian_evaluate, static_argnums=self._static_argnums @@ -737,7 +748,9 @@ def debug(self, t=None, y=None, inputs=None): y = y.reshape(-1, 1) # execute code - jaxpr = jax.make_jaxpr(self._evaluate_jax)(*self._constants, t, y, inputs).jaxpr + jaxpr = jax.make_jaxpr(self._mapped_evaluate_jax)( + *self._constants, t, y, inputs + ).jaxpr print("invars:", jaxpr.invars) print("outvars:", jaxpr.outvars) print("constvars:", jaxpr.constvars) diff --git a/pybamm/solvers/base_solver.py b/pybamm/solvers/base_solver.py index e005d50cf4..d41f84d17f 100644 --- a/pybamm/solvers/base_solver.py +++ b/pybamm/solvers/base_solver.py @@ -603,13 +603,7 @@ def _set_initial_conditions(self, model, time, inputs_list, update_rhs): ) y_zero = np.zeros((y0_total_size, 1)) - if model.convert_to_format == "casadi": - # stack inputs - inputs = casadi.vertcat( - *[x for inpts in inputs_list for x in inpts.values()] - ) - else: - inputs = inputs_list + inputs = self._inputs_to_stacked_vect(inputs_list, model.convert_to_format) if self.algebraic_solver is True: # Don't update model.y0 @@ -1407,7 +1401,6 @@ def _inputs_to_stacked_vect(inputs_list: list[dict], convert_to_format: str): for inputs in inputs_list for x in inputs.values() ] - print(inputs_list, arrays_to_stack) inputs = np.vstack(arrays_to_stack) return inputs @@ -1495,75 +1488,6 @@ def map_func_over_inputs_casadi(name, f, vars_for_processing, ninputs): return casadi.Function(name, inputs_stacked, [stack]) -def map_func_over_inputs_jax(name, f, vars_for_processing, ninputs): - """ - This takes a casadi function f and returns a new casadi function that maps f over - the provided number of inputs. Some functions (e.g. jacobian action) require an additional - vector input v, which is why add_v is provided. - - Parameters - ---------- - name: str - name of the new function. This must end in the string "_action" for jacobian action functions, - "_jac" for jacobian functions, or "_jacp" for jacp functions. - f: casadi.Function - function to map - vars_for_processing: dict - dictionary of variables for processing - ninputs: int - number of inputs to map over - """ - if f is None: - return None - - is_event = "event" in name - add_v = name.endswith("_action") - matrix_output = name.endswith("_jac") or name.endswith("_jacp") - - nstates = vars_for_processing["y_and_S"].shape[0] - nparams = vars_for_processing["p_casadi_stacked"].shape[0] - - parallelisation = "thread" - y_and_S_inputs_stacked = casadi.MX.sym("y_and_S_stacked", nstates * ninputs) - p_casadi_inputs_stacked = casadi.MX.sym("p_stacked", nparams * ninputs) - v_inputs_stacked = casadi.MX.sym("v_stacked", nstates * ninputs) - - y_and_S_2d = y_and_S_inputs_stacked.reshape((nstates, ninputs)) - p_casadi_2d = p_casadi_inputs_stacked.reshape((nparams, ninputs)) - v_2d = v_inputs_stacked.reshape((nstates, ninputs)) - - t_casadi = vars_for_processing["t_casadi"] - - if add_v: - inputs_2d = [t_casadi, y_and_S_2d, p_casadi_2d, v_2d] - inputs_stacked = [ - t_casadi, - y_and_S_inputs_stacked, - p_casadi_inputs_stacked, - v_inputs_stacked, - ] - else: - inputs_2d = [t_casadi, y_and_S_2d, p_casadi_2d] - inputs_stacked = [t_casadi, y_and_S_inputs_stacked, p_casadi_inputs_stacked] - - mapped_f = f.map(ninputs, parallelisation)(*inputs_2d) - if matrix_output: - # for matrix output we need to stack the outputs in a block diagonal matrix - splits = [i * nstates for i in range(ninputs + 1)] - split = casadi.horzsplit(mapped_f, splits) - stack = casadi.diagcat(*split) - elif is_event: - # Events need to return a scalar, so we combine the vector output - # of the mapped function into a scalar output by calculating a smooth max of the vector output. - stack = casadi.logsumexp(casadi.transpose(mapped_f), 1e-4) - else: - # for vector outputs we need to stack them vertically in a single column vector - splits = [i for i in range(ninputs + 1)] - split = casadi.horzsplit(mapped_f, splits) - stack = casadi.vertcat(*split) - return casadi.Function(name, inputs_stacked, [stack]) - - def process( symbol, name, diff --git a/pybamm/solvers/casadi_solver.py b/pybamm/solvers/casadi_solver.py index 5c2e720f65..feba20a95b 100644 --- a/pybamm/solvers/casadi_solver.py +++ b/pybamm/solvers/casadi_solver.py @@ -175,7 +175,9 @@ def _integrate(self, model, t_eval, inputs_list=None): solution = self._solve_for_event(solution, inputs_list) solution.check_ys_are_not_too_large() - return solution.split(model.len_rhs, model.len_alg, inputs_list) + return solution.split( + model.len_rhs, model.len_alg, inputs_list, is_casadi_solver=True + ) elif self.mode in ["safe", "safe without grid"]: y0 = model.y0 # Step-and-check @@ -314,7 +316,9 @@ def _integrate(self, model, t_eval, inputs_list=None): if bool(model.calculate_sensitivities): solution.sensitivities = True solution.check_ys_are_not_too_large() - return solution.split(model.len_rhs, model.len_alg, inputs_list) + return solution.split( + model.len_rhs, model.len_alg, inputs_list, is_casadi_solver=True + ) def _solve_for_event(self, coarse_solution, inputs_list): """ diff --git a/pybamm/solvers/jax_bdf_solver.py b/pybamm/solvers/jax_bdf_solver.py index 2c7bdc6d17..732aa414cf 100644 --- a/pybamm/solvers/jax_bdf_solver.py +++ b/pybamm/solvers/jax_bdf_solver.py @@ -876,17 +876,11 @@ def initialise(g0, y0, t0): def arg_to_identity(arg): return onp.identity(arg.shape[0] if arg.ndim > 0 else 1, dtype=arg.dtype) - def arg_dicts_to_values(args): - """ - Note:JAX puts in empty arrays into args for some reason, we remove them here - """ - return sum((tuple(b.values()) for b in args if isinstance(b, dict)), ()) - aug_mass = ( mass, mass, onp.array(1.0), - *arg_dicts_to_values(tree_map(arg_to_identity, args)), + *tree_map(arg_to_identity, args), ) def scan_fun(carry, i): diff --git a/pybamm/solvers/jax_solver.py b/pybamm/solvers/jax_solver.py index fbe047b3cc..9326541365 100644 --- a/pybamm/solvers/jax_solver.py +++ b/pybamm/solvers/jax_solver.py @@ -2,7 +2,6 @@ # Solver class using Scipy's adaptive time stepper # import numpy as onp -import asyncio import pybamm @@ -148,6 +147,20 @@ def create_solve(self, model, t_eval): mass = None if self.method == "BDF": mass = model.mass_matrix.entries.toarray() + print("stuff", mass.shape, model.len_rhs, model.len_alg) + + def stack_inputs(inputs: dict | list[dict]): + if isinstance(inputs, dict): + return jnp.array([x.reshape(-1, 1) for x in inputs.values()]) + if len(inputs) == 1: + return jnp.array([x.reshape(-1, 1) for x in inputs[0].values()]) + arrays_to_stack = [ + jnp.array(x).reshape(-1, 1) + for inputs in inputs + for x in inputs.values() + ] + print("stacking", len(arrays_to_stack), arrays_to_stack[0].shape) + return jnp.vstack(arrays_to_stack) def rhs_ode(y, t, inputs): return (model.rhs_eval(t, y, inputs),) @@ -162,7 +175,7 @@ def solve_model_rk45(inputs): rhs_ode, y0, t_eval, - inputs, + stack_inputs(inputs), rtol=self.rtol, atol=self.atol, **self.extra_options, @@ -174,7 +187,7 @@ def solve_model_bdf(inputs): rhs_dae, y0, t_eval, - inputs, + stack_inputs(inputs), rtol=self.rtol, atol=self.atol, mass=mass, @@ -197,7 +210,7 @@ def _integrate(self, model, t_eval, inputs=None): The model whose solution to calculate. t_eval : :class:`numpy.array`, size (k,) The times at which to compute the solution - inputs : dict, list[dict], optional + inputs_list : list[dict], optional Any input parameters to pass to the model when solving Returns @@ -207,77 +220,13 @@ def _integrate(self, model, t_eval, inputs=None): various diagnostic messages. """ - if isinstance(inputs, dict): - inputs = [inputs] + inputs = inputs or [{}] + timer = pybamm.Timer() if model not in self._cached_solves: self._cached_solves[model] = self.create_solve(model, t_eval) - y = [] - platform = jax.lib.xla_bridge.get_backend().platform.casefold() - if len(inputs) <= 1 or platform.startswith("cpu"): - # cpu execution runs faster when multithreaded - async def solve_model_for_inputs(): - async def solve_model_async(inputs_v): - return self._cached_solves[model](inputs_v) - - coro = [] - for inputs_v in inputs: - coro.append(asyncio.create_task(solve_model_async(inputs_v))) - return await asyncio.gather(*coro) - - y = asyncio.run(solve_model_for_inputs()) - elif ( - platform.startswith("gpu") - or platform.startswith("tpu") - or platform.startswith("metal") - ): - # gpu execution runs faster when parallelised with vmap - # (see also comment below regarding single-program multiple-data - # execution (SPMD) using pmap on multiple XLAs) - - # convert inputs (array of dict) to a dict of arrays for vmap - inputs_v = { - key: jnp.array([dic[key] for dic in inputs]) for key in inputs[0] - } - y.extend(jax.vmap(self._cached_solves[model])(inputs_v)) - else: - # Unknown platform, use serial execution as fallback - print( - f'Unknown platform requested: "{platform}", ' - "falling back to serial execution" - ) - for inputs_v in inputs: - y.append(self._cached_solves[model](inputs_v)) - - # This code block implements single-program multiple-data execution - # using pmap across multiple XLAs. It is currently commented out - # because it produces bus errors for even moderate-sized models. - # It is suspected that this is due to either a bug in JAX, insufficient - # sparse matrix support in JAX resulting in high memory usage, or a bug - # in the BDF solver. - # - # This issue on guthub appears related: - # https://github.com/google/jax/discussions/13930 - # - # # Split input list based on the number of available xla devices - # device_count = jax.local_device_count() - # inputs_listoflists = [inputs[x:x + device_count] - # for x in range(0, len(inputs), device_count)] - # if len(inputs_listoflists) > 1: - # print(f"{len(inputs)} parameter sets were provided, " - # f"but only {device_count} XLA devices are available") - # print(f"Parameter sets split into {len(inputs_listoflists)} " - # "lists for parallel processing") - # y = [] - # for k, inputs_list in enumerate(inputs_listoflists): - # if len(inputs_listoflists) > 1: - # print(f" Solving list {k+1} of {len(inputs_listoflists)} " - # f"({len(inputs_list)} parameter sets)") - # # convert inputs to a dict of arrays for pmap - # inputs_v = {key: jnp.array([dic[key] for dic in inputs_list]) - # for key in inputs_list[0]} - # y.extend(jax.pmap(self._cached_solves[model])(inputs_v)) + y = self._cached_solves[model](inputs) integration_time = timer.time() @@ -285,24 +234,14 @@ async def solve_model_async(inputs_v): y = onp.array(y) termination = "final time" - t_event = None - y_event = onp.array(None) - # Extract solutions from y with their associated input dicts - solutions = [] - for k, inputs_dict in enumerate(inputs): - sol = pybamm.Solution( - t_eval, - jnp.reshape(y[k,], y.shape[1:]), - model, - inputs_dict, - t_event, - y_event, - termination, - ) - sol.integration_time = integration_time - solutions.append(sol) - - if len(solutions) == 1: - return solutions[0] - return solutions + sol = pybamm.Solution( + t_eval, + y, + model, + inputs[0], + termination=termination, + check_solution=False, + ) + sol.integration_time = integration_time + return sol.split(model.len_rhs, model.len_alg, inputs) diff --git a/pybamm/solvers/solution.py b/pybamm/solvers/solution.py index 8f5dc23c79..10dda89e9e 100644 --- a/pybamm/solvers/solution.py +++ b/pybamm/solvers/solution.py @@ -138,7 +138,7 @@ def __init__( # Solution now uses CasADi pybamm.citations.register("Andersson2019") - def split(self, rhs_len, alg_len, inputs_list): + def split(self, rhs_len, alg_len, inputs_list, is_casadi_solver=False): """ split up the concatenated solution into a list of solutions for each input the state vector is assumed to have the form: @@ -151,7 +151,7 @@ def split(self, rhs_len, alg_len, inputs_list): if ninputs == 1: return [self] - if isinstance(self.all_ys[0], (casadi.DM, casadi.MX)): + if is_casadi_solver: all_ys_split = [ [ casadi.vertcat( @@ -179,37 +179,24 @@ def split(self, rhs_len, alg_len, inputs_list): ) for p in range(ninputs) ] + else: + y_events = [None] * ninputs else: + state_len = rhs_len + alg_len all_ys_split = [ [ - np.vstack( - [ - self.all_ys[i][(p * rhs_len) : (p * rhs_len + rhs_len)], - self.all_ys[i][ - (p * alg_len + ninputs * rhs_len) : ( - p * alg_len + ninputs * rhs_len + alg_len - ) - ], - ] - ) + self.all_ys[i][(p * state_len) : (p * state_len + state_len)] for i in range(len(self.all_ys)) ] for p in range(ninputs) ] if self.y_event is not None: y_events = [ - np.vstack( - [ - self.y_event[(p * rhs_len) : (p * rhs_len + rhs_len)], - self.y_event[ - (p * alg_len + ninputs * rhs_len) : ( - p * alg_len + ninputs * rhs_len + alg_len - ) - ], - ] - ) + self.y_event[(p * state_len) : (p * state_len + state_len)] for p in range(ninputs) ] + else: + y_events = [None] * ninputs ret = [ type(self)( diff --git a/tests/unit/test_solvers/test_jax_bdf_solver.py b/tests/unit/test_solvers/test_jax_bdf_solver.py index 854a618fba..6ce3dccc3c 100644 --- a/tests/unit/test_solvers/test_jax_bdf_solver.py +++ b/tests/unit/test_solvers/test_jax_bdf_solver.py @@ -88,20 +88,23 @@ def test_solver_sensitivities(self): # Solve t_eval = np.linspace(0, 10, 4) y0 = model.concatenated_initial_conditions.evaluate().reshape(-1) - rhs = pybamm.EvaluatorJax(model.concatenated_rhs) + + rate = 0.1 + inputs = {"rate": rate} + rhs = pybamm.EvaluatorJax(model.concatenated_rhs, inputs=[inputs]) + inputs_stacked = jax.numpy.array([rate]) def fun(y, t, inputs): - return rhs(t=t, y=y, inputs=inputs).reshape(-1) + return rhs(t=t, y=y, inputs=inputs_stacked).reshape(-1) h = 0.0001 - rate = 0.1 # create a dummy "model" where we calculate the sum of the time series @jax.jit def solve_bdf(rate): return jax.numpy.sum( pybamm.jax_bdf_integrate( - fun, y0, t_eval, {"rate": rate}, rtol=1e-9, atol=1e-9 + fun, y0, t_eval, jax.numpy.array([rate]), rtol=1e-9, atol=1e-9 ) ) @@ -166,14 +169,14 @@ def test_solver_with_inputs(self): # Solve t_eval = np.linspace(0, 10, 80) y0 = model.concatenated_initial_conditions.evaluate().reshape(-1) - rhs = pybamm.EvaluatorJax(model.concatenated_rhs) + inputs = {"rate": 0.1} + rhs = pybamm.EvaluatorJax(model.concatenated_rhs, inputs=[inputs]) def fun(y, t, inputs): return rhs(t=t, y=y, inputs=inputs).reshape(-1) - y = pybamm.jax_bdf_integrate( - fun, y0, t_eval, {"rate": 0.1}, rtol=1e-9, atol=1e-9 - ) + inputs = np.array([[0.1]]) + y = pybamm.jax_bdf_integrate(fun, y0, t_eval, inputs, rtol=1e-9, atol=1e-9) np.testing.assert_allclose(y[:, 0].reshape(-1), np.exp(-0.1 * t_eval)) diff --git a/tests/unit/test_solvers/test_jax_solver.py b/tests/unit/test_solvers/test_jax_solver.py index 9df28e8ac2..56c6ccada4 100644 --- a/tests/unit/test_solvers/test_jax_solver.py +++ b/tests/unit/test_solvers/test_jax_solver.py @@ -205,6 +205,34 @@ def test_model_solver_with_inputs(self): solution.y[0], np.exp(-0.2 * solution.t), rtol=1e-6, atol=1e-6 ) + def test_model_solver_multiple_inputs_jax_format(self): + # Create model + model = pybamm.BaseModel() + model.convert_to_format = "jax" + domain = ["negative electrode", "separator", "positive electrode"] + var = pybamm.Variable("var", domain=domain) + model.rhs = {var: -pybamm.InputParameter("rate") * var} + model.initial_conditions = {var: 1} + # create discretisation + mesh = get_mesh_for_testing() + spatial_methods = {"macroscale": pybamm.FiniteVolume()} + disc = pybamm.Discretisation(mesh, spatial_methods) + disc.process_model(model) + + solver = pybamm.JaxSolver(rtol=1e-8, atol=1e-8, method="RK45") + t_eval = np.linspace(0, 10, 100) + ninputs = 8 + inputs_list = [{"rate": 0.01 * (i + 1)} for i in range(ninputs)] + + solutions = solver.solve(model, t_eval, inputs=inputs_list, nproc=2) + for i in range(ninputs): + with self.subTest(i=i): + solution = solutions[i] + np.testing.assert_array_equal(solution.t, t_eval) + np.testing.assert_allclose( + solution.y[0], np.exp(-0.01 * (i + 1) * solution.t) + ) + def test_get_solve(self): # Create model model = pybamm.BaseModel() diff --git a/tests/unit/test_solvers/test_scipy_solver.py b/tests/unit/test_solvers/test_scipy_solver.py index e5647288dc..36cea54be2 100644 --- a/tests/unit/test_solvers/test_scipy_solver.py +++ b/tests/unit/test_solvers/test_scipy_solver.py @@ -315,59 +315,69 @@ def test_model_solver_multiple_inputs_discontinuity_error(self): def test_model_solver_multiple_inputs_initial_conditions(self): # Create model - model = pybamm.BaseModel() - model.convert_to_format = "casadi" - domain = ["negative electrode", "separator", "positive electrode"] - var = pybamm.Variable("var", domain=domain) - rate = pybamm.InputParameter("rate") - model.rhs = {var: -rate * var} - model.initial_conditions = {var: 2 * rate} - # create discretisation - mesh = get_mesh_for_testing() - spatial_methods = {"macroscale": pybamm.FiniteVolume()} - disc = pybamm.Discretisation(mesh, spatial_methods) - disc.process_model(model) - - solver = pybamm.ScipySolver(rtol=1e-8, atol=1e-8, method="RK45") - t_eval = np.linspace(0, 10, 100) - ninputs = 8 - inputs_list = [{"rate": 0.01 * (i + 1)} for i in range(ninputs)] - - solutions = solver.solve(model, t_eval, inputs=inputs_list, nproc=2) - for inputs, solution in zip(inputs_list, solutions): - np.testing.assert_array_equal(solution.t, t_eval) - np.testing.assert_allclose( - solution.y[0], 2 * inputs["rate"] * np.exp(-inputs["rate"] * solution.t) - ) - - def test_model_solver_multiple_inputs_jax_format(self): + formats = ["python", "casadi"] if pybamm.have_jax(): - # Create model + formats.append("jax") + formats = ["jax"] + for convert_to_format in formats: + print(convert_to_format) model = pybamm.BaseModel() - model.convert_to_format = "jax" + model.convert_to_format = convert_to_format domain = ["negative electrode", "separator", "positive electrode"] var = pybamm.Variable("var", domain=domain) - model.rhs = {var: -pybamm.InputParameter("rate") * var} - model.initial_conditions = {var: 1} + rate = pybamm.InputParameter("rate") + model.rhs = {var: -rate * var} + model.initial_conditions = {var: 2 * rate} # create discretisation mesh = get_mesh_for_testing() spatial_methods = {"macroscale": pybamm.FiniteVolume()} disc = pybamm.Discretisation(mesh, spatial_methods) disc.process_model(model) - solver = pybamm.JaxSolver(rtol=1e-8, atol=1e-8, method="RK45") + solver = pybamm.ScipySolver(rtol=1e-8, atol=1e-8, method="RK45") t_eval = np.linspace(0, 10, 100) - ninputs = 8 + ninputs = 2 inputs_list = [{"rate": 0.01 * (i + 1)} for i in range(ninputs)] solutions = solver.solve(model, t_eval, inputs=inputs_list, nproc=2) - for i in range(ninputs): - with self.subTest(i=i): - solution = solutions[i] - np.testing.assert_array_equal(solution.t, t_eval) - np.testing.assert_allclose( - solution.y[0], np.exp(-0.01 * (i + 1) * solution.t) - ) + + inputs_stacked = solver._inputs_to_stacked_vect( + inputs_list, convert_to_format=convert_to_format + ) + + # check initial conditions + y0 = np.vstack([s.y[:, 0].reshape(-1, 1) for s in solutions]) + np.testing.assert_allclose(y0, model.y0) + y0_fun = model.initial_conditions_eval(t_eval[0], y0, inputs_stacked) + np.testing.assert_allclose(y0, y0_fun) + for i, inputs in enumerate(inputs_list): + n = model.len_rhs + y0_slice = y0[i * n : (i + 1) * n] + np.testing.assert_allclose( + y0_slice, + 2 * inputs["rate"] * np.ones((n, 1)), + err_msg="failed for rate = {}".format(inputs["rate"]), + ) + + # check rhs equation + rhs_eval = model.rhs_eval(t_eval[0], y0, inputs_stacked) + for i, inputs in enumerate(inputs_list): + n = model.len_rhs + y_slice = rhs_eval[i * n : (i + 1) * n] + y0_slice = y0[i * n : (i + 1) * n] + np.testing.assert_allclose( + y_slice, + -inputs["rate"] * y0_slice, + err_msg="failed for rate = {}".format(inputs["rate"]), + ) + + # check solution + for inputs, solution in zip(inputs_list, solutions): + np.testing.assert_array_equal(solution.t, t_eval) + np.testing.assert_allclose( + solution.y[0], + 2 * inputs["rate"] * np.exp(-inputs["rate"] * solution.t), + ) def test_model_solver_with_event_with_casadi(self): # Create model