Skip to content

Commit

Permalink
add interpolant differentiation back in
Browse files Browse the repository at this point in the history
  • Loading branch information
valentinsulzer authored and js1tr3 committed Aug 12, 2024
1 parent 84d25c9 commit 07dbc10
Show file tree
Hide file tree
Showing 2 changed files with 27 additions and 3 deletions.
29 changes: 26 additions & 3 deletions pybamm/expression_tree/interpolant.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,7 @@ def __init__(
interpolator: str | None = "linear",
extrapolate: bool = True,
entries_string: str | None = None,
_num_derivatives: int = 0,
):
# Check interpolator is valid
if interpolator not in ["linear", "cubic", "pchip"]:
Expand Down Expand Up @@ -187,6 +188,12 @@ def __init__(
self.x = x
self.y = y
self.entries_string = entries_string

# Differentiate the interpolating function if necessary
self._num_derivatives = _num_derivatives
for _ in range(_num_derivatives):
interpolating_function = interpolating_function.derivative()

super().__init__(interpolating_function, *children, name=name)

# Store information as attributes
Expand All @@ -209,6 +216,7 @@ def _from_json(cls, snippet: dict):
name=snippet["name"],
interpolator=snippet["interpolator"],
extrapolate=snippet["extrapolate"],
_num_derivatives=snippet["_num_derivatives"],
)

@property
Expand Down Expand Up @@ -237,6 +245,7 @@ def set_id(self):
self.entries_string,
*tuple([child.id for child in self.children]),
*tuple(self.domain),
self._num_derivatives,
)
)

Expand All @@ -252,6 +261,7 @@ def create_copy(self, new_children=None, perform_simplifications=True):
interpolator=self.interpolator,
extrapolate=self.extrapolate,
entries_string=self.entries_string,
_num_derivatives=self._num_derivatives,
)

def _function_evaluate(self, evaluated_children):
Expand Down Expand Up @@ -312,9 +322,21 @@ def _function_diff(self, children: Sequence[pybamm.Symbol], idx: float):
Derivative with respect to child number 'idx'.
See :meth:`pybamm.Symbol._diff()`.
"""
raise NotImplementedError(
"Cannot differentiate Interpolant symbol with respect to its children."
)
if len(children) > 1:
raise ValueError(
"differentiation not implemented for functions with more than one child"
)
else:
# keep using "derivative" as derivative
return Interpolant(
self.x,
self.y,
children,
name=self.name,
interpolator=self.interpolator,
extrapolate=self.extrapolate,
_num_derivatives=self._num_derivatives + 1,
)

def to_json(self):
"""
Expand All @@ -328,6 +350,7 @@ def to_json(self):
"y": self.y.tolist(),
"interpolator": self.interpolator,
"extrapolate": self.extrapolate,
"_num_derivatives": self._num_derivatives,
}

return json_dict
1 change: 1 addition & 0 deletions tests/unit/test_expression_tree/test_interpolant.py
Original file line number Diff line number Diff line change
Expand Up @@ -369,6 +369,7 @@ def test_to_from_json(self):
],
"interpolator": "linear",
"extrapolate": True,
"_num_derivatives": 0,
}

# check correct writing to json
Expand Down

0 comments on commit 07dbc10

Please sign in to comment.