From 0a8c6ec04897ce934a3820a0d2864be860e6eeef Mon Sep 17 00:00:00 2001 From: lorenzofavaro <44714920+lorenzofavaro@users.noreply.github.com> Date: Sun, 3 Mar 2024 18:35:02 +0100 Subject: [PATCH] Fix cached-instance-method (B019) --- pybamm/expression_tree/symbol.py | 12 +++++++----- pybamm/solvers/idaklu_jax.py | 8 +++++--- 2 files changed, 12 insertions(+), 8 deletions(-) diff --git a/pybamm/expression_tree/symbol.py b/pybamm/expression_tree/symbol.py index c308c4ead3..618cc9fb19 100644 --- a/pybamm/expression_tree/symbol.py +++ b/pybamm/expression_tree/symbol.py @@ -6,7 +6,7 @@ import numpy as np from scipy.sparse import csr_matrix, issparse -from functools import lru_cache, cached_property +from functools import cached_property from typing import TYPE_CHECKING, Sequence, cast import pybamm @@ -900,7 +900,6 @@ def evaluates_to_number(self): def evaluates_to_constant_number(self): return self.evaluates_to_number() and self.is_constant() - @lru_cache def evaluates_on_edges(self, dimension: str) -> bool: """ Returns True if a symbol evaluates on an edge, i.e. symbol contains a gradient @@ -919,9 +918,12 @@ def evaluates_on_edges(self, dimension: str) -> bool: Whether the symbol evaluates on edges (in the finite volume discretisation sense) """ - eval_on_edges = self._evaluates_on_edges(dimension) - self._saved_evaluates_on_edges[dimension] = eval_on_edges - return eval_on_edges + if dimension not in self._saved_evaluates_on_edges: + self._saved_evaluates_on_edges[dimension] = self._evaluates_on_edges( + dimension + ) + + return self._saved_evaluates_on_edges[dimension] def _evaluates_on_edges(self, dimension): # Default behaviour: return False diff --git a/pybamm/solvers/idaklu_jax.py b/pybamm/solvers/idaklu_jax.py index 2dab3aee76..37ab20ad75 100644 --- a/pybamm/solvers/idaklu_jax.py +++ b/pybamm/solvers/idaklu_jax.py @@ -344,10 +344,11 @@ class _hashabledict(dict): def __hash__(self): return hash(tuple(sorted(self.items()))) + @staticmethod @lru_cache(maxsize=1) - def _cached_solve(self, model, t_hashable, *args, **kwargs): + def _cached_solve(solver, model, t_hashable, *args, **kwargs): """Cache the last solve for reuse""" - return self.solver.solve(model, t_hashable, *args, **kwargs) + return solver.solve(model, t_hashable, *args, **kwargs) def _jaxify_solve(self, t, invar, *inputs_values): """Solve the model using the IDAKLU solver @@ -370,7 +371,8 @@ def _jaxify_solve(self, t, invar, *inputs_values): logger.debug(f" invar: {invar}") logger.debug(f" inputs: {dict(d)}") logger.debug(f" calculate_sensitivities: {invar is not None}") - sim = self._cached_solve( + sim = IDAKLUJax._cached_solve( + self.solver, self.jax_model, tuple(self.jax_t_eval), inputs=self._hashabledict(d),