diff --git a/pybamm/expression_tree/operations/evaluate_python.py b/pybamm/expression_tree/operations/evaluate_python.py index ec8f59fb4b..5efbda5aab 100644 --- a/pybamm/expression_tree/operations/evaluate_python.py +++ b/pybamm/expression_tree/operations/evaluate_python.py @@ -624,8 +624,8 @@ def get_jacobian(self): def get_sensitivities(self): n = len(self._arg_list) - # forward mode autodiff wrt inputs, which is argument 3 after arg_list - jacobian_evaluate = jax.jacfwd(self._evaluate_jax, argnums=3 + n) + # forward mode autodiff wrt inputs, which is argument 2 after arg_list + jacobian_evaluate = jax.jacfwd(self._evaluate_jax, argnums=2 + n) self._sens_evaluate = jax.jit( jacobian_evaluate, static_argnums=self._static_argnums diff --git a/pybamm/solvers/base_solver.py b/pybamm/solvers/base_solver.py index c7981719e8..e5efb8f251 100644 --- a/pybamm/solvers/base_solver.py +++ b/pybamm/solvers/base_solver.py @@ -320,11 +320,9 @@ def report(string): ) ) jacp = func.get_sensitivities() - jacp = jacp.evaluate if use_jacobian: report(f"Calculating jacobian for {name} using jax") jac = func.get_jacobian() - jac = jac.evaluate else: jac = None @@ -342,17 +340,17 @@ def report(string): p: symbol.diff(pybamm.InputParameter(p)) for p in model.calculate_sensitivities } - if model.convert_to_format == "python": - report(f"Converting sensitivities for {name} to python") - jacp_dict = { - p: pybamm.EvaluatorPython(jacp) - for p, jacp in jacp_dict.items() - } + + report(f"Converting sensitivities for {name} to python") + jacp_dict = { + p: pybamm.EvaluatorPython(jacp) + for p, jacp in jacp_dict.items() + } # jacp should be a function that returns a dict of sensitivities def jacp(*args, **kwargs): return { - k: v.evaluate(*args, **kwargs) for k, v in jacp_dict.items() + k: v(*args, **kwargs) for k, v in jacp_dict.items() } else: @@ -452,8 +450,8 @@ def jacp(*args, **kwargs): jacp_dict = {} for pname in model.calculate_sensitivities: p_diff = casadi.jacobian(casadi_expression, p_casadi[pname]) - jacp_dict[pname] = casadi.casadi_expressiontion( - name, [t_casadi, y_casadi, p_casadi_stacked], [p_diff] + jacp_dict[pname] = casadi.Function( + name, [t_casadi, y_and_S, p_casadi_stacked], [p_diff] ) # jacp should be a casadi_expressiontion that returns diff --git a/tests/unit/test_solvers/test_base_solver.py b/tests/unit/test_solvers/test_base_solver.py index 9d5bcd9bce..5222f01801 100644 --- a/tests/unit/test_solvers/test_base_solver.py +++ b/tests/unit/test_solvers/test_base_solver.py @@ -375,7 +375,8 @@ def exact_diff_b(y, a, b): else: use_inputs = inputs - sens = model.sensitivities_eval(t, y, use_inputs) + sens = model.jacp_rhs_algebraic_eval(t, y, use_inputs) + np.testing.assert_allclose( sens["a"], exact_diff_a(y, inputs["a"], inputs["b"]) )