diff --git a/pybamm/solvers/algebraic_solver.py b/pybamm/solvers/algebraic_solver.py index 10671e241e..c567bec1ba 100644 --- a/pybamm/solvers/algebraic_solver.py +++ b/pybamm/solvers/algebraic_solver.py @@ -80,7 +80,14 @@ def _integrate(self, model, t_eval, inputs_dict=None): len_rhs = model.rhs_eval(t_eval[0], y0, inputs).shape[0] y0_diff, y0_alg = np.split(y0, [len_rhs]) - algebraic = model.algebraic_eval + if model.convert_to_format == 'casadi': + def algebraic(t, y): + result = model.algebraic_eval(t, y, inputs) + return result.full().flatten() + else: + def algebraic(t, y): + result = model.algebraic_eval(t, y, inputs) + return result.flatten() y_alg = np.empty((len(y0_alg), len(t_eval))) @@ -91,7 +98,7 @@ def _integrate(self, model, t_eval, inputs_dict=None): def root_fun(y_alg): "Evaluates algebraic using y" y = np.concatenate([y0_diff, y_alg]) - out = algebraic(t, y, inputs) + out = algebraic(t, y) pybamm.logger.debug( "Evaluating algebraic equations at t={}, L2-norm is {}".format( t * model.timescale_eval, np.linalg.norm(out) diff --git a/pybamm/solvers/base_solver.py b/pybamm/solvers/base_solver.py index 3b7c47bfdd..1e6fe0b8a0 100644 --- a/pybamm/solvers/base_solver.py +++ b/pybamm/solvers/base_solver.py @@ -108,7 +108,11 @@ def set_up(self, model, inputs=None, t_eval=None, ics_only=False): t_eval : numeric type, optional The times (in seconds) at which to compute the solution """ - pybamm.logger.info("Start solver set-up") + + if ics_only: + pybamm.logger.info("Start solver set-up, initial_conditions only") + else: + pybamm.logger.info("Start solver set-up") # Check model.algebraic for ode solvers if self.ode_solver is True and len(model.algebraic) > 0: @@ -474,6 +478,7 @@ def jacp(*args, **kwargs): "initial_conditions", use_jacobian=False, )[0] + model.initial_conditions_eval = initial_conditions # evaluate initial condition y0_total_size = ( @@ -640,12 +645,11 @@ def jacp(*args, **kwargs): model.jac_rhs_eval = jac_rhs model.jacp_rhs_eval = jacp_rhs - model.jacp_algebraic_eval = jacp_algebraic + model.jac_algebraic_eval = jac_algebraic model.jacp_algebraic_eval = jacp_algebraic model.jac_rhs_algebraic_eval = jac_rhs_algebraic model.jacp_rhs_algebraic_eval = jacp_rhs_algebraic - model.initial_conditions_eval = initial_conditions # Save CasADi functions for the CasADi solver # Save CasADi functions for solvers that use CasADi @@ -723,7 +727,7 @@ def _set_initial_conditions(self, model, inputs_dict, update_rhs): model.y0 = casadi.vertcat( y0_from_inputs[:len_rhs], y0_from_model[len_rhs:] ) - y0 = self.calculate_consistent_state(model, 0, inputs) + y0 = self.calculate_consistent_state(model, 0, inputs_dict) # Make y0 a function of inputs if doing symbolic with casadi model.y0 = y0 @@ -738,7 +742,7 @@ def calculate_consistent_state(self, model, time=0, inputs=None): The model for which to calculate initial conditions. time : float The time at which to calculate the states - inputs_dict : dict, optional + inputs: dict, optional Any input parameters to pass to the model when solving Returns diff --git a/tests/unit/test_solvers/test_algebraic_solver.py b/tests/unit/test_solvers/test_algebraic_solver.py index 15e8a7bbd4..74c8fe2eb0 100644 --- a/tests/unit/test_solvers/test_algebraic_solver.py +++ b/tests/unit/test_solvers/test_algebraic_solver.py @@ -44,9 +44,12 @@ class Model(pybamm.BaseModel): timescale_eval = 1 length_scales = {} jac_algebraic_eval = None - convert_to_format = "python" len_rhs_and_alg = 1 + def __init__(self): + super().__init__() + self.convert_to_format = "python" + def algebraic_eval(self, t, y, inputs): return y + 2 @@ -68,9 +71,12 @@ class Model(pybamm.BaseModel): timescale_eval = 1 length_scales = {} jac_algebraic_eval = None - convert_to_format = "casadi" len_rhs_and_alg = 1 + def __init__(self): + super().__init__() + self.convert_to_format = "python" + def algebraic_eval(self, t, y, inputs): # algebraic equation has no real root return y ** 2 + 1 @@ -99,9 +105,12 @@ class Model(pybamm.BaseModel): rhs = {} timescale_eval = 1 length_scales = {} - convert_to_format = "python" len_rhs_and_alg = 2 + def __init__(self): + super().__init__() + self.convert_to_format = "python" + def algebraic_eval(self, t, y, inputs): return A @ y - b