diff --git a/pybamm/solvers/jax_bdf_solver.py b/pybamm/solvers/jax_bdf_solver.py index 59c5459d12..d9024117b9 100644 --- a/pybamm/solvers/jax_bdf_solver.py +++ b/pybamm/solvers/jax_bdf_solver.py @@ -166,7 +166,7 @@ def _bdf_init(fun, jac, mass, t0, y0, h0, rtol, atol): state['rtol'] = rtol state['M'] = mass EPS = jnp.finfo(y0.dtype).eps - state['newton_tol'] = jnp.max((10 * EPS / rtol, jnp.min((0.03, rtol ** 0.5)))) + state['newton_tol'] = jnp.maximum(10 * EPS / rtol, jnp.minimum(0.03, rtol ** 0.5)) scale_y0 = atol + rtol * jnp.abs(y0) y0, not_converged = _select_initial_conditions( @@ -325,7 +325,7 @@ def _select_initial_step(atol, rtol, fun, t0, y0, f0, h0): d2 = jnp.sqrt(jnp.mean(((f1 - f0) / scale)**2)) order = 1 h1 = h0 * d2 ** (-1 / (order + 1)) - return jnp.min((100 * h0, h1)) + return jnp.minimum(100 * h0, h1) def _predict(state, D): @@ -559,7 +559,7 @@ def _prepare_next_step_order_change(state, d, y, n_iter): max_index = jnp.argmax(factors) order += max_index - 1 - factor = jnp.min((MAX_FACTOR, safety * factors[max_index])) + factor = jnp.minimum(MAX_FACTOR, safety * factors[max_index]) new_state = _update_step_size_and_lu(state._replace(D=D, order=order), factor) return new_state @@ -599,9 +599,9 @@ def while_body(while_state): state ) - #if not_converged * updated_jacobian: + # if not_converged * updated_jacobian: # print('not converged, update step size by 0.3') - #if not_converged * (updated_jacobian == False): + # if not_converged * (updated_jacobian == False): # print('not converged, update jacobian') # if not converged and jacobian not updated, then update the jacobian and try @@ -626,11 +626,11 @@ def while_body(while_state): error_norm = rms_norm(error / scale_y) # calculate optimal step size factor as per eq 2.46 of [2] - factor = jnp.max((MIN_FACTOR, - safety * - error_norm ** (-1 / (state.order + 1)))) + factor = jnp.maximum(MIN_FACTOR, + safety * + error_norm ** (-1 / (state.order + 1))) - #if converged * (error_norm > 1): + # if converged * (error_norm > 1): # print('converged, but error is too large',error_norm, factor, d, scale_y) (state, step_accepted) = tree_multimap(