Skip to content

Commit

Permalink
#2858 fix immediate casadi errors
Browse files Browse the repository at this point in the history
  • Loading branch information
valentinsulzer committed Apr 5, 2023
1 parent 0c12826 commit cb1d8d9
Show file tree
Hide file tree
Showing 3 changed files with 23 additions and 12 deletions.
31 changes: 21 additions & 10 deletions pybamm/solvers/casadi_solver.py
Original file line number Diff line number Diff line change
Expand Up @@ -516,9 +516,11 @@ def create_integrator(self, model, inputs, t_eval=None, use_event_switch=False):
if t_eval_shifted_rounded in self.integrators[model]:
return self.integrators[model][t_eval_shifted_rounded]
else:
method, problem, options = self.integrator_specs[model]
options["grid"] = t_eval_shifted
integrator = casadi.integrator("F", method, problem, options)
method, problem, options, time_args = self.integrator_specs[model]
time_args = [t_eval_shifted[0], t_eval_shifted[1:]]
integrator = casadi.integrator(
"F", method, problem, *time_args, options
)
self.integrators[model][t_eval_shifted_rounded] = integrator
return integrator
else:
Expand All @@ -542,6 +544,7 @@ def create_integrator(self, model, inputs, t_eval=None, use_event_switch=False):
y_full = casadi.vertcat(y_diff, y_alg)

if use_grid is False:
time_args = []
# rescale time
t_min = casadi.MX.sym("t_min")
t_max = casadi.MX.sym("t_max")
Expand All @@ -550,7 +553,7 @@ def create_integrator(self, model, inputs, t_eval=None, use_event_switch=False):
# add time limits as inputs
p_with_tlims = casadi.vertcat(p, t_min, t_max)
else:
options.update({"grid": t_eval_shifted, "output_t0": True})
time_args = [t_eval_shifted[0], t_eval_shifted[1:]]
# rescale time
t_min = casadi.MX.sym("t_min")
# Set dummy parameters for consistency with rescaled time
Expand Down Expand Up @@ -583,8 +586,8 @@ def create_integrator(self, model, inputs, t_eval=None, use_event_switch=False):
"alg": algebraic(t_scaled, y_full, p),
}
)
integrator = casadi.integrator("F", method, problem, options)
self.integrator_specs[model] = method, problem, options
integrator = casadi.integrator("F", method, problem, *time_args, options)
self.integrator_specs[model] = method, problem, options, time_args
if use_grid is False:
self.integrators[model] = {"no grid": integrator}
else:
Expand Down Expand Up @@ -655,13 +658,15 @@ def _run_integrator(
len_alg = len_alg * (num_parameters + 1)

y0_diff = y0[:len_rhs]
y0_alg = y0[len_rhs:]
y0_alg_exact = y0[len_rhs:]
if self.perturb_algebraic_initial_conditions and len_alg > 0:
# Add a tiny perturbation to the algebraic initial conditions
# For some reason this helps with convergence
# The actual value of the initial conditions for the algebraic variables
# doesn't matter
y0_alg = y0_alg * (1 + 1e-6 * casadi.DM(np.random.rand(len_alg)))
y0_alg = y0_alg_exact * (1 + 1e-6 * casadi.DM(np.random.rand(len_alg)))
else:
y0_alg = y0_alg_exact
pybamm.logger.spam("Finished preliminary setup for integrator run")

# Solve
Expand All @@ -682,7 +687,13 @@ def _run_integrator(
raise pybamm.SolverError(error.args[0])
pybamm.logger.debug("Finished casadi integrator")
integration_time = timer.time()
y_sol = casadi.vertcat(casadi_sol["xf"], casadi_sol["zf"])
# Manually add initial conditions and concatenate
x_sol = casadi.horzcat(y0_diff, casadi_sol["xf"])
if len_alg > 0:
z_sol = casadi.horzcat(y0_alg_exact, casadi_sol["zf"])
y_sol = casadi.vertcat(x_sol, z_sol)
else:
y_sol = x_sol
sol = pybamm.Solution(
t_eval,
y_sol,
Expand All @@ -696,7 +707,7 @@ def _run_integrator(
else:
# Repeated calls to the integrator
x = y0_diff
z = y0_alg
z = y0_alg_exact
y_diff = x
y_alg = z
for i in range(len(t_eval) - 1):
Expand Down
2 changes: 1 addition & 1 deletion tests/unit/test_solvers/test_base_solver.py
Original file line number Diff line number Diff line change
Expand Up @@ -228,7 +228,7 @@ def algebraic_eval(self, t, y, inputs):
# with casadi
solver = pybamm.BaseSolver(root_method="casadi")
with self.assertRaisesRegex(
pybamm.SolverError, "Could not find acceptable solution: .../casadi"
pybamm.SolverError, "Could not find acceptable solution: Error in Function"
):
solver.calculate_consistent_state(Model())

Expand Down
2 changes: 1 addition & 1 deletion tests/unit/test_solvers/test_casadi_algebraic_solver.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,7 +69,7 @@ def algebraic_eval(self, t, y, inputs):

solver = pybamm.CasadiAlgebraicSolver()
with self.assertRaisesRegex(
pybamm.SolverError, "Could not find acceptable solution: .../casadi"
pybamm.SolverError, "Could not find acceptable solution: Error in Function"
):
solver._integrate(model, np.array([0]), {})
solver = pybamm.CasadiAlgebraicSolver(extra_options={"error_on_fail": False})
Expand Down

0 comments on commit cb1d8d9

Please sign in to comment.