Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Issue 774 remove outer kron #777

Merged
merged 10 commits into from
Jan 10, 2020
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,7 @@

## Breaking changes

- Removed `Outer` and `Kron` nodes as no longer used ([#777](https://github.com/pybamm-team/PyBaMM/pull/777))
- Moved `results` to separate repositories ([#761](https://github.com/pybamm-team/PyBaMM/pull/761))
- The parameters "Bruggeman coefficient" must now be specified separately as "Bruggeman coefficient (electrolyte)" and "Bruggeman coefficient (electrode)"
- The current classes (`GetConstantCurrent`, `GetUserCurrent` and `GetUserData`) have now been removed. Please refer to the [`change-input-current` notebook](https://github.com/pybamm-team/PyBaMM/blob/master/examples/notebooks/change-input-current.ipynb) for information on how to specify an input current
Expand Down
8 changes: 0 additions & 8 deletions docs/source/expression_tree/binary_operator.rst
Original file line number Diff line number Diff line change
Expand Up @@ -25,15 +25,7 @@ Binary Operators
.. autoclass:: pybamm.Inner
:members:

.. autoclass:: pybamm.Outer
:members:

.. autoclass:: pybamm.Kron
:members:

.. autoclass:: pybamm.Heaviside
:members:

.. autofunction:: pybamm.outer

.. autofunction:: pybamm.source
3 changes: 0 additions & 3 deletions pybamm/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -81,10 +81,7 @@ def version(formatted=False):
Division,
Inner,
inner,
Outer,
Kron,
Heaviside,
outer,
source,
)
from .expression_tree.concatenations import (
Expand Down
4 changes: 1 addition & 3 deletions pybamm/discretisations/discretisation.py
Original file line number Diff line number Diff line change
Expand Up @@ -989,7 +989,7 @@ def check_variables(self, model):
"""
Check variables in variable list against rhs
Be lenient with size check if the variable in model.variables is broadcasted, or
a concatenation, or an outer product
a concatenation
(if broadcasted, variable is a multiplication with a vector of ones)
"""
for rhs_var in model.rhs.keys():
Expand All @@ -1001,7 +1001,6 @@ def check_variables(self, model):
)

not_concatenation = not isinstance(var, pybamm.Concatenation)
not_outer = not isinstance(var, pybamm.Outer)

not_mult_by_one_vec = not (
isinstance(var, pybamm.Multiplication)
Expand All @@ -1012,7 +1011,6 @@ def check_variables(self, model):
if (
different_shapes
and not_concatenation
and not_outer
and not_mult_by_one_vec
):
raise pybamm.ModelError(
Expand Down
113 changes: 4 additions & 109 deletions pybamm/expression_tree/binary_operators.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@

import numpy as np
import numbers
from scipy.sparse import issparse, csr_matrix, kron
from scipy.sparse import issparse, csr_matrix


def is_scalar_zero(expr):
Expand Down Expand Up @@ -79,17 +79,8 @@ class BinaryOperator(pybamm.Symbol):
def __init__(self, name, left, right):
left, right = self.format(left, right)

# Check and process domains, except for Outer symbol which takes the outer
# product of two smbols in different domains, and gives it the domain of the
# right child.
if isinstance(self, (pybamm.Outer, pybamm.Kron)):
domain = right.domain
auxiliary_domains = {}
if domain != []:
auxiliary_domains["secondary"] = left.domain
else:
domain = self.get_children_domains(left.domain, right.domain)
auxiliary_domains = self.get_children_auxiliary_domains([left, right])
domain = self.get_children_domains(left.domain, right.domain)
auxiliary_domains = self.get_children_auxiliary_domains([left, right])
super().__init__(
name,
children=[left, right],
Expand All @@ -116,11 +107,7 @@ def format(self, left, right):
)

# Do some broadcasting in special cases, to avoid having to do this manually
if (
not isinstance(self, (Outer, Kron))
and left.domain != []
and right.domain != []
):
if left.domain != [] and right.domain != []:
if (
left.domain != right.domain
and "secondary" in right.auxiliary_domains
Expand Down Expand Up @@ -654,86 +641,6 @@ def inner(left, right):
return pybamm.Inner(left, right)


class Outer(BinaryOperator):
"""A node in the expression tree representing an outer product.
This takes a 1D vector in the current collector domain of size (n,1) and a 1D
variable of size (m,1), takes their outer product, and reshapes this into a vector
of size (nm,1). It can also take in a vector in a single particle and a vector
of the electrolyte domain to repeat that particle.
Note: this class might be a bit dangerous, so at the moment it is very restrictive
in what symbols can be passed to it

**Extends:** :class:`BinaryOperator`
"""

def __init__(self, left, right):
""" See :meth:`pybamm.BinaryOperator.__init__()`. """
# cannot have certain types of objects in the right symbol, as these
# can already be 2D objects (so we can't take an outer product with them)
if right.has_symbol_of_classes(
(pybamm.Variable, pybamm.StateVector, pybamm.Matrix, pybamm.SpatialVariable)
):
raise TypeError("right child must only contain Vectors and Scalars" "")

super().__init__("outer product", left, right)

def __str__(self):
""" See :meth:`pybamm.Symbol.__str__()`. """
return "outer({!s}, {!s})".format(self.left, self.right)

def diff(self, variable):
""" See :meth:`pybamm.Symbol.diff()`. """
raise NotImplementedError("diff not implemented for symbol of type 'Outer'")

def _outer_jac(self, left_jac, right_jac, variable):
"""
Calculate jacobian of outer product.
See :meth:`pybamm.Jacobian._jac()`.
"""
# right cannot be a StateVector, so no need for product rule
left, right = self.orphans
if left.evaluates_to_number():
# Return zeros of correct size
return pybamm.Matrix(
csr_matrix((self.size, variable.evaluation_array.count(True)))
)
else:
return pybamm.Kron(left_jac, right)

def _binary_evaluate(self, left, right):
""" See :meth:`pybamm.BinaryOperator._binary_evaluate()`. """

return np.outer(left, right).reshape(-1, 1)


class Kron(BinaryOperator):
"""A node in the expression tree representing a (sparse) kronecker product operator

**Extends:** :class:`BinaryOperator`
"""

def __init__(self, left, right):
""" See :meth:`pybamm.BinaryOperator.__init__()`. """

super().__init__("kronecker product", left, right)

def __str__(self):
""" See :meth:`pybamm.Symbol.__str__()`. """
return "kron({!s}, {!s})".format(self.left, self.right)

def diff(self, variable):
""" See :meth:`pybamm.Symbol.diff()`. """
raise NotImplementedError("diff not implemented for symbol of type 'Kron'")

def _binary_jac(self, left_jac, right_jac):
""" See :meth:`pybamm.BinaryOperator._binary_jac()`. """
raise NotImplementedError("jac not implemented for symbol of type 'Kron'")

def _binary_evaluate(self, left, right):
""" See :meth:`pybamm.BinaryOperator._binary_evaluate()`. """
return csr_matrix(kron(left, right))


class Heaviside(BinaryOperator):
"""A node in the expression tree representing a heaviside step function.

Expand Down Expand Up @@ -783,18 +690,6 @@ def _binary_new_copy(self, left, right):
return Heaviside(left, right, self.equal)


def outer(left, right):
"""
Return outer product of two symbols. If the symbols have the same domain, the outer
product is just a multiplication. If they have different domains, make a copy of the
left child with same domain as right child, and then take outer product.
"""
try:
return left * right
except pybamm.DomainError:
return pybamm.Outer(left, right)


def source(left, right, boundary=False):
"""A convinience function for creating (part of) an expression tree representing
a source term. This is necessary for spatial methods where the mass matrix
Expand Down
7 changes: 2 additions & 5 deletions pybamm/expression_tree/operations/convert_to_casadi.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,11 +59,8 @@ def _convert(self, symbol, t=None, y=None, u=None):
# process children
converted_left = self.convert(left, t, y, u)
converted_right = self.convert(right, t, y, u)
if isinstance(symbol, pybamm.Outer):
return casadi.kron(converted_left, converted_right)
else:
# _binary_evaluate defined in derived classes for specific rules
return symbol._binary_evaluate(converted_left, converted_right)
# _binary_evaluate defined in derived classes for specific rules
return symbol._binary_evaluate(converted_left, converted_right)

elif isinstance(symbol, pybamm.UnaryOperator):
converted_child = self.convert(symbol.child, t, y, u)
Expand Down
8 changes: 0 additions & 8 deletions pybamm/expression_tree/operations/evaluate.py
Original file line number Diff line number Diff line change
Expand Up @@ -92,14 +92,6 @@ def find_symbols(symbol, constant_symbols, variable_symbols):
"if scipy.sparse.issparse({1}) else "
"{0} * {1}".format(children_vars[0], children_vars[1])
)
elif isinstance(symbol, pybamm.Outer):
symbol_str = "np.outer({}, {}).reshape(-1, 1)".format(
children_vars[0], children_vars[1]
)
elif isinstance(symbol, pybamm.Kron):
symbol_str = "scipy.sparse.csr_matrix(scipy.sparse.kron({}, {}))".format(
children_vars[0], children_vars[1]
)
else:
symbol_str = children_vars[0] + " " + symbol.name + " " + children_vars[1]

Expand Down
11 changes: 2 additions & 9 deletions pybamm/expression_tree/operations/jacobian.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,15 +46,8 @@ def _jac(self, symbol, variable):
# process children
left_jac = self.jac(left, variable)
right_jac = self.jac(right, variable)
# Need to treat outer differently. If the left child of an Outer
# evaluates to number then we need to return a matrix of zeros
# of the correct size, which requires variable.evaluation_array
if isinstance(symbol, pybamm.Outer):
# _outer_jac defined in pybamm.Outer
jac = symbol._outer_jac(left_jac, right_jac, variable)
else:
# _binary_jac defined in derived classes for specific rules
jac = symbol._binary_jac(left_jac, right_jac)
# _binary_jac defined in derived classes for specific rules
jac = symbol._binary_jac(left_jac, right_jac)

elif isinstance(symbol, pybamm.UnaryOperator):
child_jac = self.jac(symbol.child, variable)
Expand Down
12 changes: 9 additions & 3 deletions pybamm/spatial_methods/spatial_method.py
Original file line number Diff line number Diff line change
Expand Up @@ -98,9 +98,15 @@ def broadcast(self, symbol, domain, auxiliary_domains, broadcast_type):
)

if broadcast_type == "primary":
out = pybamm.Outer(
symbol, pybamm.Vector(np.ones(primary_domain_size), domain=domain)
)
# Make copies of the child stacked on top of each other
sub_vector = np.ones((primary_domain_size, 1))
if symbol.shape_for_testing == ():
out = symbol * pybamm.Vector(sub_vector)
else:
# Repeat for secondary points
matrix = csr_matrix(kron(eye(symbol.shape_for_testing[0]), sub_vector))
out = pybamm.Matrix(matrix) @ symbol
out.domain = domain
elif broadcast_type == "secondary":
secondary_domain_size = sum(
self.mesh[dom][0].npts_for_broadcast
Expand Down
34 changes: 8 additions & 26 deletions tests/unit/test_discretisations/test_discretisation.py
Original file line number Diff line number Diff line change
Expand Up @@ -781,13 +781,18 @@ def test_broadcast_2D(self):

disc.set_variable_slices([var])
broad_disc = disc.process_symbol(broad)
self.assertIsInstance(broad_disc, pybamm.Outer)
self.assertIsInstance(broad_disc.children[0], pybamm.StateVector)
self.assertIsInstance(broad_disc.children[1], pybamm.Vector)
self.assertIsInstance(broad_disc, pybamm.MatrixMultiplication)
self.assertIsInstance(broad_disc.children[0], pybamm.Matrix)
self.assertIsInstance(broad_disc.children[1], pybamm.StateVector)
self.assertEqual(
broad_disc.shape,
(mesh["separator"][0].npts * mesh["current collector"][0].npts, 1),
)
y_test = np.linspace(0, 1, mesh["current collector"][0].npts)
np.testing.assert_array_equal(
broad_disc.evaluate(y=y_test),
np.outer(y_test, np.ones(mesh["separator"][0].npts)).reshape(-1, 1),
)

def test_secondary_broadcast_2D(self):
# secondary broadcast in 2D --> Matrix multiplication
Expand All @@ -808,29 +813,6 @@ def test_secondary_broadcast_2D(self):
(mesh["negative particle"][0].npts * mesh["negative electrode"][0].npts, 1),
)

def test_outer(self):

# create discretisation
disc = get_1p1d_discretisation_for_testing()
mesh = disc.mesh

var_z = pybamm.Variable("var_z", ["current collector"])
var_x = pybamm.Vector(
np.linspace(0, 1, mesh["separator"][0].npts), domain="separator"
)

# process Outer variable
disc.set_variable_slices([var_z, var_x])
outer = pybamm.outer(var_z, var_x)
outer_disc = disc.process_symbol(outer)
self.assertIsInstance(outer_disc, pybamm.Outer)
self.assertIsInstance(outer_disc.children[0], pybamm.StateVector)
self.assertIsInstance(outer_disc.children[1], pybamm.Vector)
self.assertEqual(
outer_disc.shape,
(mesh["separator"][0].npts * mesh["current collector"][0].npts, 1),
)

def test_concatenation(self):
a = pybamm.Symbol("a")
b = pybamm.Symbol("b")
Expand Down
48 changes: 0 additions & 48 deletions tests/unit/test_expression_tree/test_binary_operators.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,54 +58,6 @@ def test_power(self):
pow2 = pybamm.Power(a, b)
self.assertEqual(pow2.evaluate(), 16)

def test_outer(self):
# Outer class
v = pybamm.Vector(np.ones(5), domain="current collector")
w = pybamm.Vector(2 * np.ones(3), domain="test")
outer = pybamm.Outer(v, w)
np.testing.assert_array_equal(outer.evaluate(), 2 * np.ones((15, 1)))
self.assertEqual(outer.domain, w.domain)
self.assertEqual(
str(outer), "outer(Column vector of length 5, Column vector of length 3)"
)

# outer function
# if there is no domain clash, normal multiplication is retured
u = pybamm.Vector(np.linspace(0, 1, 5))
outer = pybamm.outer(u, v)
self.assertIsInstance(outer, pybamm.Multiplication)
np.testing.assert_array_equal(outer.evaluate(), u.evaluate())
# otherwise, Outer class is returned
outer_fun = pybamm.outer(v, w)
outer_class = pybamm.Outer(v, w)
self.assertEqual(outer_fun.id, outer_class.id)

# failures
y = pybamm.StateVector(slice(10))
with self.assertRaisesRegex(
TypeError, "right child must only contain Vectors and Scalars"
):
pybamm.Outer(v, y)
with self.assertRaises(NotImplementedError):
outer_fun.diff(None)

def test_kron(self):
# Kron class
A = pybamm.Matrix(np.eye(2))
b = pybamm.Vector(np.array([[4], [5]]))
kron = pybamm.Kron(A, b)
np.testing.assert_array_equal(
kron.evaluate().toarray(), np.kron(A.entries, b.entries)
)

# failures
with self.assertRaises(NotImplementedError):
kron.diff(None)

y = pybamm.StateVector(slice(0, 2))
with self.assertRaises(NotImplementedError):
kron.jac(y)

def test_known_eval(self):
# Scalars
a = pybamm.Scalar(4)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -71,12 +71,6 @@ def test_convert_array_symbols(self):
# State Vector
self.assert_casadi_equal(pybamm_y.to_casadi(casadi_t, casadi_y), casadi_y)

# outer product
outer = pybamm.Outer(pybamm_a, pybamm_a)
self.assert_casadi_equal(
outer.to_casadi(), casadi.MX(outer.evaluate()), evalf=True
)

def test_special_functions(self):
a = pybamm.Array(np.array([1, 2, 3, 4, 5]))
self.assert_casadi_equal(pybamm.max(a).to_casadi(), casadi.MX(5), evalf=True)
Expand Down
Loading