From 64a4a3b0a5c51a1d5148a5712ea35e2caa48c66c Mon Sep 17 00:00:00 2001 From: martinjrobins Date: Fri, 21 Jun 2024 17:20:08 +0000 Subject: [PATCH] refactor python and jax evaluators to take a stacked inputs vector #4087 --- .../operations/evaluate_python.py | 137 +++++++++++++----- 1 file changed, 100 insertions(+), 37 deletions(-) diff --git a/pybamm/expression_tree/operations/evaluate_python.py b/pybamm/expression_tree/operations/evaluate_python.py index c5668391e1..457977823e 100644 --- a/pybamm/expression_tree/operations/evaluate_python.py +++ b/pybamm/expression_tree/operations/evaluate_python.py @@ -139,6 +139,7 @@ def find_symbols( symbol: pybamm.Symbol, constant_symbols: OrderedDict, variable_symbols: OrderedDict, + input_slices: dict[str, int | slice], output_jax=False, ): """ @@ -166,6 +167,9 @@ def find_symbols( variable_symbol: collections.OrderedDict The output dictionary of variable (with y or t) symbol ids to lines of code + input_slices: dict of (str, int or slice) + A dict mapping the name of an input to a slice of the input vector + output_jax: bool If True, only numpy and jax operations will be used in the generated code, raises NotImplNotImplementedError if any SparseStack or Mat-Mat multiply @@ -187,7 +191,9 @@ def find_symbols( # process children recursively for child in symbol.children: - find_symbols(child, constant_symbols, variable_symbols, output_jax) + find_symbols( + child, constant_symbols, variable_symbols, input_slices, output_jax + ) # calculate the variable names that will hold the result of calculating the # children variables @@ -358,7 +364,7 @@ def find_symbols( symbol_str = "t" elif isinstance(symbol, pybamm.InputParameter): - symbol_str = f'inputs["{symbol.name}"]' + symbol_str = f'inputs["{input_slices[symbol.name]}"]' else: raise NotImplementedError( @@ -369,7 +375,7 @@ def find_symbols( def to_python( - symbol: pybamm.Symbol, debug=False, output_jax=False + symbol: pybamm.Symbol, inputs: list[dict], debug=False, output_jax=False ) -> tuple[OrderedDict, str]: """ This function converts an expression tree into a dict of constant input values, and @@ -380,6 +386,9 @@ def to_python( symbol : :class:`pybamm.Symbol` The symbol to convert to python code + inputs: list of dict + The inputs to the expression tree + debug : bool If set to True, the function also emits debug code @@ -398,7 +407,18 @@ def to_python( """ constant_values: OrderedDict = OrderedDict() variable_symbols: OrderedDict = OrderedDict() - find_symbols(symbol, constant_values, variable_symbols, output_jax) + input_slices = {} + i = 0 + for input_dict in inputs: + for key, value in input_dict.items(): + if isinstance(value, np.ndarray): + inc = value.shape[0] + input_slices[key] = slice(i, i + inc) + else: + inc = 1 + input_slices[key] = i + i += inc + find_symbols(symbol, constant_values, variable_symbols, input_slices, output_jax) line_format = "{} = {}" @@ -430,6 +450,8 @@ class EvaluatorPython: symbol : :class:`pybamm.Symbol` The symbol to convert to python code + inputs: list of dict + The inputs to the expression tree is_event: bool Indicates this symbol is an event expression is_matrix: bool @@ -438,9 +460,13 @@ class EvaluatorPython: """ def __init__( - self, symbol: pybamm.Symbol, is_event: bool = False, is_matrix: bool = False + self, + symbol: pybamm.Symbol, + inputs: list[dict], + is_event: bool = False, + is_matrix: bool = False, ): - constants, python_str = pybamm.to_python(symbol, debug=False) + constants, python_str = pybamm.to_python(symbol, inputs, debug=False) # extract constants in generated function for i, symbol_id in enumerate(constants.keys()): @@ -484,6 +510,8 @@ def __init__( compiled_function = compile(python_str, result_var, "exec") exec(compiled_function) + self._ninputs = len(inputs) + def __call__(self, t=None, y=None, inputs=None): """ evaluate function @@ -492,44 +520,35 @@ def __call__(self, t=None, y=None, inputs=None): if y is not None and y.ndim == 1: y = y.reshape(-1, 1) - if isinstance(inputs, list): - if len(inputs) == 1: - # nothing to do for a single input - result = self._evaluate(self._constants, t, y, inputs[0]) - elif self._is_event: + if self._ninputs == 1: + # nothing to do for a single input + result = self._evaluate(self._constants, t, y, inputs) + else: + nstates = y.shape[0] // self._ninputs + nparams = len(inputs) // self._ninputs + + results = [ + self._evaluate( + self._constants, + t, + y[i * nstates : (i + 1) * nstates], + input[i * nparams : (i + 1) * nparams], + ) + for i in range(self._ninputs) + ] + + if self._is_event: # if an event do a soft max on the results to combine events from multiple # inputs - results = np.array( - [self._evaluate(self._constants, t, y, input) for input in inputs] - ) margin = 1e-4 alpha = np.log(len(inputs)) / margin result = scipy.special.logsumexp(alpha * results) / alpha elif self._is_matrix: # if a matrix output, concatenate the results in a block diagonal matrix - results = [ - self._evaluate(self._constants, t, y, input) for input in inputs - ] result = scipy.sparse.block_diag(*results, format="csr") else: # otherwise concatenate the results in a column vector - result = self._evaluate(self._constants, t, y, inputs[0]) - if len(inputs) > 1: - if isinstance(result, numbers.Number): - result = np.array([result]) - ny = result.shape[0] - ni = len(inputs) - results = np.zeros((ni * ny, 1)) - results[:ny] = result - i = ny - for input in inputs[1:]: - results[i : i + ny] += self._evaluate( - self._constants, t, y[i : i + ny], input - ) - i += ny - result = results - else: - result = self._evaluate(self._constants, t, y, inputs) + result = np.vstack(results) return result @@ -566,17 +585,22 @@ class EvaluatorJax: symbol : :class:`pybamm.Symbol` The symbol to convert to python code - + inputs: list of dict + The inputs to the model + is_event: bool + Indicates this symbol is an event expression """ - def __init__(self, symbol: pybamm.Symbol): + def __init__(self, symbol: pybamm.Symbol, inputs: list[dict], is_event: bool): if not pybamm.have_jax(): # pragma: no cover raise ModuleNotFoundError( "Jax or jaxlib is not installed, please see https://docs.pybamm.org/en/latest/source/user_guide/installation/gnu-linux-mac.html#optional-jaxsolver" ) - constants, python_str = pybamm.to_python(symbol, debug=False, output_jax=True) + constants, python_str = pybamm.to_python( + symbol, inputs, debug=False, output_jax=True + ) # replace numpy function calls to jax numpy calls python_str = python_str.replace("np.", "jax.numpy.") @@ -634,6 +658,45 @@ def __init__(self, symbol: pybamm.Symbol): compiled_function = compile(python_str, result_var, "exec") exec(compiled_function) + # use vmap to vectorize the function over the inputs if ninputs > 1 + in_axes = ([None] * len(self._arg_list)) + [None, None, 0] + out_axes = 0 + ninputs = len(inputs) + if ninputs > 1: + if is_event: + + def mapped_evaluate_jax_event(*args): + # change inputs to a 2d array for vmap (inputs is the last arg) + args[-1] = args[-1].reshape(ninputs, -1) + + # exectute the mapped function + results = jax.vmap( + self._evaluate_jax, in_axes=in_axes, out_axes=out_axes + )(*args) + + # if an event do a soft max on the results to combine events from multiple inputs + margin = 1e-4 + alpha = jax.numpy.log(ninputs) / margin + return jax.scipy.special.logsumexp(alpha * results) / alpha + + self._evaluate_jax = mapped_evaluate_jax_event + + else: + + def mapped_evaluate_jax(*args): + # change inputs to a 2d array for vmap (inputs is the last arg) + args[-1] = args[-1].reshape(ninputs, -1) + + # exectute the mapped function + results = jax.vmap( + self._evaluate_jax, in_axes=in_axes, out_axes=out_axes + )(*args) + + # reshape to a column vector + return results.reshape(-1, 1) + + self._evaluate_jax = mapped_evaluate_jax + self._static_argnums = tuple(static_argnums) self._jit_evaluate = jax.jit( self._evaluate_jax, # type:ignore[attr-defined]