diff --git a/pybamm/solvers/idaklu_solver.py b/pybamm/solvers/idaklu_solver.py index e9c77cf005..d09a537155 100644 --- a/pybamm/solvers/idaklu_solver.py +++ b/pybamm/solvers/idaklu_solver.py @@ -240,7 +240,8 @@ def resfn(t, y, inputs, ydot): ) else: - jac_y0_t0 = model.jac_rhs_algebraic_eval(t_eval[0], y0, inputs_dict) + t0 = 0 if t_eval is None else t_eval[0] + jac_y0_t0 = model.jac_rhs_algebraic_eval(t0, y0, inputs_dict) if sparse.issparse(jac_y0_t0): def jacfn(t, y, inputs, cj): j = ( diff --git a/tests/unit/test_solvers/test_base_solver.py b/tests/unit/test_solvers/test_base_solver.py index a96adab593..3ae5770c63 100644 --- a/tests/unit/test_solvers/test_base_solver.py +++ b/tests/unit/test_solvers/test_base_solver.py @@ -377,11 +377,18 @@ def exact_diff_b(y, a, b): sens = model.jacp_rhs_algebraic_eval(t, y, use_inputs) + if convert_to_format == "casadi": + sens_a = sens[0] + sens_b = sens[1] + else: + sens_a = sens["a"] + sens_b = sens["b"] + np.testing.assert_allclose( - sens["a"], exact_diff_a(y, inputs["a"], inputs["b"]) + sens_a, exact_diff_a(y, inputs["a"], inputs["b"]) ) np.testing.assert_allclose( - sens["b"], exact_diff_b(y, inputs["a"], inputs["b"]) + sens_b, exact_diff_b(y, inputs["a"], inputs["b"]) )