Skip to content

Commit

Permalink
#858 style fixes, remove d_dt helper function, fix bugs
Browse files Browse the repository at this point in the history
  • Loading branch information
martinjrobins committed Mar 7, 2020
1 parent 9ba63f6 commit c8eeb81
Show file tree
Hide file tree
Showing 12 changed files with 50 additions and 62 deletions.
4 changes: 0 additions & 4 deletions pybamm/discretisations/discretisation.py
Original file line number Diff line number Diff line change
Expand Up @@ -1030,10 +1030,6 @@ def check_initial_conditions_rhs(self, model):
y0 = model.concatenated_initial_conditions
# Individual
for var in model.rhs.keys():
print('rhs')
print(model.rhs[var])
print('init')
print(model.initial_conditions[var])
assert (
model.rhs[var].shape == model.initial_conditions[var].shape
), pybamm.ModelError(
Expand Down
6 changes: 3 additions & 3 deletions pybamm/expression_tree/binary_operators.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ def is_scalar_zero(expr):
Utility function to test if an expression evaluates to a constant scalar zero
"""
if expr.is_constant():
result = expr.evaluate_ignoring_errors()
result = expr.evaluate_ignoring_errors(t=None)
return isinstance(result, numbers.Number) and result == 0
else:
return False
Expand All @@ -24,7 +24,7 @@ def is_matrix_zero(expr):
Utility function to test if an expression evaluates to a constant matrix zero
"""
if expr.is_constant():
result = expr.evaluate_ignoring_errors()
result = expr.evaluate_ignoring_errors(t=None)
return (issparse(result) and result.count_nonzero() == 0) or (
isinstance(result, np.ndarray) and np.all(result == 0)
)
Expand All @@ -37,7 +37,7 @@ def is_scalar_one(expr):
Utility function to test if an expression evaluates to a constant scalar one
"""
if expr.is_constant():
result = expr.evaluate_ignoring_errors()
result = expr.evaluate_ignoring_errors(t=None)
return isinstance(result, numbers.Number) and result == 1
else:
return False
Expand Down
3 changes: 2 additions & 1 deletion pybamm/expression_tree/functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -164,7 +164,8 @@ def evaluate(self, t=None, y=None, y_dot=None, u=None, known_evals=None):
known_evals[self.id] = self._function_evaluate(evaluated_children)
return known_evals[self.id], known_evals
else:
evaluated_children = [child.evaluate(t, y, y_dot, u) for child in self.children]
evaluated_children = [child.evaluate(t, y, y_dot, u)
for child in self.children]
return self._function_evaluate(evaluated_children)

def _evaluate_for_shape(self):
Expand Down
1 change: 1 addition & 0 deletions pybamm/expression_tree/independent_variable.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,7 @@ def _jac(self, variable):
else:
return pybamm.Scalar(0)


class Time(IndependentVariable):
"""A node in the expression tree representing time
Expand Down
28 changes: 20 additions & 8 deletions pybamm/expression_tree/state_vector.py
Original file line number Diff line number Diff line change
Expand Up @@ -121,9 +121,9 @@ def _base_evaluate(self, t=None, y=None, y_dot=None, u=None):

def _jac_diff_vector(self, variable):
"""
Differentiate a slice of a StateVector of size m with respect to another
slice of a different StateVector of size n. This returns a (sparse) zero matrix of size
m x n
Differentiate a slice of a StateVector of size m with respect to another slice
of a different StateVector of size n. This returns a (sparse) zero matrix of
size m x n
Parameters
----------
Expand Down Expand Up @@ -255,13 +255,19 @@ def _base_evaluate(self, t=None, y=None, y_dot=None, u=None):
out = out[:, np.newaxis]
return out

def _jac(self, variable):
if variable.id == pybamm.t.id:
def diff(self, variable):
if variable.id == self.id:
return pybamm.Scalar(1)
elif variable.id == pybamm.t.id:
return StateVectorDot(*self._y_slices, name=self.name + "'",
domain=self.domain,
auxiliary_domains=self.auxiliary_domains,
evaluation_array=self.evaluation_array)
elif isinstance(variable, pybamm.StateVector):
else:
return pybamm.Scalar(0)

def _jac(self, variable):
if isinstance(variable, pybamm.StateVector):
return self._jac_same_vector(variable)
elif isinstance(variable, pybamm.StateVectorDot):
return self._jac_diff_vector(variable)
Expand Down Expand Up @@ -316,11 +322,17 @@ def _base_evaluate(self, t=None, y=None, y_dot=None, u=None):
out = out[:, np.newaxis]
return out

def _jac(self, variable):
if variable.id == pybamm.t.id:
def diff(self, variable):
if variable.id == self.id:
return pybamm.Scalar(1)
elif variable.id == pybamm.t.id:
raise pybamm.ModelError(
"cannot take second time derivative of a state vector"
)
else:
return pybamm.Scalar(0)

def _jac(self, variable):
if isinstance(variable, pybamm.StateVectorDot):
return self._jac_same_vector(variable)
elif isinstance(variable, pybamm.StateVector):
Expand Down
16 changes: 9 additions & 7 deletions pybamm/expression_tree/symbol.py
Original file line number Diff line number Diff line change
Expand Up @@ -524,7 +524,8 @@ def _base_evaluate(self, t=None, y=None, y_dot=None, u=None):
array with state values to evaluate when solving (default None)
y_dot : numpy.array, optional
array with time derivatives of state values to evaluate when solving (default None)
array with time derivatives of state values to evaluate when solving
(default None)
"""
raise NotImplementedError(
Expand All @@ -547,7 +548,8 @@ def evaluate(self, t=None, y=None, y_dot=None, u=None, known_evals=None):
y : numpy.array, optional
array with state values to evaluate when solving (default None)
y_dot : numpy.array, optional
array with time derivatives of state values to evaluate when solving (default None)
array with time derivatives of state values to evaluate when solving
(default None)
u : dict, optional
dictionary of inputs to use when solving (default None)
known_evals : dict, optional
Expand Down Expand Up @@ -604,19 +606,20 @@ def is_constant(self):
# do the search, return true if no relevent nodes are found
return not any((isinstance(n, search_types)) for n in self.pre_order())

def evaluate_ignoring_errors(self):
def evaluate_ignoring_errors(self, t=0):
"""
Evaluates the expression. If a node exists in the tree that cannot be evaluated
as a scalar or vector (e.g. Time, Parameter, Variable, StateVector, InputParameter),
then None is returned. Otherwise the result of the evaluation is given
as a scalar or vector (e.g. Time, Parameter, Variable, StateVector,
InputParameter), then None is returned. Otherwise the result of the evaluation
is given
See Also
--------
evaluate : evaluate the expression
"""
try:
result = self.evaluate(u="shape test")
result = self.evaluate(t=t, u="shape test")
except NotImplementedError:
# return None if NotImplementedError is raised
# (there is a e.g. Parameter, Variable, ... in the tree)
Expand Down Expand Up @@ -713,7 +716,6 @@ def shape(self):
try:
y = np.linspace(0.1, 0.9, int(1e4))
evaluated_self = self.evaluate(0, y, y, u="shape test")
print('evaluated self is ',evaluated_self)
# If that fails, fall back to calculating how big y should really be
except ValueError:
state_vectors_in_node = [
Expand Down
26 changes: 0 additions & 26 deletions pybamm/expression_tree/unary_operators.py
Original file line number Diff line number Diff line change
Expand Up @@ -763,32 +763,6 @@ def __init__(self, child, side):
super().__init__("boundary flux", child, side)



def d_dt(expression):
"""
convenience function for taking the time derivative of an expression
Note that this operator is different to the other unary operators in that it is
*not* lazily evaluated, it instead returns the expression tree that is the time
derivative of the input
Parameters
----------
expression : :class:`Symbol`
the time derivative will be performed on this sub-expression
Returns
-------
:class:`Symbol`
the time derivative of ``expression``
"""

return expression.jac(pybamm.t, clear_domain=False)



#
# Methods to call Gradient, Divergence, Laplacian and Gradient_Squared
#
Expand Down
16 changes: 9 additions & 7 deletions pybamm/expression_tree/variable.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,7 @@ def _evaluate_for_shape(self):
self.domain, self.auxiliary_domains
)


class Variable(VariableBase):
"""A node in the expression tree represending a dependent variable
Expand All @@ -70,19 +71,21 @@ class Variable(VariableBase):
*Extends:* :class:`Symbol`
"""

def __init__(self, name, domain=None, auxiliary_domains=None):
super().__init__(name, domain=domain, auxiliary_domains=auxiliary_domains)

def _jac(self, variable):
def diff(self, variable):
if variable.id == self.id:
return pybamm.Scalar(1)
elif variable.id == pybamm.t.id:
return pybamm.VariableDot(self.name+"'",
return pybamm.VariableDot(self.name + "'",
domain=self.domain,
auxiliary_domains=self.auxiliary_domains)
else:
return pybamm.Scalar(0)


class VariableDot(VariableBase):
"""
A node in the expression tree represending the time derviative of a dependent
Expand Down Expand Up @@ -124,7 +127,7 @@ def get_variable(self):
domain=self._domain,
auxiliary_domains=self._auxiliary_domains)

def _jac(self, variable):
def diff(self, variable):
if variable.id == self.id:
return pybamm.Scalar(1)
elif variable.id == pybamm.t.id:
Expand Down Expand Up @@ -195,12 +198,11 @@ def _base_evaluate(self, t=None, y=None, y_dot=None, u=None):
except KeyError:
raise KeyError("External variable '{}' not found".format(self.name))

def _jac(self, variable):
def diff(self, variable):
if variable.id == self.id:
return pybamm.Scalar(1)
elif variable.id == pybamm.t.id:
raise pybamm.ModelError("cannot take time derivative of an external variable")
raise pybamm.ModelError(
"cannot take time derivative of an external variable")
else:
return pybamm.Scalar(0)


4 changes: 2 additions & 2 deletions pybamm/models/base_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -412,8 +412,8 @@ def check_for_time_derivatives(self):
for node in eq.pre_order():
if isinstance(node, pybamm.VariableDot):
raise pybamm.ModelError(
"time derivative of variable found ({}) in algebraic equation {}"
.format(node, key)
"time derivative of variable found ({}) in algebraic"
"equation {}".format(node, key)
)
if isinstance(node, pybamm.StateVectorDot):
raise pybamm.ModelError(
Expand Down
2 changes: 1 addition & 1 deletion pybamm/solvers/base_solver.py
Original file line number Diff line number Diff line change
Expand Up @@ -292,7 +292,7 @@ def report(string):
model.residuals_eval = residuals_eval
model.jacobian_eval = jacobian_eval
y0_guess = y0.flatten()
model.y0 = self.calculate_consistent_state(model, 0, y0_guess,inputs)
model.y0 = self.calculate_consistent_state(model, 0, y0_guess, inputs)
else:
# can use DAE solver to solve ODE model
model.residuals_eval = Residuals(rhs, "residuals", model)
Expand Down
4 changes: 2 additions & 2 deletions tests/unit/test_discretisations/test_discretisation.py
Original file line number Diff line number Diff line change
Expand Up @@ -704,7 +704,7 @@ def test_process_model_ode(self):
# test that any time derivatives of variables in rhs raises an
# error
model = pybamm.BaseModel()
model.rhs = {c: pybamm.div(N) + pybamm.d_dt(c), T: pybamm.div(q), S: pybamm.div(p)}
model.rhs = {c: pybamm.div(N) + c.diff(pybamm.t), T: pybamm.div(q), S: pybamm.div(p)}
model.initial_conditions = {
c: pybamm.Scalar(2),
T: pybamm.Scalar(5),
Expand Down Expand Up @@ -821,7 +821,7 @@ def test_process_model_dae(self):
# error
model = pybamm.BaseModel()
model.rhs = {c: pybamm.div(N)}
model.algebraic = {d: d - 2 * pybamm.d_dt(c)}
model.algebraic = {d: d - 2 * c.diff(pybamm.t)}
model.initial_conditions = {d: pybamm.Scalar(6), c: pybamm.Scalar(3)}
model.boundary_conditions = {
c: {"left": (0, "Neumann"), "right": (0, "Neumann")}
Expand Down
2 changes: 1 addition & 1 deletion tests/unit/test_solvers/test_casadi_solver.py
Original file line number Diff line number Diff line change
Expand Up @@ -314,7 +314,7 @@ def test_model_solver_with_dvdt(self):
var1 = pybamm.Variable("var1", domain="negative electrode")
var2 = pybamm.Variable("var2", domain="negative electrode")
model.rhs = {var1: -2 * var1 * pybamm.t}
model.algebraic = {var2: var2 - pybamm.d_dt(var1)}
model.algebraic = {var2: var2 - var1.diff(pybamm.t)}
model.initial_conditions = {var1: 1, var2: 0}
pybamm.make_semi_explicit(model)
disc = get_discretisation_for_testing()
Expand Down

0 comments on commit c8eeb81

Please sign in to comment.