Skip to content

Commit

Permalink
#963 make equations and bcs custom dictionaries, and add check for jac
Browse files Browse the repository at this point in the history
  • Loading branch information
valentinsulzer committed Apr 24, 2020
1 parent dc19206 commit c1642b8
Show file tree
Hide file tree
Showing 4 changed files with 126 additions and 52 deletions.
8 changes: 7 additions & 1 deletion pybamm/expression_tree/symbol.py
Original file line number Diff line number Diff line change
Expand Up @@ -498,10 +498,16 @@ def _diff(self, variable):

def jac(self, variable, known_jacs=None, clear_domain=True):
"""
Differentiate a symbol with respect to a (slice of) a State Vector.
Differentiate a symbol with respect to a (slice of) a StateVector
or StateVectorDot.
See :class:`pybamm.Jacobian`.
"""
jac = pybamm.Jacobian(known_jacs, clear_domain=clear_domain)
if not isinstance(variable, (pybamm.StateVector, pybamm.StateVectorDot)):
raise TypeError(
"Jacobian can only be taken with respect to a 'StateVector' "
"or 'StateVectorDot', but {} is a {}".format(variable, type(variable))
)
return jac.jac(self, variable)

def _jac(self, variable):
Expand Down
142 changes: 97 additions & 45 deletions pybamm/models/base_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -128,30 +128,6 @@ def __init__(self, name="Unnamed model"):
# Default timescale is 1 second
self.timescale = pybamm.Scalar(1)

def _set_dictionary(self, dict, name):
"""
Convert any scalar equations in dict to 'pybamm.Scalar'
and check that domains are consistent
"""
# Convert any numbers to a pybamm.Scalar
for var, eqn in dict.items():
if isinstance(eqn, numbers.Number):
dict[var] = pybamm.Scalar(eqn)

if not all(
[
variable.domain == equation.domain
or variable.domain == []
or equation.domain == []
for variable, equation in dict.items()
]
):
raise pybamm.DomainError(
"variable and equation in '{}' must have the same domain".format(name)
)

return dict

@property
def name(self):
return self._name
Expand All @@ -166,24 +142,24 @@ def rhs(self):

@rhs.setter
def rhs(self, rhs):
self._rhs = self._set_dictionary(rhs, "rhs")
self._rhs = EquationDict("rhs", rhs)

@property
def algebraic(self):
return self._algebraic

@algebraic.setter
def algebraic(self, algebraic):
self._algebraic = self._set_dictionary(algebraic, "algebraic")
self._algebraic = EquationDict("algebraic", algebraic)

@property
def initial_conditions(self):
return self._initial_conditions

@initial_conditions.setter
def initial_conditions(self, initial_conditions):
self._initial_conditions = self._set_dictionary(
initial_conditions, "initial_conditions"
self._initial_conditions = EquationDict(
"initial_conditions", initial_conditions
)

@property
Expand All @@ -192,23 +168,7 @@ def boundary_conditions(self):

@boundary_conditions.setter
def boundary_conditions(self, boundary_conditions):
# Convert any numbers to a pybamm.Scalar
for var, bcs in boundary_conditions.items():
for side, bc in bcs.items():
if isinstance(bc[0], numbers.Number):
# typ is the type of the bc, e.g. "Dirichlet" or "Neumann"
eqn, typ = boundary_conditions[var][side]
boundary_conditions[var][side] = (pybamm.Scalar(eqn), typ)
# Check types
if bc[1] not in ["Dirichlet", "Neumann"]:
raise pybamm.ModelError(
"""
boundary condition types must be Dirichlet or Neumann, not '{}'
""".format(
bc[1]
)
)
self._boundary_conditions = boundary_conditions
self._boundary_conditions = BoundaryConditionsDict(boundary_conditions)

@property
def variables(self):
Expand Down Expand Up @@ -723,3 +683,95 @@ def find_symbol_in_model(model, name):
dic_return = find_symbol_in_dict(dic, name)
if dic_return:
return dic_return


class EquationDict(dict):
def __init__(self, name, equations):
self.name = name
equations = self.check_and_convert_equations(equations)
super().__init__(equations)

def __setitem__(self, key, value):
"Call the update functionality when doing a setitem"
self.update({key: value})

def update(self, equations):
equations = self.check_and_convert_equations(equations)
super().update(equations)

def check_and_convert_equations(self, equations):
"""
Convert any scalar equations in dict to 'pybamm.Scalar'
and check that domains are consistent
"""
# Convert any numbers to a pybamm.Scalar
for var, eqn in equations.items():
if isinstance(eqn, numbers.Number):
equations[var] = pybamm.Scalar(eqn)

if not all(
[
variable.domain == equation.domain
or variable.domain == []
or equation.domain == []
for variable, equation in equations.items()
]
):
raise pybamm.DomainError(
"variable and equation in '{}' must have the same domain".format(
self.name
)
)

# For initial conditions, check that the equation doesn't contain any
# Variable objects
# skip this if the dictionary has no "name" attribute (which will be the case
# after pickling)
if hasattr(self, "name") and self.name == "initial_conditions":
for var, eqn in equations.items():
if eqn.has_symbol_of_classes(pybamm.Variable):
unpacker = pybamm.SymbolUnpacker(pybamm.Variable)
variable_in_equation = list(unpacker.unpack_symbol(eqn).values())[0]
raise TypeError(
"Initial conditions cannot contain 'Variable' objects, "
"but '{!r}' found in initial conditions for '{}'".format(
variable_in_equation, var
)
)

return equations


class BoundaryConditionsDict(dict):
def __init__(self, bcs):
bcs = self.check_and_convert_bcs(bcs)
super().__init__(bcs)

def __setitem__(self, key, value):
"Call the update functionality when doing a setitem"
self.update({key: value})

def update(self, bcs):
bcs = self.check_and_convert_bcs(bcs)
super().update(bcs)

def check_and_convert_bcs(self, boundary_conditions):
""" Convert any scalar bcs in dict to 'pybamm.Scalar', and check types """
# Convert any numbers to a pybamm.Scalar
for var, bcs in boundary_conditions.items():
for side, bc in bcs.items():
if isinstance(bc[0], numbers.Number):
# typ is the type of the bc, e.g. "Dirichlet" or "Neumann"
eqn, typ = boundary_conditions[var][side]
boundary_conditions[var][side] = (pybamm.Scalar(eqn), typ)
# Check types
if bc[1] not in ["Dirichlet", "Neumann"]:
raise pybamm.ModelError(
"""
boundary condition types must be Dirichlet or Neumann, not '{}'
""".format(
bc[1]
)
)

return boundary_conditions
15 changes: 12 additions & 3 deletions tests/unit/test_expression_tree/test_operations/test_jac.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,13 @@ def test_multi_var_function(arg1, arg2):


class TestJacobian(unittest.TestCase):
def test_variable_is_statevector(self):
a = pybamm.Symbol("a")
with self.assertRaisesRegex(
TypeError, "Jacobian can only be taken with respect to a 'StateVector'"
):
a.jac(a)

def test_linear(self):
y = pybamm.StateVector(slice(0, 4))
u = pybamm.StateVector(slice(0, 2))
Expand Down Expand Up @@ -233,7 +240,7 @@ def test_jac_of_number(self):
a = pybamm.Scalar(1)
b = pybamm.Scalar(2)

y = pybamm.Variable("y")
y = pybamm.StateVector(slice(0, 1))

self.assertEqual(a.jac(y).evaluate(), 0)

Expand Down Expand Up @@ -261,14 +268,16 @@ def test_jac_of_symbol(self):
def test_spatial_operator(self):
a = pybamm.Variable("a")
b = pybamm.SpatialOperator("Operator", a)
y = pybamm.StateVector(slice(0, 1))
with self.assertRaises(NotImplementedError):
b.jac(None)
b.jac(y)

def test_jac_of_unary_operator(self):
a = pybamm.Scalar(1)
b = pybamm.UnaryOperator("Operator", a)
y = pybamm.StateVector(slice(0, 1))
with self.assertRaises(NotImplementedError):
b.jac(None)
b.jac(y)

def test_jac_of_independent_variable(self):
a = pybamm.IndependentVariable("Variable")
Expand Down
13 changes: 10 additions & 3 deletions tests/unit/test_models/test_base_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,10 +50,16 @@ def test_initial_conditions_set_get(self):

# Test number input
c0 = pybamm.Symbol("c0")
model.initial_conditions = {c0: 34}
model.initial_conditions[c0] = 34
self.assertIsInstance(model.initial_conditions[c0], pybamm.Scalar)
self.assertEqual(model.initial_conditions[c0].value, 34)

# Variable in initial conditions should fail
with self.assertRaisesRegex(
TypeError, "Initial conditions cannot contain 'Variable' objects"
):
model.initial_conditions = {c0: pybamm.Variable("v")}

# non-matching domains should fail
with self.assertRaises(pybamm.DomainError):
model.initial_conditions = {
Expand All @@ -72,8 +78,9 @@ def test_boundary_conditions_set_get(self):

# Test number input
c0 = pybamm.Symbol("c0")
model.boundary_conditions = {
c0: {"left": (-2, "Dirichlet"), "right": (4, "Dirichlet")}
model.boundary_conditions[c0] = {
"left": (-2, "Dirichlet"),
"right": (4, "Dirichlet"),
}
self.assertIsInstance(model.boundary_conditions[c0]["left"][0], pybamm.Scalar)
self.assertIsInstance(model.boundary_conditions[c0]["right"][0], pybamm.Scalar)
Expand Down

0 comments on commit c1642b8

Please sign in to comment.