Skip to content

Commit

Permalink
Merge pull request #1978 from pybamm-team/remove-id
Browse files Browse the repository at this point in the history
Remove .id
  • Loading branch information
valentinsulzer authored Jun 8, 2022
2 parents 79b9368 + f1954ef commit b98b446
Show file tree
Hide file tree
Showing 72 changed files with 626 additions and 796 deletions.
11 changes: 11 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
@@ -1,4 +1,10 @@
# [Unreleased](https://github.com/pybamm-team/PyBaMM/)
## Features

- Added `__eq__` and `__hash__` methods for `Symbol` objects, using `.id` ([#1978](https://github.com/pybamm-team/PyBaMM/pull/1978))
## Breaking changes

- Changed some dictionary keys to `Symbol` instead of `Symbol.id` (internal change only, should not affect external facing functions) ([#1978](https://github.com/pybamm-team/PyBaMM/pull/1978))

# [v22.5](https://github.com/pybamm-team/PyBaMM/tree/v22.5) - 2022-05-31

Expand Down Expand Up @@ -32,6 +38,10 @@
- Added "Discharge energy [W.h]", which is the integral of the power in Watts, as an optional output. Set the option "calculate discharge energy" to "true" to get this output ("false" by default, since it can slow down some of the simple models) ([#1969](https://github.com/pybamm-team/PyBaMM/pull/1969)))
- Added an option "calculate heat source for isothermal models" to choose whether or not the heat generation terms are computed when running models with the option `thermal="isothermal"` ([#1958](https://github.com/pybamm-team/PyBaMM/pull/1958))

## Optimizations

- Simplified `model.new_copy()` ([#1977](https://github.com/pybamm-team/PyBaMM/pull/1977))

## Bug fixes

- Fix bug where sensitivity calculation failed if len of `calculate_sensitivities` was less than `inputs` ([#1897](https://github.com/pybamm-team/PyBaMM/pull/1897))
Expand All @@ -40,6 +50,7 @@

## Breaking changes

- Removed `model.new_empty_copy()` (use `model.new_copy()` instead) ([#1977](https://github.com/pybamm-team/PyBaMM/pull/1977))
- Dropped support for Windows 32-bit architecture ([#1964](https://github.com/pybamm-team/PyBaMM/pull/1964))

# [v22.2](https://github.com/pybamm-team/PyBaMM/tree/v22.2) - 2022-02-28
Expand Down
2 changes: 1 addition & 1 deletion examples/notebooks/expression_tree/expression-tree.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -198,7 +198,7 @@
"source": [
"from tests import get_discretisation_for_testing\n",
"disc = get_discretisation_for_testing()\n",
"disc.y_slices = {c.id: [slice(0, 40)]}\n",
"disc.y_slices = {c: [slice(0, 40)]}\n",
"dcdt = disc.process_symbol(dcdt)\n",
"dcdt.visualise('expression_tree5.png')"
]
Expand Down
20 changes: 10 additions & 10 deletions examples/notebooks/spatial_methods/finite-volumes.ipynb

Large diffs are not rendered by default.

63 changes: 30 additions & 33 deletions pybamm/discretisations/discretisation.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,8 +10,8 @@

def has_bc_of_form(symbol, side, bcs, form):

if symbol.id in bcs:
if bcs[symbol.id][side][1] == form:
if symbol in bcs:
if bcs[symbol][side][1] == form:
return True
else:
return False
Expand Down Expand Up @@ -292,15 +292,15 @@ def set_variable_slices(self, variables):
for domain_mesh in mesh:
end += domain_mesh.npts_for_broadcast_to_nodes
# Add to slices
y_slices[child.id].append(slice(start_, end))
y_slices[child].append(slice(start_, end))
y_slices_explicit[child].append(slice(start_, end))
# Increment start_
start_ = end
else:
end += self._get_variable_size(variable)

# Add to slices
y_slices[variable.id].append(slice(start, end))
y_slices[variable].append(slice(start, end))
y_slices_explicit[variable].append(slice(start, end))
# Add to bounds
lower_bounds.extend([variable.bounds[0]] * (end - start))
Expand Down Expand Up @@ -363,7 +363,7 @@ def set_external_variables(self, model):
# Find the name in the model variables
# Look up dictionary key based on value
try:
idx = [x.id for x in model.variables.values()].index(var.id)
idx = list(model.variables.values()).index(var)
except ValueError:
raise ValueError(
"""
Expand Down Expand Up @@ -423,8 +423,7 @@ def boundary_gradient(left_symbol, right_symbol):
left_symbol_disc, right_symbol_disc, left_mesh, right_mesh
)

# bc_key_ids = [key.id for key in list(model.boundary_conditions.keys())]
bc_key_ids = list(self.bcs.keys())
bc_keys = list(self.bcs.keys())

internal_bcs = {}
for var in model.boundary_conditions.keys():
Expand All @@ -434,24 +433,24 @@ def boundary_gradient(left_symbol, right_symbol):
first_child = children[0]
next_child = children[1]

lbc = self.bcs[var.id]["left"]
lbc = self.bcs[var]["left"]
rbc = (boundary_gradient(first_child, next_child), "Neumann")

if first_child.id not in bc_key_ids:
internal_bcs.update({first_child.id: {"left": lbc, "right": rbc}})
if first_child not in bc_keys:
internal_bcs.update({first_child: {"left": lbc, "right": rbc}})

for current_child, next_child in zip(children[1:-1], children[2:]):
lbc = rbc
rbc = (boundary_gradient(current_child, next_child), "Neumann")
if current_child.id not in bc_key_ids:
if current_child not in bc_keys:
internal_bcs.update(
{current_child.id: {"left": lbc, "right": rbc}}
{current_child: {"left": lbc, "right": rbc}}
)

lbc = rbc
rbc = self.bcs[var.id]["right"]
if children[-1].id not in bc_key_ids:
internal_bcs.update({children[-1].id: {"left": lbc, "right": rbc}})
rbc = self.bcs[var]["right"]
if children[-1] not in bc_keys:
internal_bcs.update({children[-1]: {"left": lbc, "right": rbc}})

self.bcs.update(internal_bcs)

Expand Down Expand Up @@ -504,7 +503,7 @@ def process_boundary_conditions(self, model):
# process and set pybamm.variables first incase required
# in discrisation of other boundary conditions
for key, bcs in model.boundary_conditions.items():
processed_bcs[key.id] = {}
processed_bcs[key] = {}

# check if the boundary condition at the origin for sphere domains is other
# than no flux
Expand All @@ -528,7 +527,7 @@ def process_boundary_conditions(self, model):
eqn, typ = bc
pybamm.logger.debug("Discretise {} ({} bc)".format(key, side))
processed_eqn = self.process_symbol(eqn)
processed_bcs[key.id][side] = (processed_eqn, typ)
processed_bcs[key][side] = (processed_eqn, typ)

return processed_bcs

Expand Down Expand Up @@ -661,7 +660,7 @@ def create_mass_matrix(self, model):
model_variables = model.rhs.keys()
model_slices = []
for v in model_variables:
model_slices.append(self.y_slices[v.id][0])
model_slices.append(self.y_slices[v][0])
sorted_model_variables = [
v for _, v in sorted(zip(model_slices, model_variables))
]
Expand Down Expand Up @@ -792,8 +791,6 @@ def process_dict(self, var_eqn_dict):
if np.prod(eqn.shape_for_testing) == 1 and not isinstance(eqn_key, str):
eqn = pybamm.FullBroadcast(eqn, broadcast_domains=eqn_key.domains)

# note we are sending in the key.id here so we don't have to
# keep calling .id
pybamm.logger.debug("Discretise {!r}".format(eqn_key))

processed_eqn = self.process_symbol(eqn)
Expand All @@ -818,10 +815,10 @@ def process_symbol(self, symbol):
"""
try:
return self._discretised_symbols[symbol.id]
return self._discretised_symbols[symbol]
except KeyError:
discretised_symbol = self._process_symbol(symbol)
self._discretised_symbols[symbol.id] = discretised_symbol
self._discretised_symbols[symbol] = discretised_symbol
discretised_symbol.test_shape()

# Assign mesh as an attribute to the processed variable
Expand Down Expand Up @@ -978,15 +975,15 @@ def _process_symbol(self, symbol):

elif isinstance(symbol, pybamm.VariableDot):
return pybamm.StateVectorDot(
*self.y_slices[symbol.get_variable().id],
*self.y_slices[symbol.get_variable()],
domains=symbol.domains,
)

elif isinstance(symbol, pybamm.Variable):
# Check if variable is a standard variable or an external variable
if any(symbol.id == var.id for var in self.external_variables.values()):
if any(symbol == var for var in self.external_variables.values()):
# Look up dictionary key based on value
idx = [x.id for x in self.external_variables.values()].index(symbol.id)
idx = list(self.external_variables.values()).index(symbol)
name, parent_and_slice = list(self.external_variables.keys())[idx]
if parent_and_slice is None:
# Variable didn't come from a concatenation so we can just create a
Expand Down Expand Up @@ -1014,7 +1011,7 @@ def _process_symbol(self, symbol):
# can't be found. This should usually be caught earlier by
# model.check_well_posedness, but won't be if debug_mode is False
try:
y_slices = self.y_slices[symbol.id]
y_slices = self.y_slices[symbol]
except KeyError:
raise pybamm.ModelError(
"""
Expand Down Expand Up @@ -1088,19 +1085,19 @@ def _concatenate_in_order(self, var_eqn_dict, check_complete=False, sparse=False
unpacked_variables.extend([symbol] + [var for var in symbol.children])
else:
unpacked_variables.append(symbol)
slices.append(self.y_slices[symbol.id][0])
slices.append(self.y_slices[symbol][0])

if check_complete:
# Check keys from the given var_eqn_dict against self.y_slices
ids = {v.id for v in unpacked_variables}
external_id = {v.id for v in self.external_variables.values()}
unpacked_variables_set = set(unpacked_variables)
external_vars = set(self.external_variables.values())
for var in self.external_variables.values():
child_ids = {child.id for child in var.children}
external_id = external_id.union(child_ids)
child_vars = set(var.children)
external_vars = external_vars.union(child_vars)
y_slices_with_external_removed = set(self.y_slices.keys()).difference(
external_id
external_vars
)
if ids != y_slices_with_external_removed:
if unpacked_variables_set != y_slices_with_external_removed:
given_variable_names = [v.name for v in var_eqn_dict.keys()]
raise pybamm.ModelError(
"Initial conditions are insufficient. Only "
Expand Down
31 changes: 9 additions & 22 deletions pybamm/expression_tree/binary_operators.py
Original file line number Diff line number Diff line change
Expand Up @@ -102,24 +102,11 @@ def _binary_new_copy(self, left, right):
"""
return self._binary_evaluate(left, right)

def evaluate(self, t=None, y=None, y_dot=None, inputs=None, known_evals=None):
def evaluate(self, t=None, y=None, y_dot=None, inputs=None):
"""See :meth:`pybamm.Symbol.evaluate()`."""
if known_evals is not None:
id = self.id
try:
return known_evals[id], known_evals
except KeyError:
left, known_evals = self.left.evaluate(t, y, y_dot, inputs, known_evals)
right, known_evals = self.right.evaluate(
t, y, y_dot, inputs, known_evals
)
value = self._binary_evaluate(left, right)
known_evals[id] = value
return value, known_evals
else:
left = self.left.evaluate(t, y, y_dot, inputs)
right = self.right.evaluate(t, y, y_dot, inputs)
return self._binary_evaluate(left, right)
left = self.left.evaluate(t, y, y_dot, inputs)
right = self.right.evaluate(t, y, y_dot, inputs)
return self._binary_evaluate(left, right)

def _evaluate_for_shape(self):
"""See :meth:`pybamm.Symbol.evaluate_for_shape()`."""
Expand Down Expand Up @@ -181,7 +168,7 @@ def _diff(self, variable):
diff = exponent * (base ** (exponent - 1)) * base.diff(variable)
# derivative if variable is in the exponent (rare, check separately to avoid
# unecessarily big tree)
if any(variable.id == x.id for x in exponent.pre_order()):
if any(variable == x for x in exponent.pre_order()):
diff += (base ** exponent) * pybamm.log(base) * exponent.diff(variable)
return diff

Expand Down Expand Up @@ -581,7 +568,7 @@ def _diff(self, variable):
diff = left.diff(variable)
# derivative if variable is in the right term (rare, check separately to avoid
# unecessarily big tree)
if any(variable.id == x.id for x in right.pre_order()):
if any(variable == x for x in right.pre_order()):
diff += -pybamm.Floor(left / right) * right.diff(variable)
return diff

Expand Down Expand Up @@ -850,7 +837,7 @@ def simplified_addition(left, right):
elif (
isinstance(left, MatrixMultiplication)
and isinstance(right, MatrixMultiplication)
and left.right.id == right.right.id
and left.right == right.right
):
l_left, l_right = left.orphans
r_left = right.orphans[0]
Expand Down Expand Up @@ -955,7 +942,7 @@ def simplified_subtraction(left, right):
return pybamm.simplify_if_constant(Subtraction(left, right))

# a symbol minus itself is 0s of the same shape
if left.id == right.id:
if left == right:
return pybamm.zeros_like(left)

if isinstance(right, pybamm.Addition) and left.is_constant():
Expand Down Expand Up @@ -1185,7 +1172,7 @@ def simplified_division(left, right):
return left

# a symbol divided by itself is 1s of the same shape
if left.id == right.id:
if left == right:
return pybamm.ones_like(left)

# anything multiplied by a matrix one returns itself if
Expand Down
23 changes: 6 additions & 17 deletions pybamm/expression_tree/concatenations.py
Original file line number Diff line number Diff line change
Expand Up @@ -98,23 +98,13 @@ def _concatenation_evaluate(self, children_eval):
else:
return self.concatenation_function(children_eval)

def evaluate(self, t=None, y=None, y_dot=None, inputs=None, known_evals=None):
def evaluate(self, t=None, y=None, y_dot=None, inputs=None):
"""See :meth:`pybamm.Symbol.evaluate()`."""
children = self.children
if known_evals is not None:
if self.id not in known_evals:
children_eval = [None] * len(children)
for idx, child in enumerate(children):
children_eval[idx], known_evals = child.evaluate(
t, y, y_dot, inputs, known_evals
)
known_evals[self.id] = self._concatenation_evaluate(children_eval)
return known_evals[self.id], known_evals
else:
children_eval = [None] * len(children)
for idx, child in enumerate(children):
children_eval[idx] = child.evaluate(t, y, y_dot, inputs)
return self._concatenation_evaluate(children_eval)
children_eval = [None] * len(children)
for idx, child in enumerate(children):
children_eval[idx] = child.evaluate(t, y, y_dot, inputs)
return self._concatenation_evaluate(children_eval)

def create_copy(self):
"""See :meth:`pybamm.Symbol.new_copy()`."""
Expand Down Expand Up @@ -430,8 +420,7 @@ def simplified_concatenation(*children):
# Create Concatenation to easily read domains
concat = Concatenation(*children)
if all(
isinstance(child, pybamm.Broadcast)
and child.child.id == children[0].child.id
isinstance(child, pybamm.Broadcast) and child.child == children[0].child
for child in children
):
unique_child = children[0].orphans[0]
Expand Down
24 changes: 7 additions & 17 deletions pybamm/expression_tree/functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,15 +70,15 @@ def __str__(self):

def diff(self, variable):
"""See :meth:`pybamm.Symbol.diff()`."""
if variable.id == self.id:
if variable == self:
return pybamm.Scalar(1)
else:
children = self.orphans
partial_derivatives = [None] * len(children)
for i, child in enumerate(self.children):
# if variable appears in the function, differentiate
# function, and apply chain rule
if variable.id in [symbol.id for symbol in child.pre_order()]:
if variable in child.pre_order():
partial_derivatives[i] = self._function_diff(
children, i
) * child.diff(variable)
Expand Down Expand Up @@ -142,22 +142,12 @@ def _function_jac(self, children_jacs):

return jacobian

def evaluate(self, t=None, y=None, y_dot=None, inputs=None, known_evals=None):
def evaluate(self, t=None, y=None, y_dot=None, inputs=None):
"""See :meth:`pybamm.Symbol.evaluate()`."""
if known_evals is not None:
if self.id not in known_evals:
evaluated_children = [None] * len(self.children)
for i, child in enumerate(self.children):
evaluated_children[i], known_evals = child.evaluate(
t, y, y_dot, inputs, known_evals=known_evals
)
known_evals[self.id] = self._function_evaluate(evaluated_children)
return known_evals[self.id], known_evals
else:
evaluated_children = [
child.evaluate(t, y, y_dot, inputs) for child in self.children
]
return self._function_evaluate(evaluated_children)
evaluated_children = [
child.evaluate(t, y, y_dot, inputs) for child in self.children
]
return self._function_evaluate(evaluated_children)

def _evaluates_on_edges(self, dimension):
"""See :meth:`pybamm.Symbol._evaluates_on_edges()`."""
Expand Down
4 changes: 2 additions & 2 deletions pybamm/expression_tree/operations/convert_to_casadi.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,12 +37,12 @@ def convert(self, symbol, t, y, y_dot, inputs):
The converted symbol
"""
try:
return self._casadi_symbols[symbol.id]
return self._casadi_symbols[symbol]
except KeyError:
# Change inputs to empty dictionary if it's None
inputs = inputs or {}
casadi_symbol = self._convert(symbol, t, y, y_dot, inputs)
self._casadi_symbols[symbol.id] = casadi_symbol
self._casadi_symbols[symbol] = casadi_symbol

return casadi_symbol

Expand Down
Loading

0 comments on commit b98b446

Please sign in to comment.