From c4160bd07dda0fa84e0843f2d6aca686a04697e6 Mon Sep 17 00:00:00 2001 From: Valentin Sulzer Date: Sun, 19 Apr 2020 16:47:55 -0400 Subject: [PATCH] #963 fix unit tests --- pybamm/expression_tree/operations/simplify.py | 19 +++++++++++++------ pybamm/expression_tree/unary_operators.py | 10 ++++++++++ .../test_discretisation.py | 11 ++--------- .../test_operations/test_copy.py | 4 ++-- .../test_operations/test_simplify.py | 16 ++++++++++------ .../unit/test_expression_tree/test_symbol.py | 11 +++-------- .../test_unary_operators.py | 2 +- .../test_full_conductivity.py | 6 +++++- .../test_parameters/test_parameter_values.py | 10 +++++----- tests/unit/test_solvers/test_casadi_solver.py | 12 ++++++------ .../unit/test_solvers/test_scikits_solvers.py | 4 ++-- tests/unit/test_solvers/test_scipy_solver.py | 4 ++-- 12 files changed, 61 insertions(+), 48 deletions(-) diff --git a/pybamm/expression_tree/operations/simplify.py b/pybamm/expression_tree/operations/simplify.py index df033cea40..60282c893e 100644 --- a/pybamm/expression_tree/operations/simplify.py +++ b/pybamm/expression_tree/operations/simplify.py @@ -567,7 +567,7 @@ class Simplification(object): def __init__(self, simplified_symbols=None): self._simplified_symbols = simplified_symbols or {} - def simplify(self, symbol): + def simplify(self, symbol, clear_domains=True): """ This function recurses down the tree, applying any simplifications defined in classes derived from pybamm.Symbol. E.g. any expression multiplied by a @@ -577,7 +577,9 @@ def simplify(self, symbol): Parameters ---------- symbol : :class:`pybamm.Symbol` - The symbol to simplify + The symbol to simplify + clear_domains : bool + Whether to remove a symbol's domain when simplifying. Default is True. Returns ------- @@ -588,15 +590,16 @@ def simplify(self, symbol): try: return self._simplified_symbols[symbol.id] except KeyError: - simplified_symbol = self._simplify(symbol) + simplified_symbol = self._simplify(symbol, clear_domains) self._simplified_symbols[symbol.id] = simplified_symbol return simplified_symbol - def _simplify(self, symbol): + def _simplify(self, symbol, clear_domains=True): """ See :meth:`Simplification.simplify()`. """ - symbol.clear_domains() + if clear_domains: + symbol.clear_domains() if isinstance(symbol, pybamm.BinaryOperator): left, right = symbol.children @@ -607,7 +610,11 @@ def _simplify(self, symbol): new_symbol = symbol._binary_simplify(new_left, new_right) elif isinstance(symbol, pybamm.UnaryOperator): - new_child = self.simplify(symbol.child) + # Reassign domain for gradient and divergence + if isinstance(symbol, (pybamm.Gradient, pybamm.Divergence)): + new_child = self.simplify(symbol.child, clear_domains=False) + else: + new_child = self.simplify(symbol.child) # _unary_simplify defined in derived classes for specific rules new_symbol = symbol._unary_simplify(new_child) diff --git a/pybamm/expression_tree/unary_operators.py b/pybamm/expression_tree/unary_operators.py index 78e0634b66..c02244726d 100644 --- a/pybamm/expression_tree/unary_operators.py +++ b/pybamm/expression_tree/unary_operators.py @@ -316,6 +316,10 @@ def __init__(self, child): + "Try broadcasting the object first, e.g.\n\n" "\tpybamm.grad(pybamm.PrimaryBroadcast(symbol, 'domain'))" ) + if child.evaluates_on_edges() is True: + raise TypeError( + "Cannot take gradient of '{}' since it evaluates on edges".format(child) + ) super().__init__("grad", child) def evaluates_on_edges(self): @@ -338,6 +342,12 @@ def __init__(self, child): + "Try broadcasting the object first, e.g.\n\n" "\tpybamm.div(pybamm.PrimaryBroadcast(symbol, 'domain'))" ) + if child.evaluates_on_edges() is False: + raise TypeError( + "Cannot take divergence of '{}' since it does not ".format(child) + + "evaluate on edges. Usually, a gradient should be taken before the " + "divergence." + ) super().__init__("div", child) def evaluates_on_edges(self): diff --git a/tests/unit/test_discretisations/test_discretisation.py b/tests/unit/test_discretisations/test_discretisation.py index 8ffb8323b0..5bc0aa1197 100644 --- a/tests/unit/test_discretisations/test_discretisation.py +++ b/tests/unit/test_discretisations/test_discretisation.py @@ -475,12 +475,11 @@ def test_discretise_spatial_operator(self): disc.set_variable_slices(variables) # Simple expressions - for eqn in [pybamm.grad(var), pybamm.div(var)]: + for eqn in [pybamm.grad(var), pybamm.div(pybamm.grad(var))]: eqn_disc = disc.process_symbol(eqn) self.assertIsInstance(eqn_disc, pybamm.MatrixMultiplication) self.assertIsInstance(eqn_disc.children[0], pybamm.Matrix) - self.assertIsInstance(eqn_disc.children[1], pybamm.StateVector) combined_submesh = mesh.combine_submeshes(*whole_cell) y = combined_submesh[0].nodes ** 2 @@ -491,14 +490,13 @@ def test_discretise_spatial_operator(self): ) # More complex expressions - for eqn in [var * pybamm.grad(var), var * pybamm.div(var)]: + for eqn in [var * pybamm.grad(var), var * pybamm.div(pybamm.grad(var))]: eqn_disc = disc.process_symbol(eqn) self.assertIsInstance(eqn_disc, pybamm.Multiplication) self.assertIsInstance(eqn_disc.children[0], pybamm.StateVector) self.assertIsInstance(eqn_disc.children[1], pybamm.MatrixMultiplication) self.assertIsInstance(eqn_disc.children[1].children[0], pybamm.Matrix) - self.assertIsInstance(eqn_disc.children[1].children[1], pybamm.StateVector) y = combined_submesh[0].nodes ** 2 var_disc = disc.process_symbol(var) @@ -602,10 +600,6 @@ def test_process_model_ode(self): combined_submesh = mesh.combine_submeshes(*whole_cell) disc.process_model(model) - # We cannot re-discretise after discretising a first time - with self.assertRaisesRegex(pybamm.ModelError, "Cannot re-discretise a model"): - disc.process_model(model) - y0 = model.concatenated_initial_conditions.evaluate() np.testing.assert_array_equal( y0, 3 * np.ones_like(combined_submesh[0].nodes[:, np.newaxis]) @@ -766,7 +760,6 @@ def test_process_model_dae(self): mesh = disc.mesh disc.process_model(model) - combined_submesh = mesh.combine_submeshes(*whole_cell) y0 = model.concatenated_initial_conditions.evaluate() diff --git a/tests/unit/test_expression_tree/test_operations/test_copy.py b/tests/unit/test_expression_tree/test_operations/test_copy.py index e3fdb43c7d..a72faf6495 100644 --- a/tests/unit/test_expression_tree/test_operations/test_copy.py +++ b/tests/unit/test_expression_tree/test_operations/test_copy.py @@ -26,8 +26,8 @@ def test_symbol_new_copy(self): abs(a), pybamm.Function(np.sin, a), pybamm.FunctionParameter("function", {"a": a}), - pybamm.grad(a), - pybamm.div(a), + pybamm.grad(v_n), + pybamm.div(pybamm.grad(v_n)), pybamm.Integral(a, pybamm.t), pybamm.BoundaryValue(v_n, "right"), pybamm.BoundaryGradient(v_n, "right"), diff --git a/tests/unit/test_expression_tree/test_operations/test_simplify.py b/tests/unit/test_expression_tree/test_operations/test_simplify.py index d898709a7f..47d663a667 100644 --- a/tests/unit/test_expression_tree/test_operations/test_simplify.py +++ b/tests/unit/test_expression_tree/test_operations/test_simplify.py @@ -10,7 +10,7 @@ class TestSimplify(unittest.TestCase): def test_symbol_simplify(self): - a = pybamm.Scalar(0) + a = pybamm.Scalar(0, domain="domain") b = pybamm.Scalar(1) c = pybamm.Parameter("c") d = pybamm.Scalar(-1) @@ -58,13 +58,17 @@ def myfunction(x, y): # Gradient self.assertIsInstance((pybamm.grad(a)).simplify(), pybamm.Scalar) self.assertEqual((pybamm.grad(a)).simplify().evaluate(), 0) - v = pybamm.Variable("v") - self.assertIsInstance((pybamm.grad(v)).simplify(), pybamm.Gradient) + v = pybamm.Variable("v", domain="domain") + grad_v = pybamm.grad(v) + self.assertIsInstance(grad_v.simplify(), pybamm.Gradient) # Divergence - self.assertIsInstance((pybamm.div(a)).simplify(), pybamm.Scalar) - self.assertEqual((pybamm.div(a)).simplify().evaluate(), 0) - self.assertIsInstance((pybamm.div(v)).simplify(), pybamm.Divergence) + div_b = pybamm.div(pybamm.PrimaryBroadcastToEdges(b, "domain")) + self.assertIsInstance(div_b.simplify(), pybamm.Scalar) + self.assertEqual(div_b.simplify().evaluate(), 0) + self.assertIsInstance( + (pybamm.div(pybamm.grad(v))).simplify(), pybamm.Divergence + ) # Integral self.assertIsInstance( diff --git a/tests/unit/test_expression_tree/test_symbol.py b/tests/unit/test_expression_tree/test_symbol.py index 52638aebbd..3353fc258e 100644 --- a/tests/unit/test_expression_tree/test_symbol.py +++ b/tests/unit/test_expression_tree/test_symbol.py @@ -271,12 +271,6 @@ def test_symbol_repr(self): + r", \*, children\=\['c', 'd'\], domain=\['test'\]" + r", auxiliary_domains\=\{'sec': \"\['other test'\]\"\}\)", ) - self.assertRegex( - pybamm.grad(a).__repr__(), - r"Gradient\(" - + hex_regex - + r", grad, children\=\['a'\], domain=\[\], auxiliary_domains\=\{\}\)", - ) self.assertRegex( pybamm.grad(c).__repr__(), r"Gradient\(" @@ -332,9 +326,10 @@ def test_symbol_visualise(self): rhs.visualise("StefanMaxwell_test") def test_has_spatial_derivatives(self): - var = pybamm.Variable("var") + var = pybamm.Variable("var", domain="test") grad_eqn = pybamm.grad(var) - div_eqn = pybamm.div(var) + var2 = pybamm.PrimaryBroadcastToEdges(pybamm.Variable("var2"), "test") + div_eqn = pybamm.div(var2) grad_div_eqn = pybamm.div(grad_eqn) algebraic_eqn = 2 * var + 3 self.assertTrue(grad_eqn.has_symbol_of_classes(pybamm.Gradient)) diff --git a/tests/unit/test_expression_tree/test_unary_operators.py b/tests/unit/test_expression_tree/test_unary_operators.py index b43bd8e9ab..5610df749b 100644 --- a/tests/unit/test_expression_tree/test_unary_operators.py +++ b/tests/unit/test_expression_tree/test_unary_operators.py @@ -197,7 +197,7 @@ def test_diff(self): spatial_a.diff(a) def test_printing(self): - a = pybamm.Symbol("a") + a = pybamm.Symbol("a", domain="test") self.assertEqual(str(-a), "-a") grad = pybamm.Gradient(a) self.assertEqual(grad.name, "grad") diff --git a/tests/unit/test_models/test_submodels/test_electrolyte_conductivity/test_full_conductivity.py b/tests/unit/test_models/test_submodels/test_electrolyte_conductivity/test_full_conductivity.py index 4f16b15b15..7056c1bb36 100644 --- a/tests/unit/test_models/test_submodels/test_electrolyte_conductivity/test_full_conductivity.py +++ b/tests/unit/test_models/test_submodels/test_electrolyte_conductivity/test_full_conductivity.py @@ -13,7 +13,11 @@ def test_public_functions(self): a = pybamm.Scalar(0) variables = { "Electrolyte tortuosity": a, - "Electrolyte concentration": a, + "Electrolyte concentration": pybamm.FullBroadcast( + a, + ["negative electrode", "separator", "positive electrode"], + "current collector", + ), "Negative electrode interfacial current density": pybamm.FullBroadcast( a, "negative electrode", "current collector" ), diff --git a/tests/unit/test_parameters/test_parameter_values.py b/tests/unit/test_parameters/test_parameter_values.py index 351f7b23e5..27fb55a88c 100644 --- a/tests/unit/test_parameters/test_parameter_values.py +++ b/tests/unit/test_parameters/test_parameter_values.py @@ -127,11 +127,11 @@ def test_process_symbol(self): self.assertEqual(processed_integ.integration_variable[0].id, x.id) # process unary operation - grad = pybamm.Gradient(a) + v = pybamm.Variable("v", domain="test") + grad = pybamm.Gradient(v) processed_grad = parameter_values.process_symbol(grad) self.assertIsInstance(processed_grad, pybamm.Gradient) - self.assertIsInstance(processed_grad.children[0], pybamm.Scalar) - self.assertEqual(processed_grad.children[0].value, 1) + self.assertIsInstance(processed_grad.children[0], pybamm.Variable) # process delta function aa = pybamm.Parameter("a") @@ -435,8 +435,8 @@ def test_process_model(self): b = pybamm.Parameter("b") c = pybamm.Parameter("c") d = pybamm.Parameter("d") - var1 = pybamm.Variable("var1") - var2 = pybamm.Variable("var2") + var1 = pybamm.Variable("var1", domain="test") + var2 = pybamm.Variable("var2", domain="test") model.rhs = {var1: a * pybamm.grad(var1)} model.algebraic = {var2: c * var2} model.initial_conditions = {var1: b, var2: d} diff --git a/tests/unit/test_solvers/test_casadi_solver.py b/tests/unit/test_solvers/test_casadi_solver.py index 1018baa68a..020fce254d 100644 --- a/tests/unit/test_solvers/test_casadi_solver.py +++ b/tests/unit/test_solvers/test_casadi_solver.py @@ -23,11 +23,11 @@ def test_model_solver(self): # create discretisation disc = pybamm.Discretisation() - disc.process_model(model) + model_disc = disc.process_model(model, inplace=False) # Solve solver = pybamm.CasadiSolver(mode="fast", rtol=1e-8, atol=1e-8) t_eval = np.linspace(0, 1, 100) - solution = solver.solve(model, t_eval) + solution = solver.solve(model_disc, t_eval) np.testing.assert_array_equal(solution.t, t_eval) np.testing.assert_array_almost_equal( solution.y[0], np.exp(0.1 * solution.t), decimal=5 @@ -79,20 +79,20 @@ def test_model_solver_failure(self): # create discretisation disc = pybamm.Discretisation() - disc.process_model(model) + model_disc = disc.process_model(model, inplace=False) solver = pybamm.CasadiSolver(regularity_check=False) # Solve with failure at t=2 t_eval = np.linspace(0, 20, 100) with self.assertRaises(pybamm.SolverError): - solver.solve(model, t_eval) + solver.solve(model_disc, t_eval) # Solve with failure at t=0 model.initial_conditions = {var: 0} - disc.process_model(model) + model_disc = disc.process_model(model, inplace=False) t_eval = np.linspace(0, 20, 100) with self.assertRaises(pybamm.SolverError): - solver.solve(model, t_eval) + solver.solve(model_disc, t_eval) def test_model_solver_events(self): # Create model diff --git a/tests/unit/test_solvers/test_scikits_solvers.py b/tests/unit/test_solvers/test_scikits_solvers.py index c646075a3a..dbaf00cb6e 100644 --- a/tests/unit/test_solvers/test_scikits_solvers.py +++ b/tests/unit/test_solvers/test_scikits_solvers.py @@ -524,12 +524,12 @@ def test_model_solver_dae_events_casadi(self): pybamm.Event("var2 = 2.5", pybamm.min(var2 - 2.5)), ] disc = get_discretisation_for_testing() - disc.process_model(model) + model_disc = disc.process_model(model, inplace=False) # Solve solver = pybamm.ScikitsDaeSolver(rtol=1e-8, atol=1e-8) t_eval = np.linspace(0, 5, 100) - solution = solver.solve(model, t_eval) + solution = solver.solve(model_disc, t_eval) np.testing.assert_array_less(solution.y[0], 1.5) np.testing.assert_array_less(solution.y[-1], 2.5) np.testing.assert_allclose(solution.y[0], np.exp(0.1 * solution.t)) diff --git a/tests/unit/test_solvers/test_scipy_solver.py b/tests/unit/test_solvers/test_scipy_solver.py index e7393927ce..77fe891862 100644 --- a/tests/unit/test_solvers/test_scipy_solver.py +++ b/tests/unit/test_solvers/test_scipy_solver.py @@ -254,11 +254,11 @@ def test_model_solver_with_event_with_casadi(self): mesh = get_mesh_for_testing() spatial_methods = {"macroscale": pybamm.FiniteVolume()} disc = pybamm.Discretisation(mesh, spatial_methods) - disc.process_model(model) + model_disc = disc.process_model(model, inplace=False) # Solve solver = pybamm.ScipySolver(rtol=1e-8, atol=1e-8, method="RK45") t_eval = np.linspace(0, 10, 100) - solution = solver.solve(model, t_eval) + solution = solver.solve(model_disc, t_eval) self.assertLess(len(solution.t), len(t_eval)) np.testing.assert_array_equal(solution.t, t_eval[: len(solution.t)]) np.testing.assert_allclose(solution.y[0], np.exp(-0.1 * solution.t))