Skip to content

Commit

Permalink
pybamm-team#4183 remove autograd
Browse files Browse the repository at this point in the history
  • Loading branch information
valentinsulzer authored and js1tr3 committed Aug 12, 2024
1 parent ac13eee commit f98f933
Show file tree
Hide file tree
Showing 8 changed files with 17 additions and 116 deletions.
1 change: 0 additions & 1 deletion asv.conf.json
Original file line number Diff line number Diff line change
Expand Up @@ -82,7 +82,6 @@
"wget": [],
"cmake": [],
"anytree": [],
"autograd": [],
"scikit-fem": [],
"imageio": [],
"pybtex": [],
Expand Down
36 changes: 4 additions & 32 deletions pybamm/expression_tree/functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,6 @@
from typing import Callable

import pybamm
from pybamm.util import import_optional_dependency


class Function(pybamm.Symbol):
Expand All @@ -23,9 +22,6 @@ class Function(pybamm.Symbol):
func(child0.evaluate(t, y, u), child1.evaluate(t, y, u), etc).
children : :class:`pybamm.Symbol`
The children nodes to apply the function to
derivative : str, optional
Which derivative to use when differentiating ("autograd" or "derivative").
Default is "autograd".
differentiated_function : method, optional
The function which was differentiated to obtain this one. Default is None.
"""
Expand All @@ -35,7 +31,6 @@ def __init__(
function: Callable,
*children: pybamm.Symbol,
name: str | None = None,
derivative: str | None = "autograd",
differentiated_function: Callable | None = None,
):
# Turn numbers into scalars
Expand All @@ -54,7 +49,6 @@ def __init__(
domains = self.get_children_domains(children)

self.function = function
self.derivative = derivative
self.differentiated_function = differentiated_function

super().__init__(name, children=children, domains=domains)
Expand Down Expand Up @@ -96,30 +90,10 @@ def _function_diff(self, children: Sequence[pybamm.Symbol], idx: float):
Derivative with respect to child number 'idx'.
See :meth:`pybamm.Symbol._diff()`.
"""
autograd = import_optional_dependency("autograd")
# Store differentiated function, needed in case we want to convert to CasADi
if self.derivative == "autograd":
return Function(
autograd.elementwise_grad(self.function, idx),
*children,
differentiated_function=self.function,
)
elif self.derivative == "derivative":
if len(children) > 1:
raise ValueError(
"""
differentiation using '.derivative()' not implemented for functions
with more than one child
"""
)
else:
# keep using "derivative" as derivative
return pybamm.Function(
self.function.derivative(), # type: ignore[attr-defined]
*children,
derivative="derivative",
differentiated_function=self.function,
)
raise NotImplementedError(
"Derivative of base Function class is not implemented. "
"Please implement in child class."
)

def _function_jac(self, children_jacs):
"""Calculate the Jacobian of a function."""
Expand Down Expand Up @@ -187,7 +161,6 @@ def create_copy(
self.function,
*children,
name=self.name,
derivative=self.derivative,
differentiated_function=self.differentiated_function,
)
else:
Expand All @@ -214,7 +187,6 @@ def _function_new_copy(self, children: list) -> Function:
self.function,
*children,
name=self.name,
derivative=self.derivative,
differentiated_function=self.differentiated_function,
)
)
Expand Down
13 changes: 10 additions & 3 deletions pybamm/expression_tree/interpolant.py
Original file line number Diff line number Diff line change
Expand Up @@ -187,9 +187,7 @@ def __init__(
self.x = x
self.y = y
self.entries_string = entries_string
super().__init__(
interpolating_function, *children, name=name, derivative="derivative"
)
super().__init__(interpolating_function, *children, name=name)

# Store information as attributes
self.interpolator = interpolator
Expand Down Expand Up @@ -309,6 +307,15 @@ def _function_evaluate(self, evaluated_children):
else: # pragma: no cover
raise ValueError(f"Invalid dimension: {self.dimension}")

def _function_diff(self, children: Sequence[pybamm.Symbol], idx: float):
"""
Derivative with respect to child number 'idx'.
See :meth:`pybamm.Symbol._diff()`.
"""
raise NotImplementedError(
"Cannot differentiate Interpolant symbol with respect to its children."
)

def to_json(self):
"""
Method to serialise an Interpolant object into JSON.
Expand Down
1 change: 0 additions & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -113,7 +113,6 @@ jax = [
]
# Contains all optional dependencies, except for jax and dev dependencies
all = [
"autograd>=1.6.2",
"scikit-fem>=8.1.0",
"pybamm[examples,plot,cite,bpx,tqdm]",
]
Expand Down
52 changes: 3 additions & 49 deletions tests/unit/test_expression_tree/test_functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,6 @@
from tests import (
function_test,
multi_var_function_test,
multi_var_function_cube_test,
)


Expand Down Expand Up @@ -52,57 +51,12 @@ def test_function_of_one_variable(self):

def test_diff(self):
a = pybamm.StateVector(slice(0, 1))
b = pybamm.StateVector(slice(1, 2))
y = np.array([5])
func = pybamm.Function(function_test, a)
self.assertEqual(func.diff(a).evaluate(y=y), 2)
self.assertEqual(func.diff(func).evaluate(), 1)
func = pybamm.sin(a)
self.assertEqual(func.evaluate(y=y), np.sin(a.evaluate(y=y)))
self.assertEqual(func.diff(a).evaluate(y=y), np.cos(a.evaluate(y=y)))
func = pybamm.exp(a)
self.assertEqual(func.evaluate(y=y), np.exp(a.evaluate(y=y)))
self.assertEqual(func.diff(a).evaluate(y=y), np.exp(a.evaluate(y=y)))

# multiple variables
func = pybamm.Function(multi_var_function_test, 4 * a, 3 * a)
self.assertEqual(func.diff(a).evaluate(y=y), 7)
func = pybamm.Function(multi_var_function_test, 4 * a, 3 * b)
self.assertEqual(func.diff(a).evaluate(y=np.array([5, 6])), 4)
self.assertEqual(func.diff(b).evaluate(y=np.array([5, 6])), 3)
func = pybamm.Function(multi_var_function_cube_test, 4 * a, 3 * b)
self.assertEqual(func.diff(a).evaluate(y=np.array([5, 6])), 4)
self.assertEqual(
func.diff(b).evaluate(y=np.array([5, 6])), 3 * 3 * (3 * 6) ** 2
)

# exceptions
func = pybamm.Function(
multi_var_function_cube_test, 4 * a, 3 * b, derivative="derivative"
)
with self.assertRaises(ValueError):
with self.assertRaisesRegex(
NotImplementedError, "Derivative of base Function class is not implemented"
):
func.diff(a)

def test_function_of_multiple_variables(self):
a = pybamm.Variable("a")
b = pybamm.Parameter("b")
func = pybamm.Function(multi_var_function_test, a, b)
self.assertEqual(func.name, "function (multi_var_function_test)")
self.assertEqual(str(func), "multi_var_function_test(a, b)")
self.assertEqual(func.children[0].name, a.name)
self.assertEqual(func.children[1].name, b.name)

# test eval and diff
a = pybamm.StateVector(slice(0, 1))
b = pybamm.StateVector(slice(1, 2))
y = np.array([5, 2])
func = pybamm.Function(multi_var_function_test, a, b)

self.assertEqual(func.evaluate(y=y), 7)
self.assertEqual(func.diff(a).evaluate(y=y), 1)
self.assertEqual(func.diff(b).evaluate(y=y), 1)
self.assertEqual(func.diff(func).evaluate(), 1)

def test_exceptions(self):
a = pybamm.Variable("a", domain="something")
b = pybamm.Variable("b", domain="something else")
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -314,22 +314,6 @@ def test_concatenations(self):
y_eval = np.linspace(0, 1, expr.size)
self.assert_casadi_equal(f(y_eval), casadi.SX(expr.evaluate(y=y_eval)))

def test_convert_differentiated_function(self):
a = pybamm.InputParameter("a")
b = pybamm.InputParameter("b")

def myfunction(x, y):
return x + y**3

f = pybamm.Function(myfunction, a, b).diff(a)
self.assert_casadi_equal(
f.to_casadi(inputs={"a": 1, "b": 2}), casadi.DM(1), evalf=True
)
f = pybamm.Function(myfunction, a, b).diff(b)
self.assert_casadi_equal(
f.to_casadi(inputs={"a": 1, "b": 2}), casadi.DM(12), evalf=True
)

def test_convert_input_parameter(self):
casadi_t = casadi.MX.sym("t")
casadi_y = casadi.MX.sym("y", 10)
Expand Down
7 changes: 0 additions & 7 deletions tests/unit/test_expression_tree/test_operations/test_jac.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,6 @@
import unittest
from scipy.sparse import eye
from tests import get_mesh_for_testing
from tests import multi_var_function_test


class TestJacobian(TestCase):
Expand Down Expand Up @@ -213,12 +212,6 @@ def test_functions(self):
dfunc_dy = func.jac(y).evaluate(y=y0)
np.testing.assert_array_equal(0, dfunc_dy)

# several children
func = pybamm.Function(multi_var_function_test, 2 * y, 3 * y)
jacobian = np.diag(5 * np.ones(4))
dfunc_dy = func.jac(y).evaluate(y=y0)
np.testing.assert_array_equal(jacobian, dfunc_dy.toarray())

def test_index(self):
vec = pybamm.StateVector(slice(0, 5))
ind = pybamm.Index(vec, 3)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,6 @@
from scipy.sparse import eye
from tests import (
get_1p1d_discretisation_for_testing,
multi_var_function_test,
)


Expand Down Expand Up @@ -200,12 +199,6 @@ def test_functions(self):
dfunc_dy = func.jac(y).evaluate(y=y0)
np.testing.assert_array_equal(0, dfunc_dy)

# several children
func = pybamm.Function(multi_var_function_test, 2 * y, 3 * y)
jacobian = np.diag(5 * np.ones(8))
dfunc_dy = func.jac(y).evaluate(y=y0)
np.testing.assert_array_equal(jacobian, dfunc_dy.toarray())

def test_jac_of_domain_concatenation(self):
# create mesh
disc = get_1p1d_discretisation_for_testing()
Expand Down

0 comments on commit f98f933

Please sign in to comment.