diff --git a/pybamm/solvers/base_solver.py b/pybamm/solvers/base_solver.py index 81e74f89c9..434e1b91e7 100644 --- a/pybamm/solvers/base_solver.py +++ b/pybamm/solvers/base_solver.py @@ -1407,12 +1407,11 @@ def map_func_over_inputs(name, f, vars_for_processing, ninputs): dictionary of variables for processing ninputs: int number of inputs to map over - add_v: bool - whether to add a vector v to the inputs """ if f is None: return None + is_event = "event" in name add_v = name.endswith("_action") matrix_output = name.endswith("_jac") or name.endswith("_jacp") @@ -1448,6 +1447,10 @@ def map_func_over_inputs(name, f, vars_for_processing, ninputs): splits = [i * nstates for i in range(ninputs + 1)] split = casadi.horzsplit(mapped_f, splits) stack = casadi.diagcat(*split) + elif is_event: + # Events need to return a scalar, so we combine the vector output + # of the mapped function into a scalar output by calculating a smooth max of the vector output. + stack = casadi.logsumexp(casadi.transpose(mapped_f), 1e-4) else: # for vector outputs we need to stack them vertically in a single column vector splits = [i for i in range(ninputs + 1)] diff --git a/pybamm/solvers/solution.py b/pybamm/solvers/solution.py index fe1bbc6e49..8f5dc23c79 100644 --- a/pybamm/solvers/solution.py +++ b/pybamm/solvers/solution.py @@ -167,17 +167,18 @@ def split(self, rhs_len, alg_len, inputs_list): ] for p in range(ninputs) ] - y_events = [ - casadi.vertcat( - self.y_event[(p * rhs_len) : (p * rhs_len + rhs_len)], - self.y_event[ - (p * alg_len + ninputs * rhs_len) : ( - p * alg_len + ninputs * rhs_len + alg_len - ) - ], - ) - for p in range(ninputs) - ] + if self.y_event is not None: + y_events = [ + casadi.vertcat( + self.y_event[(p * rhs_len) : (p * rhs_len + rhs_len)], + self.y_event[ + (p * alg_len + ninputs * rhs_len) : ( + p * alg_len + ninputs * rhs_len + alg_len + ) + ], + ) + for p in range(ninputs) + ] else: all_ys_split = [ [ @@ -195,19 +196,20 @@ def split(self, rhs_len, alg_len, inputs_list): ] for p in range(ninputs) ] - y_events = [ - np.vstack( - [ - self.y_event[(p * rhs_len) : (p * rhs_len + rhs_len)], - self.y_event[ - (p * alg_len + ninputs * rhs_len) : ( - p * alg_len + ninputs * rhs_len + alg_len - ) - ], - ] - ) - for p in range(ninputs) - ] + if self.y_event is not None: + y_events = [ + np.vstack( + [ + self.y_event[(p * rhs_len) : (p * rhs_len + rhs_len)], + self.y_event[ + (p * alg_len + ninputs * rhs_len) : ( + p * alg_len + ninputs * rhs_len + alg_len + ) + ], + ] + ) + for p in range(ninputs) + ] ret = [ type(self)( diff --git a/tests/unit/test_solvers/test_casadi_solver.py b/tests/unit/test_solvers/test_casadi_solver.py index 6a79d639f2..5ad24810b6 100644 --- a/tests/unit/test_solvers/test_casadi_solver.py +++ b/tests/unit/test_solvers/test_casadi_solver.py @@ -396,6 +396,7 @@ def test_model_solver_with_inputs(self): t_eval = np.linspace(0, 10, 100) solution = solver.solve(model, t_eval, inputs={"rate": 0.1}) self.assertLess(len(solution.t), len(t_eval)) + single_len = len(solution.t) np.testing.assert_allclose( solution.y.full()[0], np.exp(-0.1 * solution.t), rtol=1e-04 ) @@ -407,6 +408,7 @@ def test_model_solver_with_inputs(self): self.assertEqual(len(solutions), 2) for solution, rate in zip(solutions, [0.1, 0.2]): self.assertLess(len(solution.t), len(t_eval)) + self.assertEqual(len(solution.t), single_len) np.testing.assert_allclose( solution.y.full()[0], np.exp(-rate * solution.t), rtol=1e-04 )