Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Issue 1259 concat #1368

Merged
merged 7 commits into from
Feb 9, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@

## Bug fixes

- Added a check for domains in `Concatenation` ([#1368](https://github.com/pybamm-team/PyBaMM/pull/1368))
- Differentiation now works even when the differentiation variable is a constant ([#1294](https://github.com/pybamm-team/PyBaMM/pull/1294))
- Fixed a bug where the event time and state were no longer returned as part of the solution ([#1344](https://github.com/pybamm-team/PyBaMM/pull/1344))
- Fixed a bug in `CasadiSolver` safe mode which crashed when there were extrapolation events but no termination events ([#1321](https://github.com/pybamm-team/PyBaMM/pull/1321))
Expand Down
1 change: 1 addition & 0 deletions examples/scripts/compare_lead_acid.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
# create and run simulations
sims = []
for model in models:
model.convert_to_format = None
sim = pybamm.Simulation(model)
sim.solve([0, 3600 * 17])
sims.append(sim)
Expand Down
2 changes: 1 addition & 1 deletion pybamm/discretisations/discretisation.py
Original file line number Diff line number Diff line change
Expand Up @@ -862,7 +862,7 @@ def _process_symbol(self, symbol):
disc_right = self.process_symbol(right)
if symbol.domain == []:
return pybamm.simplify_if_constant(
symbol._binary_new_copy(disc_left, disc_right), clear_domains=False
symbol._binary_new_copy(disc_left, disc_right)
)
else:
return spatial_method.process_binary_operators(
Expand Down
48 changes: 14 additions & 34 deletions pybamm/expression_tree/binary_operators.py
Original file line number Diff line number Diff line change
Expand Up @@ -203,9 +203,7 @@ def _binary_jac(self, left_jac, right_jac):
""" See :meth:`pybamm.BinaryOperator._binary_jac()`. """
# apply chain rule and power rule
left, right = self.orphans
if left.evaluates_to_constant_number() and right.evaluates_to_constant_number():
return pybamm.Scalar(0)
elif right.evaluates_to_constant_number():
if right.evaluates_to_constant_number():
return (right * left ** (right - 1)) * left_jac
elif left.evaluates_to_constant_number():
return (left ** right * pybamm.log(left)) * right_jac
Expand Down Expand Up @@ -292,9 +290,7 @@ def _binary_jac(self, left_jac, right_jac):
""" See :meth:`pybamm.BinaryOperator._binary_jac()`. """
# apply product rule
left, right = self.orphans
if left.evaluates_to_constant_number() and right.evaluates_to_constant_number():
return pybamm.Scalar(0)
elif left.evaluates_to_constant_number():
if left.evaluates_to_constant_number():
return left * right_jac
elif right.evaluates_to_constant_number():
return right * left_jac
Expand Down Expand Up @@ -375,9 +371,7 @@ def _binary_jac(self, left_jac, right_jac):
""" See :meth:`pybamm.BinaryOperator._binary_jac()`. """
# apply quotient rule
left, right = self.orphans
if left.evaluates_to_constant_number() and right.evaluates_to_constant_number():
return pybamm.Scalar(0)
elif left.evaluates_to_constant_number():
if left.evaluates_to_constant_number():
return -left / right ** 2 * right_jac
elif right.evaluates_to_constant_number():
return left_jac / right
Expand Down Expand Up @@ -430,9 +424,7 @@ def _binary_jac(self, left_jac, right_jac):
""" See :meth:`pybamm.BinaryOperator._binary_jac()`. """
# apply product rule
left, right = self.orphans
if left.evaluates_to_constant_number() and right.evaluates_to_constant_number():
return pybamm.Scalar(0)
elif left.evaluates_to_constant_number():
if left.evaluates_to_constant_number():
return left * right_jac
elif right.evaluates_to_constant_number():
return right * left_jac
Expand Down Expand Up @@ -480,7 +472,7 @@ def inner(left, right):
if pybamm.is_scalar_one(right):
return left

return pybamm.simplify_if_constant(pybamm.Inner(left, right), clear_domains=False)
return pybamm.simplify_if_constant(pybamm.Inner(left, right))


class Heaviside(BinaryOperator):
Expand Down Expand Up @@ -573,9 +565,7 @@ def _binary_jac(self, left_jac, right_jac):
""" See :meth:`pybamm.BinaryOperator._binary_jac()`. """
# apply chain rule and power rule
left, right = self.orphans
if left.evaluates_to_constant_number() and right.evaluates_to_constant_number():
return pybamm.Scalar(0)
elif right.evaluates_to_constant_number():
if right.evaluates_to_constant_number():
return left_jac
elif left.evaluates_to_constant_number():
return -right_jac * pybamm.Floor(left / right)
Expand Down Expand Up @@ -712,7 +702,7 @@ def simplified_power(left, right):
if new_left.is_constant() or new_right.is_constant():
return new_left / new_right

return pybamm.simplify_if_constant(pybamm.Power(left, right), clear_domains=False)
return pybamm.simplify_if_constant(pybamm.Power(left, right))


def simplified_addition(left, right):
Expand Down Expand Up @@ -784,9 +774,7 @@ def simplified_addition(left, right):
new_sum.copy_domains(pybamm.Addition(left, right))
return new_sum

return pybamm.simplify_if_constant(
pybamm.Addition(left, right), clear_domains=False
)
return pybamm.simplify_if_constant(pybamm.Addition(left, right))


def simplified_subtraction(left, right):
Expand Down Expand Up @@ -844,9 +832,7 @@ def simplified_subtraction(left, right):
if left.id == right.id:
return pybamm.zeros_like(left)

return pybamm.simplify_if_constant(
pybamm.Subtraction(left, right), clear_domains=False
)
return pybamm.simplify_if_constant(pybamm.Subtraction(left, right))


def simplified_multiplication(left, right):
Expand Down Expand Up @@ -893,9 +879,7 @@ def simplified_multiplication(left, right):

# Return constant if both sides are constant
if left.is_constant() and right.is_constant():
return pybamm.simplify_if_constant(
pybamm.Multiplication(left, right), clear_domains=False
)
return pybamm.simplify_if_constant(pybamm.Multiplication(left, right))

# Simplify (B @ c) * a to (a * B) @ c if (a * B) is constant
# This is a common construction that appears from discretisation of spatial
Expand Down Expand Up @@ -1024,9 +1008,7 @@ def simplified_division(left, right):
if new_right.is_constant():
return l_left * new_right

return pybamm.simplify_if_constant(
pybamm.Division(left, right), clear_domains=False
)
return pybamm.simplify_if_constant(pybamm.Division(left, right))


def simplified_matrix_multiplication(left, right):
Expand Down Expand Up @@ -1080,9 +1062,7 @@ def simplified_matrix_multiplication(left, right):
r_left, r_right = right.orphans
return (left @ r_left) + (left @ r_right)

return pybamm.simplify_if_constant(
pybamm.MatrixMultiplication(left, right), clear_domains=False
)
return pybamm.simplify_if_constant(pybamm.MatrixMultiplication(left, right))


def minimum(left, right):
Expand All @@ -1097,7 +1077,7 @@ def minimum(left, right):
out = Minimum(left, right)
else:
out = pybamm.softminus(left, right, k)
return pybamm.simplify_if_constant(out, clear_domains=False)
return pybamm.simplify_if_constant(out)


def maximum(left, right):
Expand All @@ -1112,7 +1092,7 @@ def maximum(left, right):
out = Maximum(left, right)
else:
out = pybamm.softplus(left, right, k)
return pybamm.simplify_if_constant(out, clear_domains=False)
return pybamm.simplify_if_constant(out)


def softminus(left, right, k):
Expand Down
4 changes: 0 additions & 4 deletions pybamm/expression_tree/broadcasts.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,10 +59,6 @@ def __init__(
self.broadcast_domain = broadcast_domain
super().__init__(name, child, domain, auxiliary_domains)

def _unary_simplify(self, simplified_child):
""" See :meth:`pybamm.UnaryOperator.simplify()`. """
return self._unary_new_copy(simplified_child)


class PrimaryBroadcast(Broadcast):
"""A node in the expression tree representing a primary broadcasting operator.
Expand Down
25 changes: 7 additions & 18 deletions pybamm/expression_tree/concatenations.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,10 +49,14 @@ def get_children_domains(self, children):
if not isinstance(child, pybamm.Symbol):
raise TypeError("{} is not a pybamm symbol".format(child))
child_domain = child.domain
if child_domain == []:
raise pybamm.DomainError(
"Cannot concatenate child '{}' with empty domain".format(child)
)
if set(domain).isdisjoint(child_domain):
domain += child_domain
else:
raise pybamm.DomainError("""domain of children must be disjoint""")
raise pybamm.DomainError("domain of children must be disjoint")
return domain

def _concatenation_evaluate(self, children_eval):
Expand Down Expand Up @@ -191,16 +195,6 @@ def __init__(self, children, full_mesh, copy_this=None):
# store mesh
self._full_mesh = full_mesh

# Check that there is a domain, otherwise the functionality won't work
# and we should raise a DomainError
if self.domain == []:
raise pybamm.DomainError(
"""
domain cannot be empty for a DomainConcatenation.
Perhaps the children should have been Broadcasted first?
"""
)

# create dict of domain => slice of final vector
self.secondary_dimensions_npts = self._get_auxiliary_domain_repeats(
self.domains
Expand Down Expand Up @@ -337,9 +331,7 @@ def simplified_numpy_concatenation(*children):
new_children.extend(child.orphans)
else:
new_children.append(child)
return pybamm.simplify_if_constant(
NumpyConcatenation(*new_children), clear_domains=False
)
return pybamm.simplify_if_constant(NumpyConcatenation(*new_children))


def numpy_concatenation(*children):
Expand Down Expand Up @@ -375,10 +367,7 @@ def simplified_domain_concatenation(children, mesh, copy_this=None):
auxiliary_domains=concat.auxiliary_domains,
)

return pybamm.simplify_if_constant(
concat,
clear_domains=False,
)
return pybamm.simplify_if_constant(concat)


def domain_concatenation(children, mesh):
Expand Down
Loading