Skip to content

Commit

Permalink
#546 use secondary dimension instead of secondary broadcast
Browse files Browse the repository at this point in the history
  • Loading branch information
valentinsulzer committed Jul 23, 2019
1 parent 0a79b9b commit f277217
Show file tree
Hide file tree
Showing 20 changed files with 219 additions and 172 deletions.
7 changes: 6 additions & 1 deletion pybamm/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,12 @@ def version(formatted=False):
#
# Classes for the Expression Tree
#
from .expression_tree.symbol import Symbol, evaluate_for_shape_using_domain
from .expression_tree.symbol import (
Symbol,
domain_size,
create_object_of_size,
evaluate_for_shape_using_domain,
)
from .expression_tree.binary_operators import (
is_scalar_zero,
is_matrix_zero,
Expand Down
58 changes: 50 additions & 8 deletions pybamm/expression_tree/broadcasts.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,14 @@ class Broadcast(pybamm.SpatialOperator):
**Extends:** :class:`SpatialOperator`
"""

def __init__(self, child, broadcast_domain, broadcast_type="full", name=None):
def __init__(
self,
child,
broadcast_domain,
secondary_domain=None,
broadcast_type="full",
name=None,
):
# Convert child to scalar if it is a number
if isinstance(child, numbers.Number):
child = pybamm.Scalar(child)
Expand All @@ -41,7 +48,7 @@ def __init__(self, child, broadcast_domain, broadcast_type="full", name=None):
)
self.broadcast_type = broadcast_type
self.broadcast_domain = broadcast_domain
super().__init__(name, child, domain)
super().__init__(name, child, domain, secondary_domain)

def check_and_set_domain_and_broadcast_type(
self, child, broadcast_domain, broadcast_type
Expand Down Expand Up @@ -130,7 +137,7 @@ class PrimaryBroadcast(Broadcast):
"A class for primary broadcasts"

def __init__(self, child, broadcast_domain, name=None):
super().__init__(child, broadcast_domain, "primary", name)
super().__init__(child, broadcast_domain, broadcast_type="primary", name=name)

def _unary_simplify(self, child):
""" See :meth:`pybamm.UnaryOperator.simplify()`. """
Expand All @@ -140,12 +147,21 @@ def _unary_new_copy(self, child):
""" See :meth:`pybamm.UnaryOperator.simplify()`. """
return PrimaryBroadcast(child, self.broadcast_domain)

def evaluate_for_shape(self):
"""
Returns a vector of NaNs to represent the shape of a Broadcast.
See :meth:`pybamm.Symbol.evaluate_for_shape_using_domain()`
"""
child_eval = self.children[0].evaluate_for_shape()
vec = pybamm.evaluate_for_shape_using_domain(self.domain)
return np.outer(child_eval, vec).reshape(-1, 1)


class SecondaryBroadcast(Broadcast):
"A class for secondary broadcasts"

def __init__(self, child, broadcast_domain, name=None):
super().__init__(child, broadcast_domain, "secondary", name)
super().__init__(child, broadcast_domain, broadcast_type="secondary", name=name)

def _unary_simplify(self, child):
""" See :meth:`pybamm.UnaryOperator.simplify()`. """
Expand All @@ -155,17 +171,43 @@ def _unary_new_copy(self, child):
""" See :meth:`pybamm.UnaryOperator.simplify()`. """
return SecondaryBroadcast(child, self.broadcast_domain)

def evaluate_for_shape(self):
"""
Returns a vector of NaNs to represent the shape of a Broadcast.
See :meth:`pybamm.Symbol.evaluate_for_shape_using_domain()`
"""
child_eval = self.children[0].evaluate_for_shape()
return np.outer(
pybamm.evaluate_for_shape_using_domain(self.broadcast_domain), child_eval
).reshape(-1, 1)


class FullBroadcast(Broadcast):
"A class for full broadcasts"

def __init__(self, child, broadcast_domain, name=None):
super().__init__(child, broadcast_domain, "full", name)
def __init__(self, child, broadcast_domain, secondary_domain=None, name=None):
super().__init__(
child,
broadcast_domain,
secondary_domain=secondary_domain,
broadcast_type="full",
name=name,
)

def _unary_simplify(self, child):
""" See :meth:`pybamm.UnaryOperator.simplify()`. """
return FullBroadcast(child, self.broadcast_domain)
return FullBroadcast(child, self.broadcast_domain, self.secondary_domain)

def _unary_new_copy(self, child):
""" See :meth:`pybamm.UnaryOperator.simplify()`. """
return FullBroadcast(child, self.broadcast_domain)
return FullBroadcast(child, self.broadcast_domain, self.secondary_domain)

def evaluate_for_shape(self):
"""
Returns a vector of NaNs to represent the shape of a Broadcast.
See :meth:`pybamm.Symbol.evaluate_for_shape_using_domain()`
"""
child_eval = self.children[0].evaluate_for_shape()
vec = pybamm.evaluate_for_shape_using_domain(self.domain, self.secondary_domain)

return child_eval * vec
26 changes: 21 additions & 5 deletions pybamm/expression_tree/symbol.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ def domain_size(domain):
}
if isinstance(domain, str):
domain = [domain]
if domain == []:
if domain in [[], None]:
size = 1
elif all(dom in fixed_domain_sizes for dom in domain):
size = sum(fixed_domain_sizes[dom] for dom in domain)
Expand All @@ -47,13 +47,14 @@ def create_object_of_size(size, typ="vector"):
return np.nan * np.ones((size, size))


def evaluate_for_shape_using_domain(domain, typ="vector"):
def evaluate_for_shape_using_domain(domain, secondary_domain=None, typ="vector"):
"""
Return a vector of the appropriate shape, based on the domain.
Domain 'sizes' can clash, but are unlikely to, and won't cause failures if they do.
"""
size = domain_size(domain)
return create_object_of_size(size, typ)
_domain_size = domain_size(domain)
_secondary_domain_size = domain_size(secondary_domain)
return create_object_of_size(_domain_size * _secondary_domain_size, typ)


class Symbol(anytree.NodeMixin):
Expand All @@ -72,10 +73,21 @@ class Symbol(anytree.NodeMixin):
"""

def __init__(self, name, children=[], domain=[]):
def __init__(self, name, children=None, domain=None, secondary_domain=None):
super(Symbol, self).__init__()
self.name = name

if children is None:
children = []
if domain is None:
domain = []
elif isinstance(domain, str):
domain = [domain]
if secondary_domain is None:
secondary_domain = []
elif isinstance(secondary_domain, str):
secondary_domain = [secondary_domain]

for child in children:
# copy child before adding
# this also adds copy.copy(child) to self.children
Expand All @@ -84,6 +96,8 @@ def __init__(self, name, children=[], domain=[]):
# cache children
self.cached_children = super(Symbol, self).children

# Set secondary domain
self.secondary_domain = secondary_domain
# Set domain (and hence id)
self.domain = domain

Expand Down Expand Up @@ -160,6 +174,7 @@ def set_id(self):
(self.__class__, self.name)
+ tuple([child.id for child in self.children])
+ tuple(self.domain)
+ tuple(self.secondary_domain)
)

@property
Expand Down Expand Up @@ -264,6 +279,7 @@ def __repr__(self):
self._name,
[str(child) for child in self.children],
[str(subdomain) for subdomain in self.domain],
[str(subdomain) for subdomain in self.secondary_domain],
)

def __add__(self, other):
Expand Down
10 changes: 6 additions & 4 deletions pybamm/expression_tree/unary_operators.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,10 +23,12 @@ class UnaryOperator(pybamm.Symbol):
"""

def __init__(self, name, child, domain=None):
def __init__(self, name, child, domain=None, secondary_domain=None):
if domain is None:
domain = child.domain
super().__init__(name, children=[child], domain=domain)
super().__init__(
name, children=[child], domain=domain, secondary_domain=secondary_domain
)
self.child = self.children[0]

def __str__(self):
Expand Down Expand Up @@ -247,8 +249,8 @@ class with a :class:`Matrix`
"""

def __init__(self, name, child, domain=None):
super().__init__(name, child, domain)
def __init__(self, name, child, domain=None, secondary_domain=None):
super().__init__(name, child, domain, secondary_domain)

def diff(self, variable):
""" See :meth:`pybamm.Symbol.diff()`. """
Expand Down
14 changes: 10 additions & 4 deletions pybamm/expression_tree/variable.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,13 +22,19 @@ class Variable(pybamm.Symbol):
*Extends:* :class:`Symbol`
"""

def __init__(self, name, domain=[]):
super().__init__(name, domain=domain)
def __init__(self, name, domain=None, secondary_domain=None):
if domain is None:
domain = []
if secondary_domain is None:
secondary_domain = []
super().__init__(name, domain=domain, secondary_domain=secondary_domain)

def new_copy(self):
""" See :meth:`pybamm.Symbol.new_copy()`. """
return Variable(self.name, self.domain)
return Variable(self.name, self.domain, self.secondary_domain)

def evaluate_for_shape(self):
""" See :meth:`pybamm.Symbol.evaluate_for_shape_using_domain()` """
return pybamm.evaluate_for_shape_using_domain(self.domain)
return pybamm.evaluate_for_shape_using_domain(
self.domain, self.secondary_domain
)
Loading

0 comments on commit f277217

Please sign in to comment.