Skip to content

Commit

Permalink
refactor python and jax evaluators to take a stacked inputs vector #4087
Browse files Browse the repository at this point in the history
  • Loading branch information
martinjrobins committed Jul 5, 2024
1 parent 5d2f7c6 commit 64a4a3b
Showing 1 changed file with 100 additions and 37 deletions.
137 changes: 100 additions & 37 deletions pybamm/expression_tree/operations/evaluate_python.py
Original file line number Diff line number Diff line change
Expand Up @@ -139,6 +139,7 @@ def find_symbols(
symbol: pybamm.Symbol,
constant_symbols: OrderedDict,
variable_symbols: OrderedDict,
input_slices: dict[str, int | slice],
output_jax=False,
):
"""
Expand Down Expand Up @@ -166,6 +167,9 @@ def find_symbols(
variable_symbol: collections.OrderedDict
The output dictionary of variable (with y or t) symbol ids to lines of code
input_slices: dict of (str, int or slice)
A dict mapping the name of an input to a slice of the input vector
output_jax: bool
If True, only numpy and jax operations will be used in the generated code,
raises NotImplNotImplementedError if any SparseStack or Mat-Mat multiply
Expand All @@ -187,7 +191,9 @@ def find_symbols(

# process children recursively
for child in symbol.children:
find_symbols(child, constant_symbols, variable_symbols, output_jax)
find_symbols(
child, constant_symbols, variable_symbols, input_slices, output_jax
)

# calculate the variable names that will hold the result of calculating the
# children variables
Expand Down Expand Up @@ -358,7 +364,7 @@ def find_symbols(
symbol_str = "t"

elif isinstance(symbol, pybamm.InputParameter):
symbol_str = f'inputs["{symbol.name}"]'
symbol_str = f'inputs["{input_slices[symbol.name]}"]'

else:
raise NotImplementedError(
Expand All @@ -369,7 +375,7 @@ def find_symbols(


def to_python(
symbol: pybamm.Symbol, debug=False, output_jax=False
symbol: pybamm.Symbol, inputs: list[dict], debug=False, output_jax=False
) -> tuple[OrderedDict, str]:
"""
This function converts an expression tree into a dict of constant input values, and
Expand All @@ -380,6 +386,9 @@ def to_python(
symbol : :class:`pybamm.Symbol`
The symbol to convert to python code
inputs: list of dict
The inputs to the expression tree
debug : bool
If set to True, the function also emits debug code
Expand All @@ -398,7 +407,18 @@ def to_python(
"""
constant_values: OrderedDict = OrderedDict()
variable_symbols: OrderedDict = OrderedDict()
find_symbols(symbol, constant_values, variable_symbols, output_jax)
input_slices = {}
i = 0
for input_dict in inputs:
for key, value in input_dict.items():
if isinstance(value, np.ndarray):
inc = value.shape[0]
input_slices[key] = slice(i, i + inc)
else:
inc = 1
input_slices[key] = i
i += inc
find_symbols(symbol, constant_values, variable_symbols, input_slices, output_jax)

line_format = "{} = {}"

Expand Down Expand Up @@ -430,6 +450,8 @@ class EvaluatorPython:
symbol : :class:`pybamm.Symbol`
The symbol to convert to python code
inputs: list of dict
The inputs to the expression tree
is_event: bool
Indicates this symbol is an event expression
is_matrix: bool
Expand All @@ -438,9 +460,13 @@ class EvaluatorPython:
"""

def __init__(
self, symbol: pybamm.Symbol, is_event: bool = False, is_matrix: bool = False
self,
symbol: pybamm.Symbol,
inputs: list[dict],
is_event: bool = False,
is_matrix: bool = False,
):
constants, python_str = pybamm.to_python(symbol, debug=False)
constants, python_str = pybamm.to_python(symbol, inputs, debug=False)

# extract constants in generated function
for i, symbol_id in enumerate(constants.keys()):
Expand Down Expand Up @@ -484,6 +510,8 @@ def __init__(
compiled_function = compile(python_str, result_var, "exec")
exec(compiled_function)

self._ninputs = len(inputs)

def __call__(self, t=None, y=None, inputs=None):
"""
evaluate function
Expand All @@ -492,44 +520,35 @@ def __call__(self, t=None, y=None, inputs=None):
if y is not None and y.ndim == 1:
y = y.reshape(-1, 1)

if isinstance(inputs, list):
if len(inputs) == 1:
# nothing to do for a single input
result = self._evaluate(self._constants, t, y, inputs[0])
elif self._is_event:
if self._ninputs == 1:
# nothing to do for a single input
result = self._evaluate(self._constants, t, y, inputs)
else:
nstates = y.shape[0] // self._ninputs
nparams = len(inputs) // self._ninputs

results = [
self._evaluate(
self._constants,
t,
y[i * nstates : (i + 1) * nstates],
input[i * nparams : (i + 1) * nparams],
)
for i in range(self._ninputs)
]

if self._is_event:
# if an event do a soft max on the results to combine events from multiple
# inputs
results = np.array(
[self._evaluate(self._constants, t, y, input) for input in inputs]
)
margin = 1e-4
alpha = np.log(len(inputs)) / margin
result = scipy.special.logsumexp(alpha * results) / alpha
elif self._is_matrix:
# if a matrix output, concatenate the results in a block diagonal matrix
results = [
self._evaluate(self._constants, t, y, input) for input in inputs
]
result = scipy.sparse.block_diag(*results, format="csr")
else:
# otherwise concatenate the results in a column vector
result = self._evaluate(self._constants, t, y, inputs[0])
if len(inputs) > 1:
if isinstance(result, numbers.Number):
result = np.array([result])
ny = result.shape[0]
ni = len(inputs)
results = np.zeros((ni * ny, 1))
results[:ny] = result
i = ny
for input in inputs[1:]:
results[i : i + ny] += self._evaluate(
self._constants, t, y[i : i + ny], input
)
i += ny
result = results
else:
result = self._evaluate(self._constants, t, y, inputs)
result = np.vstack(results)

return result

Expand Down Expand Up @@ -566,17 +585,22 @@ class EvaluatorJax:
symbol : :class:`pybamm.Symbol`
The symbol to convert to python code
inputs: list of dict
The inputs to the model
is_event: bool
Indicates this symbol is an event expression
"""

def __init__(self, symbol: pybamm.Symbol):
def __init__(self, symbol: pybamm.Symbol, inputs: list[dict], is_event: bool):
if not pybamm.have_jax(): # pragma: no cover
raise ModuleNotFoundError(
"Jax or jaxlib is not installed, please see https://docs.pybamm.org/en/latest/source/user_guide/installation/gnu-linux-mac.html#optional-jaxsolver"
)

constants, python_str = pybamm.to_python(symbol, debug=False, output_jax=True)
constants, python_str = pybamm.to_python(
symbol, inputs, debug=False, output_jax=True
)

# replace numpy function calls to jax numpy calls
python_str = python_str.replace("np.", "jax.numpy.")
Expand Down Expand Up @@ -634,6 +658,45 @@ def __init__(self, symbol: pybamm.Symbol):
compiled_function = compile(python_str, result_var, "exec")
exec(compiled_function)

# use vmap to vectorize the function over the inputs if ninputs > 1
in_axes = ([None] * len(self._arg_list)) + [None, None, 0]
out_axes = 0
ninputs = len(inputs)
if ninputs > 1:
if is_event:

def mapped_evaluate_jax_event(*args):
# change inputs to a 2d array for vmap (inputs is the last arg)
args[-1] = args[-1].reshape(ninputs, -1)

# exectute the mapped function
results = jax.vmap(
self._evaluate_jax, in_axes=in_axes, out_axes=out_axes
)(*args)

# if an event do a soft max on the results to combine events from multiple inputs
margin = 1e-4
alpha = jax.numpy.log(ninputs) / margin
return jax.scipy.special.logsumexp(alpha * results) / alpha

self._evaluate_jax = mapped_evaluate_jax_event

else:

def mapped_evaluate_jax(*args):
# change inputs to a 2d array for vmap (inputs is the last arg)
args[-1] = args[-1].reshape(ninputs, -1)

# exectute the mapped function
results = jax.vmap(
self._evaluate_jax, in_axes=in_axes, out_axes=out_axes
)(*args)

# reshape to a column vector
return results.reshape(-1, 1)

self._evaluate_jax = mapped_evaluate_jax

self._static_argnums = tuple(static_argnums)
self._jit_evaluate = jax.jit(
self._evaluate_jax, # type:ignore[attr-defined]
Expand Down

0 comments on commit 64a4a3b

Please sign in to comment.