Skip to content

Commit

Permalink
feat: add discrete time sum expression tree node (#4501)
Browse files Browse the repository at this point in the history
* feat: add discrete time sum expression tree node #4485

* docs: fix math syntax in docstring

* remove prints

* test casadi solver as well

* coverage

* coverage

* add to changelog and tidy solution test
  • Loading branch information
martinjrobins authored Oct 9, 2024
1 parent 9e62b66 commit e4eb82a
Show file tree
Hide file tree
Showing 11 changed files with 319 additions and 41 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
- Added phase-dependent particle options to LAM
([#4369](https://github.com/pybamm-team/PyBaMM/pull/4369))
- Added a lithium ion equivalent circuit model with split open circuit voltages for each electrode (`SplitOCVR`). ([#4330](https://github.com/pybamm-team/PyBaMM/pull/4330))
- Added the `pybamm.DiscreteTimeSum` expression node to sum an expression over a sequence of data times, and accompanying `pybamm.DiscreteTimeData` class to store the data times and values ([#4501](https://github.com/pybamm-team/PyBaMM/pull/4501))

## Optimizations

Expand Down
2 changes: 2 additions & 0 deletions src/pybamm/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@
from .expression_tree.broadcasts import *
from .expression_tree.functions import *
from .expression_tree.interpolant import Interpolant
from .expression_tree.discrete_time_sum import *
from .expression_tree.input_parameter import InputParameter
from .expression_tree.parameter import Parameter, FunctionParameter
from .expression_tree.scalar import Scalar
Expand Down Expand Up @@ -158,6 +159,7 @@

# Solver classes
from .solvers.solution import Solution, EmptySolution, make_cycle_solution
from .solvers.processed_variable_time_integral import ProcessedVariableTimeIntegral
from .solvers.processed_variable import ProcessedVariable, process_variable
from .solvers.processed_variable_computed import ProcessedVariableComputed
from .solvers.base_solver import BaseSolver
Expand Down
2 changes: 1 addition & 1 deletion src/pybamm/expression_tree/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,4 +2,4 @@
'concatenations', 'exceptions', 'functions', 'independent_variable',
'input_parameter', 'interpolant', 'matrix', 'operations',
'parameter', 'printing', 'scalar', 'state_vector', 'symbol',
'unary_operators', 'variable', 'vector']
'unary_operators', 'variable', 'vector', 'discrete_time_sum' ]
88 changes: 88 additions & 0 deletions src/pybamm/expression_tree/discrete_time_sum.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,88 @@
import pybamm
import numpy as np


class DiscreteTimeData(pybamm.Interpolant):
"""
A class for representing data that is only defined at discrete points in time.
This is implemented as a 1D interpolant with the time points as the nodes.
Parameters
----------
time_points : :class:`numpy.ndarray`
The time points at which the data is defined
data : :class:`numpy.ndarray`
The data to be interpolated
name : str
The name of the data
"""

def __init__(self, time_points: np.ndarray, data: np.ndarray, name: str):
super().__init__(time_points, data, pybamm.t, name)

def create_copy(self, new_children=None, perform_simplifications=True):
"""See :meth:`pybamm.Symbol.new_copy()`."""
return pybamm.DiscreteTimeData(self.x[0], self.y, self.name)


class DiscreteTimeSum(pybamm.UnaryOperator):
"""
A node in the expression tree representing a discrete time sum operator.
.. math::
\\sum_{i=0}^{N} f(y(t_i), t_i)
where f is the expression given by the child, and the sum is over the discrete
time points t_i. The set of time points is given by the :class:`pybamm.DiscreteTimeData` node,
which must be somewhere in the expression tree given by the child. If the child
does not contain a :class:`pybamm.DiscreteTimeData` node, then an error will be raised when
the node is created. If the child contains multiple :class:`pybamm.DiscreteTimeData` nodes,
an error will be raised when the node is created.
Parameters
----------
child: :class:`pybamm.Symbol`
The symbol to be summed
Attributes
----------
data: :class:`pybamm.DiscreteTimeData`
The discrete time data node in the child
Raises
------
:class:`pybamm.ModelError`
If the child does not contain a :class:`pybamm.DiscreteTimeData` node, or if the child
contains multiple :class:`pybamm.DiscreteTimeData` nodes.
"""

def __init__(self, child: pybamm.Symbol):
self.data = None
for node in child.pre_order():
if isinstance(node, DiscreteTimeData):
# Check that there is exactly one DiscreteTimeData node in the child
if self.data is not None:
raise pybamm.ModelError(
"DiscreteTimeSum can only have one DiscreteTimeData node in the child"
)
self.data = node
if self.data is None:
raise pybamm.ModelError(
"DiscreteTimeSum must contain a DiscreteTimeData node"
)
super().__init__("discrete time sum", child)

@property
def sum_values(self):
return self.data.y

@property
def sum_times(self):
return self.data.x[0]

def _unary_evaluate(self, child):
# return result of evaluating the child, we'll only implement the sum once the model is solved (in pybamm.ProcessedVariable)
return child
2 changes: 1 addition & 1 deletion src/pybamm/solvers/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,4 +2,4 @@
'casadi_algebraic_solver', 'casadi_solver', 'dummy_solver',
'idaklu_jax', 'idaklu_solver', 'jax_bdf_solver', 'jax_solver',
'lrudict', 'processed_variable', 'processed_variable_computed',
'scipy_solver', 'solution']
'scipy_solver', 'solution', 'processed_variable_time_integral']
78 changes: 50 additions & 28 deletions src/pybamm/solvers/processed_variable.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
#
# Processed Variable class
#
from typing import Optional
import casadi
import numpy as np
import pybamm
Expand Down Expand Up @@ -29,14 +30,16 @@ class ProcessedVariable:
`base_Variable.evaluate` (but more efficiently).
solution : :class:`pybamm.Solution`
The solution object to be used to create the processed variables
time_integral : :class:`pybamm.ProcessedVariableTimeIntegral`, optional
Not none if the variable is to be time-integrated (default is None)
"""

def __init__(
self,
base_variables,
base_variables_casadi,
solution,
cumtrapz_ic=None,
time_integral: Optional[pybamm.ProcessedVariableTimeIntegral] = None,
):
self.base_variables = base_variables
self.base_variables_casadi = base_variables_casadi
Expand All @@ -50,7 +53,7 @@ def __init__(
self.mesh = base_variables[0].mesh
self.domain = base_variables[0].domain
self.domains = base_variables[0].domains
self.cumtrapz_ic = cumtrapz_ic
self.time_integral = time_integral

# Process spatial variables
geometry = solution.all_models[0].geometry
Expand Down Expand Up @@ -271,18 +274,21 @@ def __call__(
self._coords_raw,
)

processed_entries = self._xr_interpolate(
entries_for_interp,
coords,
observe_raw,
t,
x,
r,
y,
z,
R,
fill_value,
)
if self.time_integral is None:
processed_entries = self._xr_interpolate(
entries_for_interp,
coords,
observe_raw,
t,
x,
r,
y,
z,
R,
fill_value,
)
else:
processed_entries = entries_for_interp
else:
processed_entries = entries

Expand Down Expand Up @@ -343,6 +349,16 @@ def _check_observe_raw(self, t):
t_observe (np.ndarray): time points to observe
observe_raw (bool): True if observing the raw data
"""
# if this is a time integral variable, t must be None and we observe either the
# data times (for a discrete sum) or the solution times (for a continuous sum)
if self.time_integral is not None:
if self.time_integral.method == "discrete":
# discrete sum should be observed at the discrete times
t = self.time_integral.discrete_times
else:
# assume we can do a sufficiently accurate trapezoidal integration at t_pts
t = self.t_pts

observe_raw = (t is None) or (
np.asarray(t).size == len(self.t_pts) and np.all(t == self.t_pts)
)
Expand Down Expand Up @@ -483,14 +499,14 @@ def __init__(
base_variables,
base_variables_casadi,
solution,
cumtrapz_ic=None,
time_integral: Optional[pybamm.ProcessedVariableTimeIntegral] = None,
):
self.dimensions = 0
super().__init__(
base_variables,
base_variables_casadi,
solution,
cumtrapz_ic=cumtrapz_ic,
time_integral=time_integral,
)

def _observe_raw_python(self):
Expand All @@ -510,13 +526,19 @@ def _observe_raw_python(self):
idx += 1
return entries

def _observe_postfix(self, entries, _):
if self.cumtrapz_ic is None:
def _observe_postfix(self, entries, t):
if self.time_integral is None:
return entries

return cumulative_trapezoid(
entries, self.t_pts, initial=float(self.cumtrapz_ic)
)
if self.time_integral.method == "discrete":
return np.sum(entries, axis=0, initial=self.time_integral.initial_condition)
elif self.time_integral.method == "continuous":
return cumulative_trapezoid(
entries, self.t_pts, initial=float(self.time_integral.initial_condition)
)
else:
raise ValueError(
"time_integral method must be 'discrete' or 'continuous'"
) # pragma: no cover

def _interp_setup(self, entries, t):
# save attributes for interpolation
Expand Down Expand Up @@ -556,14 +578,14 @@ def __init__(
base_variables,
base_variables_casadi,
solution,
cumtrapz_ic=None,
time_integral: Optional[pybamm.ProcessedVariableTimeIntegral] = None,
):
self.dimensions = 1
super().__init__(
base_variables,
base_variables_casadi,
solution,
cumtrapz_ic=cumtrapz_ic,
time_integral=time_integral,
)

def _observe_raw_python(self):
Expand Down Expand Up @@ -653,14 +675,14 @@ def __init__(
base_variables,
base_variables_casadi,
solution,
cumtrapz_ic=None,
time_integral: Optional[pybamm.ProcessedVariableTimeIntegral] = None,
):
self.dimensions = 2
super().__init__(
base_variables,
base_variables_casadi,
solution,
cumtrapz_ic=cumtrapz_ic,
time_integral=time_integral,
)
first_dim_nodes = self.mesh.nodes
first_dim_edges = self.mesh.edges
Expand Down Expand Up @@ -819,14 +841,14 @@ def __init__(
base_variables,
base_variables_casadi,
solution,
cumtrapz_ic=None,
time_integral: Optional[pybamm.ProcessedVariableTimeIntegral] = None,
):
self.dimensions = 2
super(ProcessedVariable2D, self).__init__(
base_variables,
base_variables_casadi,
solution,
cumtrapz_ic=cumtrapz_ic,
time_integral=time_integral,
)
y_sol = self.mesh.edges["y"]
z_sol = self.mesh.edges["z"]
Expand Down
28 changes: 28 additions & 0 deletions src/pybamm/solvers/processed_variable_time_integral.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
from dataclasses import dataclass
from typing import Literal, Optional, Union
import numpy as np
import pybamm


@dataclass
class ProcessedVariableTimeIntegral:
method: Literal["discrete", "continuous"]
initial_condition: np.ndarray
discrete_times: Optional[np.ndarray]

@staticmethod
def from_pybamm_var(
var: Union[pybamm.DiscreteTimeSum, pybamm.ExplicitTimeIntegral],
) -> "ProcessedVariableTimeIntegral":
if isinstance(var, pybamm.DiscreteTimeSum):
return ProcessedVariableTimeIntegral(
method="discrete", initial_condition=0.0, discrete_times=var.sum_times
)
elif isinstance(var, pybamm.ExplicitTimeIntegral):
return ProcessedVariableTimeIntegral(
method="continuous",
initial_condition=var.initial_condition.evaluate(),
discrete_times=None,
)
else:
raise ValueError("Unsupported variable type") # pragma: no cover
28 changes: 17 additions & 11 deletions src/pybamm/solvers/solution.py
Original file line number Diff line number Diff line change
Expand Up @@ -571,7 +571,7 @@ def update(self, variables):
self._update_variable(variable)

def _update_variable(self, variable):
cumtrapz_ic = None
time_integral = None
pybamm.logger.debug(f"Post-processing {variable}")
vars_pybamm = [
model.variables_and_events[variable] for model in self.all_models
Expand All @@ -591,16 +591,22 @@ def _update_variable(self, variable):
"solve. Please re-run the solve with `output_variables` set to "
"include this variable."
)
elif isinstance(var_pybamm, pybamm.ExplicitTimeIntegral):
cumtrapz_ic = var_pybamm.initial_condition
cumtrapz_ic = cumtrapz_ic.evaluate()
var_pybamm = var_pybamm.child
var_casadi = self.process_casadi_var(
var_pybamm,
inputs,
ys.shape,
elif isinstance(
var_pybamm, (pybamm.ExplicitTimeIntegral, pybamm.DiscreteTimeSum)
):
time_integral = pybamm.ProcessedVariableTimeIntegral.from_pybamm_var(
var_pybamm
)
model._variables_casadi[variable] = var_casadi
var_pybamm = var_pybamm.child
if variable in model._variables_casadi:
var_casadi = model._variables_casadi[variable]
else:
var_casadi = self.process_casadi_var(
var_pybamm,
inputs,
ys.shape,
)
model._variables_casadi[variable] = var_casadi
vars_pybamm[i] = var_pybamm
elif variable in model._variables_casadi:
var_casadi = model._variables_casadi[variable]
Expand All @@ -613,7 +619,7 @@ def _update_variable(self, variable):
model._variables_casadi[variable] = var_casadi
vars_casadi.append(var_casadi)
var = pybamm.process_variable(
vars_pybamm, vars_casadi, self, cumtrapz_ic=cumtrapz_ic
vars_pybamm, vars_casadi, self, time_integral=time_integral
)

self._variables[variable] = var
Expand Down
Loading

0 comments on commit e4eb82a

Please sign in to comment.