From 5c1217338b5018bbafc860493604baa6734f7891 Mon Sep 17 00:00:00 2001 From: Valentin Sulzer Date: Tue, 2 Aug 2022 16:03:13 -0400 Subject: [PATCH 1/7] add a few more expression simplifications --- pybamm/expression_tree/binary_operators.py | 43 ++++++++++++++----- pybamm/expression_tree/state_vector.py | 2 +- .../test_binary_operators.py | 16 ++++++- 3 files changed, 49 insertions(+), 12 deletions(-) diff --git a/pybamm/expression_tree/binary_operators.py b/pybamm/expression_tree/binary_operators.py index af93c71d8d..45c029df5f 100644 --- a/pybamm/expression_tree/binary_operators.py +++ b/pybamm/expression_tree/binary_operators.py @@ -1123,7 +1123,7 @@ def simplified_multiplication(left, right): elif isinstance(right, (Addition, Subtraction)): mul_classes = ( pybamm.Multiplication, - pybamm.MatrixMultiplication, + MatrixMultiplication, pybamm.Division, ) if ( @@ -1257,6 +1257,16 @@ def simplified_division(left, right): r_left, r_right = right.orphans return (left * r_right) / r_left + # Cancelling out common terms + if ( + isinstance(left, pybamm.Multiplication) + and isinstance(right, pybamm.Multiplication) + and left.left == right.left + ): + _, l_right = left.orphans + _, r_right = right.orphans + return l_right / r_right + # Negation simplifications if isinstance(left, pybamm.Negate) and isinstance(right, pybamm.Negate): # Double negation cancels out @@ -1275,7 +1285,7 @@ def simplified_division(left, right): def simplified_matrix_multiplication(left, right): left, right = preprocess_binary(left, right) if pybamm.is_matrix_zero(left) or pybamm.is_matrix_zero(right): - return pybamm.zeros_like(pybamm.MatrixMultiplication(left, right)) + return pybamm.zeros_like(MatrixMultiplication(left, right)) if isinstance(right, Multiplication) and left.is_constant(): # Simplify A @ (b * c) to (A * b) @ c if (A * b) is constant @@ -1309,18 +1319,31 @@ def simplified_matrix_multiplication(left, right): new_mul.copy_domains(right) return new_mul - # Simplify A @ (b + c) to (A @ b) + (A @ c) if (A @ b) or (A @ c) is constant - # This is a common construction that appears from discretisation of spatial - # operators - # Don't do this if either b or c is a number as this will lead to matmul errors - elif isinstance(right, Addition): - if (right.left.is_constant() or right.right.is_constant()) and not ( + elif isinstance(right, (Addition, Subtraction)): + # Simplify A @ (b +- c) to (A @ b) +- (A @ c) if (A @ b) or (A @ c) is constant + # This is a common construction that appears from discretisation of spatial + # operators + # Or simplify A @ (B @ b +- C @ c) to (A @ B @ b) +- (A @ C @ c) if (A @ B) + # and (A @ C) are constant + # Don't do this if either b or c is a number as this will lead to matmul errors + if ( + (right.left.is_constant() or right.right.is_constant()) + or ( + isinstance(right.left, MatrixMultiplication) + and right.left.left.is_constant() + and isinstance(right.right, MatrixMultiplication) + and right.right.left.is_constant() + ) + ) and not ( right.left.size_for_testing == 1 or right.right.size_for_testing == 1 ): r_left, r_right = right.orphans - return (left @ r_left) + (left @ r_right) + if isinstance(right, Addition): + return (left @ r_left) + (left @ r_right) + elif isinstance(right, Subtraction): + return (left @ r_left) - (left @ r_right) - return pybamm.simplify_if_constant(pybamm.MatrixMultiplication(left, right)) + return pybamm.simplify_if_constant(MatrixMultiplication(left, right)) def minimum(left, right): diff --git a/pybamm/expression_tree/state_vector.py b/pybamm/expression_tree/state_vector.py index 920f553e8c..68c6a4fe5d 100644 --- a/pybamm/expression_tree/state_vector.py +++ b/pybamm/expression_tree/state_vector.py @@ -49,7 +49,7 @@ def __init__( raise TypeError("all y_slices must be slice objects") if name is None: if y_slices[0].start is None: - name = base_name + "[:{:d}]".format(y_slice.stop) + name = base_name + "[:{:d}".format(y_slice.stop) else: name = base_name + "[{:d}:{:d}".format( y_slices[0].start, y_slices[0].stop diff --git a/tests/unit/test_expression_tree/test_binary_operators.py b/tests/unit/test_expression_tree/test_binary_operators.py index 2625df5da4..b1ebc16342 100644 --- a/tests/unit/test_expression_tree/test_binary_operators.py +++ b/tests/unit/test_expression_tree/test_binary_operators.py @@ -5,7 +5,7 @@ import numpy as np import sympy -from scipy.sparse.coo import coo_matrix +from scipy.sparse import coo_matrix import pybamm @@ -521,6 +521,8 @@ def test_binary_simplifications(self): # division by zero with self.assertRaises(ZeroDivisionError): b / a + # division with a common term + self.assertEqual((2 * c) / (2 * var), (c / var)) def test_binary_simplifications_concatenations(self): def conc_broad(x, y, z): @@ -556,7 +558,9 @@ def test_advanced_binary_simplifications(self): # MatMul simplifications that often appear when discretising spatial operators A = pybamm.Matrix(np.random.rand(10, 10)) B = pybamm.Matrix(np.random.rand(10, 10)) + C = pybamm.Matrix(np.random.rand(10, 10)) var = pybamm.StateVector(slice(0, 10)) + var2 = pybamm.StateVector(slice(10, 20)) vec = pybamm.Vector(np.random.rand(10)) # Do A@B first if it is constant @@ -567,9 +571,19 @@ def test_advanced_binary_simplifications(self): # constant expr = A @ (var + vec) self.assertEqual(expr, ((A @ var) + (A @ vec))) + expr = A @ (var - vec) + self.assertEqual(expr, ((A @ var) - (A @ vec))) expr = A @ ((B @ var) + vec) self.assertEqual(expr, (((A @ B) @ var) + (A @ vec))) + expr = A @ ((B @ var) - vec) + self.assertEqual(expr, (((A @ B) @ var) - (A @ vec))) + + # Distribute the @ operator to a sum if both symbols being summed are matmuls + expr = A @ (B @ var + C @ var2) + self.assertEqual(expr, ((A @ B) @ var + (A @ C) @ var2)) + expr = A @ (B @ var - C @ var2) + self.assertEqual(expr, ((A @ B) @ var - (A @ C) @ var2)) # Reduce (A@var + B@var) to ((A+B)@var) expr = A @ var + B @ var From 610f0dcce4c1c30b2080ea1941ee0ab7e3c84705 Mon Sep 17 00:00:00 2001 From: Valentin Sulzer Date: Tue, 2 Aug 2022 16:06:38 -0400 Subject: [PATCH 2/7] renaming --- pybamm/expression_tree/binary_operators.py | 46 ++++++++++------------ 1 file changed, 21 insertions(+), 25 deletions(-) diff --git a/pybamm/expression_tree/binary_operators.py b/pybamm/expression_tree/binary_operators.py index 45c029df5f..5d4671f858 100644 --- a/pybamm/expression_tree/binary_operators.py +++ b/pybamm/expression_tree/binary_operators.py @@ -837,7 +837,7 @@ def simplified_addition(left, right): # Return constant if both sides are constant if left.is_constant() and right.is_constant(): - return pybamm.simplify_if_constant(pybamm.Addition(left, right)) + return pybamm.simplify_if_constant(Addition(left, right)) # Simplify A @ c + B @ c to (A + B) @ c if (A + B) is constant # This is a common construction that appears from discretisation of spatial @@ -852,10 +852,10 @@ def simplified_addition(left, right): new_left = l_left + r_left if new_left.is_constant(): new_sum = new_left @ l_right - new_sum.copy_domains(pybamm.Addition(left, right)) + new_sum.copy_domains(Addition(left, right)) return new_sum - if isinstance(right, pybamm.Addition) and left.is_constant(): + if isinstance(right, Addition) and left.is_constant(): # Simplify a + (b + c) to (a + b) + c if (a + b) is constant if right.left.is_constant(): r_left, r_right = right.orphans @@ -864,7 +864,7 @@ def simplified_addition(left, right): elif right.right.is_constant(): r_left, r_right = right.orphans return (left + r_right) + r_left - elif isinstance(right, pybamm.Subtraction) and left.is_constant(): + elif isinstance(right, Subtraction) and left.is_constant(): # Simplify a + (b - c) to (a + b) - c if (a + b) is constant if right.left.is_constant(): r_left, r_right = right.orphans @@ -873,7 +873,7 @@ def simplified_addition(left, right): elif right.right.is_constant(): r_left, r_right = right.orphans return (left - r_right) + r_left - if isinstance(left, pybamm.Addition) and right.is_constant(): + if isinstance(left, Addition) and right.is_constant(): # Simplify (a + b) + c to a + (b + c) if (b + c) is constant if left.right.is_constant(): l_left, l_right = left.orphans @@ -882,7 +882,7 @@ def simplified_addition(left, right): elif left.left.is_constant(): l_left, l_right = left.orphans return (l_left + right) + l_right - elif isinstance(left, pybamm.Subtraction) and right.is_constant(): + elif isinstance(left, Subtraction) and right.is_constant(): # Simplify (a - b) + c to a + (c - b) if (c - b) is constant if left.right.is_constant(): l_left, l_right = left.orphans @@ -892,7 +892,7 @@ def simplified_addition(left, right): l_left, l_right = left.orphans return (l_left + right) - l_right - return pybamm.simplify_if_constant(pybamm.Addition(left, right)) + return pybamm.simplify_if_constant(Addition(left, right)) def simplified_subtraction(left, right): @@ -953,7 +953,7 @@ def simplified_subtraction(left, right): if left == right: return pybamm.zeros_like(left) - if isinstance(right, pybamm.Addition) and left.is_constant(): + if isinstance(right, Addition) and left.is_constant(): # Simplify a - (b + c) to (a - b) - c if (a - b) is constant if right.left.is_constant(): r_left, r_right = right.orphans @@ -962,7 +962,7 @@ def simplified_subtraction(left, right): elif right.right.is_constant(): r_left, r_right = right.orphans return (left - r_right) - r_left - elif isinstance(right, pybamm.Subtraction) and left.is_constant(): + elif isinstance(right, Subtraction) and left.is_constant(): # Simplify a - (b - c) to (a - b) + c if (a - b) is constant if right.left.is_constant(): r_left, r_right = right.orphans @@ -971,7 +971,7 @@ def simplified_subtraction(left, right): elif right.right.is_constant(): r_left, r_right = right.orphans return (left + r_right) - r_left - if isinstance(left, pybamm.Addition) and right.is_constant(): + if isinstance(left, Addition) and right.is_constant(): # Simplify (a + b) - c to a + (b - c) if (b - c) is constant if left.right.is_constant(): l_left, l_right = left.orphans @@ -980,7 +980,7 @@ def simplified_subtraction(left, right): elif left.left.is_constant(): l_left, l_right = left.orphans return (l_left - right) + l_right - elif isinstance(left, pybamm.Subtraction) and right.is_constant(): + elif isinstance(left, Subtraction) and right.is_constant(): # Simplify (a - b) - c to a - (c + b) if (c + b) is constant if left.right.is_constant(): l_left, l_right = left.orphans @@ -990,7 +990,7 @@ def simplified_subtraction(left, right): l_left, l_right = left.orphans return (l_left - right) - l_right - return pybamm.simplify_if_constant(pybamm.Subtraction(left, right)) + return pybamm.simplify_if_constant(Subtraction(left, right)) def simplified_multiplication(left, right): @@ -1011,7 +1011,7 @@ def simplified_multiplication(left, right): # if one of the children is a zero matrix, we have to be careful about shapes if pybamm.is_matrix_zero(left) or pybamm.is_matrix_zero(right): - return pybamm.zeros_like(pybamm.Multiplication(left, right)) + return pybamm.zeros_like(Multiplication(left, right)) # anything multiplied by a scalar one returns itself if pybamm.is_scalar_one(left): @@ -1027,7 +1027,7 @@ def simplified_multiplication(left, right): # Return constant if both sides are constant if left.is_constant() and right.is_constant(): - return pybamm.simplify_if_constant(pybamm.Multiplication(left, right)) + return pybamm.simplify_if_constant(Multiplication(left, right)) # anything multiplied by a matrix one returns itself if # - the shapes are the same @@ -1121,11 +1121,7 @@ def simplified_multiplication(left, right): # operators # Also do this for cases like a * (b @ c + d) where (a * b) is constant elif isinstance(right, (Addition, Subtraction)): - mul_classes = ( - pybamm.Multiplication, - MatrixMultiplication, - pybamm.Division, - ) + mul_classes = (Multiplication, MatrixMultiplication, Division) if ( right.left.is_constant() or right.right.is_constant() @@ -1152,7 +1148,7 @@ def simplified_multiplication(left, right): # Simplify a * (-b) to (-a) * b if (-a) is constant return (-left) * right.orphans[0] - return pybamm.Multiplication(left, right) + return Multiplication(left, right) def simplified_division(left, right): @@ -1169,7 +1165,7 @@ def simplified_division(left, right): # matrix zero divided by anything returns matrix zero (i.e. itself) if pybamm.is_matrix_zero(left): - return pybamm.zeros_like(pybamm.Division(left, right)) + return pybamm.zeros_like(Division(left, right)) # anything divided by zero raises error if pybamm.is_scalar_zero(right): @@ -1204,7 +1200,7 @@ def simplified_division(left, right): # Return constant if both sides are constant if left.is_constant() and right.is_constant(): - return pybamm.simplify_if_constant(pybamm.Division(left, right)) + return pybamm.simplify_if_constant(Division(left, right)) # Simplify (B @ c) / a to (B / a) @ c if (B / a) is constant # This is a common construction that appears from discretisation of averages @@ -1259,8 +1255,8 @@ def simplified_division(left, right): # Cancelling out common terms if ( - isinstance(left, pybamm.Multiplication) - and isinstance(right, pybamm.Multiplication) + isinstance(left, Multiplication) + and isinstance(right, Multiplication) and left.left == right.left ): _, l_right = left.orphans @@ -1279,7 +1275,7 @@ def simplified_division(left, right): # Simplify a / (-b) to (-a) / b if (-a) is constant return (-left) / right.orphans[0] - return pybamm.simplify_if_constant(pybamm.Division(left, right)) + return pybamm.simplify_if_constant(Division(left, right)) def simplified_matrix_multiplication(left, right): From fbb39542d586fc823bcad06b14f289577c43904b Mon Sep 17 00:00:00 2001 From: Valentin Sulzer Date: Wed, 3 Aug 2022 11:52:54 -0400 Subject: [PATCH 3/7] changelog --- CHANGELOG.md | 1 + 1 file changed, 1 insertion(+) diff --git a/CHANGELOG.md b/CHANGELOG.md index dd3159739d..630a13e7b2 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -15,6 +15,7 @@ ## Optimizations +- Added more rules for simplifying expressions ([#2211](https://github.com/pybamm-team/PyBaMM/pull/2211)) - Improved eSOH calculations to be more robust ([#2192](https://github.com/pybamm-team/PyBaMM/pull/2192),[#2199](https://github.com/pybamm-team/PyBaMM/pull/2199)) - The (2x2x2=8) particle diffusion submodels have been consolidated into just three submodels (Fickian diffusion, polynomial profile, and x-averaged polynomial profile) with optional x-averaging and size distribution. Polynomial profile and x-averaged polynomial profile are still two separate submodels, since they deal with surface concentration differently. - Added error for when solution vector gets too large, to help debug solver errors ([#2138](https://github.com/pybamm-team/PyBaMM/pull/2138)) From be6d7860eeb7730ce63119b5661b615ade0e2e33 Mon Sep 17 00:00:00 2001 From: Valentin Sulzer Date: Wed, 3 Aug 2022 12:30:15 -0400 Subject: [PATCH 4/7] fix 1+1D bug --- pybamm/expression_tree/binary_operators.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/pybamm/expression_tree/binary_operators.py b/pybamm/expression_tree/binary_operators.py index 5d4671f858..7f70ffce70 100644 --- a/pybamm/expression_tree/binary_operators.py +++ b/pybamm/expression_tree/binary_operators.py @@ -1334,6 +1334,8 @@ def simplified_matrix_multiplication(left, right): right.left.size_for_testing == 1 or right.right.size_for_testing == 1 ): r_left, r_right = right.orphans + r_left.domains = right.domains + r_right.domains = right.domains if isinstance(right, Addition): return (left @ r_left) + (left @ r_right) elif isinstance(right, Subtraction): From 3fcf9d884a86588b28122c26bd6f1e0d34a18b61 Mon Sep 17 00:00:00 2001 From: Valentin Sulzer Date: Thu, 4 Aug 2022 12:08:54 -0400 Subject: [PATCH 5/7] comment out simplification for matmuls, possibly failing because of poor model conditioning --- pybamm/expression_tree/binary_operators.py | 14 ++++++++------ .../test_expression_tree/test_binary_operators.py | 8 ++++---- 2 files changed, 12 insertions(+), 10 deletions(-) diff --git a/pybamm/expression_tree/binary_operators.py b/pybamm/expression_tree/binary_operators.py index 7f70ffce70..ef16b5e187 100644 --- a/pybamm/expression_tree/binary_operators.py +++ b/pybamm/expression_tree/binary_operators.py @@ -1324,12 +1324,14 @@ def simplified_matrix_multiplication(left, right): # Don't do this if either b or c is a number as this will lead to matmul errors if ( (right.left.is_constant() or right.right.is_constant()) - or ( - isinstance(right.left, MatrixMultiplication) - and right.left.left.is_constant() - and isinstance(right.right, MatrixMultiplication) - and right.right.left.is_constant() - ) + # these lines should work but don't, possibly because of poorly + # conditioned model? + # or ( + # isinstance(right.left, MatrixMultiplication) + # and right.left.left.is_constant() + # and isinstance(right.right, MatrixMultiplication) + # and right.right.left.is_constant() + # ) ) and not ( right.left.size_for_testing == 1 or right.right.size_for_testing == 1 ): diff --git a/tests/unit/test_expression_tree/test_binary_operators.py b/tests/unit/test_expression_tree/test_binary_operators.py index b1ebc16342..ed8fa82ba8 100644 --- a/tests/unit/test_expression_tree/test_binary_operators.py +++ b/tests/unit/test_expression_tree/test_binary_operators.py @@ -580,10 +580,10 @@ def test_advanced_binary_simplifications(self): self.assertEqual(expr, (((A @ B) @ var) - (A @ vec))) # Distribute the @ operator to a sum if both symbols being summed are matmuls - expr = A @ (B @ var + C @ var2) - self.assertEqual(expr, ((A @ B) @ var + (A @ C) @ var2)) - expr = A @ (B @ var - C @ var2) - self.assertEqual(expr, ((A @ B) @ var - (A @ C) @ var2)) + # expr = A @ (B @ var + C @ var2) + # self.assertEqual(expr, ((A @ B) @ var + (A @ C) @ var2)) + # expr = A @ (B @ var - C @ var2) + # self.assertEqual(expr, ((A @ B) @ var - (A @ C) @ var2)) # Reduce (A@var + B@var) to ((A+B)@var) expr = A @ var + B @ var From b2e4f4d383990e78f66893ab45939be688789b40 Mon Sep 17 00:00:00 2001 From: Valentin Sulzer Date: Fri, 5 Aug 2022 09:44:34 -0400 Subject: [PATCH 6/7] flake8 --- CHANGELOG.md | 5 ++++- tests/unit/test_expression_tree/test_binary_operators.py | 4 ++-- 2 files changed, 6 insertions(+), 3 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 1b717fdae6..cc9c297793 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,5 +1,9 @@ # [Unreleased](https://github.com/pybamm-team/PyBaMM/) +## Optimizations + +- Added more rules for simplifying expressions ([#2211](https://github.com/pybamm-team/PyBaMM/pull/2211)) + # [v22.7](https://github.com/pybamm-team/PyBaMM/tree/v22.7) - 2022-07-31 ## Features @@ -15,7 +19,6 @@ ## Optimizations -- Added more rules for simplifying expressions ([#2211](https://github.com/pybamm-team/PyBaMM/pull/2211)) - Improved eSOH calculations to be more robust ([#2192](https://github.com/pybamm-team/PyBaMM/pull/2192),[#2199](https://github.com/pybamm-team/PyBaMM/pull/2199)) - The (2x2x2=8) particle diffusion submodels have been consolidated into just three submodels (Fickian diffusion, polynomial profile, and x-averaged polynomial profile) with optional x-averaging and size distribution. Polynomial profile and x-averaged polynomial profile are still two separate submodels, since they deal with surface concentration differently. - Added error for when solution vector gets too large, to help debug solver errors ([#2138](https://github.com/pybamm-team/PyBaMM/pull/2138)) diff --git a/tests/unit/test_expression_tree/test_binary_operators.py b/tests/unit/test_expression_tree/test_binary_operators.py index ed8fa82ba8..cb363200ef 100644 --- a/tests/unit/test_expression_tree/test_binary_operators.py +++ b/tests/unit/test_expression_tree/test_binary_operators.py @@ -558,9 +558,9 @@ def test_advanced_binary_simplifications(self): # MatMul simplifications that often appear when discretising spatial operators A = pybamm.Matrix(np.random.rand(10, 10)) B = pybamm.Matrix(np.random.rand(10, 10)) - C = pybamm.Matrix(np.random.rand(10, 10)) + # C = pybamm.Matrix(np.random.rand(10, 10)) var = pybamm.StateVector(slice(0, 10)) - var2 = pybamm.StateVector(slice(10, 20)) + # var2 = pybamm.StateVector(slice(10, 20)) vec = pybamm.Vector(np.random.rand(10)) # Do A@B first if it is constant From e28b196878918ee275bdbe598b270f50cdaaafe3 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Fri, 7 Oct 2022 14:28:06 +0000 Subject: [PATCH 7/7] style: pre-commit fixes --- CHANGELOG.md | 4 ++-- tests/unit/test_expression_tree/test_binary_operators.py | 2 +- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index ed48a90387..e3d12bd88c 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -6,8 +6,8 @@ ## Optimizations -- Added more rules for simplifying expressions ([#2211](https://github.com/pybamm-team/PyBaMM/pull/2211)) -- Sped up calculations of Electrode SOH variables for summary variables ([#2210](https://github.com/pybamm-team/PyBaMM/pull/2210)) +- Added more rules for simplifying expressions ([#2211](https://github.com/pybamm-team/PyBaMM/pull/2211)) +- Sped up calculations of Electrode SOH variables for summary variables ([#2210](https://github.com/pybamm-team/PyBaMM/pull/2210)) # [v22.9](https://github.com/pybamm-team/PyBaMM/tree/v22.9) - 2022-09-30 diff --git a/tests/unit/test_expression_tree/test_binary_operators.py b/tests/unit/test_expression_tree/test_binary_operators.py index b82a0113f2..58fba662ff 100644 --- a/tests/unit/test_expression_tree/test_binary_operators.py +++ b/tests/unit/test_expression_tree/test_binary_operators.py @@ -532,7 +532,7 @@ def test_binary_simplifications(self): # division with a common term self.assertEqual((2 * c) / (2 * var), (c / var)) - self.assertEqual((c * 2) / (var*2), (c / var)) + self.assertEqual((c * 2) / (var * 2), (c / var)) def test_binary_simplifications_concatenations(self): def conc_broad(x, y, z):