Skip to content

Commit

Permalink
Refactor variable generation through the standard process() function
Browse files Browse the repository at this point in the history
  • Loading branch information
jsbrittain committed Sep 13, 2023
1 parent 2187b89 commit 2b9507f
Show file tree
Hide file tree
Showing 2 changed files with 42 additions and 34 deletions.
68 changes: 38 additions & 30 deletions pybamm/solvers/base_solver.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,7 +67,7 @@ def __init__(
self.ode_solver = False
self.algebraic_solver = False
self._on_extrapolation = "warn"
self.var_casadi_fcns = {}
self.computed_var_fcns = {}

@property
def root_method(self):
Expand Down Expand Up @@ -136,7 +136,7 @@ def set_up(self, model, inputs=None, t_eval=None, ics_only=False):
)

# Process initial conditions
initial_conditions, _, jacp_ic, _ = process(
initial_conditions, _, jacp_ic, _, _ = process(
model.concatenated_initial_conditions,
"initial_conditions",
vars_for_processing,
Expand Down Expand Up @@ -179,11 +179,11 @@ def set_up(self, model, inputs=None, t_eval=None, ics_only=False):

# Process rhs, algebraic, residual and event expressions
# and wrap in callables
rhs, jac_rhs, jacp_rhs, jac_rhs_action = process(
rhs, jac_rhs, jacp_rhs, jac_rhs_action, _ = process(
model.concatenated_rhs, "RHS", vars_for_processing
)

algebraic, jac_algebraic, jacp_algebraic, jac_algebraic_action = process(
algebraic, jac_algebraic, jacp_algebraic, jac_algebraic_action, _ = process(
model.concatenated_algebraic, "algebraic", vars_for_processing
)

Expand All @@ -202,6 +202,7 @@ def set_up(self, model, inputs=None, t_eval=None, ics_only=False):
jac_rhs_algebraic,
jacp_rhs_algebraic,
jac_rhs_algebraic_action,
_,
) = process(rhs_algebraic, "rhs_algebraic", vars_for_processing)

(
Expand Down Expand Up @@ -241,9 +242,7 @@ def set_up(self, model, inputs=None, t_eval=None, ics_only=False):
# can use DAE solver to solve model with algebraic equations only
if len(model.rhs) > 0:
t_casadi = vars_for_processing["t_casadi"]
y_casadi = vars_for_processing["y_casadi"]
y_and_S = vars_for_processing["y_and_S"]
p_casadi = vars_for_processing["p_casadi"]
p_casadi_stacked = vars_for_processing["p_casadi_stacked"]
mass_matrix_inv = casadi.MX(model.mass_matrix_inv.entries)
explicit_rhs = mass_matrix_inv @ rhs(
Expand All @@ -260,8 +259,9 @@ def set_up(self, model, inputs=None, t_eval=None, ics_only=False):

# if output_variables specified then convert functions to casadi
# expressions for evaluation within the respective solver
self.var_casadi_fcns = {}
self.dvar_dy_casadi_fcns = {}
self.computed_var_fcns = {}
self.computed_dvar_dy_fcns = {}
self.computed_dvar_dp_fcns = {}
self.dvar_dp_casadi_fcns = {}
for key in self.output_variables:
# ExplicitTimeIntegral's are not computed as part of the solver and
Expand All @@ -270,28 +270,20 @@ def set_up(self, model, inputs=None, t_eval=None, ics_only=False):
model.variables_and_events[key], pybamm.ExplicitTimeIntegral
):
continue
# Generate Casadi function to calculate variable
fcn_name = BaseSolver._wrangle_name(key)
var_casadi = model.variables_and_events[key].to_casadi(
t_casadi, y_casadi, inputs=p_casadi
)
self.var_casadi_fcns[key] = casadi.Function(
fcn_name, [t_casadi, y_casadi, p_casadi_stacked], [var_casadi]
# Generate Casadi function to calculate variable and derivates
# to enable sensitivites to be computed within the solver
(
self.computed_var_fcns[key],
self.computed_dvar_dy_fcns[key],
_,
_,
self.computed_dvar_dp_fcns[key],
) = process(
model.variables_and_events[key],
BaseSolver._wrangle_name(key),
vars_for_processing,
use_jacobian=True,
)
# Generate derivative functions for sensitivities
if (len(inputs) > 0) and (model.calculate_sensitivities):
dvar_dy = casadi.jacobian(var_casadi, y_casadi)
dvar_dp = casadi.jacobian(var_casadi, p_casadi_stacked)
self.dvar_dy_casadi_fcns[key] = casadi.Function(
f"d{fcn_name}_dy",
[t_casadi, y_casadi, p_casadi_stacked],
[dvar_dy],
)
self.dvar_dp_casadi_fcns[key] = casadi.Function(
f"d{fcn_name}_dp",
[t_casadi, y_casadi, p_casadi_stacked],
[dvar_dp],
)

pybamm.logger.info("Finish solver set-up")

Expand Down Expand Up @@ -1466,6 +1458,11 @@ def process(symbol, name, vars_for_processing, use_jacobian=None):
:class:`casadi.Function`
evaluator for product of the Jacobian with a vector $v$,
i.e. $\frac{\partial f}{\partial y} * v$
jacps: :class:`pybamm.EvaluatorPython` or
:class:`pybamm.EvaluatorJax` or
:class:`casadi.Function`
evaluator for derivative of $f(y, t, p)$ wrt $p$ (all stacked)
"""

def report(string):
Expand All @@ -1482,6 +1479,7 @@ def report(string):
report(f"Converting {name} to jax")
func = pybamm.EvaluatorJax(symbol)
jacp = None
jacps = None
if model.calculate_sensitivities:
report(
(
Expand All @@ -1490,6 +1488,7 @@ def report(string):
)
)
jacp = func.get_sensitivities()
jacp = jacps
if use_jacobian:
report(f"Calculating jacobian for {name} using jax")
jac = func.get_jacobian()
Expand All @@ -1503,6 +1502,7 @@ def report(string):
jacobian = vars_for_processing["jacobian"]
# Process with pybamm functions, converting
# to python evaluator
jacps = None
if model.calculate_sensitivities:
report(
(
Expand Down Expand Up @@ -1650,12 +1650,20 @@ def jacp(*args, **kwargs):
[t_casadi, y_and_S, p_casadi_stacked, v],
[jac_action_casadi],
)
# Compute derivate wrt p-stacked (can be passed to solver to
# compute sensitivities online)
jacps = casadi.Function(
f"d{name}_dp",
[t_casadi, y_casadi, p_casadi_stacked],
[casadi.jacobian(casadi_expression, p_casadi_stacked)],
)
else:
jac = None
jac_action = None
jacps = None

func = casadi.Function(
name, [t_casadi, y_and_S, p_casadi_stacked], [casadi_expression]
)

return func, jac, jacp, jac_action
return func, jac, jacp, jac_action, jacps
8 changes: 4 additions & 4 deletions pybamm/solvers/idaklu_solver.py
Original file line number Diff line number Diff line change
Expand Up @@ -276,15 +276,15 @@ def resfn(t, y, inputs, ydot):
):
continue
self.var_idaklu_fcns.append(
idaklu.generate_function(self.var_casadi_fcns[key].serialize())
idaklu.generate_function(self.computed_var_fcns[key].serialize())
)
# Convert derivative functions for sensitivities
if (len(inputs) > 0) and (model.calculate_sensitivities):
self.dvar_dy_idaklu_fcns.append(
idaklu.generate_function(self.dvar_dy_casadi_fcns[key].serialize())
idaklu.generate_function(self.computed_dvar_dy_fcns[key].serialize())
)
self.dvar_dp_idaklu_fcns.append(
idaklu.generate_function(self.dvar_dp_casadi_fcns[key].serialize())
idaklu.generate_function(self.computed_dvar_dp_fcns[key].serialize())
)

else:
Expand Down Expand Up @@ -458,7 +458,7 @@ def sensfn(resvalS, t, y, inputs, yp, yS, ypS):
"sensitivity_names": sensitivity_names,
"number_of_sensitivity_parameters": number_of_sensitivity_parameters,
"output_variables": self.output_variables,
"var_casadi_fcns": self.var_casadi_fcns,
"var_casadi_fcns": self.computed_var_fcns,
"var_idaklu_fcns": self.var_idaklu_fcns,
"dvar_dy_idaklu_fcns": self.dvar_dy_idaklu_fcns,
"dvar_dp_idaklu_fcns": self.dvar_dp_idaklu_fcns,
Expand Down

0 comments on commit 2b9507f

Please sign in to comment.