Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

add a few more expression simplifications #2211

Merged
merged 12 commits into from
Oct 10, 2022
3 changes: 2 additions & 1 deletion CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@

- Added function `pybamm.get_git_commit_info()`, which returns information about the last git commit, useful for reproducibility ([#2293](https://github.com/pybamm-team/PyBaMM/pull/2293))
- Added SEI model for composite electrodes ([#2290](https://github.com/pybamm-team/PyBaMM/pull/2290))
- Default options for `particle mechanics` now dealt with differently in each electrode ([#2262](https://github.com/pybamm-team/PyBaMM/pull/2262))
- For experiments, the simulation now automatically checks and skips steps that cannot be performed (e.g. "Charge at 1C until 4.2V" from 100% SOC) ([#2212](https://github.com/pybamm-team/PyBaMM/pull/2212))

## Bug fixes
Expand All @@ -15,7 +16,7 @@

## Optimizations

- Default options for `particle mechanics` now dealt with differently in each electrode ([#2262](https://github.com/pybamm-team/PyBaMM/pull/2262))
- Added more rules for simplifying expressions ([#2211](https://github.com/pybamm-team/PyBaMM/pull/2211))
- Sped up calculations of Electrode SOH variables for summary variables ([#2210](https://github.com/pybamm-team/PyBaMM/pull/2210))

## Breaking changes
Expand Down
89 changes: 56 additions & 33 deletions pybamm/expression_tree/binary_operators.py
Original file line number Diff line number Diff line change
Expand Up @@ -837,7 +837,7 @@ def simplified_addition(left, right):

# Return constant if both sides are constant
if left.is_constant() and right.is_constant():
return pybamm.simplify_if_constant(pybamm.Addition(left, right))
return pybamm.simplify_if_constant(Addition(left, right))

# Simplify A @ c + B @ c to (A + B) @ c if (A + B) is constant
# This is a common construction that appears from discretisation of spatial
Expand All @@ -852,10 +852,10 @@ def simplified_addition(left, right):
new_left = l_left + r_left
if new_left.is_constant():
new_sum = new_left @ l_right
new_sum.copy_domains(pybamm.Addition(left, right))
new_sum.copy_domains(Addition(left, right))
return new_sum

if isinstance(right, pybamm.Addition) and left.is_constant():
if isinstance(right, Addition) and left.is_constant():
# Simplify a + (b + c) to (a + b) + c if (a + b) is constant
if right.left.is_constant():
r_left, r_right = right.orphans
Expand All @@ -864,7 +864,7 @@ def simplified_addition(left, right):
elif right.right.is_constant():
r_left, r_right = right.orphans
return (left + r_right) + r_left
elif isinstance(right, pybamm.Subtraction) and left.is_constant():
elif isinstance(right, Subtraction) and left.is_constant():
# Simplify a + (b - c) to (a + b) - c if (a + b) is constant
if right.left.is_constant():
r_left, r_right = right.orphans
Expand All @@ -873,7 +873,7 @@ def simplified_addition(left, right):
elif right.right.is_constant():
r_left, r_right = right.orphans
return (left - r_right) + r_left
if isinstance(left, pybamm.Addition) and right.is_constant():
if isinstance(left, Addition) and right.is_constant():
# Simplify (a + b) + c to a + (b + c) if (b + c) is constant
if left.right.is_constant():
l_left, l_right = left.orphans
Expand All @@ -882,7 +882,7 @@ def simplified_addition(left, right):
elif left.left.is_constant():
l_left, l_right = left.orphans
return (l_left + right) + l_right
elif isinstance(left, pybamm.Subtraction) and right.is_constant():
elif isinstance(left, Subtraction) and right.is_constant():
# Simplify (a - b) + c to a + (c - b) if (c - b) is constant
if left.right.is_constant():
l_left, l_right = left.orphans
Expand All @@ -892,7 +892,7 @@ def simplified_addition(left, right):
l_left, l_right = left.orphans
return (l_left + right) - l_right

return pybamm.simplify_if_constant(pybamm.Addition(left, right))
return pybamm.simplify_if_constant(Addition(left, right))


def simplified_subtraction(left, right):
Expand Down Expand Up @@ -953,7 +953,7 @@ def simplified_subtraction(left, right):
if left == right:
return pybamm.zeros_like(left)

if isinstance(right, pybamm.Addition) and left.is_constant():
if isinstance(right, Addition) and left.is_constant():
# Simplify a - (b + c) to (a - b) - c if (a - b) is constant
if right.left.is_constant():
r_left, r_right = right.orphans
Expand All @@ -962,7 +962,7 @@ def simplified_subtraction(left, right):
elif right.right.is_constant():
r_left, r_right = right.orphans
return (left - r_right) - r_left
elif isinstance(right, pybamm.Subtraction) and left.is_constant():
elif isinstance(right, Subtraction) and left.is_constant():
# Simplify a - (b - c) to (a - b) + c if (a - b) is constant
if right.left.is_constant():
r_left, r_right = right.orphans
Expand All @@ -971,7 +971,7 @@ def simplified_subtraction(left, right):
elif right.right.is_constant():
r_left, r_right = right.orphans
return (left + r_right) - r_left
if isinstance(left, pybamm.Addition) and right.is_constant():
if isinstance(left, Addition) and right.is_constant():
# Simplify (a + b) - c to a + (b - c) if (b - c) is constant
if left.right.is_constant():
l_left, l_right = left.orphans
Expand All @@ -980,7 +980,7 @@ def simplified_subtraction(left, right):
elif left.left.is_constant():
l_left, l_right = left.orphans
return (l_left - right) + l_right
elif isinstance(left, pybamm.Subtraction) and right.is_constant():
elif isinstance(left, Subtraction) and right.is_constant():
# Simplify (a - b) - c to a - (c + b) if (c + b) is constant
if left.right.is_constant():
l_left, l_right = left.orphans
Expand All @@ -990,7 +990,7 @@ def simplified_subtraction(left, right):
l_left, l_right = left.orphans
return (l_left - right) - l_right

return pybamm.simplify_if_constant(pybamm.Subtraction(left, right))
return pybamm.simplify_if_constant(Subtraction(left, right))


def simplified_multiplication(left, right):
Expand All @@ -1011,7 +1011,7 @@ def simplified_multiplication(left, right):

# if one of the children is a zero matrix, we have to be careful about shapes
if pybamm.is_matrix_zero(left) or pybamm.is_matrix_zero(right):
return pybamm.zeros_like(pybamm.Multiplication(left, right))
return pybamm.zeros_like(Multiplication(left, right))

# anything multiplied by a scalar one returns itself
if pybamm.is_scalar_one(left):
Expand All @@ -1027,7 +1027,7 @@ def simplified_multiplication(left, right):

# Return constant if both sides are constant
if left.is_constant() and right.is_constant():
return pybamm.simplify_if_constant(pybamm.Multiplication(left, right))
return pybamm.simplify_if_constant(Multiplication(left, right))

# anything multiplied by a matrix one returns itself if
# - the shapes are the same
Expand Down Expand Up @@ -1121,11 +1121,7 @@ def simplified_multiplication(left, right):
# operators
# Also do this for cases like a * (b @ c + d) where (a * b) is constant
elif isinstance(right, (Addition, Subtraction)):
mul_classes = (
pybamm.Multiplication,
pybamm.MatrixMultiplication,
pybamm.Division,
)
mul_classes = (Multiplication, MatrixMultiplication, Division)
if (
right.left.is_constant()
or right.right.is_constant()
Expand All @@ -1152,7 +1148,7 @@ def simplified_multiplication(left, right):
# Simplify a * (-b) to (-a) * b if (-a) is constant
return (-left) * right.orphans[0]

return pybamm.Multiplication(left, right)
return Multiplication(left, right)


def simplified_division(left, right):
Expand All @@ -1169,7 +1165,7 @@ def simplified_division(left, right):

# matrix zero divided by anything returns matrix zero (i.e. itself)
if pybamm.is_matrix_zero(left):
return pybamm.zeros_like(pybamm.Division(left, right))
return pybamm.zeros_like(Division(left, right))

# anything divided by zero raises error
if pybamm.is_scalar_zero(right):
Expand Down Expand Up @@ -1204,7 +1200,7 @@ def simplified_division(left, right):

# Return constant if both sides are constant
if left.is_constant() and right.is_constant():
return pybamm.simplify_if_constant(pybamm.Division(left, right))
return pybamm.simplify_if_constant(Division(left, right))

# Simplify (B @ c) / a to (B / a) @ c if (B / a) is constant
# This is a common construction that appears from discretisation of averages
Expand Down Expand Up @@ -1257,6 +1253,16 @@ def simplified_division(left, right):
r_left, r_right = right.orphans
return (left * r_right) / r_left

# Cancelling out common terms
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Aren't there more simplifications here? I think this catches things like (a*b)/(a*c) but wouldn't catch (b*a)/(c*a)? Is it much overhead to check for these?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yeah, just adding things as I see them come up in expression trees

if (
isinstance(left, Multiplication)
and isinstance(right, Multiplication)
and left.left == right.left
):
_, l_right = left.orphans
_, r_right = right.orphans
return l_right / r_right

# Negation simplifications
if isinstance(left, pybamm.Negate) and isinstance(right, pybamm.Negate):
# Double negation cancels out
Expand All @@ -1269,13 +1275,13 @@ def simplified_division(left, right):
# Simplify a / (-b) to (-a) / b if (-a) is constant
return (-left) / right.orphans[0]

return pybamm.simplify_if_constant(pybamm.Division(left, right))
return pybamm.simplify_if_constant(Division(left, right))


def simplified_matrix_multiplication(left, right):
left, right = preprocess_binary(left, right)
if pybamm.is_matrix_zero(left) or pybamm.is_matrix_zero(right):
return pybamm.zeros_like(pybamm.MatrixMultiplication(left, right))
return pybamm.zeros_like(MatrixMultiplication(left, right))

if isinstance(right, Multiplication) and left.is_constant():
# Simplify A @ (b * c) to (A * b) @ c if (A * b) is constant
Expand Down Expand Up @@ -1309,18 +1315,35 @@ def simplified_matrix_multiplication(left, right):
new_mul.copy_domains(right)
return new_mul

# Simplify A @ (b + c) to (A @ b) + (A @ c) if (A @ b) or (A @ c) is constant
# This is a common construction that appears from discretisation of spatial
# operators
# Don't do this if either b or c is a number as this will lead to matmul errors
elif isinstance(right, Addition):
if (right.left.is_constant() or right.right.is_constant()) and not (
elif isinstance(right, (Addition, Subtraction)):
# Simplify A @ (b +- c) to (A @ b) +- (A @ c) if (A @ b) or (A @ c) is constant
# This is a common construction that appears from discretisation of spatial
# operators
# Or simplify A @ (B @ b +- C @ c) to (A @ B @ b) +- (A @ C @ c) if (A @ B)
# and (A @ C) are constant
# Don't do this if either b or c is a number as this will lead to matmul errors
if (
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Do we need A to be constant here too? I.e. left.is_constant()?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

yeah, good point

(right.left.is_constant() or right.right.is_constant())
# these lines should work but don't, possibly because of poorly
# conditioned model?
# or (
# isinstance(right.left, MatrixMultiplication)
# and right.left.left.is_constant()
# and isinstance(right.right, MatrixMultiplication)
# and right.right.left.is_constant()
# )
) and not (
right.left.size_for_testing == 1 or right.right.size_for_testing == 1
):
r_left, r_right = right.orphans
return (left @ r_left) + (left @ r_right)

return pybamm.simplify_if_constant(pybamm.MatrixMultiplication(left, right))
r_left.domains = right.domains
r_right.domains = right.domains
if isinstance(right, Addition):
return (left @ r_left) + (left @ r_right)
elif isinstance(right, Subtraction):
return (left @ r_left) - (left @ r_right)

return pybamm.simplify_if_constant(MatrixMultiplication(left, right))


def minimum(left, right):
Expand Down
2 changes: 1 addition & 1 deletion pybamm/expression_tree/state_vector.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,7 @@ def __init__(
raise TypeError("all y_slices must be slice objects")
if name is None:
if y_slices[0].start is None:
name = base_name + "[:{:d}]".format(y_slice.stop)
name = base_name + "[:{:d}".format(y_slice.stop)
else:
name = base_name + "[{:d}:{:d}".format(
y_slices[0].start, y_slices[0].stop
Expand Down
14 changes: 14 additions & 0 deletions tests/unit/test_expression_tree/test_binary_operators.py
Original file line number Diff line number Diff line change
Expand Up @@ -529,6 +529,8 @@ def test_binary_simplifications(self):
# division by zero
with self.assertRaises(ZeroDivisionError):
b / a
# division with a common term
self.assertEqual((2 * c) / (2 * var), (c / var))

def test_binary_simplifications_concatenations(self):
def conc_broad(x, y, z):
Expand Down Expand Up @@ -564,7 +566,9 @@ def test_advanced_binary_simplifications(self):
# MatMul simplifications that often appear when discretising spatial operators
A = pybamm.Matrix(np.random.rand(10, 10))
B = pybamm.Matrix(np.random.rand(10, 10))
# C = pybamm.Matrix(np.random.rand(10, 10))
var = pybamm.StateVector(slice(0, 10))
# var2 = pybamm.StateVector(slice(10, 20))
vec = pybamm.Vector(np.random.rand(10))

# Do A@B first if it is constant
Expand All @@ -575,9 +579,19 @@ def test_advanced_binary_simplifications(self):
# constant
expr = A @ (var + vec)
self.assertEqual(expr, ((A @ var) + (A @ vec)))
expr = A @ (var - vec)
self.assertEqual(expr, ((A @ var) - (A @ vec)))

expr = A @ ((B @ var) + vec)
self.assertEqual(expr, (((A @ B) @ var) + (A @ vec)))
expr = A @ ((B @ var) - vec)
self.assertEqual(expr, (((A @ B) @ var) - (A @ vec)))

# Distribute the @ operator to a sum if both symbols being summed are matmuls
# expr = A @ (B @ var + C @ var2)
# self.assertEqual(expr, ((A @ B) @ var + (A @ C) @ var2))
# expr = A @ (B @ var - C @ var2)
# self.assertEqual(expr, ((A @ B) @ var - (A @ C) @ var2))

# Reduce (A@var + B@var) to ((A+B)@var)
expr = A @ var + B @ var
Expand Down