Skip to content

Commit

Permalink
change to stop if all events #4087
Browse files Browse the repository at this point in the history
  • Loading branch information
martinjrobins committed Jun 20, 2024
1 parent 9ce0917 commit 7727aef
Show file tree
Hide file tree
Showing 3 changed files with 33 additions and 26 deletions.
7 changes: 5 additions & 2 deletions pybamm/solvers/base_solver.py
Original file line number Diff line number Diff line change
Expand Up @@ -1407,12 +1407,11 @@ def map_func_over_inputs(name, f, vars_for_processing, ninputs):
dictionary of variables for processing
ninputs: int
number of inputs to map over
add_v: bool
whether to add a vector v to the inputs
"""
if f is None:
return None

is_event = "event" in name
add_v = name.endswith("_action")
matrix_output = name.endswith("_jac") or name.endswith("_jacp")

Expand Down Expand Up @@ -1448,6 +1447,10 @@ def map_func_over_inputs(name, f, vars_for_processing, ninputs):
splits = [i * nstates for i in range(ninputs + 1)]
split = casadi.horzsplit(mapped_f, splits)
stack = casadi.diagcat(*split)
elif is_event:
# Events need to return a scalar, so we combine the vector output
# of the mapped function into a scalar output by calculating a smooth max of the vector output.
stack = casadi.logsumexp(casadi.transpose(mapped_f), 1e-4)
else:
# for vector outputs we need to stack them vertically in a single column vector
splits = [i for i in range(ninputs + 1)]
Expand Down
50 changes: 26 additions & 24 deletions pybamm/solvers/solution.py
Original file line number Diff line number Diff line change
Expand Up @@ -167,17 +167,18 @@ def split(self, rhs_len, alg_len, inputs_list):
]
for p in range(ninputs)
]
y_events = [
casadi.vertcat(
self.y_event[(p * rhs_len) : (p * rhs_len + rhs_len)],
self.y_event[
(p * alg_len + ninputs * rhs_len) : (
p * alg_len + ninputs * rhs_len + alg_len
)
],
)
for p in range(ninputs)
]
if self.y_event is not None:
y_events = [
casadi.vertcat(
self.y_event[(p * rhs_len) : (p * rhs_len + rhs_len)],
self.y_event[
(p * alg_len + ninputs * rhs_len) : (
p * alg_len + ninputs * rhs_len + alg_len
)
],
)
for p in range(ninputs)
]
else:
all_ys_split = [
[
Expand All @@ -195,19 +196,20 @@ def split(self, rhs_len, alg_len, inputs_list):
]
for p in range(ninputs)
]
y_events = [
np.vstack(
[
self.y_event[(p * rhs_len) : (p * rhs_len + rhs_len)],
self.y_event[
(p * alg_len + ninputs * rhs_len) : (
p * alg_len + ninputs * rhs_len + alg_len
)
],
]
)
for p in range(ninputs)
]
if self.y_event is not None:
y_events = [
np.vstack(
[
self.y_event[(p * rhs_len) : (p * rhs_len + rhs_len)],
self.y_event[
(p * alg_len + ninputs * rhs_len) : (
p * alg_len + ninputs * rhs_len + alg_len
)
],
]
)
for p in range(ninputs)
]

ret = [
type(self)(
Expand Down
2 changes: 2 additions & 0 deletions tests/unit/test_solvers/test_casadi_solver.py
Original file line number Diff line number Diff line change
Expand Up @@ -396,6 +396,7 @@ def test_model_solver_with_inputs(self):
t_eval = np.linspace(0, 10, 100)
solution = solver.solve(model, t_eval, inputs={"rate": 0.1})
self.assertLess(len(solution.t), len(t_eval))
single_len = len(solution.t)
np.testing.assert_allclose(
solution.y.full()[0], np.exp(-0.1 * solution.t), rtol=1e-04
)
Expand All @@ -407,6 +408,7 @@ def test_model_solver_with_inputs(self):
self.assertEqual(len(solutions), 2)
for solution, rate in zip(solutions, [0.1, 0.2]):
self.assertLess(len(solution.t), len(t_eval))
self.assertEqual(len(solution.t), single_len)
np.testing.assert_allclose(
solution.y.full()[0], np.exp(-rate * solution.t), rtol=1e-04
)
Expand Down

0 comments on commit 7727aef

Please sign in to comment.