diff --git a/pybamm/solvers/base_solver.py b/pybamm/solvers/base_solver.py index 07ad28d165..6bc05e1296 100644 --- a/pybamm/solvers/base_solver.py +++ b/pybamm/solvers/base_solver.py @@ -1440,55 +1440,3 @@ def _set_up_ext_and_inputs( ext_and_inputs = {**external_variables, **ordered_inputs} return ext_and_inputs - - -class SolverCallable: - """A class that will be called by the solver when integrating""" - - def __init__(self, function, name, model): - self._function = function - if isinstance(function, casadi.Function): - self.form = "casadi" - else: - self.form = "python" - self.name = name - self.model = model - self.timescale = self.model.timescale_eval - - def __call__(self, t, y, inputs): - pybamm.logger.debug( - "Evaluating {} for {} at t={}".format( - self.name, self.model.name, t * self.timescale - ) - ) - if self.name in ["RHS", "algebraic", "residuals", "event"]: - - return self.function(t, y, inputs).flatten() - else: - return self.function(t, y, inputs) - - def function(self, t, y, inputs): - if self.form == "casadi": - states_eval = self._function(t, y, inputs) - if self.name in ["RHS", "algebraic", "residuals", "event"]: - return states_eval.full() - else: - # keep jacobians sparse - return states_eval - else: - return self._function(t, y, inputs=inputs, known_evals={})[0] - - -class InitialConditions(SolverCallable): - """Returns initial conditions given inputs""" - - def __init__(self, function, model): - super().__init__(function, "initial conditions", model) - - def __call__(self, inputs): - if self.form == "casadi": - if isinstance(inputs, dict): - inputs = casadi.vertcat(*[x for x in inputs.values()]) - return self._function(0, self.y_dummy, inputs) - else: - return self._function(0, self.y_dummy, inputs=inputs).flatten() diff --git a/tests/unit/test_solvers/test_scikits_solvers.py b/tests/unit/test_solvers/test_scikits_solvers.py index 4700d2acbe..5d22b32411 100644 --- a/tests/unit/test_solvers/test_scikits_solvers.py +++ b/tests/unit/test_solvers/test_scikits_solvers.py @@ -782,13 +782,13 @@ def test_model_step_events(self): def test_model_step_nonsmooth_events(self): # Create model model = pybamm.BaseModel() - model.timescale = pybamm.Scalar(1) + model.timescale_eval = pybamm.Scalar(1) var1 = pybamm.Variable("var1") var2 = pybamm.Variable("var2") a = 0.6 discontinuities = (np.arange(3) + 1) * a - model.rhs = {var1: pybamm.Modulo(pybamm.t * model.timescale, a)} + model.rhs = {var1: pybamm.Modulo(pybamm.t * model.timescale_eval, a)} model.algebraic = {var2: 2 * var1 - var2} model.initial_conditions = {var1: 0, var2: 0} model.events = [