diff --git a/tests/unit/test_solvers/test_scipy_solver.py b/tests/unit/test_solvers/test_scipy_solver.py index e1fe8fc15d..fad6651d55 100644 --- a/tests/unit/test_solvers/test_scipy_solver.py +++ b/tests/unit/test_solvers/test_scipy_solver.py @@ -339,32 +339,33 @@ def test_model_solver_multiple_inputs_initial_conditions_error(self): solver.solve(model, t_eval, inputs=inputs_list, nproc=2) def test_model_solver_multiple_inputs_jax_format(self): - # Create model - model = pybamm.BaseModel() - model.convert_to_format = "jax" - domain = ["negative electrode", "separator", "positive electrode"] - var = pybamm.Variable("var", domain=domain) - model.rhs = {var: -pybamm.InputParameter("rate") * var} - model.initial_conditions = {var: 1} - # create discretisation - mesh = get_mesh_for_testing() - spatial_methods = {"macroscale": pybamm.FiniteVolume()} - disc = pybamm.Discretisation(mesh, spatial_methods) - disc.process_model(model) + if pybamm.have_jax(): + # Create model + model = pybamm.BaseModel() + model.convert_to_format = "jax" + domain = ["negative electrode", "separator", "positive electrode"] + var = pybamm.Variable("var", domain=domain) + model.rhs = {var: -pybamm.InputParameter("rate") * var} + model.initial_conditions = {var: 1} + # create discretisation + mesh = get_mesh_for_testing() + spatial_methods = {"macroscale": pybamm.FiniteVolume()} + disc = pybamm.Discretisation(mesh, spatial_methods) + disc.process_model(model) - solver = pybamm.JaxSolver(rtol=1e-8, atol=1e-8, method="RK45") - t_eval = np.linspace(0, 10, 100) - ninputs = 8 - inputs_list = [{"rate": 0.01 * (i + 1)} for i in range(ninputs)] + solver = pybamm.JaxSolver(rtol=1e-8, atol=1e-8, method="RK45") + t_eval = np.linspace(0, 10, 100) + ninputs = 8 + inputs_list = [{"rate": 0.01 * (i + 1)} for i in range(ninputs)] - solutions = solver.solve(model, t_eval, inputs=inputs_list, nproc=2) - for i in range(ninputs): - with self.subTest(i=i): - solution = solutions[i] - np.testing.assert_array_equal(solution.t, t_eval) - np.testing.assert_allclose( - solution.y[0], np.exp(-0.01 * (i + 1) * solution.t) - ) + solutions = solver.solve(model, t_eval, inputs=inputs_list, nproc=2) + for i in range(ninputs): + with self.subTest(i=i): + solution = solutions[i] + np.testing.assert_array_equal(solution.t, t_eval) + np.testing.assert_allclose( + solution.y[0], np.exp(-0.01 * (i + 1) * solution.t) + ) def test_model_solver_with_event_with_casadi(self): # Create model