Skip to content

Commit

Permalink
Merge pull request #2211 from pybamm-team/more-expression-simplificat…
Browse files Browse the repository at this point in the history
…ions

add a few more expression simplifications
  • Loading branch information
valentinsulzer authored Oct 10, 2022
2 parents 82a705c + e28b196 commit dbc7a6c
Show file tree
Hide file tree
Showing 4 changed files with 79 additions and 34 deletions.
5 changes: 5 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,11 @@

- For simulations with events that cause the simulation to stop early, the sensitivities could be evaluated incorrectly to zero ([#2331](https://github.com/pybamm-team/PyBaMM/pull/2337))

## Optimizations

- Added more rules for simplifying expressions ([#2211](https://github.com/pybamm-team/PyBaMM/pull/2211))
- Sped up calculations of Electrode SOH variables for summary variables ([#2210](https://github.com/pybamm-team/PyBaMM/pull/2210))

# [v22.9](https://github.com/pybamm-team/PyBaMM/tree/v22.9) - 2022-09-30

## Features
Expand Down
90 changes: 57 additions & 33 deletions pybamm/expression_tree/binary_operators.py
Original file line number Diff line number Diff line change
Expand Up @@ -837,7 +837,7 @@ def simplified_addition(left, right):

# Return constant if both sides are constant
if left.is_constant() and right.is_constant():
return pybamm.simplify_if_constant(pybamm.Addition(left, right))
return pybamm.simplify_if_constant(Addition(left, right))

# Simplify A @ c + B @ c to (A + B) @ c if (A + B) is constant
# This is a common construction that appears from discretisation of spatial
Expand All @@ -852,10 +852,10 @@ def simplified_addition(left, right):
new_left = l_left + r_left
if new_left.is_constant():
new_sum = new_left @ l_right
new_sum.copy_domains(pybamm.Addition(left, right))
new_sum.copy_domains(Addition(left, right))
return new_sum

if isinstance(right, pybamm.Addition) and left.is_constant():
if isinstance(right, Addition) and left.is_constant():
# Simplify a + (b + c) to (a + b) + c if (a + b) is constant
if right.left.is_constant():
r_left, r_right = right.orphans
Expand All @@ -864,7 +864,7 @@ def simplified_addition(left, right):
elif right.right.is_constant():
r_left, r_right = right.orphans
return (left + r_right) + r_left
elif isinstance(right, pybamm.Subtraction) and left.is_constant():
elif isinstance(right, Subtraction) and left.is_constant():
# Simplify a + (b - c) to (a + b) - c if (a + b) is constant
if right.left.is_constant():
r_left, r_right = right.orphans
Expand All @@ -873,7 +873,7 @@ def simplified_addition(left, right):
elif right.right.is_constant():
r_left, r_right = right.orphans
return (left - r_right) + r_left
if isinstance(left, pybamm.Addition) and right.is_constant():
if isinstance(left, Addition) and right.is_constant():
# Simplify (a + b) + c to a + (b + c) if (b + c) is constant
if left.right.is_constant():
l_left, l_right = left.orphans
Expand All @@ -882,7 +882,7 @@ def simplified_addition(left, right):
elif left.left.is_constant():
l_left, l_right = left.orphans
return (l_left + right) + l_right
elif isinstance(left, pybamm.Subtraction) and right.is_constant():
elif isinstance(left, Subtraction) and right.is_constant():
# Simplify (a - b) + c to a + (c - b) if (c - b) is constant
if left.right.is_constant():
l_left, l_right = left.orphans
Expand All @@ -892,7 +892,7 @@ def simplified_addition(left, right):
l_left, l_right = left.orphans
return (l_left + right) - l_right

return pybamm.simplify_if_constant(pybamm.Addition(left, right))
return pybamm.simplify_if_constant(Addition(left, right))


def simplified_subtraction(left, right):
Expand Down Expand Up @@ -953,7 +953,7 @@ def simplified_subtraction(left, right):
if left == right:
return pybamm.zeros_like(left)

if isinstance(right, pybamm.Addition) and left.is_constant():
if isinstance(right, Addition) and left.is_constant():
# Simplify a - (b + c) to (a - b) - c if (a - b) is constant
if right.left.is_constant():
r_left, r_right = right.orphans
Expand All @@ -962,7 +962,7 @@ def simplified_subtraction(left, right):
elif right.right.is_constant():
r_left, r_right = right.orphans
return (left - r_right) - r_left
elif isinstance(right, pybamm.Subtraction) and left.is_constant():
elif isinstance(right, Subtraction) and left.is_constant():
# Simplify a - (b - c) to (a - b) + c if (a - b) is constant
if right.left.is_constant():
r_left, r_right = right.orphans
Expand All @@ -971,7 +971,7 @@ def simplified_subtraction(left, right):
elif right.right.is_constant():
r_left, r_right = right.orphans
return (left + r_right) - r_left
if isinstance(left, pybamm.Addition) and right.is_constant():
if isinstance(left, Addition) and right.is_constant():
# Simplify (a + b) - c to a + (b - c) if (b - c) is constant
if left.right.is_constant():
l_left, l_right = left.orphans
Expand All @@ -980,7 +980,7 @@ def simplified_subtraction(left, right):
elif left.left.is_constant():
l_left, l_right = left.orphans
return (l_left - right) + l_right
elif isinstance(left, pybamm.Subtraction) and right.is_constant():
elif isinstance(left, Subtraction) and right.is_constant():
# Simplify (a - b) - c to a - (c + b) if (c + b) is constant
if left.right.is_constant():
l_left, l_right = left.orphans
Expand All @@ -990,7 +990,7 @@ def simplified_subtraction(left, right):
l_left, l_right = left.orphans
return (l_left - right) - l_right

return pybamm.simplify_if_constant(pybamm.Subtraction(left, right))
return pybamm.simplify_if_constant(Subtraction(left, right))


def simplified_multiplication(left, right):
Expand All @@ -1011,7 +1011,7 @@ def simplified_multiplication(left, right):

# if one of the children is a zero matrix, we have to be careful about shapes
if pybamm.is_matrix_zero(left) or pybamm.is_matrix_zero(right):
return pybamm.zeros_like(pybamm.Multiplication(left, right))
return pybamm.zeros_like(Multiplication(left, right))

# anything multiplied by a scalar one returns itself
if pybamm.is_scalar_one(left):
Expand All @@ -1027,7 +1027,7 @@ def simplified_multiplication(left, right):

# Return constant if both sides are constant
if left.is_constant() and right.is_constant():
return pybamm.simplify_if_constant(pybamm.Multiplication(left, right))
return pybamm.simplify_if_constant(Multiplication(left, right))

# anything multiplied by a matrix one returns itself if
# - the shapes are the same
Expand Down Expand Up @@ -1121,11 +1121,7 @@ def simplified_multiplication(left, right):
# operators
# Also do this for cases like a * (b @ c + d) where (a * b) is constant
elif isinstance(right, (Addition, Subtraction)):
mul_classes = (
pybamm.Multiplication,
pybamm.MatrixMultiplication,
pybamm.Division,
)
mul_classes = (Multiplication, MatrixMultiplication, Division)
if (
right.left.is_constant()
or right.right.is_constant()
Expand All @@ -1152,7 +1148,7 @@ def simplified_multiplication(left, right):
# Simplify a * (-b) to (-a) * b if (-a) is constant
return (-left) * right.orphans[0]

return pybamm.Multiplication(left, right)
return Multiplication(left, right)


def simplified_division(left, right):
Expand All @@ -1169,7 +1165,7 @@ def simplified_division(left, right):

# matrix zero divided by anything returns matrix zero (i.e. itself)
if pybamm.is_matrix_zero(left):
return pybamm.zeros_like(pybamm.Division(left, right))
return pybamm.zeros_like(Division(left, right))

# anything divided by zero raises error
if pybamm.is_scalar_zero(right):
Expand Down Expand Up @@ -1204,7 +1200,7 @@ def simplified_division(left, right):

# Return constant if both sides are constant
if left.is_constant() and right.is_constant():
return pybamm.simplify_if_constant(pybamm.Division(left, right))
return pybamm.simplify_if_constant(Division(left, right))

# Simplify (B @ c) / a to (B / a) @ c if (B / a) is constant
# This is a common construction that appears from discretisation of averages
Expand Down Expand Up @@ -1257,6 +1253,17 @@ def simplified_division(left, right):
r_left, r_right = right.orphans
return (left * r_right) / r_left

# Cancelling out common terms
if isinstance(left, Multiplication) and isinstance(right, Multiplication):
if left.left == right.left:
_, l_right = left.orphans
_, r_right = right.orphans
return l_right / r_right
if left.right == right.right:
l_left, _ = left.orphans
r_left, _ = right.orphans
return l_left / r_left

# Negation simplifications
if isinstance(left, pybamm.Negate) and isinstance(right, pybamm.Negate):
# Double negation cancels out
Expand All @@ -1269,13 +1276,13 @@ def simplified_division(left, right):
# Simplify a / (-b) to (-a) / b if (-a) is constant
return (-left) / right.orphans[0]

return pybamm.simplify_if_constant(pybamm.Division(left, right))
return pybamm.simplify_if_constant(Division(left, right))


def simplified_matrix_multiplication(left, right):
left, right = preprocess_binary(left, right)
if pybamm.is_matrix_zero(left) or pybamm.is_matrix_zero(right):
return pybamm.zeros_like(pybamm.MatrixMultiplication(left, right))
return pybamm.zeros_like(MatrixMultiplication(left, right))

if isinstance(right, Multiplication) and left.is_constant():
# Simplify A @ (b * c) to (A * b) @ c if (A * b) is constant
Expand Down Expand Up @@ -1309,18 +1316,35 @@ def simplified_matrix_multiplication(left, right):
new_mul.copy_domains(right)
return new_mul

# Simplify A @ (b + c) to (A @ b) + (A @ c) if (A @ b) or (A @ c) is constant
# This is a common construction that appears from discretisation of spatial
# operators
# Don't do this if either b or c is a number as this will lead to matmul errors
elif isinstance(right, Addition):
if (right.left.is_constant() or right.right.is_constant()) and not (
elif left.is_constant() and isinstance(right, (Addition, Subtraction)):
# Simplify A @ (b +- c) to (A @ b) +- (A @ c) if (A @ b) or (A @ c) is constant
# This is a common construction that appears from discretisation of spatial
# operators
# Or simplify A @ (B @ b +- C @ c) to (A @ B @ b) +- (A @ C @ c) if (A @ B)
# and (A @ C) are constant
# Don't do this if either b or c is a number as this will lead to matmul errors
if (
(right.left.is_constant() or right.right.is_constant())
# these lines should work but don't, possibly because of poorly
# conditioned model?
# or (
# isinstance(right.left, MatrixMultiplication)
# and right.left.left.is_constant()
# and isinstance(right.right, MatrixMultiplication)
# and right.right.left.is_constant()
# )
) and not (
right.left.size_for_testing == 1 or right.right.size_for_testing == 1
):
r_left, r_right = right.orphans
return (left @ r_left) + (left @ r_right)

return pybamm.simplify_if_constant(pybamm.MatrixMultiplication(left, right))
r_left.domains = right.domains
r_right.domains = right.domains
if isinstance(right, Addition):
return (left @ r_left) + (left @ r_right)
elif isinstance(right, Subtraction):
return (left @ r_left) - (left @ r_right)

return pybamm.simplify_if_constant(MatrixMultiplication(left, right))


def minimum(left, right):
Expand Down
2 changes: 1 addition & 1 deletion pybamm/expression_tree/state_vector.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,7 @@ def __init__(
raise TypeError("all y_slices must be slice objects")
if name is None:
if y_slices[0].start is None:
name = base_name + "[:{:d}]".format(y_slice.stop)
name = base_name + "[:{:d}".format(y_slice.stop)
else:
name = base_name + "[{:d}:{:d}".format(
y_slices[0].start, y_slices[0].stop
Expand Down
16 changes: 16 additions & 0 deletions tests/unit/test_expression_tree/test_binary_operators.py
Original file line number Diff line number Diff line change
Expand Up @@ -530,6 +530,10 @@ def test_binary_simplifications(self):
with self.assertRaises(ZeroDivisionError):
b / a

# division with a common term
self.assertEqual((2 * c) / (2 * var), (c / var))
self.assertEqual((c * 2) / (var * 2), (c / var))

def test_binary_simplifications_concatenations(self):
def conc_broad(x, y, z):
return pybamm.concatenation(
Expand Down Expand Up @@ -564,7 +568,9 @@ def test_advanced_binary_simplifications(self):
# MatMul simplifications that often appear when discretising spatial operators
A = pybamm.Matrix(np.random.rand(10, 10))
B = pybamm.Matrix(np.random.rand(10, 10))
# C = pybamm.Matrix(np.random.rand(10, 10))
var = pybamm.StateVector(slice(0, 10))
# var2 = pybamm.StateVector(slice(10, 20))
vec = pybamm.Vector(np.random.rand(10))

# Do A@B first if it is constant
Expand All @@ -575,9 +581,19 @@ def test_advanced_binary_simplifications(self):
# constant
expr = A @ (var + vec)
self.assertEqual(expr, ((A @ var) + (A @ vec)))
expr = A @ (var - vec)
self.assertEqual(expr, ((A @ var) - (A @ vec)))

expr = A @ ((B @ var) + vec)
self.assertEqual(expr, (((A @ B) @ var) + (A @ vec)))
expr = A @ ((B @ var) - vec)
self.assertEqual(expr, (((A @ B) @ var) - (A @ vec)))

# Distribute the @ operator to a sum if both symbols being summed are matmuls
# expr = A @ (B @ var + C @ var2)
# self.assertEqual(expr, ((A @ B) @ var + (A @ C) @ var2))
# expr = A @ (B @ var - C @ var2)
# self.assertEqual(expr, ((A @ B) @ var - (A @ C) @ var2))

# Reduce (A@var + B@var) to ((A+B)@var)
expr = A @ var + B @ var
Expand Down

0 comments on commit dbc7a6c

Please sign in to comment.