Skip to content

Commit

Permalink
#963 fix unit tests
Browse files Browse the repository at this point in the history
  • Loading branch information
valentinsulzer committed Apr 19, 2020
1 parent 18895f1 commit c4160bd
Show file tree
Hide file tree
Showing 12 changed files with 61 additions and 48 deletions.
19 changes: 13 additions & 6 deletions pybamm/expression_tree/operations/simplify.py
Original file line number Diff line number Diff line change
Expand Up @@ -567,7 +567,7 @@ class Simplification(object):
def __init__(self, simplified_symbols=None):
self._simplified_symbols = simplified_symbols or {}

def simplify(self, symbol):
def simplify(self, symbol, clear_domains=True):
"""
This function recurses down the tree, applying any simplifications defined in
classes derived from pybamm.Symbol. E.g. any expression multiplied by a
Expand All @@ -577,7 +577,9 @@ def simplify(self, symbol):
Parameters
----------
symbol : :class:`pybamm.Symbol`
The symbol to simplify
The symbol to simplify
clear_domains : bool
Whether to remove a symbol's domain when simplifying. Default is True.
Returns
-------
Expand All @@ -588,15 +590,16 @@ def simplify(self, symbol):
try:
return self._simplified_symbols[symbol.id]
except KeyError:
simplified_symbol = self._simplify(symbol)
simplified_symbol = self._simplify(symbol, clear_domains)

self._simplified_symbols[symbol.id] = simplified_symbol

return simplified_symbol

def _simplify(self, symbol):
def _simplify(self, symbol, clear_domains=True):
""" See :meth:`Simplification.simplify()`. """
symbol.clear_domains()
if clear_domains:
symbol.clear_domains()

if isinstance(symbol, pybamm.BinaryOperator):
left, right = symbol.children
Expand All @@ -607,7 +610,11 @@ def _simplify(self, symbol):
new_symbol = symbol._binary_simplify(new_left, new_right)

elif isinstance(symbol, pybamm.UnaryOperator):
new_child = self.simplify(symbol.child)
# Reassign domain for gradient and divergence
if isinstance(symbol, (pybamm.Gradient, pybamm.Divergence)):
new_child = self.simplify(symbol.child, clear_domains=False)
else:
new_child = self.simplify(symbol.child)
# _unary_simplify defined in derived classes for specific rules
new_symbol = symbol._unary_simplify(new_child)

Expand Down
10 changes: 10 additions & 0 deletions pybamm/expression_tree/unary_operators.py
Original file line number Diff line number Diff line change
Expand Up @@ -316,6 +316,10 @@ def __init__(self, child):
+ "Try broadcasting the object first, e.g.\n\n"
"\tpybamm.grad(pybamm.PrimaryBroadcast(symbol, 'domain'))"
)
if child.evaluates_on_edges() is True:
raise TypeError(
"Cannot take gradient of '{}' since it evaluates on edges".format(child)
)
super().__init__("grad", child)

def evaluates_on_edges(self):
Expand All @@ -338,6 +342,12 @@ def __init__(self, child):
+ "Try broadcasting the object first, e.g.\n\n"
"\tpybamm.div(pybamm.PrimaryBroadcast(symbol, 'domain'))"
)
if child.evaluates_on_edges() is False:
raise TypeError(
"Cannot take divergence of '{}' since it does not ".format(child)
+ "evaluate on edges. Usually, a gradient should be taken before the "
"divergence."
)
super().__init__("div", child)

def evaluates_on_edges(self):
Expand Down
11 changes: 2 additions & 9 deletions tests/unit/test_discretisations/test_discretisation.py
Original file line number Diff line number Diff line change
Expand Up @@ -475,12 +475,11 @@ def test_discretise_spatial_operator(self):
disc.set_variable_slices(variables)

# Simple expressions
for eqn in [pybamm.grad(var), pybamm.div(var)]:
for eqn in [pybamm.grad(var), pybamm.div(pybamm.grad(var))]:
eqn_disc = disc.process_symbol(eqn)

self.assertIsInstance(eqn_disc, pybamm.MatrixMultiplication)
self.assertIsInstance(eqn_disc.children[0], pybamm.Matrix)
self.assertIsInstance(eqn_disc.children[1], pybamm.StateVector)

combined_submesh = mesh.combine_submeshes(*whole_cell)
y = combined_submesh[0].nodes ** 2
Expand All @@ -491,14 +490,13 @@ def test_discretise_spatial_operator(self):
)

# More complex expressions
for eqn in [var * pybamm.grad(var), var * pybamm.div(var)]:
for eqn in [var * pybamm.grad(var), var * pybamm.div(pybamm.grad(var))]:
eqn_disc = disc.process_symbol(eqn)

self.assertIsInstance(eqn_disc, pybamm.Multiplication)
self.assertIsInstance(eqn_disc.children[0], pybamm.StateVector)
self.assertIsInstance(eqn_disc.children[1], pybamm.MatrixMultiplication)
self.assertIsInstance(eqn_disc.children[1].children[0], pybamm.Matrix)
self.assertIsInstance(eqn_disc.children[1].children[1], pybamm.StateVector)

y = combined_submesh[0].nodes ** 2
var_disc = disc.process_symbol(var)
Expand Down Expand Up @@ -602,10 +600,6 @@ def test_process_model_ode(self):
combined_submesh = mesh.combine_submeshes(*whole_cell)
disc.process_model(model)

# We cannot re-discretise after discretising a first time
with self.assertRaisesRegex(pybamm.ModelError, "Cannot re-discretise a model"):
disc.process_model(model)

y0 = model.concatenated_initial_conditions.evaluate()
np.testing.assert_array_equal(
y0, 3 * np.ones_like(combined_submesh[0].nodes[:, np.newaxis])
Expand Down Expand Up @@ -766,7 +760,6 @@ def test_process_model_dae(self):
mesh = disc.mesh

disc.process_model(model)

combined_submesh = mesh.combine_submeshes(*whole_cell)

y0 = model.concatenated_initial_conditions.evaluate()
Expand Down
4 changes: 2 additions & 2 deletions tests/unit/test_expression_tree/test_operations/test_copy.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,8 +26,8 @@ def test_symbol_new_copy(self):
abs(a),
pybamm.Function(np.sin, a),
pybamm.FunctionParameter("function", {"a": a}),
pybamm.grad(a),
pybamm.div(a),
pybamm.grad(v_n),
pybamm.div(pybamm.grad(v_n)),
pybamm.Integral(a, pybamm.t),
pybamm.BoundaryValue(v_n, "right"),
pybamm.BoundaryGradient(v_n, "right"),
Expand Down
16 changes: 10 additions & 6 deletions tests/unit/test_expression_tree/test_operations/test_simplify.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@

class TestSimplify(unittest.TestCase):
def test_symbol_simplify(self):
a = pybamm.Scalar(0)
a = pybamm.Scalar(0, domain="domain")
b = pybamm.Scalar(1)
c = pybamm.Parameter("c")
d = pybamm.Scalar(-1)
Expand Down Expand Up @@ -58,13 +58,17 @@ def myfunction(x, y):
# Gradient
self.assertIsInstance((pybamm.grad(a)).simplify(), pybamm.Scalar)
self.assertEqual((pybamm.grad(a)).simplify().evaluate(), 0)
v = pybamm.Variable("v")
self.assertIsInstance((pybamm.grad(v)).simplify(), pybamm.Gradient)
v = pybamm.Variable("v", domain="domain")
grad_v = pybamm.grad(v)
self.assertIsInstance(grad_v.simplify(), pybamm.Gradient)

# Divergence
self.assertIsInstance((pybamm.div(a)).simplify(), pybamm.Scalar)
self.assertEqual((pybamm.div(a)).simplify().evaluate(), 0)
self.assertIsInstance((pybamm.div(v)).simplify(), pybamm.Divergence)
div_b = pybamm.div(pybamm.PrimaryBroadcastToEdges(b, "domain"))
self.assertIsInstance(div_b.simplify(), pybamm.Scalar)
self.assertEqual(div_b.simplify().evaluate(), 0)
self.assertIsInstance(
(pybamm.div(pybamm.grad(v))).simplify(), pybamm.Divergence
)

# Integral
self.assertIsInstance(
Expand Down
11 changes: 3 additions & 8 deletions tests/unit/test_expression_tree/test_symbol.py
Original file line number Diff line number Diff line change
Expand Up @@ -271,12 +271,6 @@ def test_symbol_repr(self):
+ r", \*, children\=\['c', 'd'\], domain=\['test'\]"
+ r", auxiliary_domains\=\{'sec': \"\['other test'\]\"\}\)",
)
self.assertRegex(
pybamm.grad(a).__repr__(),
r"Gradient\("
+ hex_regex
+ r", grad, children\=\['a'\], domain=\[\], auxiliary_domains\=\{\}\)",
)
self.assertRegex(
pybamm.grad(c).__repr__(),
r"Gradient\("
Expand Down Expand Up @@ -332,9 +326,10 @@ def test_symbol_visualise(self):
rhs.visualise("StefanMaxwell_test")

def test_has_spatial_derivatives(self):
var = pybamm.Variable("var")
var = pybamm.Variable("var", domain="test")
grad_eqn = pybamm.grad(var)
div_eqn = pybamm.div(var)
var2 = pybamm.PrimaryBroadcastToEdges(pybamm.Variable("var2"), "test")
div_eqn = pybamm.div(var2)
grad_div_eqn = pybamm.div(grad_eqn)
algebraic_eqn = 2 * var + 3
self.assertTrue(grad_eqn.has_symbol_of_classes(pybamm.Gradient))
Expand Down
2 changes: 1 addition & 1 deletion tests/unit/test_expression_tree/test_unary_operators.py
Original file line number Diff line number Diff line change
Expand Up @@ -197,7 +197,7 @@ def test_diff(self):
spatial_a.diff(a)

def test_printing(self):
a = pybamm.Symbol("a")
a = pybamm.Symbol("a", domain="test")
self.assertEqual(str(-a), "-a")
grad = pybamm.Gradient(a)
self.assertEqual(grad.name, "grad")
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,11 @@ def test_public_functions(self):
a = pybamm.Scalar(0)
variables = {
"Electrolyte tortuosity": a,
"Electrolyte concentration": a,
"Electrolyte concentration": pybamm.FullBroadcast(
a,
["negative electrode", "separator", "positive electrode"],
"current collector",
),
"Negative electrode interfacial current density": pybamm.FullBroadcast(
a, "negative electrode", "current collector"
),
Expand Down
10 changes: 5 additions & 5 deletions tests/unit/test_parameters/test_parameter_values.py
Original file line number Diff line number Diff line change
Expand Up @@ -127,11 +127,11 @@ def test_process_symbol(self):
self.assertEqual(processed_integ.integration_variable[0].id, x.id)

# process unary operation
grad = pybamm.Gradient(a)
v = pybamm.Variable("v", domain="test")
grad = pybamm.Gradient(v)
processed_grad = parameter_values.process_symbol(grad)
self.assertIsInstance(processed_grad, pybamm.Gradient)
self.assertIsInstance(processed_grad.children[0], pybamm.Scalar)
self.assertEqual(processed_grad.children[0].value, 1)
self.assertIsInstance(processed_grad.children[0], pybamm.Variable)

# process delta function
aa = pybamm.Parameter("a")
Expand Down Expand Up @@ -435,8 +435,8 @@ def test_process_model(self):
b = pybamm.Parameter("b")
c = pybamm.Parameter("c")
d = pybamm.Parameter("d")
var1 = pybamm.Variable("var1")
var2 = pybamm.Variable("var2")
var1 = pybamm.Variable("var1", domain="test")
var2 = pybamm.Variable("var2", domain="test")
model.rhs = {var1: a * pybamm.grad(var1)}
model.algebraic = {var2: c * var2}
model.initial_conditions = {var1: b, var2: d}
Expand Down
12 changes: 6 additions & 6 deletions tests/unit/test_solvers/test_casadi_solver.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,11 +23,11 @@ def test_model_solver(self):

# create discretisation
disc = pybamm.Discretisation()
disc.process_model(model)
model_disc = disc.process_model(model, inplace=False)
# Solve
solver = pybamm.CasadiSolver(mode="fast", rtol=1e-8, atol=1e-8)
t_eval = np.linspace(0, 1, 100)
solution = solver.solve(model, t_eval)
solution = solver.solve(model_disc, t_eval)
np.testing.assert_array_equal(solution.t, t_eval)
np.testing.assert_array_almost_equal(
solution.y[0], np.exp(0.1 * solution.t), decimal=5
Expand Down Expand Up @@ -79,20 +79,20 @@ def test_model_solver_failure(self):

# create discretisation
disc = pybamm.Discretisation()
disc.process_model(model)
model_disc = disc.process_model(model, inplace=False)

solver = pybamm.CasadiSolver(regularity_check=False)

# Solve with failure at t=2
t_eval = np.linspace(0, 20, 100)
with self.assertRaises(pybamm.SolverError):
solver.solve(model, t_eval)
solver.solve(model_disc, t_eval)
# Solve with failure at t=0
model.initial_conditions = {var: 0}
disc.process_model(model)
model_disc = disc.process_model(model, inplace=False)
t_eval = np.linspace(0, 20, 100)
with self.assertRaises(pybamm.SolverError):
solver.solve(model, t_eval)
solver.solve(model_disc, t_eval)

def test_model_solver_events(self):
# Create model
Expand Down
4 changes: 2 additions & 2 deletions tests/unit/test_solvers/test_scikits_solvers.py
Original file line number Diff line number Diff line change
Expand Up @@ -524,12 +524,12 @@ def test_model_solver_dae_events_casadi(self):
pybamm.Event("var2 = 2.5", pybamm.min(var2 - 2.5)),
]
disc = get_discretisation_for_testing()
disc.process_model(model)
model_disc = disc.process_model(model, inplace=False)

# Solve
solver = pybamm.ScikitsDaeSolver(rtol=1e-8, atol=1e-8)
t_eval = np.linspace(0, 5, 100)
solution = solver.solve(model, t_eval)
solution = solver.solve(model_disc, t_eval)
np.testing.assert_array_less(solution.y[0], 1.5)
np.testing.assert_array_less(solution.y[-1], 2.5)
np.testing.assert_allclose(solution.y[0], np.exp(0.1 * solution.t))
Expand Down
4 changes: 2 additions & 2 deletions tests/unit/test_solvers/test_scipy_solver.py
Original file line number Diff line number Diff line change
Expand Up @@ -254,11 +254,11 @@ def test_model_solver_with_event_with_casadi(self):
mesh = get_mesh_for_testing()
spatial_methods = {"macroscale": pybamm.FiniteVolume()}
disc = pybamm.Discretisation(mesh, spatial_methods)
disc.process_model(model)
model_disc = disc.process_model(model, inplace=False)
# Solve
solver = pybamm.ScipySolver(rtol=1e-8, atol=1e-8, method="RK45")
t_eval = np.linspace(0, 10, 100)
solution = solver.solve(model, t_eval)
solution = solver.solve(model_disc, t_eval)
self.assertLess(len(solution.t), len(t_eval))
np.testing.assert_array_equal(solution.t, t_eval[: len(solution.t)])
np.testing.assert_allclose(solution.y[0], np.exp(-0.1 * solution.t))
Expand Down

0 comments on commit c4160bd

Please sign in to comment.