diff --git a/CHANGELOG.md b/CHANGELOG.md index 7070ebaf52..f70aae0272 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -10,6 +10,7 @@ ## Optimizations +- Performance refactor of JAX BDF Solver with default Jax method set to `"BDF"`. ([#4456](https://github.com/pybamm-team/PyBaMM/pull/4456)) - Improved performance of initialization and reinitialization of ODEs in the (`IDAKLUSolver`). ([#4453](https://github.com/pybamm-team/PyBaMM/pull/4453)) - Removed the `start_step_offset` setting and disabled minimum `dt` warnings for drive cycles with the (`IDAKLUSolver`). ([#4416](https://github.com/pybamm-team/PyBaMM/pull/4416)) diff --git a/examples/scripts/multiprocess_jax_solver.py b/examples/scripts/multiprocess_jax_solver.py new file mode 100644 index 0000000000..8192256ed1 --- /dev/null +++ b/examples/scripts/multiprocess_jax_solver.py @@ -0,0 +1,57 @@ +import pybamm +import time +import numpy as np + + +# This script provides an example for massively vectorised +# model solves using the JAX BDF solver. First, +# we set up the model and process parameters +model = pybamm.lithium_ion.SPM() +model.convert_to_format = "jax" +model.events = [] # remove events (not supported in jax) +geometry = model.default_geometry +param = pybamm.ParameterValues("Chen2020") +param.update({"Current function [A]": "[input]"}) +param.process_geometry(geometry) +param.process_model(model) + +# Discretise and setup solver +mesh = pybamm.Mesh(geometry, model.default_submesh_types, model.default_var_pts) +disc = pybamm.Discretisation(mesh, model.default_spatial_methods) +disc.process_model(model) +t_eval = np.linspace(0, 3600, 100) +solver = pybamm.JaxSolver(atol=1e-6, rtol=1e-6, method="BDF") + +# Set number of vectorised solves +values = np.linspace(0.01, 1.0, 1000) +inputs = [{"Current function [A]": value} for value in values] + +# Run solve for all inputs, with a just-in-time compilation +# occurring on the first solve. All sequential solves will +# use the compiled code, with a large performance improvement. +start_time = time.time() +sol = solver.solve(model, t_eval, inputs=inputs) +print(f"Time taken: {time.time() - start_time}") # 1.3s + +# Rerun the vectorised solve, showing performance improvement +start_time = time.time() +compiled_sol = solver.solve(model, t_eval, inputs=inputs) +print(f"Compiled time taken: {time.time() - start_time}") # 0.42s + +# Plot one of the solves +plot = pybamm.QuickPlot( + sol[5], + [ + "Negative particle concentration [mol.m-3]", + "Electrolyte concentration [mol.m-3]", + "Positive particle concentration [mol.m-3]", + "Current [A]", + "Negative electrode potential [V]", + "Electrolyte potential [V]", + "Positive electrode potential [V]", + "Voltage [V]", + ], + time_unit="seconds", + spatial_unit="um", +) +plot.dynamic_plot() diff --git a/noxfile.py b/noxfile.py index 6567ed167c..14bafcca47 100644 --- a/noxfile.py +++ b/noxfile.py @@ -222,7 +222,7 @@ def run_scripts(session): # https://bitbucket.org/pybtex-devs/pybtex/issues/169/replace-pkg_resources-with # is fixed session.install("setuptools", silent=False) - session.install("-e", ".[all,dev]", silent=False) + session.install("-e", ".[all,dev,jax]", silent=False) session.run("python", "-m", "pytest", "-m", "scripts") diff --git a/src/pybamm/solvers/base_solver.py b/src/pybamm/solvers/base_solver.py index 6334169cf0..a7fc79fe3a 100644 --- a/src/pybamm/solvers/base_solver.py +++ b/src/pybamm/solvers/base_solver.py @@ -90,6 +90,10 @@ def root_method(self): def supports_parallel_solve(self): return False + @property + def requires_explicit_sensitivities(self): + return True + @root_method.setter def root_method(self, method): if method == "casadi": @@ -141,7 +145,7 @@ def set_up(self, model, inputs=None, t_eval=None, ics_only=False): # see if we need to form the explicit sensitivity equations calculate_sensitivities_explicit = ( - model.calculate_sensitivities and not isinstance(self, pybamm.IDAKLUSolver) + model.calculate_sensitivities and self.requires_explicit_sensitivities ) self._set_up_model_sensitivities_inplace( @@ -494,11 +498,7 @@ def _set_up_model_sensitivities_inplace( # if we have a mass matrix, we need to extend it def extend_mass_matrix(M): M_extend = [M.entries] * (num_parameters + 1) - M_extend_pybamm = pybamm.Matrix(block_diag(M_extend, format="csr")) - return M_extend_pybamm - - model.mass_matrix = extend_mass_matrix(model.mass_matrix) - model.mass_matrix = extend_mass_matrix(model.mass_matrix) + return pybamm.Matrix(block_diag(M_extend, format="csr")) model.mass_matrix = extend_mass_matrix(model.mass_matrix) diff --git a/src/pybamm/solvers/idaklu_solver.py b/src/pybamm/solvers/idaklu_solver.py index 32048d89c0..2b2d852697 100644 --- a/src/pybamm/solvers/idaklu_solver.py +++ b/src/pybamm/solvers/idaklu_solver.py @@ -736,6 +736,10 @@ def _demote_64_to_32(self, x: pybamm.EvaluatorJax): def supports_parallel_solve(self): return True + @property + def requires_explicit_sensitivities(self): + return False + def _integrate(self, model, t_eval, inputs_list=None, t_interp=None): """ Solve a DAE model defined by residuals with initial conditions y0. diff --git a/src/pybamm/solvers/jax_bdf_solver.py b/src/pybamm/solvers/jax_bdf_solver.py index 6f0c62b9a8..3db82ca0da 100644 --- a/src/pybamm/solvers/jax_bdf_solver.py +++ b/src/pybamm/solvers/jax_bdf_solver.py @@ -111,68 +111,63 @@ def fun_bind_inputs(y, t): jac_bind_inputs = jax.jacfwd(fun_bind_inputs, argnums=0) - t0 = t_eval[0] - h0 = t_eval[1] - t0 - + t0, h0 = t_eval[0], t_eval[1] - t_eval[0] stepper = _bdf_init( fun_bind_inputs, jac_bind_inputs, mass, t0, y0, h0, rtol, atol ) - i = 0 y_out = jnp.empty((len(t_eval), len(y0)), dtype=y0.dtype) - init_state = [stepper, t_eval, i, y_out] - def cond_fun(state): - _, t_eval, i, _ = state + _, _, i, _ = state return i < len(t_eval) def body_fun(state): stepper, t_eval, i, y_out = state stepper = _bdf_step(stepper, fun_bind_inputs, jac_bind_inputs) - index = jnp.searchsorted(t_eval, stepper.t) - index = index.astype( - "int" + t_eval.dtype.name[-2:] - ) # Coerce index to correct type + index = jnp.searchsorted(t_eval, stepper.t).astype(jnp.int32) - def for_body(j, y_out): - t = t_eval[j] - y_out = y_out.at[jnp.index_exp[j, :]].set(_bdf_interpolate(stepper, t)) - return y_out + def interpolate_and_update(j, y_out): + y = _bdf_interpolate(stepper, t_eval[j]) + return y_out.at[j].set(y) - y_out = jax.lax.fori_loop(i, index, for_body, y_out) - return [stepper, t_eval, index, y_out] + y_out = jax.lax.fori_loop(i, index, interpolate_and_update, y_out) + return stepper, t_eval, index, y_out + + init_state = (stepper, t_eval, 0, y_out) + _, _, _, y_out = jax.lax.while_loop(cond_fun, body_fun, init_state) - stepper, t_eval, i, y_out = jax.lax.while_loop(cond_fun, body_fun, init_state) return y_out - BDFInternalStates = [ - "t", - "atol", - "rtol", - "M", - "newton_tol", - "order", - "h", - "n_equal_steps", - "D", - "y0", - "scale_y0", - "kappa", - "gamma", - "alpha", - "c", - "error_const", - "J", - "LU", - "U", - "psi", - "n_function_evals", - "n_jacobian_evals", - "n_lu_decompositions", - "n_steps", - "consistent_y0_failed", - ] - BDFState = collections.namedtuple("BDFState", BDFInternalStates) + BDFState = collections.namedtuple( + "BDFState", + [ + "t", + "atol", + "rtol", + "M", + "newton_tol", + "order", + "h", + "n_equal_steps", + "D", + "y0", + "scale_y0", + "kappa", + "gamma", + "alpha", + "c", + "error_const", + "J", + "LU", + "U", + "psi", + "n_function_evals", + "n_jacobian_evals", + "n_lu_decompositions", + "n_steps", + "consistent_y0_failed", + ], + ) jax.tree_util.register_pytree_node( BDFState, lambda xs: (tuple(xs), None), lambda _, xs: BDFState(*xs) @@ -211,62 +206,70 @@ def _bdf_init(fun, jac, mass, t0, y0, h0, rtol, atol): absolute tolerance for the solver """ - state = {} - state["t"] = t0 - state["atol"] = atol - state["rtol"] = rtol - state["M"] = mass EPS = jnp.finfo(y0.dtype).eps - state["newton_tol"] = jnp.maximum(10 * EPS / rtol, jnp.minimum(0.03, rtol**0.5)) + # Scaling and tolerance initialisation scale_y0 = atol + rtol * jnp.abs(y0) + newton_tol = jnp.maximum(10 * EPS / rtol, jnp.minimum(0.03, rtol**0.5)) + y0, not_converged = _select_initial_conditions( - fun, mass, t0, y0, state["newton_tol"], scale_y0 + fun, mass, t0, y0, newton_tol, scale_y0 ) - state["consistent_y0_failed"] = not_converged + # Compute initial function and step size f0 = fun(y0, t0) - order = 1 - state["order"] = order - state["h"] = _select_initial_step(atol, rtol, fun, t0, y0, f0, h0) - state["n_equal_steps"] = 0 + h = _select_initial_step(atol, rtol, fun, t0, y0, f0, h0) + + # Initialise the difference matrix, D D = jnp.empty((MAX_ORDER + 1, len(y0)), dtype=y0.dtype) D = D.at[jnp.index_exp[0, :]].set(y0) - D = D.at[jnp.index_exp[1, :]].set(f0 * state["h"]) - state["D"] = D - state["y0"] = y0 - state["scale_y0"] = scale_y0 + D = D.at[jnp.index_exp[1, :]].set(f0 * h) # kappa values for difference orders, taken from Table 1 of [1] kappa = jnp.array([0, -0.1850, -1 / 9, -0.0823, -0.0415, 0]) gamma = jnp.hstack((0, jnp.cumsum(1 / jnp.arange(1, MAX_ORDER + 1)))) alpha = 1.0 / ((1 - kappa) * gamma) - c = state["h"] * alpha[order] + c = h * alpha[1] error_const = kappa * gamma + 1 / jnp.arange(1, MAX_ORDER + 2) - state["kappa"] = kappa - state["gamma"] = gamma - state["alpha"] = alpha - state["c"] = c - state["error_const"] = error_const - + # Jacobian and LU decomp J = jac(y0, t0) - state["J"] = J - - state["LU"] = jax.scipy.linalg.lu_factor(state["M"] - c * J) - - state["U"] = _compute_R(order, 1) - state["psi"] = None - - state["n_function_evals"] = 2 - state["n_jacobian_evals"] = 1 - state["n_lu_decompositions"] = 1 - state["n_steps"] = 0 + LU = jax.scipy.linalg.lu_factor(mass - c * J) + U = _compute_R(1, 1) # Order 1 + + # Create initial BDFState + state = BDFState( + t=t0, + atol=atol, + rtol=rtol, + M=mass, + newton_tol=newton_tol, + consistent_y0_failed=not_converged, + order=1, + h=h, + n_equal_steps=0, + D=D, + y0=y0, + scale_y0=scale_y0, + kappa=kappa, + gamma=gamma, + alpha=alpha, + c=c, + error_const=error_const, + J=J, + LU=LU, + U=U, + psi=None, + n_function_evals=2, + n_jacobian_evals=1, + n_lu_decompositions=1, + n_steps=0, + ) - tuple_state = BDFState(*[state[k] for k in BDFInternalStates]) - y0, scale_y0 = _predict(tuple_state, D) - psi = _update_psi(tuple_state, D) - return tuple_state._replace(y0=y0, scale_y0=scale_y0, psi=psi) + # Predict initial y0, scale_yo, update state + y0, scale_y0 = _predict(state, D) + psi = _update_psi(state, D) + return state._replace(y0=y0, scale_y0=scale_y0, psi=psi) def _compute_R(order, factor): """ @@ -374,10 +377,8 @@ def _predict(state, D): """ predict forward to new step (eq 2 in [1]) """ - n = len(state.y0) - order = state.order - orders = jnp.repeat(jnp.arange(MAX_ORDER + 1).reshape(-1, 1), n, axis=1) - subD = jnp.where(orders <= order, D, 0) + orders = jnp.arange(MAX_ORDER + 1)[:, None] + subD = jnp.where(orders <= state.order, D, 0) y0 = jnp.sum(subD, axis=0) scale_y0 = state.atol + state.rtol * jnp.abs(state.y0) return y0, scale_y0 @@ -397,7 +398,7 @@ def _update_psi(state, D): def _update_difference_for_next_step(state, d): """ - update of difference equations can be done efficiently + Update of difference equations can be done efficiently by reusing d and D. From first equation on page 4 of [1]: @@ -409,34 +410,21 @@ def _update_difference_for_next_step(state, d): Combining these gives the following algorithm """ order = state.order - D = state.D - D = D.at[jnp.index_exp[order + 2]].set(d - D[order + 1]) - D = D.at[jnp.index_exp[order + 1]].set(d) - i = order - while_state = [i, D] - - def while_cond(while_state): - i, _ = while_state - return i >= 0 + D = state.D.at[order + 2].set(d - state.D[order + 1]) + D = D.at[order + 1].set(d) - def while_body(while_state): - i, D = while_state - D = D.at[jnp.index_exp[i]].add(D[i + 1]) - i -= 1 - return [i, D] + def update_D(i, D): + return D.at[order - i].add(D[order - i + 1]) - i, D = jax.lax.while_loop(while_cond, while_body, while_state) - - return D + return jax.lax.fori_loop(0, order + 1, update_D, D) def _update_step_size_and_lu(state, factor): + """ + Update step size and recompute LU decomposition. + """ state = _update_step_size(state, factor) - - # redo lu (c has changed) LU = jax.scipy.linalg.lu_factor(state.M - state.c * state.J) - n_lu_decompositions = state.n_lu_decompositions + 1 - - return state._replace(LU=LU, n_lu_decompositions=n_lu_decompositions) + return state._replace(LU=LU, n_lu_decompositions=state.n_lu_decompositions + 1) def _update_step_size(state, factor): """ @@ -449,7 +437,6 @@ def _update_step_size(state, factor): """ order = state.order h = state.h * factor - n_equal_steps = 0 c = h * state.alpha[order] # update D using equations in section 3.2 of [1] @@ -461,19 +448,14 @@ def _update_step_size(state, factor): RU = jnp.where( jnp.logical_and(I <= order, J <= order), RU, jnp.identity(MAX_ORDER + 1) ) - D = state.D - D = jnp.dot(RU.T, D) - # D = jax.ops.index_update(D, jax.ops.index[:order + 1], - # jnp.dot(RU.T, D[:order + 1])) + D = jnp.dot(RU.T, state.D) - # update psi (D has changed) + # update psi, y0 (D has changed) psi = _update_psi(state, D) - - # update y0 (D has changed) y0, scale_y0 = _predict(state, D) return state._replace( - n_equal_steps=n_equal_steps, + n_equal_steps=0, h=h, c=c, D=D, @@ -484,27 +466,23 @@ def _update_step_size(state, factor): def _update_jacobian(state, jac): """ - we update the jacobian using J(t_{n+1}, y^0_{n+1}) + Update the jacobian using J(t_{n+1}, y^0_{n+1}) following the scipy bdf implementation rather than J(t_n, y_n) as per [1] """ J = jac(state.y0, state.t + state.h) - n_jacobian_evals = state.n_jacobian_evals + 1 LU = jax.scipy.linalg.lu_factor(state.M - state.c * J) - n_lu_decompositions = state.n_lu_decompositions + 1 return state._replace( J=J, - n_jacobian_evals=n_jacobian_evals, + n_jacobian_evals=state.n_jacobian_evals + 1, LU=LU, - n_lu_decompositions=n_lu_decompositions, + n_lu_decompositions=state.n_lu_decompositions + 1, ) def _newton_iteration(state, fun): - tol = state.newton_tol - c = state.c - psi = state.psi + """ + Perform Newton iteration to solve the system. + """ y0 = state.y0 - LU = state.LU - M = state.M scale_y0 = state.scale_y0 t = state.t + state.h d = jnp.zeros(y0.shape, dtype=y0.dtype) @@ -522,17 +500,20 @@ def while_cond(while_state): def while_body(while_state): k, converged, dy_norm_old, d, y, n_function_evals = while_state + f_eval = fun(y, t) n_function_evals += 1 - b = c * f_eval - M @ (psi + d) - dy = jax.scipy.linalg.lu_solve(LU, b) + b = state.c * f_eval - state.M @ (state.psi + d) + dy = jax.scipy.linalg.lu_solve(state.LU, b) dy_norm = jnp.sqrt(jnp.mean((dy / scale_y0) ** 2)) rate = dy_norm / dy_norm_old # if iteration is not going to converge in NEWTON_MAXITER # (assuming the current rate), then abort pred = rate >= 1 - pred += rate ** (NEWTON_MAXITER - k) / (1 - rate) * dy_norm > tol + pred += ( + rate ** (NEWTON_MAXITER - k) / (1 - rate) * dy_norm > state.newton_tol + ) pred *= dy_norm_old >= 0 k += pred * (NEWTON_MAXITER - k - 1) @@ -541,7 +522,7 @@ def while_body(while_state): # if converged then break out of iteration early pred = dy_norm_old >= 0.0 - pred *= rate / (1 - rate) * dy_norm < tol + pred *= rate / (1 - rate) * dy_norm < state.newton_tol converged = (dy_norm == 0.0) + pred dy_norm_old = dy_norm @@ -564,7 +545,6 @@ def _prepare_next_step(state, d): def _prepare_next_step_order_change(state, d, y, n_iter): order = state.order - D = _update_difference_for_next_step(state, d) # Note: we are recalculating these from the while loop above, could re-use? @@ -586,7 +566,6 @@ def _prepare_next_step_order_change(state, d, y, n_iter): rms_norm(state.error_const[order + 1] * D[order + 2] / scale_y), jnp.inf, ) - error_norms = jnp.array([error_m_norm, error_norm, error_p_norm]) factors = error_norms ** (-1 / (jnp.arange(3) + order)) @@ -595,111 +574,89 @@ def _prepare_next_step_order_change(state, d, y, n_iter): max_index = jnp.argmax(factors) order += max_index - 1 + # New step size factor factor = jnp.minimum(MAX_FACTOR, safety * factors[max_index]) - - new_state = _update_step_size_and_lu(state._replace(D=D, order=order), factor) - return new_state + new_state = state._replace(D=D, order=order) + return _update_step_size_and_lu(new_state, factor) def _bdf_step(state, fun, jac): - # print('bdf_step', state.t, state.h) - # we will try and use the old jacobian unless convergence of newton iteration - # fails - updated_jacobian = False - # initialise step size and try to make the step, - # iterate, reducing step size until error is in bounds - step_accepted = False - y = jnp.empty_like(state.y0) - d = jnp.empty_like(state.y0) - n_iter = -1 - - # loop until step is accepted - while_state = [state, step_accepted, updated_jacobian, y, d, n_iter] + """ + Perform a BDF step. - def while_cond(while_state): - _, step_accepted, _, _, _, _ = while_state - return step_accepted == False # noqa: E712 + We will try and use the old jacobian unless + convergence of newton iteration fails. + """ - def while_body(while_state): - state, step_accepted, updated_jacobian, y, d, n_iter = while_state + def step_iteration(while_state): + state, updated_jacobian = while_state - # solve BDF equation using y0 as starting point + # Solve BDF equation using Newton iteration converged, n_iter, y, d, state = _newton_iteration(state, fun) - not_converged = converged == False # noqa: E712 - - # newton iteration did not converge, but jacobian has already been - # evaluated so reduce step size by 0.3 (as per [1]) and try again - state = tree_map( - partial(jnp.where, not_converged * updated_jacobian), - _update_step_size_and_lu(state, 0.3), - state, - ) - # if not_converged * updated_jacobian: - # print('not converged, update step size by 0.3') - # if not_converged * (updated_jacobian == False): - # print('not converged, update jacobian') - - # if not converged and jacobian not updated, then update the jacobian and - # try again - (state, updated_jacobian) = tree_map( - partial( - jnp.where, - not_converged * (updated_jacobian == False), # noqa: E712 + # Update Jacobian or reduce step size if not converged + # Evaluated so reduce step size by 0.3 (as per [1]) and try again + state, updated_jacobian = jax.lax.cond( + ~converged, + lambda s, uj: jax.lax.cond( + uj, + lambda s: (_update_step_size_and_lu(s, 0.3), True), + lambda s: (_update_jacobian(s, jac), True), + s, ), - (_update_jacobian(state, jac), True), - (state, False + updated_jacobian), + lambda s, uj: (s, uj), + state, + updated_jacobian, ) safety = 0.9 * (2 * NEWTON_MAXITER + 1) / (2 * NEWTON_MAXITER + n_iter) scale_y = state.atol + state.rtol * jnp.abs(y) + # Calculate error and updated step size factor # combine eq 3, 4 and 6 from [1] to obtain error # Note that error = C_k * h^{k+1} y^{k+1} # and d = D^{k+1} y_{n+1} \approx h^{k+1} y^{k+1} error = state.error_const[state.order] * d - error_norm = rms_norm(error / scale_y) - # calculate optimal step size factor as per eq 2.46 of [2] - factor = jnp.maximum( - MIN_FACTOR, safety * error_norm ** (-1 / (state.order + 1)) + # Calculate optimal step size factor as per eq 2.46 of [2] + factor = jnp.clip( + safety * error_norm ** (-1 / (state.order + 1)), MIN_FACTOR, None ) - # if converged * (error_norm > 1): - # print( - # "converged, but error is too large", - # error_norm, - # factor, - # d, - # scale_y, - # ) - - (state, step_accepted) = tree_map( - partial(jnp.where, converged * (error_norm > 1)), - (_update_step_size_and_lu(state, factor), False), - (state, converged), + # Update step size if error is too large + state = jax.lax.cond( + converged & (error_norm > 1), + lambda s: _update_step_size_and_lu(s, factor), + lambda s: s, + state, ) - return [state, step_accepted, updated_jacobian, y, d, n_iter] - - state, step_accepted, updated_jacobian, y, d, n_iter = jax.lax.while_loop( - while_cond, while_body, while_state + step_accepted = converged & (error_norm <= 1) + return (state, updated_jacobian), (step_accepted, y, d, n_iter) + + # Iterate until step is accepted + (state, _), (_, y, d, n_iter) = jax.lax.while_loop( + lambda carry_and_aux: ~carry_and_aux[1][0], + lambda carry_and_aux: step_iteration(carry_and_aux[0]), + ( + (state, False), + (False, jnp.empty_like(state.y0), jnp.empty_like(state.y0), -1), + ), ) - # take the accepted step + # Update state for the accepted step n_steps = state.n_steps + 1 t = state.t + state.h - - # a change in order is only done after running at order k for k + 1 steps - # (see page 83 of [2]) n_equal_steps = state.n_equal_steps + 1 - state = state._replace(n_equal_steps=n_equal_steps, t=t, n_steps=n_steps) - state = tree_map( - partial(jnp.where, n_equal_steps < state.order + 1), - _prepare_next_step(state, d), - _prepare_next_step_order_change(state, d, y, n_iter), + # Prepare for the next step, potentially changing order + # (see page 83 of [2]) + state = jax.lax.cond( + n_equal_steps < state.order + 1, + lambda s: _prepare_next_step(s, d), + lambda s: _prepare_next_step_order_change(s, d, y, n_iter), + state, ) return state @@ -710,8 +667,6 @@ def _bdf_interpolate(state, t_eval): definition of the interpolating polynomial can be found on page 7 of [1] """ - order = state.order - t = state.t h = state.h D = state.D j = 0 @@ -721,11 +676,11 @@ def _bdf_interpolate(state, t_eval): def while_cond(while_state): j, _, _ = while_state - return j < order + return j < state.order def while_body(while_state): j, time_factor, order_summation = while_state - time_factor *= (t_eval - (t - h * j)) / (h * (1 + j)) + time_factor *= (t_eval - (state.t - h * j)) / (h * (1 + j)) order_summation += D[j + 1] * time_factor j += 1 return [j, time_factor, order_summation] @@ -972,13 +927,13 @@ def jax_bdf_integrate(func, y0, t_eval, *args, rtol=1e-6, atol=1e-6, mass=None): """ Backward Difference formula (BDF) implicit multistep integrator. The basic algorithm is derived in :footcite:t:`byrne1975polyalgorithm`. This particular implementation - follows that implemented in the Matlab routine ode15s described in - :footcite:t:`shampine1997matlab` and the SciPy implementation - :footcite:t:`Virtanen2020` which features the NDF formulas for improved stability, - with associated differences in the error constants, and calculates the jacobian at - J(t_{n+1}, y^0_{n+1}). This implementation was based on that implemented in the - SciPy library :footcite:t:`Virtanen2020`, which also mainly follows - :footcite:t:`shampine1997matlab` but uses the more standard jacobian update. + follows the Matlab routine ode15s described in :footcite:t:`shampine1997matlab` + and the SciPy implementation :footcite:t:`Virtanen2020` which features + the NDF formulas for improved stability, with associated differences in the + error constants, and calculates the jacobian at J(t_{n+1}, y^0_{n+1}). This + implementation was based on that implemented in the SciPy library + :footcite:t:`Virtanen2020`, which also mainly follows :footcite:t:`shampine1997matlab` + but uses the more standard jacobian update. Parameters ---------- diff --git a/src/pybamm/solvers/jax_solver.py b/src/pybamm/solvers/jax_solver.py index bfcdef1882..a1f1733ed6 100644 --- a/src/pybamm/solvers/jax_solver.py +++ b/src/pybamm/solvers/jax_solver.py @@ -31,10 +31,10 @@ class JaxSolver(pybamm.BaseSolver): Parameters ---------- method: str, optional (see `jax.experimental.ode.odeint` for details) - * 'RK45' (default) uses jax.experimental.ode.odeint - * 'BDF' uses custom jax_bdf_integrate (see `jax_bdf_integrate.py` for details) + * 'BDF' (default) uses custom jax_bdf_integrate (see `jax_bdf_integrate.py` for details) + * 'RK45' uses jax.experimental.ode.odeint root_method: str, optional - Method to use to calculate consistent initial conditions. By default this uses + Method to use to calculate consistent initial conditions. By default, this uses the newton chord method internal to the jax bdf solver, otherwise choose from the set of default options defined in docs for pybamm.BaseSolver rtol : float, optional @@ -52,7 +52,7 @@ class JaxSolver(pybamm.BaseSolver): def __init__( self, - method="RK45", + method="BDF", root_method=None, rtol=1e-6, atol=1e-6, @@ -189,6 +189,10 @@ def solve_model_bdf(inputs): def supports_parallel_solve(self): return True + @property + def requires_explicit_sensitivities(self): + return False + def _integrate(self, model, t_eval, inputs=None, t_interp=None): """ Solve a model defined by dydt with initial conditions y0. diff --git a/tests/unit/test_solvers/test_base_solver.py b/tests/unit/test_solvers/test_base_solver.py index 6753513e72..1733cafc4c 100644 --- a/tests/unit/test_solvers/test_base_solver.py +++ b/tests/unit/test_solvers/test_base_solver.py @@ -19,6 +19,7 @@ def test_base_solver_init(self): assert solver.rtol == 1e-5 solver.rtol = 1e-7 assert solver.rtol == 1e-7 + assert solver.requires_explicit_sensitivities def test_root_method_init(self): solver = pybamm.BaseSolver(root_method="casadi") diff --git a/tests/unit/test_solvers/test_jax_solver.py b/tests/unit/test_solvers/test_jax_solver.py index f7e5b8d3b6..8f43eda3c7 100644 --- a/tests/unit/test_solvers/test_jax_solver.py +++ b/tests/unit/test_solvers/test_jax_solver.py @@ -237,3 +237,11 @@ def test_get_solve(self): y = solver({"rate": 0.2}) np.testing.assert_allclose(y[0], np.exp(-0.2 * t_eval), rtol=1e-6, atol=1e-6) + + # Reset solver, test passing `calculate_sensitivities` + for method in ["RK45", "BDF"]: + solver = pybamm.JaxSolver(method=method, rtol=1e-8, atol=1e-8) + solution_sens = solver.solve( + model, t_eval, inputs={"rate": 0.1}, calculate_sensitivities=True + ) + assert len(solution_sens.sensitivities) == 0