diff --git a/pybamm/solvers/casadi_solver.py b/pybamm/solvers/casadi_solver.py index c1415007fb..1029a8560e 100644 --- a/pybamm/solvers/casadi_solver.py +++ b/pybamm/solvers/casadi_solver.py @@ -173,7 +173,6 @@ def get_integrator(self, model, t_eval, inputs): # Only set up problem once if model not in self.problems: y0 = model.y0 - ydot0 = model.ydot0 rhs = model.casadi_rhs algebraic = model.casadi_algebraic u_stacked = casadi.vertcat(*[x for x in inputs.values()]) @@ -191,22 +190,21 @@ def get_integrator(self, model, t_eval, inputs): u = casadi.MX.sym("u", u_stacked.shape[0]) y_diff = casadi.MX.sym("y_diff", rhs(t_eval[0], y0, u).shape[0]) problem = {"t": t, "x": y_diff, "p": u} - if algebraic(t_eval[0], y0, ydot0, u).is_empty(): + if algebraic(t_eval[0], y0, u).is_empty(): method = "cvodes" problem.update({"ode": rhs(t, y_diff, u)}) else: options["calc_ic"] = True method = "idas" - y_alg = casadi.MX.sym("y_alg", algebraic(t_eval[0], y0, ydot0, u).shape[0]) + y_alg = casadi.MX.sym("y_alg", algebraic(t_eval[0], y0, u).shape[0]) y_full = casadi.vertcat(y_diff, y_alg) problem.update( { "z": y_alg, "ode": rhs(t, y_full, u), - "alg": algebraic(t, y_full, ydot0, u), + "alg": algebraic(t, y_full, u), } ) - print('using method' ,method) self.problems[model] = problem self.options[model] = options self.methods[model] = method diff --git a/tests/unit/test_solvers/test_casadi_solver.py b/tests/unit/test_solvers/test_casadi_solver.py index 27094449c4..39781ed088 100644 --- a/tests/unit/test_solvers/test_casadi_solver.py +++ b/tests/unit/test_solvers/test_casadi_solver.py @@ -316,17 +316,7 @@ def test_model_solver_with_dvdt(self): model.rhs = {var1: -2 * var1 * pybamm.t} model.algebraic = {var2: var2 - pybamm.d_dt(var1)} model.initial_conditions = {var1: 1, var2: 0} - print('before semi explicit') - for key,value in model.rhs.items(): - print('{}: {}'.format(key, value)) - for key,value in model.algebraic.items(): - print('{}: {}'.format(key, value)) pybamm.make_semi_explicit(model) - print('after semi explicit') - for key,value in model.rhs.items(): - print('{}: {}'.format(key, value)) - for key,value in model.algebraic.items(): - print('{}: {}'.format(key, value)) disc = get_discretisation_for_testing() disc.process_model(model) @@ -335,13 +325,9 @@ def test_model_solver_with_dvdt(self): t_eval = np.linspace(0, 1, 100) solution = solver.solve(model, t_eval) np.testing.assert_array_equal(solution.t, t_eval) - import matplotlib.pyplot as plt - plt.plot(solution.y[0]) - plt.plot(np.exp(-solution.t**2)) - plt.show() - np.testing.assert_allclose(solution.y[0], np.exp(-solution.t ** 2), rtol=1e-06) + np.testing.assert_allclose(solution.y[0], np.exp(-solution.t**2), rtol=1e-06) np.testing.assert_allclose(solution.y[-1], - -2 * solution.t * np.exp(-solution.t**2)) + -2 * solution.t * np.exp(-solution.t**2), rtol=1e-06) if __name__ == "__main__":