Skip to content

Commit

Permalink
#4087 fixes to casadi and scipy solvers to support multiple inputs
Browse files Browse the repository at this point in the history
  • Loading branch information
martinjrobins committed Jun 5, 2024
1 parent 69ca159 commit d09bf18
Show file tree
Hide file tree
Showing 9 changed files with 150 additions and 103 deletions.
24 changes: 15 additions & 9 deletions pybamm/expression_tree/operations/evaluate_python.py
Original file line number Diff line number Diff line change
Expand Up @@ -486,15 +486,21 @@ def __call__(self, t=None, y=None, inputs=None):
y = y.reshape(-1, 1)

if isinstance(inputs, list):
ny = y.shape[0]
ni = len(inputs)
result = np.zeros((ni * ny, 1))
i = 0
for input in inputs:
result[i : i + ny] += self._evaluate(
self._constants, t, y[i : i + ny], input
)
i += ny
result = self._evaluate(self._constants, t, y, inputs[0])
if len(inputs) > 1:
if isinstance(result, numbers.Number):
result = np.array([result])
ny = result.shape[0]
ni = len(inputs)
results = np.zeros((ni * ny, 1))
results[:ny] = result
i = ny
for input in inputs[1:]:
results[i : i + ny] += self._evaluate(
self._constants, t, y[i : i + ny], input
)
i += ny
result = results
else:
result = self._evaluate(self._constants, t, y, inputs)

Expand Down
27 changes: 14 additions & 13 deletions pybamm/solvers/algebraic_solver.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,7 @@ def _integrate(self, model, t_eval, inputs_list=None):
inputs_list: list of dict, optional
Any input parameters to pass to the model when solving
"""
inputs_list = inputs_list or {}
inputs_list = inputs_list or [{}]
if model.convert_to_format == "casadi":
inputs = casadi.vertcat(
*[x for inputs in inputs_list for x in inputs.values()]
Expand Down Expand Up @@ -149,7 +149,7 @@ def jac_fn(y_alg, jac=jac):
if jac_fn is None:
jac_fn = "2-point"
timer.reset()
sol = optimize.least_squares(
solns = optimize.least_squares(
root_fun,
y0_alg,
method=method,
Expand Down Expand Up @@ -186,7 +186,7 @@ def jac_norm(y, jac_fn=jac_fn):
]
extra_options["bounds"] = bounds
timer.reset()
sol = optimize.minimize(
solns = optimize.minimize(
root_norm,
y0_alg,
method=method,
Expand All @@ -197,7 +197,7 @@ def jac_norm(y, jac_fn=jac_fn):
integration_time += timer.time()
else:
timer.reset()
sol = optimize.root(
solns = optimize.root(
root_fun,
y0_alg,
method=self.method,
Expand All @@ -207,32 +207,33 @@ def jac_norm(y, jac_fn=jac_fn):
)
integration_time += timer.time()

if sol.success and np.all(abs(sol.fun) < self.tol):
if solns.success and np.all(abs(solns.fun) < self.tol):
# update initial guess for the next iteration
y0_alg = sol.x
y0_alg = solns.x
# update solution array
y_alg[:, idx] = y0_alg
success = True
elif not sol.success:
elif not solns.success:
raise pybamm.SolverError(
f"Could not find acceptable solution: {sol.message}"
f"Could not find acceptable solution: {solns.message}"
)
else:
y0_alg = sol.x
y0_alg = solns.x
if itr > maxiter:
raise pybamm.SolverError(
"Could not find acceptable solution: solver terminated "
"successfully, but maximum solution error "
f"({np.max(abs(sol.fun))}) above tolerance ({self.tol})"
f"({np.max(abs(solns.fun))}) above tolerance ({self.tol})"
)
itr += 1

# Concatenate differential part
y_diff = np.r_[[y0_diff] * len(t_eval)].T
y_sol = np.r_[y_diff, y_alg]
# Return solution object (no events, so pass None to t_event, y_event)
sol = pybamm.Solution.from_concatenated_state(
solns = pybamm.Solution.from_concatenated_state(
t_eval, y_sol, model, inputs_list, termination="final time"
)
sol.integration_time = integration_time
return sol
for s in solns:
s.integration_time = integration_time
return solns
Loading

0 comments on commit d09bf18

Please sign in to comment.