Skip to content

Commit

Permalink
#746 fixing more tests
Browse files Browse the repository at this point in the history
  • Loading branch information
valentinsulzer committed Nov 26, 2019
1 parent e5d5c12 commit 07f4e46
Show file tree
Hide file tree
Showing 4 changed files with 22 additions and 11 deletions.
8 changes: 3 additions & 5 deletions pybamm/expression_tree/binary_operators.py
Original file line number Diff line number Diff line change
Expand Up @@ -667,14 +667,12 @@ class Outer(BinaryOperator):

def __init__(self, left, right):
""" See :meth:`pybamm.BinaryOperator.__init__()`. """
# cannot have Variable, StateVector or Matrix in the right symbol, as these
# cannot have certain types of objects in the right symbol, as these
# can already be 2D objects (so we can't take an outer product with them)
if right.has_symbol_of_classes(
(pybamm.Variable, pybamm.StateVector, pybamm.Matrix)
(pybamm.Variable, pybamm.StateVector, pybamm.Matrix, pybamm.SpatialVariable)
):
raise TypeError(
"right child must only contain SpatialVariable and scalars" ""
)
raise TypeError("right child must only contain Vectors and Scalars" "")

super().__init__("outer product", left, right)

Expand Down
4 changes: 3 additions & 1 deletion pybamm/spatial_methods/finite_volume.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,9 @@ def spatial_variable(self, symbol):
# for finite volume we use the cell centres
symbol_mesh = self.mesh.combine_submeshes(*symbol.domain)
entries = np.concatenate([mesh.nodes for mesh in symbol_mesh])
return pybamm.Vector(entries, domain=symbol.domain)
return pybamm.Vector(
entries, domain=symbol.domain, auxiliary_domains=symbol.auxiliary_domains
)

def gradient(self, symbol, discretised_symbol, boundary_conditions):
"""Matrix-vector multiplication to implement the gradient operator.
Expand Down
11 changes: 7 additions & 4 deletions tests/unit/test_discretisations/test_discretisation.py
Original file line number Diff line number Diff line change
Expand Up @@ -790,16 +790,19 @@ def test_broadcast_2D(self):
)

def test_outer(self):
var = pybamm.Variable("var", ["current collector"])
x = pybamm.SpatialVariable("x_s", ["separator"])

# create discretisation
disc = get_1p1d_discretisation_for_testing()
mesh = disc.mesh

var_z = pybamm.Variable("var_z", ["current collector"])
var_x = pybamm.Vector(
np.linspace(0, 1, mesh["separator"][0].npts), domain="separator"
)

# process Outer variable
disc.set_variable_slices([var])
outer = pybamm.outer(var, x)
disc.set_variable_slices([var_z, var_x])
outer = pybamm.outer(var_z, var_x)
outer_disc = disc.process_symbol(outer)
self.assertIsInstance(outer_disc, pybamm.Outer)
self.assertIsInstance(outer_disc.children[0], pybamm.StateVector)
Expand Down
10 changes: 9 additions & 1 deletion tests/unit/test_processed_variable.py
Original file line number Diff line number Diff line change
Expand Up @@ -93,6 +93,8 @@ def test_processed_variable_3D_x_r(self):
disc.set_variable_slices([var])
x_sol = disc.process_symbol(x).entries[:, 0]
r_sol = disc.process_symbol(r).entries[:, 0]
# Keep only the first iteration of entries
r_sol = r_sol[: len(r_sol) // len(x_sol)]
var_sol = disc.process_symbol(var)
t_sol = np.linspace(0, 1)
y_sol = np.ones(len(x_sol) * len(r_sol))[:, np.newaxis] * np.linspace(0, 5)
Expand All @@ -114,8 +116,10 @@ def test_processed_variable_3D_x_z(self):

disc = tests.get_1p1d_discretisation_for_testing()
disc.set_variable_slices([var])
x_sol = disc.process_symbol(x).entries[:, 0]
z_sol = disc.process_symbol(z).entries[:, 0]
x_sol = disc.process_symbol(x).entries[:, 0]
# Keep only the first iteration of entries
x_sol = x_sol[: len(x_sol) // len(z_sol)]
var_sol = disc.process_symbol(var)
t_sol = np.linspace(0, 1)
y_sol = np.ones(len(x_sol) * len(z_sol))[:, np.newaxis] * np.linspace(0, 5)
Expand Down Expand Up @@ -257,6 +261,8 @@ def test_processed_var_3D_interpolation(self):
disc.set_variable_slices([var])
x_sol = disc.process_symbol(x).entries[:, 0]
r_sol = disc.process_symbol(r).entries[:, 0]
# Keep only the first iteration of entries
r_sol = r_sol[: len(r_sol) // len(x_sol)]
var_sol = disc.process_symbol(var)
t_sol = np.linspace(0, 1)
y_sol = np.ones(len(x_sol) * len(r_sol))[:, np.newaxis] * np.linspace(0, 5)
Expand Down Expand Up @@ -293,6 +299,8 @@ def test_processed_var_3D_interpolation(self):
disc.set_variable_slices([var])
x_sol = disc.process_symbol(x).entries[:, 0]
r_sol = disc.process_symbol(r).entries[:, 0]
# Keep only the first iteration of entries
r_sol = r_sol[: len(r_sol) // len(x_sol)]
var_sol = disc.process_symbol(var)
t_sol = np.linspace(0, 1)
y_sol = np.ones(len(x_sol) * len(r_sol))[:, np.newaxis] * np.linspace(0, 5)
Expand Down

0 comments on commit 07f4e46

Please sign in to comment.