Skip to content

Commit

Permalink
#858 add d_dt unary operator, fixes to jac to support this
Browse files Browse the repository at this point in the history
  • Loading branch information
martinjrobins committed Mar 6, 2020
1 parent 5489aa3 commit e2b4ba7
Show file tree
Hide file tree
Showing 6 changed files with 76 additions and 16 deletions.
14 changes: 7 additions & 7 deletions pybamm/expression_tree/binary_operators.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ def is_matrix_zero(expr):
return False


def is_one(expr):
def is_scalar_one(expr):
"""
Utility function to test if an expression evaluates to a constant scalar one
"""
Expand Down Expand Up @@ -253,7 +253,7 @@ def _binary_simplify(self, left, right):
return pybamm.Scalar(1)

# anything to the power of one is itself
if is_scalar_zero(left):
if is_scalar_one(right):
return left

return self.__class__(left, right)
Expand Down Expand Up @@ -425,9 +425,9 @@ def _binary_simplify(self, left, right):
return zeros_of_shape(shape)

# anything multiplied by a scalar one returns itself
if is_one(left):
if is_scalar_one(left):
return right
if is_one(right):
if is_scalar_one(right):
return left

return pybamm.simplify_multiplication_division(self.__class__, left, right)
Expand Down Expand Up @@ -549,7 +549,7 @@ def _binary_simplify(self, left, right):
return pybamm.Array(np.inf * np.ones(left.shape_for_testing))

# anything divided by one is itself
if is_one(right):
if is_scalar_one(right):
return left

return pybamm.simplify_multiplication_division(self.__class__, left, right)
Expand Down Expand Up @@ -622,9 +622,9 @@ def _binary_simplify(self, left, right):
return zeros_of_shape(shape)

# anything multiplied by a scalar one returns itself
if is_one(left):
if is_scalar_one(left):
return right
if is_one(right):
if is_scalar_one(right):
return left

return pybamm.simplify_multiplication_division(self.__class__, left, right)
Expand Down
6 changes: 4 additions & 2 deletions pybamm/expression_tree/independent_variable.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,8 +35,10 @@ def _evaluate_for_shape(self):

def _jac(self, variable):
""" See :meth:`pybamm.Symbol._jac()`. """
return pybamm.Scalar(0)

if variable.id == self.id:
return pybamm.Scalar(1)
else:
return pybamm.Scalar(0)

class Time(IndependentVariable):
"""A node in the expression tree representing time
Expand Down
21 changes: 18 additions & 3 deletions pybamm/expression_tree/operations/jacobian.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,22 @@


class Jacobian(object):
def __init__(self, known_jacs=None):
"""
Helper class to calculate the jacobian of an expression.
Parameters
----------
known_jacs: dict {variable ids -> :class:`pybamm.Symbol`}
cached jacobians
clear_domain: bool
wether or not the jacobian clears the domain (default True)
"""

def __init__(self, known_jacs=None, clear_domain=True):
self._known_jacs = known_jacs or {}
self._clear_domain = clear_domain

def jac(self, symbol, variable):
"""
Expand Down Expand Up @@ -75,6 +89,7 @@ def _jac(self, symbol, variable):
)
)

# jacobian removes the domain(s)
jac.clear_domains()
# jacobian by default removes the domain(s)
if self._clear_domain:
jac.clear_domains()
return jac
15 changes: 11 additions & 4 deletions pybamm/expression_tree/symbol.py
Original file line number Diff line number Diff line change
Expand Up @@ -492,12 +492,13 @@ def _diff(self, variable):
"Default behaviour for differentiation, overriden by Binary and Unary Operators"
raise NotImplementedError

def jac(self, variable, known_jacs=None):
def jac(self, variable, known_jacs=None, clear_domain=True):
"""
Differentiate a symbol with respect to a (slice of) a State Vector.
See :class:`pybamm.Jacobian`.
"""
return pybamm.Jacobian(known_jacs).jac(self, variable)
jac = pybamm.Jacobian(known_jacs, clear_domain=clear_domain)
return jac.jac(self, variable)

def _jac(self, variable):
"""
Expand Down Expand Up @@ -606,7 +607,7 @@ def is_constant(self):
def evaluate_ignoring_errors(self):
"""
Evaluates the expression. If a node exists in the tree that cannot be evaluated
as a scalar or vector (e.g. Parameter, Variable, StateVector, InputParameter),
as a scalar or vector (e.g. Time, Parameter, Variable, StateVector, InputParameter),
then None is returned. Otherwise the result of the evaluation is given
See Also
Expand All @@ -615,7 +616,7 @@ def evaluate_ignoring_errors(self):
"""
try:
result = self.evaluate(t=0, u="shape test")
result = self.evaluate(u="shape test")
except NotImplementedError:
# return None if NotImplementedError is raised
# (there is a e.g. Parameter, Variable, ... in the tree)
Expand All @@ -628,6 +629,10 @@ def evaluate_ignoring_errors(self):
else:
raise error
except ValueError as e:
# return None if specific ValueError is raised
# (there is a e.g. Time in the tree)
if e.args[0] == "t must be provided":
return None
raise pybamm.ShapeError("Cannot find shape (original error: {})".format(e))
return result

Expand Down Expand Up @@ -762,3 +767,5 @@ def test_shape(self):
self.shape_for_testing
except ValueError as e:
raise pybamm.ShapeError("Cannot find shape (original error: {})".format(e))


26 changes: 26 additions & 0 deletions pybamm/expression_tree/unary_operators.py
Original file line number Diff line number Diff line change
Expand Up @@ -763,6 +763,32 @@ def __init__(self, child, side):
super().__init__("boundary flux", child, side)



def d_dt(expression):
"""
convenience function for taking the time derivative of an expression
Note that this operator is different to the other unary operators in that it is
*not* lazily evaluated, it instead returns the expression tree that is the time
derivative of the input
Parameters
----------
expression : :class:`Symbol`
the time derivative will be performed on this sub-expression
Returns
-------
:class:`Symbol`
the time derivative of ``expression``
"""

return expression.jac(pybamm.t, clear_domain=False)



#
# Methods to call Gradient, Divergence, Laplacian and Gradient_Squared
#
Expand Down
10 changes: 10 additions & 0 deletions pybamm/expression_tree/variable.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,16 @@ def _evaluate_for_shape(self):
self.domain, self.auxiliary_domains
)

def _jac(self, variable):
if variable == self:
return pybamm.Scalar(1)
elif variable == pybamm.t:
return pybamm.VariableDot(self.name+"'",
domain=self.domain,
auxiliary_domains=self.auxiliary_domains)
else:
return pybamm.Scalar(0)


class VariableDot(Variable):
"""
Expand Down

0 comments on commit e2b4ba7

Please sign in to comment.