Skip to content

Commit

Permalink
solvers pass unit tests #4087
Browse files Browse the repository at this point in the history
  • Loading branch information
martinjrobins committed Jul 5, 2024
1 parent 4aca532 commit 9907906
Show file tree
Hide file tree
Showing 9 changed files with 161 additions and 259 deletions.
39 changes: 26 additions & 13 deletions pybamm/expression_tree/operations/evaluate_python.py
Original file line number Diff line number Diff line change
Expand Up @@ -424,6 +424,9 @@ def to_python(
+ "; print(type({0}),np.shape({0}))".format(
id_to_python_variable(symbol_id, False)
)
# + "; jax.debug.print(\"{0} = {{x}}\", x={0}.flatten())".format(
# id_to_python_variable(symbol_id, False)
# )
for symbol_id, symbol_line in variable_symbols.items()
]
else:
Expand Down Expand Up @@ -523,15 +526,13 @@ def __call__(self, t=None, y=None, inputs=None):
else:
nstates = y.shape[0] // self._ninputs
nparams = len(inputs) // self._ninputs
print("nstates:", nstates)
print("nparams:", nparams)

results = [
self._evaluate(
self._constants,
t,
y[i * nstates : (i + 1) * nstates],
input[i * nparams : (i + 1) * nparams],
inputs[i * nparams : (i + 1) * nparams],
)
for i in range(self._ninputs)
]
Expand Down Expand Up @@ -665,15 +666,19 @@ def __init__(
exec(compiled_function)

# use vmap to vectorize the function over the inputs if ninputs > 1
in_axes = ([None] * len(self._arg_list)) + [None, None, 0]
in_axes = ([None] * len(self._arg_list)) + [None, 0, 0]
out_axes = 0
ninputs = len(inputs)
if ninputs > 1:
if is_event:

def mapped_evaluate_jax_event(*args):
# change inputs to a 2d array for vmap (inputs is the last arg)
args[-1] = args[-1].reshape(ninputs, -1)
# change inputs and y to a 2d array for vmap (inputs is the last arg)
args = (
*args[:-2],
args[-2].reshape(ninputs, -1),
args[-1].reshape(ninputs, -1),
)

# exectute the mapped function
results = jax.vmap(
Expand All @@ -685,13 +690,17 @@ def mapped_evaluate_jax_event(*args):
alpha = jax.numpy.log(ninputs) / margin
return jax.scipy.special.logsumexp(alpha * results) / alpha

self._evaluate_jax = mapped_evaluate_jax_event
self._mapped_evaluate_jax = mapped_evaluate_jax_event

else:

def mapped_evaluate_jax(*args):
# change inputs to a 2d array for vmap (inputs is the last arg)
args[-1] = args[-1].reshape(ninputs, -1)
# change inputs and y to a 2d array for vmap (inputs is the last arg)
args = (
*args[:-2],
args[-2].reshape(ninputs, -1),
args[-1].reshape(ninputs, -1),
)

# exectute the mapped function
results = jax.vmap(
Expand All @@ -701,11 +710,13 @@ def mapped_evaluate_jax(*args):
# reshape to a column vector
return results.reshape(-1, 1)

self._evaluate_jax = mapped_evaluate_jax
self._mapped_evaluate_jax = mapped_evaluate_jax
else:
self._mapped_evaluate_jax = self._evaluate_jax

self._static_argnums = tuple(static_argnums)
self._jit_evaluate = jax.jit(
self._evaluate_jax, # type:ignore[attr-defined]
self._mapped_evaluate_jax, # type:ignore[attr-defined]
static_argnums=self._static_argnums,
)

Expand All @@ -715,7 +726,7 @@ def get_jacobian(self):
return self._get_jacfwd(1 + n)

def _get_jacfwd(self, argnum):
jacobian_evaluate = jax.jacfwd(self._evaluate_jax, argnums=argnum)
jacobian_evaluate = jax.jacfwd(self._mapped_evaluate_jax, argnums=argnum)

self._jac_evaluate = jax.jit(
jacobian_evaluate, static_argnums=self._static_argnums
Expand All @@ -737,7 +748,9 @@ def debug(self, t=None, y=None, inputs=None):
y = y.reshape(-1, 1)

# execute code
jaxpr = jax.make_jaxpr(self._evaluate_jax)(*self._constants, t, y, inputs).jaxpr
jaxpr = jax.make_jaxpr(self._mapped_evaluate_jax)(
*self._constants, t, y, inputs
).jaxpr
print("invars:", jaxpr.invars)
print("outvars:", jaxpr.outvars)
print("constvars:", jaxpr.constvars)
Expand Down
78 changes: 1 addition & 77 deletions pybamm/solvers/base_solver.py
Original file line number Diff line number Diff line change
Expand Up @@ -603,13 +603,7 @@ def _set_initial_conditions(self, model, time, inputs_list, update_rhs):
)
y_zero = np.zeros((y0_total_size, 1))

if model.convert_to_format == "casadi":
# stack inputs
inputs = casadi.vertcat(
*[x for inpts in inputs_list for x in inpts.values()]
)
else:
inputs = inputs_list
inputs = self._inputs_to_stacked_vect(inputs_list, model.convert_to_format)

if self.algebraic_solver is True:
# Don't update model.y0
Expand Down Expand Up @@ -1407,7 +1401,6 @@ def _inputs_to_stacked_vect(inputs_list: list[dict], convert_to_format: str):
for inputs in inputs_list
for x in inputs.values()
]
print(inputs_list, arrays_to_stack)
inputs = np.vstack(arrays_to_stack)
return inputs

Expand Down Expand Up @@ -1495,75 +1488,6 @@ def map_func_over_inputs_casadi(name, f, vars_for_processing, ninputs):
return casadi.Function(name, inputs_stacked, [stack])


def map_func_over_inputs_jax(name, f, vars_for_processing, ninputs):
"""
This takes a casadi function f and returns a new casadi function that maps f over
the provided number of inputs. Some functions (e.g. jacobian action) require an additional
vector input v, which is why add_v is provided.
Parameters
----------
name: str
name of the new function. This must end in the string "_action" for jacobian action functions,
"_jac" for jacobian functions, or "_jacp" for jacp functions.
f: casadi.Function
function to map
vars_for_processing: dict
dictionary of variables for processing
ninputs: int
number of inputs to map over
"""
if f is None:
return None

is_event = "event" in name
add_v = name.endswith("_action")
matrix_output = name.endswith("_jac") or name.endswith("_jacp")

nstates = vars_for_processing["y_and_S"].shape[0]
nparams = vars_for_processing["p_casadi_stacked"].shape[0]

parallelisation = "thread"
y_and_S_inputs_stacked = casadi.MX.sym("y_and_S_stacked", nstates * ninputs)
p_casadi_inputs_stacked = casadi.MX.sym("p_stacked", nparams * ninputs)
v_inputs_stacked = casadi.MX.sym("v_stacked", nstates * ninputs)

y_and_S_2d = y_and_S_inputs_stacked.reshape((nstates, ninputs))
p_casadi_2d = p_casadi_inputs_stacked.reshape((nparams, ninputs))
v_2d = v_inputs_stacked.reshape((nstates, ninputs))

t_casadi = vars_for_processing["t_casadi"]

if add_v:
inputs_2d = [t_casadi, y_and_S_2d, p_casadi_2d, v_2d]
inputs_stacked = [
t_casadi,
y_and_S_inputs_stacked,
p_casadi_inputs_stacked,
v_inputs_stacked,
]
else:
inputs_2d = [t_casadi, y_and_S_2d, p_casadi_2d]
inputs_stacked = [t_casadi, y_and_S_inputs_stacked, p_casadi_inputs_stacked]

mapped_f = f.map(ninputs, parallelisation)(*inputs_2d)
if matrix_output:
# for matrix output we need to stack the outputs in a block diagonal matrix
splits = [i * nstates for i in range(ninputs + 1)]
split = casadi.horzsplit(mapped_f, splits)
stack = casadi.diagcat(*split)
elif is_event:
# Events need to return a scalar, so we combine the vector output
# of the mapped function into a scalar output by calculating a smooth max of the vector output.
stack = casadi.logsumexp(casadi.transpose(mapped_f), 1e-4)
else:
# for vector outputs we need to stack them vertically in a single column vector
splits = [i for i in range(ninputs + 1)]
split = casadi.horzsplit(mapped_f, splits)
stack = casadi.vertcat(*split)
return casadi.Function(name, inputs_stacked, [stack])


def process(
symbol,
name,
Expand Down
8 changes: 6 additions & 2 deletions pybamm/solvers/casadi_solver.py
Original file line number Diff line number Diff line change
Expand Up @@ -175,7 +175,9 @@ def _integrate(self, model, t_eval, inputs_list=None):
solution = self._solve_for_event(solution, inputs_list)
solution.check_ys_are_not_too_large()

return solution.split(model.len_rhs, model.len_alg, inputs_list)
return solution.split(
model.len_rhs, model.len_alg, inputs_list, is_casadi_solver=True
)
elif self.mode in ["safe", "safe without grid"]:
y0 = model.y0
# Step-and-check
Expand Down Expand Up @@ -314,7 +316,9 @@ def _integrate(self, model, t_eval, inputs_list=None):
if bool(model.calculate_sensitivities):
solution.sensitivities = True
solution.check_ys_are_not_too_large()
return solution.split(model.len_rhs, model.len_alg, inputs_list)
return solution.split(
model.len_rhs, model.len_alg, inputs_list, is_casadi_solver=True
)

def _solve_for_event(self, coarse_solution, inputs_list):
"""
Expand Down
8 changes: 1 addition & 7 deletions pybamm/solvers/jax_bdf_solver.py
Original file line number Diff line number Diff line change
Expand Up @@ -876,17 +876,11 @@ def initialise(g0, y0, t0):
def arg_to_identity(arg):
return onp.identity(arg.shape[0] if arg.ndim > 0 else 1, dtype=arg.dtype)

def arg_dicts_to_values(args):
"""
Note:JAX puts in empty arrays into args for some reason, we remove them here
"""
return sum((tuple(b.values()) for b in args if isinstance(b, dict)), ())

aug_mass = (
mass,
mass,
onp.array(1.0),
*arg_dicts_to_values(tree_map(arg_to_identity, args)),
*tree_map(arg_to_identity, args),
)

def scan_fun(carry, i):
Expand Down
Loading

0 comments on commit 9907906

Please sign in to comment.