Skip to content

Commit

Permalink
#1863 #1898 got scipy solver mostly working
Browse files Browse the repository at this point in the history
  • Loading branch information
martinjrobins committed Jan 14, 2022
1 parent 1c4b04d commit fda76c9
Show file tree
Hide file tree
Showing 3 changed files with 41 additions and 16 deletions.
2 changes: 1 addition & 1 deletion pybamm/expression_tree/operations/evaluate_python.py
Original file line number Diff line number Diff line change
Expand Up @@ -501,7 +501,7 @@ def __call__(self, t=None, y=None, inputs=None):

result = self._evaluate(self._constants, t, y, inputs)

return result.flatten()
return result

def __getstate__(self):
# Control the state of instances of EvaluatorPython
Expand Down
35 changes: 25 additions & 10 deletions pybamm/solvers/base_solver.py
Original file line number Diff line number Diff line change
Expand Up @@ -321,7 +321,6 @@ def report(string):
else:
jac = None


elif model.convert_to_format != "casadi":
# Process with pybamm functions, converting
# to python evaluator
Expand Down Expand Up @@ -367,7 +366,8 @@ def jacp(*args, **kwargs):
else:
# Process with CasADi
report(f"Converting {name} to CasADi")
casadi_expression = symbol.to_casadi(t_casadi, y_casadi, inputs=p_casadi)
casadi_expression = symbol.to_casadi(
t_casadi, y_casadi, inputs=p_casadi)
# Add sensitivity vectors to the rhs and algebraic equations
jacp = None
if calculate_sensitivities_explicit:
Expand Down Expand Up @@ -462,7 +462,6 @@ def jacp(*args, **kwargs):
else:
jac = None


func = casadi.Function(
name, [t_casadi, y_and_S, p_casadi_stacked], [casadi_expression]
)
Expand All @@ -476,6 +475,18 @@ def jacp(*args, **kwargs):
use_jacobian=False,
)[0]

# evaluate initial condition
y0_total_size = (
model.len_rhs + model.len_rhs_sens + model.len_alg + model.len_alg_sens
)
y_zero = np.zeros((y0_total_size, 1))
if model.convert_to_format == "casadi":
# stack inputs
inputs_casadi = casadi.vertcat(*[x for x in inputs.values()])
model.y0 = initial_conditions(0, y_zero, inputs_casadi)
else:
model.y0 = initial_conditions(0, y_zero, inputs)

if ics_only:
pybamm.logger.info("Finish solver set-up")
return
Expand Down Expand Up @@ -550,8 +561,6 @@ def jacp(*args, **kwargs):
)
)



# Process rhs, algebraic, residual and event expressions
# and wrap in callables
rhs, jac_rhs, jacp_rhs = process(model.concatenated_rhs, "RHS")
Expand All @@ -560,7 +569,6 @@ def jacp(*args, **kwargs):
model.concatenated_algebraic, "algebraic"
)


# combine rhs and algebraic functions
if len(model.rhs) == 0:
rhs_algebraic = model.concatenated_rhs
Expand Down Expand Up @@ -623,15 +631,22 @@ def jacp(*args, **kwargs):
# Add the solver attributes
model.rhs_eval = rhs
model.algebraic_eval = algebraic
model.rhs_algebraic_eval = rhs_algebraic

model.terminate_events_eval = terminate_events
model.discontinuity_events_eval = discontinuity_events
model.interpolant_extrapolation_events_eval = interpolant_extrapolation_events
model.rhs_algebraic_eval = rhs_algebraic

model.jac_rhs_eval = jac_rhs
model.jacp_rhs_eval = jacp_rhs

model.jacp_algebraic_eval = jacp_algebraic
model.jacp_algebraic_eval = jacp_algebraic

model.jac_rhs_algebraic_eval = jac_rhs_algebraic
model.jacp_rhs_algebraic_eval = jacp_rhs_algebraic
model.initial_conditions_eval = initial_conditions


# Save CasADi functions for the CasADi solver
# Save CasADi functions for solvers that use CasADi
# Note: when we pass to casadi the ode part of the problem must be in
Expand All @@ -653,8 +668,6 @@ def jacp(*args, **kwargs):
model.casadi_sensitivities_rhs = jacp_rhs
model.casadi_sensitivities_algebraic = jacp_algebraic



pybamm.logger.info("Finish solver set-up")

def _set_initial_conditions(self, model, inputs_dict, update_rhs):
Expand Down Expand Up @@ -1448,6 +1461,7 @@ def _set_up_ext_and_inputs(
ext_and_inputs = {**external_variables, **inputs}
return ext_and_inputs


class SolverCallable:
"""A class that will be called by the solver when integrating"""

Expand Down Expand Up @@ -1484,6 +1498,7 @@ def function(self, t, y, inputs):
else:
return self._function(t, y, inputs=inputs, known_evals={})[0]


class InitialConditions(SolverCallable):
"""Returns initial conditions given inputs"""

Expand Down
20 changes: 15 additions & 5 deletions pybamm/solvers/scipy_solver.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,19 +78,29 @@ def _integrate(self, model, t_eval, inputs_dict=None):
# Initial conditions
y0 = model.y0
if isinstance(y0, casadi.DM):
y0 = y0.full().flatten()
y0 = y0.full()
y0 = y0.flatten()

# check for user-supplied Jacobian
implicit_methods = ["Radau", "BDF", "LSODA"]
if np.any([self.method in implicit_methods]):
if model.jac_rhs_algebraic_eval:
if model.jac_rhs_eval:
def jacobian(t, y):
return model.jac_rhs_eval(t, y, inputs)
extra_options.update(
{"jac": lambda t, y: model.jac_rhs_algebraic_eval(t, y, inputs)}
{"jac": jacobian}
)

# rhs equation
def rhs(t, y):
return model.rhs_eval(t, y, inputs).reshape(-1)

if model.convert_to_format == 'casadi':
def rhs(t, y):
return model.rhs_eval(t, y, inputs).full().reshape(-1)

# make events terminal so that the solver stops when they are reached
if model.terminate_events_eval:

def event_wrapper(event):
def event_fn(t, y):
return event(t, y, inputs)
Expand All @@ -103,7 +113,7 @@ def event_fn(t, y):

timer = pybamm.Timer()
sol = it.solve_ivp(
lambda t, y: model.rhs_eval(t, y, inputs),
rhs,
(t_eval[0], t_eval[-1]),
y0,
t_eval=t_eval,
Expand Down

0 comments on commit fda76c9

Please sign in to comment.