From fda76c99afe27d44a690a36f672de78c21154acf Mon Sep 17 00:00:00 2001 From: martinjrobins Date: Fri, 14 Jan 2022 17:48:36 +0000 Subject: [PATCH] #1863 #1898 got scipy solver mostly working --- .../operations/evaluate_python.py | 2 +- pybamm/solvers/base_solver.py | 35 +++++++++++++------ pybamm/solvers/scipy_solver.py | 20 ++++++++--- 3 files changed, 41 insertions(+), 16 deletions(-) diff --git a/pybamm/expression_tree/operations/evaluate_python.py b/pybamm/expression_tree/operations/evaluate_python.py index 78e0b4f594..89ae85892c 100644 --- a/pybamm/expression_tree/operations/evaluate_python.py +++ b/pybamm/expression_tree/operations/evaluate_python.py @@ -501,7 +501,7 @@ def __call__(self, t=None, y=None, inputs=None): result = self._evaluate(self._constants, t, y, inputs) - return result.flatten() + return result def __getstate__(self): # Control the state of instances of EvaluatorPython diff --git a/pybamm/solvers/base_solver.py b/pybamm/solvers/base_solver.py index 3b94b95c4e..3b7c47bfdd 100644 --- a/pybamm/solvers/base_solver.py +++ b/pybamm/solvers/base_solver.py @@ -321,7 +321,6 @@ def report(string): else: jac = None - elif model.convert_to_format != "casadi": # Process with pybamm functions, converting # to python evaluator @@ -367,7 +366,8 @@ def jacp(*args, **kwargs): else: # Process with CasADi report(f"Converting {name} to CasADi") - casadi_expression = symbol.to_casadi(t_casadi, y_casadi, inputs=p_casadi) + casadi_expression = symbol.to_casadi( + t_casadi, y_casadi, inputs=p_casadi) # Add sensitivity vectors to the rhs and algebraic equations jacp = None if calculate_sensitivities_explicit: @@ -462,7 +462,6 @@ def jacp(*args, **kwargs): else: jac = None - func = casadi.Function( name, [t_casadi, y_and_S, p_casadi_stacked], [casadi_expression] ) @@ -476,6 +475,18 @@ def jacp(*args, **kwargs): use_jacobian=False, )[0] + # evaluate initial condition + y0_total_size = ( + model.len_rhs + model.len_rhs_sens + model.len_alg + model.len_alg_sens + ) + y_zero = np.zeros((y0_total_size, 1)) + if model.convert_to_format == "casadi": + # stack inputs + inputs_casadi = casadi.vertcat(*[x for x in inputs.values()]) + model.y0 = initial_conditions(0, y_zero, inputs_casadi) + else: + model.y0 = initial_conditions(0, y_zero, inputs) + if ics_only: pybamm.logger.info("Finish solver set-up") return @@ -550,8 +561,6 @@ def jacp(*args, **kwargs): ) ) - - # Process rhs, algebraic, residual and event expressions # and wrap in callables rhs, jac_rhs, jacp_rhs = process(model.concatenated_rhs, "RHS") @@ -560,7 +569,6 @@ def jacp(*args, **kwargs): model.concatenated_algebraic, "algebraic" ) - # combine rhs and algebraic functions if len(model.rhs) == 0: rhs_algebraic = model.concatenated_rhs @@ -623,15 +631,22 @@ def jacp(*args, **kwargs): # Add the solver attributes model.rhs_eval = rhs model.algebraic_eval = algebraic + model.rhs_algebraic_eval = rhs_algebraic + model.terminate_events_eval = terminate_events model.discontinuity_events_eval = discontinuity_events model.interpolant_extrapolation_events_eval = interpolant_extrapolation_events - model.rhs_algebraic_eval = rhs_algebraic + + model.jac_rhs_eval = jac_rhs + model.jacp_rhs_eval = jacp_rhs + + model.jacp_algebraic_eval = jacp_algebraic + model.jacp_algebraic_eval = jacp_algebraic + model.jac_rhs_algebraic_eval = jac_rhs_algebraic model.jacp_rhs_algebraic_eval = jacp_rhs_algebraic model.initial_conditions_eval = initial_conditions - # Save CasADi functions for the CasADi solver # Save CasADi functions for solvers that use CasADi # Note: when we pass to casadi the ode part of the problem must be in @@ -653,8 +668,6 @@ def jacp(*args, **kwargs): model.casadi_sensitivities_rhs = jacp_rhs model.casadi_sensitivities_algebraic = jacp_algebraic - - pybamm.logger.info("Finish solver set-up") def _set_initial_conditions(self, model, inputs_dict, update_rhs): @@ -1448,6 +1461,7 @@ def _set_up_ext_and_inputs( ext_and_inputs = {**external_variables, **inputs} return ext_and_inputs + class SolverCallable: """A class that will be called by the solver when integrating""" @@ -1484,6 +1498,7 @@ def function(self, t, y, inputs): else: return self._function(t, y, inputs=inputs, known_evals={})[0] + class InitialConditions(SolverCallable): """Returns initial conditions given inputs""" diff --git a/pybamm/solvers/scipy_solver.py b/pybamm/solvers/scipy_solver.py index 450fc94a3b..df62f724b1 100644 --- a/pybamm/solvers/scipy_solver.py +++ b/pybamm/solvers/scipy_solver.py @@ -78,19 +78,29 @@ def _integrate(self, model, t_eval, inputs_dict=None): # Initial conditions y0 = model.y0 if isinstance(y0, casadi.DM): - y0 = y0.full().flatten() + y0 = y0.full() + y0 = y0.flatten() # check for user-supplied Jacobian implicit_methods = ["Radau", "BDF", "LSODA"] if np.any([self.method in implicit_methods]): - if model.jac_rhs_algebraic_eval: + if model.jac_rhs_eval: + def jacobian(t, y): + return model.jac_rhs_eval(t, y, inputs) extra_options.update( - {"jac": lambda t, y: model.jac_rhs_algebraic_eval(t, y, inputs)} + {"jac": jacobian} ) + # rhs equation + def rhs(t, y): + return model.rhs_eval(t, y, inputs).reshape(-1) + + if model.convert_to_format == 'casadi': + def rhs(t, y): + return model.rhs_eval(t, y, inputs).full().reshape(-1) + # make events terminal so that the solver stops when they are reached if model.terminate_events_eval: - def event_wrapper(event): def event_fn(t, y): return event(t, y, inputs) @@ -103,7 +113,7 @@ def event_fn(t, y): timer = pybamm.Timer() sol = it.solve_ivp( - lambda t, y: model.rhs_eval(t, y, inputs), + rhs, (t_eval[0], t_eval[-1]), y0, t_eval=t_eval,