Skip to content

Commit

Permalink
Only run Jax tests when Jax installed
Browse files Browse the repository at this point in the history
  • Loading branch information
jsbrittain committed Jul 7, 2023
1 parent efc61be commit 5abff1d
Showing 1 changed file with 25 additions and 24 deletions.
49 changes: 25 additions & 24 deletions tests/unit/test_solvers/test_scipy_solver.py
Original file line number Diff line number Diff line change
Expand Up @@ -339,32 +339,33 @@ def test_model_solver_multiple_inputs_initial_conditions_error(self):
solver.solve(model, t_eval, inputs=inputs_list, nproc=2)

def test_model_solver_multiple_inputs_jax_format(self):
# Create model
model = pybamm.BaseModel()
model.convert_to_format = "jax"
domain = ["negative electrode", "separator", "positive electrode"]
var = pybamm.Variable("var", domain=domain)
model.rhs = {var: -pybamm.InputParameter("rate") * var}
model.initial_conditions = {var: 1}
# create discretisation
mesh = get_mesh_for_testing()
spatial_methods = {"macroscale": pybamm.FiniteVolume()}
disc = pybamm.Discretisation(mesh, spatial_methods)
disc.process_model(model)
if pybamm.have_jax():
# Create model
model = pybamm.BaseModel()
model.convert_to_format = "jax"
domain = ["negative electrode", "separator", "positive electrode"]
var = pybamm.Variable("var", domain=domain)
model.rhs = {var: -pybamm.InputParameter("rate") * var}
model.initial_conditions = {var: 1}
# create discretisation
mesh = get_mesh_for_testing()
spatial_methods = {"macroscale": pybamm.FiniteVolume()}
disc = pybamm.Discretisation(mesh, spatial_methods)
disc.process_model(model)

solver = pybamm.JaxSolver(rtol=1e-8, atol=1e-8, method="RK45")
t_eval = np.linspace(0, 10, 100)
ninputs = 8
inputs_list = [{"rate": 0.01 * (i + 1)} for i in range(ninputs)]
solver = pybamm.JaxSolver(rtol=1e-8, atol=1e-8, method="RK45")
t_eval = np.linspace(0, 10, 100)
ninputs = 8
inputs_list = [{"rate": 0.01 * (i + 1)} for i in range(ninputs)]

solutions = solver.solve(model, t_eval, inputs=inputs_list, nproc=2)
for i in range(ninputs):
with self.subTest(i=i):
solution = solutions[i]
np.testing.assert_array_equal(solution.t, t_eval)
np.testing.assert_allclose(
solution.y[0], np.exp(-0.01 * (i + 1) * solution.t)
)
solutions = solver.solve(model, t_eval, inputs=inputs_list, nproc=2)
for i in range(ninputs):
with self.subTest(i=i):
solution = solutions[i]
np.testing.assert_array_equal(solution.t, t_eval)
np.testing.assert_allclose(
solution.y[0], np.exp(-0.01 * (i + 1) * solution.t)
)

def test_model_solver_with_event_with_casadi(self):
# Create model
Expand Down

0 comments on commit 5abff1d

Please sign in to comment.