Skip to content

Commit

Permalink
pybamm-team#744 add test for ones_like
Browse files Browse the repository at this point in the history
  • Loading branch information
valentinsulzer committed Dec 11, 2019
1 parent abe7e65 commit 0a32edd
Show file tree
Hide file tree
Showing 3 changed files with 33 additions and 5 deletions.
17 changes: 13 additions & 4 deletions pybamm/expression_tree/broadcasts.py
Original file line number Diff line number Diff line change
Expand Up @@ -186,14 +186,23 @@ def evaluate_for_shape(self):
return child_eval * vec


def ones_like(symbol):
def ones_like(*symbols):
"""
Create a symbol with the same shape as the input symbol and with constant value '1',
using `FullBroadcast`.
Parameters
----------
symbol : :class:`Symbol`
Symbol whose shape to copy
symbols : :class:`Symbol`
Symbols whose shape to copy
"""
return FullBroadcast(1, symbol.domain, symbol.auxiliary_domains)
# Make a symbol that combines all the children, to get the right domain
# that takes all the child symbols into account
sum_symbol = 0
for sym in symbols:
import ipdb

ipdb.set_trace()
sum_symbol += sym
return FullBroadcast(1, sum_symbol.domain, sum_symbol.auxiliary_domains)

2 changes: 1 addition & 1 deletion pybamm/parameters/parameter_values.py
Original file line number Diff line number Diff line change
Expand Up @@ -445,7 +445,7 @@ def _process_symbol(self, symbol):
# Also use ones_like so that we get the right shapes
function = pybamm.Scalar(
function_name, name=symbol.name
) * pybamm.ones_like(new_children[0])
) * pybamm.ones_like(new_children)
else:
# otherwise evaluate the function to create a new PyBaMM object
function = function_name(*new_children)
Expand Down
19 changes: 19 additions & 0 deletions tests/unit/test_expression_tree/test_broadcasts.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,25 @@ def test_broadcast_type(self):
with self.assertRaisesRegex(ValueError, "Variables on the current collector"):
pybamm.Broadcast(a, "electrode")

def test_ones_like(self):
a = pybamm.Variable(
"a",
domain="negative electrode",
auxiliary_domains={"secondary": "current collector"},
)
ones_like_a = pybamm.ones_like(a)
self.assertIsInstance(ones_like_a, pybamm.FullBroadcast)
self.assertEqual(ones_like_a.name, "broadcast")
self.assertEqual(ones_like_a.domain, a.domain)
self.assertEqual(ones_like_a.auxiliary_domains, a.auxiliary_domains)

b = pybamm.Variable("b", domain="negative electrode")
ones_like_ab = pybamm.ones_like(b, a)
self.assertIsInstance(ones_like_ab, pybamm.FullBroadcast)
self.assertEqual(ones_like_ab.name, "broadcast")
self.assertEqual(ones_like_ab.domain, a.domain)
self.assertEqual(ones_like_ab.auxiliary_domains, a.auxiliary_domains)


if __name__ == "__main__":
print("Add -v for more debug output")
Expand Down

0 comments on commit 0a32edd

Please sign in to comment.