Skip to content

Commit

Permalink
Merge pull request #4196 from pybamm-team/issue-4183-remove-autograd
Browse files Browse the repository at this point in the history
#4183 remove autograd
  • Loading branch information
valentinsulzer authored Jun 21, 2024
2 parents 8ba4791 + d38117b commit 205ca81
Show file tree
Hide file tree
Showing 10 changed files with 56 additions and 116 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,7 @@

## Breaking changes

- Functions that are created using `pybamm.Function(function_object, children)` can no longer be differentiated symbolically (e.g. to compute the Jacobian). This should affect no users, since function derivatives for all "standard" functions are explicitly implemented ([#4196](https://github.com/pybamm-team/PyBaMM/pull/4196))
- Removed data files under `pybamm/input` and released them in a separate repository upstream at [pybamm-data](https://github.com/pybamm-team/pybamm-data/releases/tag/v1.0.0). Note that data files under `pybamm/input/parameters` have not been removed. ([#4098](https://github.com/pybamm-team/PyBaMM/pull/4098))
- Removed `check_model` argument from `Simulation.solve`. To change the `check_model` option, use `Simulation(..., discretisation_kwargs={"check_model": False})`. ([#4020](https://github.com/pybamm-team/PyBaMM/pull/4020))
- Removed multiple Docker images. Here on, a single Docker image tagged `pybamm/pybamm:latest` will be provided with both solvers (`IDAKLU` and `JAX`) pre-installed. ([#3992](https://github.com/pybamm-team/PyBaMM/pull/3992))
Expand Down
1 change: 0 additions & 1 deletion asv.conf.json
Original file line number Diff line number Diff line change
Expand Up @@ -82,7 +82,6 @@
"wget": [],
"cmake": [],
"anytree": [],
"autograd": [],
"scikit-fem": [],
"imageio": [],
"pybtex": [],
Expand Down
36 changes: 4 additions & 32 deletions pybamm/expression_tree/functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,6 @@
from typing_extensions import TypeVar

import pybamm
from pybamm.util import import_optional_dependency


class Function(pybamm.Symbol):
Expand All @@ -26,9 +25,6 @@ class Function(pybamm.Symbol):
func(child0.evaluate(t, y, u), child1.evaluate(t, y, u), etc).
children : :class:`pybamm.Symbol`
The children nodes to apply the function to
derivative : str, optional
Which derivative to use when differentiating ("autograd" or "derivative").
Default is "autograd".
differentiated_function : method, optional
The function which was differentiated to obtain this one. Default is None.
"""
Expand All @@ -38,7 +34,6 @@ def __init__(
function: Callable,
*children: pybamm.Symbol,
name: str | None = None,
derivative: str | None = "autograd",
differentiated_function: Callable | None = None,
):
# Turn numbers into scalars
Expand All @@ -57,7 +52,6 @@ def __init__(
domains = self.get_children_domains(children)

self.function = function
self.derivative = derivative
self.differentiated_function = differentiated_function

super().__init__(name, children=children, domains=domains)
Expand Down Expand Up @@ -99,30 +93,10 @@ def _function_diff(self, children: Sequence[pybamm.Symbol], idx: float):
Derivative with respect to child number 'idx'.
See :meth:`pybamm.Symbol._diff()`.
"""
autograd = import_optional_dependency("autograd")
# Store differentiated function, needed in case we want to convert to CasADi
if self.derivative == "autograd":
return Function(
autograd.elementwise_grad(self.function, idx),
*children,
differentiated_function=self.function,
)
elif self.derivative == "derivative":
if len(children) > 1:
raise ValueError(
"""
differentiation using '.derivative()' not implemented for functions
with more than one child
"""
)
else:
# keep using "derivative" as derivative
return pybamm.Function(
self.function.derivative(), # type: ignore[attr-defined]
*children,
derivative="derivative",
differentiated_function=self.function,
)
raise NotImplementedError(
"Derivative of base Function class is not implemented. "
"Please implement in child class."
)

def _function_jac(self, children_jacs):
"""Calculate the Jacobian of a function."""
Expand Down Expand Up @@ -190,7 +164,6 @@ def create_copy(
self.function,
*children,
name=self.name,
derivative=self.derivative,
differentiated_function=self.differentiated_function,
)
else:
Expand All @@ -217,7 +190,6 @@ def _function_new_copy(self, children: list) -> Function:
self.function,
*children,
name=self.name,
derivative=self.derivative,
differentiated_function=self.differentiated_function,
)
)
Expand Down
36 changes: 33 additions & 3 deletions pybamm/expression_tree/interpolant.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,7 @@ def __init__(
interpolator: str | None = "linear",
extrapolate: bool = True,
entries_string: str | None = None,
_num_derivatives: int = 0,
):
# Check interpolator is valid
if interpolator not in ["linear", "cubic", "pchip"]:
Expand Down Expand Up @@ -189,9 +190,13 @@ def __init__(
self.x = x
self.y = y
self.entries_string = entries_string
super().__init__(
interpolating_function, *children, name=name, derivative="derivative"
)

# Differentiate the interpolating function if necessary
self._num_derivatives = _num_derivatives
for _ in range(_num_derivatives):
interpolating_function = interpolating_function.derivative()

super().__init__(interpolating_function, *children, name=name)

# Store information as attributes
self.interpolator = interpolator
Expand All @@ -213,6 +218,7 @@ def _from_json(cls, snippet: dict):
name=snippet["name"],
interpolator=snippet["interpolator"],
extrapolate=snippet["extrapolate"],
_num_derivatives=snippet["_num_derivatives"],
)

@property
Expand Down Expand Up @@ -241,6 +247,7 @@ def set_id(self):
self.entries_string,
*tuple([child.id for child in self.children]),
*tuple(self.domain),
self._num_derivatives,
)
)

Expand All @@ -256,6 +263,7 @@ def create_copy(self, new_children=None, perform_simplifications=True):
interpolator=self.interpolator,
extrapolate=self.extrapolate,
entries_string=self.entries_string,
_num_derivatives=self._num_derivatives,
)

def _function_evaluate(self, evaluated_children):
Expand Down Expand Up @@ -311,6 +319,27 @@ def _function_evaluate(self, evaluated_children):
else: # pragma: no cover
raise ValueError(f"Invalid dimension: {self.dimension}")

def _function_diff(self, children: Sequence[pybamm.Symbol], idx: float):
"""
Derivative with respect to child number 'idx'.
See :meth:`pybamm.Symbol._diff()`.
"""
if len(children) > 1:
raise NotImplementedError(
"differentiation not implemented for functions with more than one child"
)
else:
# keep using "derivative" as derivative
return Interpolant(
self.x,
self.y,
children,
name=self.name,
interpolator=self.interpolator,
extrapolate=self.extrapolate,
_num_derivatives=self._num_derivatives + 1,
)

def to_json(self):
"""
Method to serialise an Interpolant object into JSON.
Expand All @@ -323,6 +352,7 @@ def to_json(self):
"y": self.y.tolist(),
"interpolator": self.interpolator,
"extrapolate": self.extrapolate,
"_num_derivatives": self._num_derivatives,
}

return json_dict
1 change: 0 additions & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -121,7 +121,6 @@ jax = [
]
# Contains all optional dependencies, except for jax and dev dependencies
all = [
"autograd>=1.6.2",
"scikit-fem>=8.1.0",
"pybamm[examples,plot,cite,bpx,tqdm]",
]
Expand Down
52 changes: 3 additions & 49 deletions tests/unit/test_expression_tree/test_functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,6 @@
from tests import (
function_test,
multi_var_function_test,
multi_var_function_cube_test,
)


Expand Down Expand Up @@ -52,57 +51,12 @@ def test_function_of_one_variable(self):

def test_diff(self):
a = pybamm.StateVector(slice(0, 1))
b = pybamm.StateVector(slice(1, 2))
y = np.array([5])
func = pybamm.Function(function_test, a)
self.assertEqual(func.diff(a).evaluate(y=y), 2)
self.assertEqual(func.diff(func).evaluate(), 1)
func = pybamm.sin(a)
self.assertEqual(func.evaluate(y=y), np.sin(a.evaluate(y=y)))
self.assertEqual(func.diff(a).evaluate(y=y), np.cos(a.evaluate(y=y)))
func = pybamm.exp(a)
self.assertEqual(func.evaluate(y=y), np.exp(a.evaluate(y=y)))
self.assertEqual(func.diff(a).evaluate(y=y), np.exp(a.evaluate(y=y)))

# multiple variables
func = pybamm.Function(multi_var_function_test, 4 * a, 3 * a)
self.assertEqual(func.diff(a).evaluate(y=y), 7)
func = pybamm.Function(multi_var_function_test, 4 * a, 3 * b)
self.assertEqual(func.diff(a).evaluate(y=np.array([5, 6])), 4)
self.assertEqual(func.diff(b).evaluate(y=np.array([5, 6])), 3)
func = pybamm.Function(multi_var_function_cube_test, 4 * a, 3 * b)
self.assertEqual(func.diff(a).evaluate(y=np.array([5, 6])), 4)
self.assertEqual(
func.diff(b).evaluate(y=np.array([5, 6])), 3 * 3 * (3 * 6) ** 2
)

# exceptions
func = pybamm.Function(
multi_var_function_cube_test, 4 * a, 3 * b, derivative="derivative"
)
with self.assertRaises(ValueError):
with self.assertRaisesRegex(
NotImplementedError, "Derivative of base Function class is not implemented"
):
func.diff(a)

def test_function_of_multiple_variables(self):
a = pybamm.Variable("a")
b = pybamm.Parameter("b")
func = pybamm.Function(multi_var_function_test, a, b)
self.assertEqual(func.name, "function (multi_var_function_test)")
self.assertEqual(str(func), "multi_var_function_test(a, b)")
self.assertEqual(func.children[0].name, a.name)
self.assertEqual(func.children[1].name, b.name)

# test eval and diff
a = pybamm.StateVector(slice(0, 1))
b = pybamm.StateVector(slice(1, 2))
y = np.array([5, 2])
func = pybamm.Function(multi_var_function_test, a, b)

self.assertEqual(func.evaluate(y=y), 7)
self.assertEqual(func.diff(a).evaluate(y=y), 1)
self.assertEqual(func.diff(b).evaluate(y=y), 1)
self.assertEqual(func.diff(func).evaluate(), 1)

def test_exceptions(self):
a = pybamm.Variable("a", domain="something")
b = pybamm.Variable("b", domain="something else")
Expand Down
15 changes: 15 additions & 0 deletions tests/unit/test_expression_tree/test_interpolant.py
Original file line number Diff line number Diff line change
Expand Up @@ -326,6 +326,20 @@ def test_diff(self):
decimal=3,
)

# test 2D interpolation diff fails
x = (np.arange(-5.01, 5.01, 0.05), np.arange(-5.01, 5.01, 0.01))
xx, yy = np.meshgrid(x[0], x[1], indexing="ij")
z = np.sin(xx**2 + yy**2)
var1 = pybamm.StateVector(slice(0, 1))
var2 = pybamm.StateVector(slice(1, 2))
# linear
interp = pybamm.Interpolant(x, z, (var1, var2), interpolator="linear")
with self.assertRaisesRegex(
NotImplementedError,
"differentiation not implemented for functions with more than one child",
):
interp.diff(var1)

def test_processing(self):
x = np.linspace(0, 1, 200)
y = pybamm.StateVector(slice(0, 2))
Expand Down Expand Up @@ -369,6 +383,7 @@ def test_to_from_json(self):
],
"interpolator": "linear",
"extrapolate": True,
"_num_derivatives": 0,
}

# check correct writing to json
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -314,22 +314,6 @@ def test_concatenations(self):
y_eval = np.linspace(0, 1, expr.size)
self.assert_casadi_equal(f(y_eval), casadi.SX(expr.evaluate(y=y_eval)))

def test_convert_differentiated_function(self):
a = pybamm.InputParameter("a")
b = pybamm.InputParameter("b")

def myfunction(x, y):
return x + y**3

f = pybamm.Function(myfunction, a, b).diff(a)
self.assert_casadi_equal(
f.to_casadi(inputs={"a": 1, "b": 2}), casadi.DM(1), evalf=True
)
f = pybamm.Function(myfunction, a, b).diff(b)
self.assert_casadi_equal(
f.to_casadi(inputs={"a": 1, "b": 2}), casadi.DM(12), evalf=True
)

def test_convert_input_parameter(self):
casadi_t = casadi.MX.sym("t")
casadi_y = casadi.MX.sym("y", 10)
Expand Down
7 changes: 0 additions & 7 deletions tests/unit/test_expression_tree/test_operations/test_jac.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,6 @@
import unittest
from scipy.sparse import eye
from tests import get_mesh_for_testing
from tests import multi_var_function_test


class TestJacobian(TestCase):
Expand Down Expand Up @@ -213,12 +212,6 @@ def test_functions(self):
dfunc_dy = func.jac(y).evaluate(y=y0)
np.testing.assert_array_equal(0, dfunc_dy)

# several children
func = pybamm.Function(multi_var_function_test, 2 * y, 3 * y)
jacobian = np.diag(5 * np.ones(4))
dfunc_dy = func.jac(y).evaluate(y=y0)
np.testing.assert_array_equal(jacobian, dfunc_dy.toarray())

def test_index(self):
vec = pybamm.StateVector(slice(0, 5))
ind = pybamm.Index(vec, 3)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,6 @@
from scipy.sparse import eye
from tests import (
get_1p1d_discretisation_for_testing,
multi_var_function_test,
)


Expand Down Expand Up @@ -200,12 +199,6 @@ def test_functions(self):
dfunc_dy = func.jac(y).evaluate(y=y0)
np.testing.assert_array_equal(0, dfunc_dy)

# several children
func = pybamm.Function(multi_var_function_test, 2 * y, 3 * y)
jacobian = np.diag(5 * np.ones(8))
dfunc_dy = func.jac(y).evaluate(y=y0)
np.testing.assert_array_equal(jacobian, dfunc_dy.toarray())

def test_jac_of_domain_concatenation(self):
# create mesh
disc = get_1p1d_discretisation_for_testing()
Expand Down

0 comments on commit 205ca81

Please sign in to comment.