Skip to content

Commit

Permalink
#1863 #1898 removing solvercallables
Browse files Browse the repository at this point in the history
  • Loading branch information
martinjrobins committed Jan 13, 2022
1 parent 7ce6d06 commit 7c7139d
Show file tree
Hide file tree
Showing 4 changed files with 81 additions and 191 deletions.
54 changes: 21 additions & 33 deletions pybamm/expression_tree/operations/evaluate_python.py
Original file line number Diff line number Diff line change
Expand Up @@ -465,7 +465,7 @@ def __init__(self, symbol):
# add function def to first line
python_str = (
"def evaluate(constants, t=None, y=None, "
"y_dot=None, inputs=None, known_evals=None):\n" + python_str
"inputs=None):\n" + python_str
)

# calculate the final variable that will output the result of calling `evaluate`
Expand All @@ -491,21 +491,17 @@ def __init__(self, symbol):
compiled_function = compile(python_str, result_var, "exec")
exec(compiled_function)

def evaluate(self, t=None, y=None, y_dot=None, inputs=None, known_evals=None):
def __call__(self, t=None, y=None, inputs=None):
"""
Acts as a drop-in replacement for :func:`pybamm.Symbol.evaluate`
evaluate function
"""
# generated code assumes y is a column vector
if y is not None and y.ndim == 1:
y = y.reshape(-1, 1)

result = self._evaluate(self._constants, t, y, y_dot, inputs, known_evals)
result = self._evaluate(self._constants, t, y, inputs)

# don't need known_evals, but need to reproduce Symbol.evaluate signature
if known_evals is not None:
return result, known_evals
else:
return result
return result.flatten()

def __getstate__(self):
# Control the state of instances of EvaluatorPython
Expand Down Expand Up @@ -581,7 +577,7 @@ def __init__(self, symbol):
python_str = python_str.replace("\n", "\n ")

# add function def to first line
args = "t=None, y=None, y_dot=None, inputs=None, known_evals=None"
args = "t=None, y=None, inputs=None"
if self._arg_list:
args = ",".join(self._arg_list) + ", " + args
python_str = "def evaluate_jax({}):\n".format(args) + python_str
Expand Down Expand Up @@ -637,14 +633,14 @@ def get_sensitivities(self):

return EvaluatorJaxSensitivities(self._sens_evaluate, self._constants)

def debug(self, t=None, y=None, y_dot=None, inputs=None, known_evals=None):
def debug(self, t=None, y=None, y_dot=None, inputs=None):
# generated code assumes y is a column vector
if y is not None and y.ndim == 1:
y = y.reshape(-1, 1)

# execute code
jaxpr = jax.make_jaxpr(self._evaluate_jax)(
*self._constants, t, y, y_dot, inputs, known_evals
*self._constants, t, y, y_dot, input
).jaxpr
print("invars:", jaxpr.invars)
print("outvars:", jaxpr.outvars)
Expand All @@ -654,52 +650,47 @@ def debug(self, t=None, y=None, y_dot=None, inputs=None, known_evals=None):
print()
print("jaxpr:", jaxpr)

def evaluate(self, t=None, y=None, y_dot=None, inputs=None, known_evals=None):
def __call__(self, t=None, y=None, inputs=None):
"""
Acts as a drop-in replacement for :func:`pybamm.Symbol.evaluate`
evaluate function
"""
# generated code assumes y is a column vector
if y is not None and y.ndim == 1:
y = y.reshape(-1, 1)

result = self._jit_evaluate(*self._constants, t, y, y_dot, inputs, known_evals)
result = self._jit_evaluate(*self._constants, t, y, y_dot, inputs)

return result.flatten()

# don't need known_evals, but need to reproduce Symbol.evaluate signature
if known_evals is not None:
return result, known_evals
else:
return result


class EvaluatorJaxJacobian:
def __init__(self, jac_evaluate, constants):
self._jac_evaluate = jac_evaluate
self._constants = constants

def evaluate(self, t=None, y=None, y_dot=None, inputs=None, known_evals=None):

def __call__(self, t=None, y=None, inputs=None):
"""
Acts as a drop-in replacement for :func:`pybamm.Symbol.evaluate`
evaluate function
"""
# generated code assumes y is a column vector
if y is not None and y.ndim == 1:
y = y.reshape(-1, 1)

# execute code
result = self._jac_evaluate(*self._constants, t, y, y_dot, inputs, known_evals)
result = self._jac_evaluate(*self._constants, t, y, inputs)
result = result.reshape(result.shape[0], -1)

if known_evals is not None:
return result, known_evals
else:
return result
return result


class EvaluatorJaxSensitivities:
def __init__(self, jac_evaluate, constants):
self._jac_evaluate = jac_evaluate
self._constants = constants

def evaluate(self, t=None, y=None, y_dot=None, inputs=None, known_evals=None):
def __call__(self, t=None, y=None, inputs=None):
"""
Acts as a drop-in replacement for :func:`pybamm.Symbol.evaluate`
"""
Expand All @@ -708,9 +699,6 @@ def evaluate(self, t=None, y=None, y_dot=None, inputs=None, known_evals=None):
y = y.reshape(-1, 1)

# execute code
result = self._jac_evaluate(*self._constants, t, y, y_dot, inputs, known_evals)
result = self._jac_evaluate(*self._constants, t, y, inputs)

if known_evals is not None:
return result, known_evals
else:
return result
return result
Loading

0 comments on commit 7c7139d

Please sign in to comment.