Skip to content

Commit

Permalink
Refactor variable generation through the standard process() function
Browse files Browse the repository at this point in the history
  • Loading branch information
jsbrittain committed Sep 13, 2023
1 parent 2187b89 commit 7f9aa24
Show file tree
Hide file tree
Showing 2 changed files with 32 additions and 32 deletions.
56 changes: 28 additions & 28 deletions pybamm/solvers/base_solver.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,7 +67,7 @@ def __init__(
self.ode_solver = False
self.algebraic_solver = False
self._on_extrapolation = "warn"
self.var_casadi_fcns = {}
self.computed_var_fcns = {}

@property
def root_method(self):
Expand Down Expand Up @@ -241,9 +241,7 @@ def set_up(self, model, inputs=None, t_eval=None, ics_only=False):
# can use DAE solver to solve model with algebraic equations only
if len(model.rhs) > 0:
t_casadi = vars_for_processing["t_casadi"]
y_casadi = vars_for_processing["y_casadi"]
y_and_S = vars_for_processing["y_and_S"]
p_casadi = vars_for_processing["p_casadi"]
p_casadi_stacked = vars_for_processing["p_casadi_stacked"]
mass_matrix_inv = casadi.MX(model.mass_matrix_inv.entries)
explicit_rhs = mass_matrix_inv @ rhs(
Expand All @@ -260,38 +258,30 @@ def set_up(self, model, inputs=None, t_eval=None, ics_only=False):

# if output_variables specified then convert functions to casadi
# expressions for evaluation within the respective solver
self.var_casadi_fcns = {}
self.dvar_dy_casadi_fcns = {}
self.dvar_dp_casadi_fcns = {}
self.computed_var_fcns = {}
self.computed_dvar_dy_fcns = {}
self.computed_dvar_dp_fcns = {}
for key in self.output_variables:
# ExplicitTimeIntegral's are not computed as part of the solver and
# do not need to be converted
if isinstance(
model.variables_and_events[key], pybamm.ExplicitTimeIntegral
):
continue
# Generate Casadi function to calculate variable
fcn_name = BaseSolver._wrangle_name(key)
var_casadi = model.variables_and_events[key].to_casadi(
t_casadi, y_casadi, inputs=p_casadi
)
self.var_casadi_fcns[key] = casadi.Function(
fcn_name, [t_casadi, y_casadi, p_casadi_stacked], [var_casadi]
# Generate Casadi function to calculate variable and derivates
# to enable sensitivites to be computed within the solver
(
self.computed_var_fcns[key],
self.computed_dvar_dy_fcns[key],
self.computed_dvar_dp_fcns[key],
_,
) = process(
model.variables_and_events[key],
BaseSolver._wrangle_name(key),
vars_for_processing,
use_jacobian=True,
return_jacp_stacked=True,
)
# Generate derivative functions for sensitivities
if (len(inputs) > 0) and (model.calculate_sensitivities):
dvar_dy = casadi.jacobian(var_casadi, y_casadi)
dvar_dp = casadi.jacobian(var_casadi, p_casadi_stacked)
self.dvar_dy_casadi_fcns[key] = casadi.Function(
f"d{fcn_name}_dy",
[t_casadi, y_casadi, p_casadi_stacked],
[dvar_dy],
)
self.dvar_dp_casadi_fcns[key] = casadi.Function(
f"d{fcn_name}_dp",
[t_casadi, y_casadi, p_casadi_stacked],
[dvar_dp],
)

pybamm.logger.info("Finish solver set-up")

Expand Down Expand Up @@ -1430,7 +1420,7 @@ def _set_up_model_inputs(self, model, inputs):
return ordered_inputs


def process(symbol, name, vars_for_processing, use_jacobian=None):
def process(symbol, name, vars_for_processing, use_jacobian=None, return_jacp_stacked=None):
"""
Parameters
----------
Expand All @@ -1440,6 +1430,8 @@ def process(symbol, name, vars_for_processing, use_jacobian=None):
function evaluators created will have this base name
use_jacobian: bool, optional
whether to return Jacobian functions
return_jacp_stacked: bool, optional
returns Jacobian function wrt stacked parameters instead of jacp
Returns
-------
Expand Down Expand Up @@ -1650,6 +1642,14 @@ def jacp(*args, **kwargs):
[t_casadi, y_and_S, p_casadi_stacked, v],
[jac_action_casadi],
)
# Compute derivate wrt p-stacked (can be passed to solver to
# compute sensitivities online)
if return_jacp_stacked:
jacp = casadi.Function(
f"d{name}_dp",
[t_casadi, y_casadi, p_casadi_stacked],
[casadi.jacobian(casadi_expression, p_casadi_stacked)],
)
else:
jac = None
jac_action = None
Expand Down
8 changes: 4 additions & 4 deletions pybamm/solvers/idaklu_solver.py
Original file line number Diff line number Diff line change
Expand Up @@ -276,15 +276,15 @@ def resfn(t, y, inputs, ydot):
):
continue
self.var_idaklu_fcns.append(
idaklu.generate_function(self.var_casadi_fcns[key].serialize())
idaklu.generate_function(self.computed_var_fcns[key].serialize())
)
# Convert derivative functions for sensitivities
if (len(inputs) > 0) and (model.calculate_sensitivities):
self.dvar_dy_idaklu_fcns.append(
idaklu.generate_function(self.dvar_dy_casadi_fcns[key].serialize())
idaklu.generate_function(self.computed_dvar_dy_fcns[key].serialize())
)
self.dvar_dp_idaklu_fcns.append(
idaklu.generate_function(self.dvar_dp_casadi_fcns[key].serialize())
idaklu.generate_function(self.computed_dvar_dp_fcns[key].serialize())
)

else:
Expand Down Expand Up @@ -458,7 +458,7 @@ def sensfn(resvalS, t, y, inputs, yp, yS, ypS):
"sensitivity_names": sensitivity_names,
"number_of_sensitivity_parameters": number_of_sensitivity_parameters,
"output_variables": self.output_variables,
"var_casadi_fcns": self.var_casadi_fcns,
"var_casadi_fcns": self.computed_var_fcns,
"var_idaklu_fcns": self.var_idaklu_fcns,
"dvar_dy_idaklu_fcns": self.dvar_dy_idaklu_fcns,
"dvar_dp_idaklu_fcns": self.dvar_dp_idaklu_fcns,
Expand Down

0 comments on commit 7f9aa24

Please sign in to comment.