Skip to content

Commit

Permalink
#858 improve coverage
Browse files Browse the repository at this point in the history
  • Loading branch information
martinjrobins committed Mar 15, 2020
1 parent 3cdb6dc commit 00990f9
Show file tree
Hide file tree
Showing 5 changed files with 73 additions and 19 deletions.
5 changes: 1 addition & 4 deletions pybamm/expression_tree/independent_variable.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,10 +35,7 @@ def _evaluate_for_shape(self):

def _jac(self, variable):
""" See :meth:`pybamm.Symbol._jac()`. """
if variable.id == self.id:
return pybamm.Scalar(1)
else:
return pybamm.Scalar(0)
return pybamm.Scalar(0)


class Time(IndependentVariable):
Expand Down
16 changes: 1 addition & 15 deletions pybamm/expression_tree/state_vector.py
Original file line number Diff line number Diff line change
Expand Up @@ -105,20 +105,6 @@ def set_id(self):
+ tuple(self.domain)
)

def _base_evaluate(self, t=None, y=None, y_dot=None, u=None):
""" See :meth:`pybamm.Symbol._base_evaluate()`. """
if y is None:
raise TypeError("StateVector cannot evaluate input 'y=None'")
if y.shape[0] < len(self.evaluation_array):
raise ValueError(
"y is too short, so value with slice is smaller than expected"
)
else:
out = (y[: len(self._evaluation_array)])[self._evaluation_array]
if isinstance(out, np.ndarray) and out.ndim == 1:
out = out[:, np.newaxis]
return out

def _jac_diff_vector(self, variable):
"""
Differentiate a slice of a StateVector of size m with respect to another slice
Expand All @@ -139,7 +125,7 @@ def _jac_diff_vector(self, variable):
variable_size = variable.last_point - variable.first_point

# Return zeros of correct size since no entries match
return csr_matrix(slices_size, variable_size)
return pybamm.Matrix(csr_matrix((slices_size, variable_size)))

def _jac_same_vector(self, variable):
"""
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -212,6 +212,11 @@ def test_errors(self):
ValueError, "Must provide a 'y' for converting state vectors"
):
y.to_casadi()
y_dot = pybamm.StateVectorDot(slice(0, 10))
with self.assertRaisesRegex(
ValueError, "Must provide a 'y_dot' for converting state vectors"
):
y_dot.to_casadi()
var = pybamm.Variable("var")
with self.assertRaisesRegex(TypeError, "Cannot convert symbol of type"):
var.to_casadi()
Expand Down
41 changes: 41 additions & 0 deletions tests/unit/test_expression_tree/test_operations/test_jac.py
Original file line number Diff line number Diff line change
Expand Up @@ -109,6 +109,47 @@ def test_nonlinear(self):
with self.assertRaises(pybamm.UndefinedOperationError):
func.jac(y)

def test_linear_ydot(self):
y = pybamm.StateVector(slice(0, 4))
y_dot = pybamm.StateVectorDot(slice(0, 4))
u = pybamm.StateVector(slice(0, 2))
v = pybamm.StateVector(slice(2, 4))
u_dot = pybamm.StateVectorDot(slice(0, 2))
v_dot = pybamm.StateVectorDot(slice(2, 4))

y0 = np.ones(4)
y_dot0 = np.ones(4)

func = u_dot
jacobian = np.array([[1, 0, 0, 0], [0, 1, 0, 0]])
dfunc_dy = func.jac(y_dot).evaluate(y=y0, y_dot=y_dot0)
np.testing.assert_array_equal(jacobian, dfunc_dy.toarray())

func = -v_dot
jacobian = np.array([[0, 0, -1, 0], [0, 0, 0, -1]])
dfunc_dy = func.jac(y_dot).evaluate(y=y0, y_dot=y_dot0)
np.testing.assert_array_equal(jacobian, dfunc_dy.toarray())

func = u_dot
jacobian = np.array([[0, 0, 0, 0], [0, 0, 0, 0]])
dfunc_dy = func.jac(y).evaluate(y=y0, y_dot=y_dot0)
np.testing.assert_array_equal(jacobian, dfunc_dy.toarray())

func = -v_dot
jacobian = np.array([[0, 0, 0, 0], [0, 0, 0, 0]])
dfunc_dy = func.jac(y).evaluate(y=y0, y_dot=y_dot0)
np.testing.assert_array_equal(jacobian, dfunc_dy.toarray())

func = u
jacobian = np.array([[0, 0, 0, 0], [0, 0, 0, 0]])
dfunc_dy = func.jac(y_dot).evaluate(y=y0, y_dot=y_dot0)
np.testing.assert_array_equal(jacobian, dfunc_dy.toarray())

func = -v
jacobian = np.array([[0, 0, 0, 0], [0, 0, 0, 0]])
dfunc_dy = func.jac(y_dot).evaluate(y=y0, y_dot=y_dot0)
np.testing.assert_array_equal(jacobian, dfunc_dy.toarray())

def test_functions(self):
y = pybamm.StateVector(slice(0, 4))
u = pybamm.StateVector(slice(0, 2))
Expand Down
25 changes: 25 additions & 0 deletions tests/unit/test_expression_tree/test_variable.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,14 @@ def test_variable_init(self):
self.assertEqual(a.domain[0], "test")
self.assertRaises(TypeError, pybamm.Variable("a", domain="test"))

def test_variable_diff(self):
a = pybamm.Variable("a")
b = pybamm.Variable("b")
self.assertIsInstance(a.diff(a), pybamm.Scalar)
self.assertEqual(a.diff(a).evaluate(), 1)
self.assertIsInstance(a.diff(b), pybamm.Scalar)
self.assertEqual(a.diff(b).evaluate(), 0)

def test_variable_id(self):
a1 = pybamm.Variable("a", domain=["negative electrode"])
a2 = pybamm.Variable("a", domain=["negative electrode"])
Expand All @@ -25,6 +33,7 @@ def test_variable_id(self):
self.assertNotEqual(a1.id, a3.id)
self.assertNotEqual(a1.id, a4.id)


class TestVariableDot(unittest.TestCase):
def test_variable_init(self):
a = pybamm.VariableDot("a'")
Expand All @@ -43,6 +52,14 @@ def test_variable_id(self):
self.assertNotEqual(a1.id, a3.id)
self.assertNotEqual(a1.id, a4.id)

def test_variable_diff(self):
a = pybamm.VariableDot("a")
b = pybamm.Variable("b")
self.assertIsInstance(a.diff(a), pybamm.Scalar)
self.assertEqual(a.diff(a).evaluate(), 1)
self.assertIsInstance(a.diff(b), pybamm.Scalar)
self.assertEqual(a.diff(b).evaluate(), 0)


class TestExternalVariable(unittest.TestCase):
def test_external_variable_scalar(self):
Expand All @@ -68,6 +85,14 @@ def test_external_variable_vector(self):
with self.assertRaisesRegex(ValueError, "External variable"):
a.evaluate(u={"a": np.ones((5, 1))})

def test_external_variable_diff(self):
a = pybamm.ExternalVariable("a", 10)
b = pybamm.Variable("b")
self.assertIsInstance(a.diff(a), pybamm.Scalar)
self.assertEqual(a.diff(a).evaluate(), 1)
self.assertIsInstance(a.diff(b), pybamm.Scalar)
self.assertEqual(a.diff(b).evaluate(), 0)


if __name__ == "__main__":
print("Add -v for more debug output")
Expand Down

0 comments on commit 00990f9

Please sign in to comment.