Skip to content

Commit

Permalink
#1863 fix base solver tests
Browse files Browse the repository at this point in the history
  • Loading branch information
martinjrobins committed Mar 2, 2022
1 parent 2c34e43 commit 2bbcccf
Show file tree
Hide file tree
Showing 3 changed files with 13 additions and 14 deletions.
4 changes: 2 additions & 2 deletions pybamm/expression_tree/operations/evaluate_python.py
Original file line number Diff line number Diff line change
Expand Up @@ -624,8 +624,8 @@ def get_jacobian(self):
def get_sensitivities(self):
n = len(self._arg_list)

# forward mode autodiff wrt inputs, which is argument 3 after arg_list
jacobian_evaluate = jax.jacfwd(self._evaluate_jax, argnums=3 + n)
# forward mode autodiff wrt inputs, which is argument 2 after arg_list
jacobian_evaluate = jax.jacfwd(self._evaluate_jax, argnums=2 + n)

self._sens_evaluate = jax.jit(
jacobian_evaluate, static_argnums=self._static_argnums
Expand Down
20 changes: 9 additions & 11 deletions pybamm/solvers/base_solver.py
Original file line number Diff line number Diff line change
Expand Up @@ -320,11 +320,9 @@ def report(string):
)
)
jacp = func.get_sensitivities()
jacp = jacp.evaluate
if use_jacobian:
report(f"Calculating jacobian for {name} using jax")
jac = func.get_jacobian()
jac = jac.evaluate
else:
jac = None

Expand All @@ -342,17 +340,17 @@ def report(string):
p: symbol.diff(pybamm.InputParameter(p))
for p in model.calculate_sensitivities
}
if model.convert_to_format == "python":
report(f"Converting sensitivities for {name} to python")
jacp_dict = {
p: pybamm.EvaluatorPython(jacp)
for p, jacp in jacp_dict.items()
}

report(f"Converting sensitivities for {name} to python")
jacp_dict = {
p: pybamm.EvaluatorPython(jacp)
for p, jacp in jacp_dict.items()
}

# jacp should be a function that returns a dict of sensitivities
def jacp(*args, **kwargs):
return {
k: v.evaluate(*args, **kwargs) for k, v in jacp_dict.items()
k: v(*args, **kwargs) for k, v in jacp_dict.items()
}

else:
Expand Down Expand Up @@ -452,8 +450,8 @@ def jacp(*args, **kwargs):
jacp_dict = {}
for pname in model.calculate_sensitivities:
p_diff = casadi.jacobian(casadi_expression, p_casadi[pname])
jacp_dict[pname] = casadi.casadi_expressiontion(
name, [t_casadi, y_casadi, p_casadi_stacked], [p_diff]
jacp_dict[pname] = casadi.Function(
name, [t_casadi, y_and_S, p_casadi_stacked], [p_diff]
)

# jacp should be a casadi_expressiontion that returns
Expand Down
3 changes: 2 additions & 1 deletion tests/unit/test_solvers/test_base_solver.py
Original file line number Diff line number Diff line change
Expand Up @@ -375,7 +375,8 @@ def exact_diff_b(y, a, b):
else:
use_inputs = inputs

sens = model.sensitivities_eval(t, y, use_inputs)
sens = model.jacp_rhs_algebraic_eval(t, y, use_inputs)

np.testing.assert_allclose(
sens["a"], exact_diff_a(y, inputs["a"], inputs["b"])
)
Expand Down

0 comments on commit 2bbcccf

Please sign in to comment.