diff --git a/pybamm/solvers/base_solver.py b/pybamm/solvers/base_solver.py index 3d42b4e4cf..6f43922112 100644 --- a/pybamm/solvers/base_solver.py +++ b/pybamm/solvers/base_solver.py @@ -67,7 +67,7 @@ def __init__( self.ode_solver = False self.algebraic_solver = False self._on_extrapolation = "warn" - self.var_casadi_fcns = {} + self.computed_var_fcns = {} @property def root_method(self): @@ -136,7 +136,7 @@ def set_up(self, model, inputs=None, t_eval=None, ics_only=False): ) # Process initial conditions - initial_conditions, _, jacp_ic, _ = process( + initial_conditions, _, jacp_ic, _, _ = process( model.concatenated_initial_conditions, "initial_conditions", vars_for_processing, @@ -179,11 +179,11 @@ def set_up(self, model, inputs=None, t_eval=None, ics_only=False): # Process rhs, algebraic, residual and event expressions # and wrap in callables - rhs, jac_rhs, jacp_rhs, jac_rhs_action = process( + rhs, jac_rhs, jacp_rhs, jac_rhs_action, _ = process( model.concatenated_rhs, "RHS", vars_for_processing ) - algebraic, jac_algebraic, jacp_algebraic, jac_algebraic_action = process( + algebraic, jac_algebraic, jacp_algebraic, jac_algebraic_action, _ = process( model.concatenated_algebraic, "algebraic", vars_for_processing ) @@ -202,6 +202,7 @@ def set_up(self, model, inputs=None, t_eval=None, ics_only=False): jac_rhs_algebraic, jacp_rhs_algebraic, jac_rhs_algebraic_action, + _, ) = process(rhs_algebraic, "rhs_algebraic", vars_for_processing) ( @@ -241,9 +242,7 @@ def set_up(self, model, inputs=None, t_eval=None, ics_only=False): # can use DAE solver to solve model with algebraic equations only if len(model.rhs) > 0: t_casadi = vars_for_processing["t_casadi"] - y_casadi = vars_for_processing["y_casadi"] y_and_S = vars_for_processing["y_and_S"] - p_casadi = vars_for_processing["p_casadi"] p_casadi_stacked = vars_for_processing["p_casadi_stacked"] mass_matrix_inv = casadi.MX(model.mass_matrix_inv.entries) explicit_rhs = mass_matrix_inv @ rhs( @@ -260,8 +259,9 @@ def set_up(self, model, inputs=None, t_eval=None, ics_only=False): # if output_variables specified then convert functions to casadi # expressions for evaluation within the respective solver - self.var_casadi_fcns = {} - self.dvar_dy_casadi_fcns = {} + self.computed_var_fcns = {} + self.computed_dvar_dy_fcns = {} + self.computed_dvar_dp_fcns = {} self.dvar_dp_casadi_fcns = {} for key in self.output_variables: # ExplicitTimeIntegral's are not computed as part of the solver and @@ -270,28 +270,20 @@ def set_up(self, model, inputs=None, t_eval=None, ics_only=False): model.variables_and_events[key], pybamm.ExplicitTimeIntegral ): continue - # Generate Casadi function to calculate variable - fcn_name = BaseSolver._wrangle_name(key) - var_casadi = model.variables_and_events[key].to_casadi( - t_casadi, y_casadi, inputs=p_casadi - ) - self.var_casadi_fcns[key] = casadi.Function( - fcn_name, [t_casadi, y_casadi, p_casadi_stacked], [var_casadi] + # Generate Casadi function to calculate variable and derivates + # to enable sensitivites to be computed within the solver + ( + self.computed_var_fcns[key], + self.computed_dvar_dy_fcns[key], + _, + _, + self.computed_dvar_dp_fcns[key], + ) = process( + model.variables_and_events[key], + BaseSolver._wrangle_name(key), + vars_for_processing, + use_jacobian=True, ) - # Generate derivative functions for sensitivities - if (len(inputs) > 0) and (model.calculate_sensitivities): - dvar_dy = casadi.jacobian(var_casadi, y_casadi) - dvar_dp = casadi.jacobian(var_casadi, p_casadi_stacked) - self.dvar_dy_casadi_fcns[key] = casadi.Function( - f"d{fcn_name}_dy", - [t_casadi, y_casadi, p_casadi_stacked], - [dvar_dy], - ) - self.dvar_dp_casadi_fcns[key] = casadi.Function( - f"d{fcn_name}_dp", - [t_casadi, y_casadi, p_casadi_stacked], - [dvar_dp], - ) pybamm.logger.info("Finish solver set-up") @@ -1466,6 +1458,11 @@ def process(symbol, name, vars_for_processing, use_jacobian=None): :class:`casadi.Function` evaluator for product of the Jacobian with a vector $v$, i.e. $\frac{\partial f}{\partial y} * v$ + + jacps: :class:`pybamm.EvaluatorPython` or + :class:`pybamm.EvaluatorJax` or + :class:`casadi.Function` + evaluator for derivative of $f(y, t, p)$ wrt $p$ (all stacked) """ def report(string): @@ -1482,6 +1479,7 @@ def report(string): report(f"Converting {name} to jax") func = pybamm.EvaluatorJax(symbol) jacp = None + jacps = None if model.calculate_sensitivities: report( ( @@ -1490,6 +1488,7 @@ def report(string): ) ) jacp = func.get_sensitivities() + jacp = jacps if use_jacobian: report(f"Calculating jacobian for {name} using jax") jac = func.get_jacobian() @@ -1503,6 +1502,7 @@ def report(string): jacobian = vars_for_processing["jacobian"] # Process with pybamm functions, converting # to python evaluator + jacps = None if model.calculate_sensitivities: report( ( @@ -1650,12 +1650,20 @@ def jacp(*args, **kwargs): [t_casadi, y_and_S, p_casadi_stacked, v], [jac_action_casadi], ) + # Compute derivate wrt p-stacked (can be passed to solver to + # compute sensitivities online) + jacps = casadi.Function( + f"d{name}_dp", + [t_casadi, y_casadi, p_casadi_stacked], + [casadi.jacobian(casadi_expression, p_casadi_stacked)], + ) else: jac = None jac_action = None + jacps = None func = casadi.Function( name, [t_casadi, y_and_S, p_casadi_stacked], [casadi_expression] ) - return func, jac, jacp, jac_action + return func, jac, jacp, jac_action, jacps diff --git a/pybamm/solvers/idaklu_solver.py b/pybamm/solvers/idaklu_solver.py index b9e474d2af..d9819f1608 100644 --- a/pybamm/solvers/idaklu_solver.py +++ b/pybamm/solvers/idaklu_solver.py @@ -276,15 +276,15 @@ def resfn(t, y, inputs, ydot): ): continue self.var_idaklu_fcns.append( - idaklu.generate_function(self.var_casadi_fcns[key].serialize()) + idaklu.generate_function(self.computed_var_fcns[key].serialize()) ) # Convert derivative functions for sensitivities if (len(inputs) > 0) and (model.calculate_sensitivities): self.dvar_dy_idaklu_fcns.append( - idaklu.generate_function(self.dvar_dy_casadi_fcns[key].serialize()) + idaklu.generate_function(self.computed_dvar_dy_fcns[key].serialize()) ) self.dvar_dp_idaklu_fcns.append( - idaklu.generate_function(self.dvar_dp_casadi_fcns[key].serialize()) + idaklu.generate_function(self.computed_dvar_dp_fcns[key].serialize()) ) else: @@ -458,7 +458,7 @@ def sensfn(resvalS, t, y, inputs, yp, yS, ypS): "sensitivity_names": sensitivity_names, "number_of_sensitivity_parameters": number_of_sensitivity_parameters, "output_variables": self.output_variables, - "var_casadi_fcns": self.var_casadi_fcns, + "var_casadi_fcns": self.computed_var_fcns, "var_idaklu_fcns": self.var_idaklu_fcns, "dvar_dy_idaklu_fcns": self.dvar_dy_idaklu_fcns, "dvar_dp_idaklu_fcns": self.dvar_dp_idaklu_fcns,