Skip to content

Commit

Permalink
#923 some under-the-hood optimizations
Browse files Browse the repository at this point in the history
  • Loading branch information
valentinsulzer committed Mar 28, 2020
1 parent 8ba60aa commit f591eec
Show file tree
Hide file tree
Showing 31 changed files with 349 additions and 162 deletions.
1 change: 1 addition & 0 deletions pybamm/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -101,6 +101,7 @@ def version(formatted=False):
)
from .expression_tree.operations.jacobian import Jacobian
from .expression_tree.operations.convert_to_casadi import CasadiConverter
from .expression_tree.operations.unpack_symbols import SymbolUnpacker

#
# Model classes
Expand Down
2 changes: 1 addition & 1 deletion pybamm/expression_tree/functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -86,7 +86,7 @@ def diff(self, variable):
children = self.orphans
partial_derivatives = [None] * len(children)
for i, child in enumerate(self.children):
# if variable appears in the function,use autograd to differentiate
# if variable appears in the function, differentiate
# function, and apply chain rule
if variable.id in [symbol.id for symbol in child.pre_order()]:
partial_derivatives[i] = self._function_diff(
Expand Down
89 changes: 89 additions & 0 deletions pybamm/expression_tree/operations/unpack_symbols.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,89 @@
#
# Helper function to unpack a symbol
#


class SymbolUnpacker(object):
"""
Helper class to unpack a (set of) symbol(s) to find all instances of a class.
Uses caching to speed up the process.
Parameters
----------
classes_to_find : list of pybamm classes
Classes to identify in the equations
unpacked_symbols: dict {variable ids -> :class:`pybamm.Symbol`}
cached unpacked equations
"""

def __init__(self, classes_to_find, unpacked_symbols=None):
self.classes_to_find = classes_to_find
self._unpacked_symbols = unpacked_symbols or {}

def unpack_list_of_symbols(self, list_of_symbols):
"""
Unpack a list of symbols. See :meth:`EquationUnpacker.unpack()`
Parameters
----------
list_of_symbols : list of :class:`pybamm.Symbol`
List of symbols to unpack
Returns
-------
list of :class:`pybamm.Symbol`
List of unpacked symbols with class in `self.classes_to_find`
"""
all_instances = {}
for symbol in list_of_symbols:
new_instances = self.unpack_symbol(symbol)
all_instances.update(new_instances)

return all_instances

def unpack_symbol(self, symbol):
"""
This function recurses down the tree, unpacking the symbols and saving the ones
that have a class in `self.classes_to_find`.
Parameters
----------
symbol : list of :class:`pybamm.Symbol`
The symbols to unpack
Returns
-------
list of :class:`pybamm.Symbol`
List of unpacked symbols with class in `self.classes_to_find`
"""

try:
return self._unpacked_symbols[symbol.id]
except KeyError:
unpacked = self._unpack(symbol)
self._unpacked_symbols[symbol.id] = unpacked
return unpacked

def _unpack(self, symbol):
""" See :meth:`EquationUnpacker.unpack()`. """

children = symbol.children

# If symbol has no children, just check its class
if len(children) == 0:
# found a symbol of the right class -> return it
if isinstance(symbol, self.classes_to_find):
return {symbol.id: symbol}
# otherwise return empty dictionary
else:
return {}

else:
# iterate over all children
found_vars = {}
for child in children:
# call back unpack_symbol to cache values
child_vars = self.unpack_symbol(child)
found_vars.update(child_vars)
return found_vars

14 changes: 6 additions & 8 deletions pybamm/expression_tree/symbol.py
Original file line number Diff line number Diff line change
Expand Up @@ -485,9 +485,8 @@ def diff(self, variable):
return pybamm.Scalar(1)
elif any(variable.id == x.id for x in self.pre_order()):
return self._diff(variable)
elif variable.id == pybamm.t.id and any(
isinstance(x, (pybamm.VariableBase, pybamm.StateVectorBase))
for x in self.pre_order()
elif variable.id == pybamm.t.id and self.has_symbol_of_classes(
(pybamm.VariableBase, pybamm.StateVectorBase)
):
return self._diff(variable)
else:
Expand Down Expand Up @@ -609,7 +608,7 @@ def is_constant(self):
)

# do the search, return true if no relevent nodes are found
return not any((isinstance(n, search_types)) for n in self.pre_order())
return not self.has_symbol_of_classes(search_types)

def evaluate_ignoring_errors(self, t=0):
"""
Expand Down Expand Up @@ -720,15 +719,14 @@ def shape(self):
Shape of an object, found by evaluating it with appropriate t and y.
"""
# Default behaviour is to try to evaluate the object directly
# Try with some large y, to avoid having to use pre_order (slow)
# Try with some large y, to avoid having to unpack (slow)
try:
y = np.linspace(0.1, 0.9, int(1e4))
evaluated_self = self.evaluate(0, y, y, inputs="shape test")
# If that fails, fall back to calculating how big y should really be
except ValueError:
state_vectors_in_node = [
x for x in self.pre_order() if isinstance(x, pybamm.StateVector)
]
unpacker = pybamm.SymbolUnpacker(pybamm.StateVector)
state_vectors_in_node = unpacker.unpack_symbol(self).values()
if state_vectors_in_node == []:
y = None
else:
Expand Down
4 changes: 2 additions & 2 deletions pybamm/expression_tree/unary_operators.py
Original file line number Diff line number Diff line change
Expand Up @@ -224,7 +224,7 @@ def _unary_jac(self, child_jac):
# when trying to simplify the node Index(child_jac). Instead, search the
# tree for StateVectors and return a matrix of zeros of the correct size
# if none are found.
if all([not (isinstance(n, pybamm.StateVector)) for n in self.pre_order()]):
if not self.has_symbol_of_classes(pybamm.StateVector):
jac = csr_matrix((1, child_jac.shape[1]))
return pybamm.Matrix(jac)
else:
Expand Down Expand Up @@ -297,7 +297,7 @@ def _unary_simplify(self, simplified_child):
search_types = (pybamm.Variable, pybamm.StateVector, pybamm.SpatialVariable)

# do the search, return a scalar zero node if no relevent nodes are found
if all([not (isinstance(n, search_types)) for n in self.pre_order()]):
if not self.has_symbol_of_classes(search_types):
return pybamm.Scalar(0)
else:
return self.__class__(simplified_child)
Expand Down
90 changes: 45 additions & 45 deletions pybamm/models/base_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -398,13 +398,13 @@ def check_for_time_derivatives(self):
for node in eq.pre_order():
if isinstance(node, pybamm.VariableDot):
raise pybamm.ModelError(
"time derivative of variable found ({}) in rhs equation {}"
.format(node, key)
"time derivative of variable found "
"({}) in rhs equation {}".format(node, key)
)
if isinstance(node, pybamm.StateVectorDot):
raise pybamm.ModelError(
"time derivative of state vector found ({}) in rhs equation {}"
.format(node, key)
"time derivative of state vector found "
"({}) in rhs equation {}".format(node, key)
)

# Check that no variable time derivatives exist in the algebraic equations
Expand Down Expand Up @@ -433,37 +433,43 @@ def check_well_determined(self, post_discretisation):
# For equations we look through the whole expression tree.
# "Variables" can be Concatenations so we also have to look in the whole
# expression tree
unpacker = pybamm.SymbolUnpacker((pybamm.Variable, pybamm.VariableDot))

for var, eqn in self.rhs.items():
# Find all variables and variabledot objects
vars_in_rhs_keys_dict = unpacker.unpack_symbol(var)
vars_in_eqns_dict = unpacker.unpack_symbol(eqn)

# Store ids only
# Look only for Variable (not VariableDot) in rhs keys
vars_in_rhs_keys.update(
[x.id for x in var.pre_order() if isinstance(x, pybamm.Variable)]
)
vars_in_eqns.update(
[x.id for x in eqn.pre_order() if isinstance(x, pybamm.Variable)]
)
vars_in_eqns.update(
[x.get_variable().id for x in eqn.pre_order()
if isinstance(x, pybamm.VariableDot)]
[
var_id
for var_id, var in vars_in_rhs_keys_dict.items()
if isinstance(var, pybamm.Variable)
]
)
vars_in_eqns.update(vars_in_eqns_dict.keys())
for var, eqn in self.algebraic.items():
# Find all variables and variabledot objects
vars_in_algebraic_keys_dict = unpacker.unpack_symbol(var)
vars_in_eqns_dict = unpacker.unpack_symbol(eqn)

# Store ids only
# Look only for Variable (not VariableDot) in algebraic keys
vars_in_algebraic_keys.update(
[x.id for x in var.pre_order() if isinstance(x, pybamm.Variable)]
)
vars_in_eqns.update(
[x.id for x in eqn.pre_order() if isinstance(x, pybamm.Variable)]
)
vars_in_eqns.update(
[x.get_variable().id for x in eqn.pre_order()
if isinstance(x, pybamm.VariableDot)]
[
var_id
for var_id, var in vars_in_algebraic_keys_dict.items()
if isinstance(var, pybamm.Variable)
]
)
vars_in_eqns.update(vars_in_eqns_dict.keys())
for var, side_eqn in self.boundary_conditions.items():
for side, (eqn, typ) in side_eqn.items():
vars_in_eqns.update(
[x.id for x in eqn.pre_order() if isinstance(x, pybamm.Variable)]
)
vars_in_eqns.update(
[x.get_variable().id for x in eqn.pre_order()
if isinstance(x, pybamm.VariableDot)]
)
vars_in_eqns_dict = unpacker.unpack_symbol(eqn)
vars_in_eqns.update(vars_in_eqns_dict.keys())

# If any keys are repeated between rhs and algebraic then the model is
# overdetermined
if not set(vars_in_rhs_keys).isdisjoint(vars_in_algebraic_keys):
Expand Down Expand Up @@ -501,11 +507,12 @@ def check_algebraic_equations(self, post_discretisation):
equation
"""
vars_in_bcs = set()
for var, side_eqn in self.boundary_conditions.items():
for eqn, _ in side_eqn.values():
vars_in_bcs.update(
[x.id for x in eqn.pre_order() if isinstance(x, pybamm.Variable)]
)
unpacker = pybamm.SymbolUnpacker(pybamm.Variable)
for side_eqn in self.boundary_conditions.values():
all_vars = unpacker.unpack_list_of_symbols(
[eqn for eqn, _ in side_eqn.values()]
)
vars_in_bcs.update(all_vars.keys())
if not post_discretisation:
# After the model has been defined, each algebraic equation key should
# appear in that algebraic equation, or in the boundary conditions
Expand All @@ -524,7 +531,7 @@ def check_algebraic_equations(self, post_discretisation):
# with the state vectors in the algebraic equations. Instead, we check
# that each algebraic equation contains some StateVector
for eqn in self.algebraic.values():
if not any(isinstance(x, pybamm.StateVector) for x in eqn.pre_order()):
if not eqn.has_symbol_of_classes(pybamm.StateVector):
raise pybamm.ModelError(
"each algebraic equation must contain at least one StateVector"
)
Expand Down Expand Up @@ -553,12 +560,8 @@ def check_ics_bcs(self):
for x in symbol.pre_order()
):
raise pybamm.ModelError(
"""
no boundary condition given for
variable '{}' with equation '{}'.
""".format(
var, eqn
)
"no boundary condition given for "
"variable '{}' with equation '{}'.".format(var, eqn)
)

def check_default_variables_dictionaries(self):
Expand All @@ -581,12 +584,9 @@ def check_default_variables_dictionaries(self):

def check_variables(self):
# Create list of all Variable nodes that appear in the model's list of variables
all_vars = {}
for eqn in self.variables.values():
# Add all variables in the equation to the list of variables
all_vars.update(
{x.id: x for x in eqn.pre_order() if isinstance(x, pybamm.Variable)}
)
unpacker = pybamm.SymbolUnpacker(pybamm.Variable)
all_vars = unpacker.unpack_list_of_symbols(self.variables.values())

var_ids_in_keys = set()

model_and_external_variables = (
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -24,17 +24,19 @@ def __init__(self, param, domain=None, reactions=None):
super().__init__(param, domain)
self.reactions = reactions

def _get_standard_potential_variables(self, phi_e, phi_e_av):
def _get_standard_potential_variables(self, phi_e_n, phi_e_s, phi_e_p):
"""
A private function to obtain the standard variables which
can be derived from the potential in the electrolyte.
Parameters
----------
phi_e : :class:`pybamm.Symbol`
The potential in the electrolyte.
phi_e_av : :class:`pybamm.Symbol`
The cell-averaged potential in the electrolyte.
phi_e_n : :class:`pybamm.Symbol`
The electrolyte potential in the negative electrode.
phi_e_s : :class:`pybamm.Symbol`
The electrolyte potential in the separator.
phi_e_p : :class:`pybamm.Symbol`
The electrolyte potential in the positive electrode.
Returns
-------
Expand All @@ -45,8 +47,8 @@ def _get_standard_potential_variables(self, phi_e, phi_e_av):

param = self.param
pot_scale = param.potential_scale
phi_e_n, phi_e_s, phi_e_p = phi_e.orphans

phi_e = pybamm.Concatenation(phi_e_n, phi_e_s, phi_e_p)
phi_e_n_av = pybamm.x_average(phi_e_n)
phi_e_s_av = pybamm.x_average(phi_e_s)
phi_e_p_av = pybamm.x_average(phi_e_p)
Expand Down Expand Up @@ -277,15 +279,15 @@ def _get_whole_cell_variables(self, variables):
phi_e_n = variables["Negative electrolyte potential"]
phi_e_s = variables["Separator electrolyte potential"]
phi_e_p = variables["Positive electrolyte potential"]
phi_e = pybamm.Concatenation(phi_e_n, phi_e_s, phi_e_p)
phi_e_av = pybamm.x_average(phi_e)

i_e_n = variables["Negative electrolyte current density"]
i_e_s = variables["Separator electrolyte current density"]
i_e_p = variables["Positive electrolyte current density"]
i_e = pybamm.Concatenation(i_e_n, i_e_s, i_e_p)

variables.update(self._get_standard_potential_variables(phi_e, phi_e_av))
variables.update(
self._get_standard_potential_variables(phi_e_n, phi_e_s, phi_e_p)
)
variables.update(self._get_standard_current_variables(i_e))

return variables
Expand Down
Loading

0 comments on commit f591eec

Please sign in to comment.