Skip to content

Commit

Permalink
feat: make solver changes so that _integrate takes a list of inputs #…
Browse files Browse the repository at this point in the history
  • Loading branch information
martinjrobins committed Jul 5, 2024
1 parent f270232 commit c38ab0d
Show file tree
Hide file tree
Showing 9 changed files with 226 additions and 169 deletions.
13 changes: 12 additions & 1 deletion pybamm/expression_tree/operations/evaluate_python.py
Original file line number Diff line number Diff line change
Expand Up @@ -485,7 +485,18 @@ def __call__(self, t=None, y=None, inputs=None):
if y is not None and y.ndim == 1:
y = y.reshape(-1, 1)

result = self._evaluate(self._constants, t, y, inputs)
if isinstance(inputs, list):
ny = y.shape[0]
ni = len(inputs)
result = np.zeros((ni * ny, 1))
i = 0
for input in inputs:
result[i : i + ny] += self._evaluate(
self._constants, t, y[i : i + ny], input
)
i += ny
else:
result = self._evaluate(self._constants, t, y, inputs)

return result

Expand Down
16 changes: 9 additions & 7 deletions pybamm/solvers/algebraic_solver.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@ def tol(self):
def tol(self, value):
self._tol = value

def _integrate(self, model, t_eval, inputs_dict=None):
def _integrate(self, model, t_eval, inputs_list=None):
"""
Calculate the solution of the algebraic equations through root-finding
Expand All @@ -57,14 +57,16 @@ def _integrate(self, model, t_eval, inputs_dict=None):
The model whose solution to calculate.
t_eval : :class:`numpy.array`, size (k,)
The times at which to compute the solution
inputs_dict : dict, optional
inputs_list: list of dict, optional
Any input parameters to pass to the model when solving
"""
inputs_dict = inputs_dict or {}
inputs_list = inputs_list or {}
if model.convert_to_format == "casadi":
inputs = casadi.vertcat(*[x for x in inputs_dict.values()])
inputs = casadi.vertcat(
*[x for inputs in inputs_list for x in inputs.values()]
)
else:
inputs = inputs_dict
inputs = inputs_list

y0 = model.y0
if isinstance(y0, casadi.DM):
Expand Down Expand Up @@ -230,8 +232,8 @@ def jac_norm(y, jac_fn=jac_fn):
y_diff = np.r_[[y0_diff] * len(t_eval)].T
y_sol = np.r_[y_diff, y_alg]
# Return solution object (no events, so pass None to t_event, y_event)
sol = pybamm.Solution(
t_eval, y_sol, model, inputs_dict, termination="final time"
sol = pybamm.Solution.from_concatenated_state(
t_eval, y_sol, model, inputs_list, termination="final time"
)
sol.integration_time = integration_time
return sol
127 changes: 54 additions & 73 deletions pybamm/solvers/base_solver.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
import copy
import itertools
from scipy.sparse import block_diag
import multiprocessing as mp
import numbers
import sys
import warnings
Expand Down Expand Up @@ -104,19 +103,19 @@ def set_up(self, model, inputs=None, t_eval=None, ics_only=False):
model : :class:`pybamm.BaseModel`
The model whose solution to calculate. Must have attributes rhs and
initial_conditions
inputs : dict, optional
inputs : list of dict, optional
Any input parameters to pass to the model when solving
t_eval : numeric type, optional
The times (in seconds) at which to compute the solution
"""
inputs = inputs or {}
inputs = inputs or [{}]

if ics_only:
pybamm.logger.info("Start solver set-up, initial_conditions only")
else:
pybamm.logger.info("Start solver set-up")

self._check_and_prepare_model_inplace(model, inputs, ics_only)
self._check_and_prepare_model_inplace(model)

# set default calculate sensitivities on model
if not hasattr(model, "calculate_sensitivities"):
Expand All @@ -141,18 +140,21 @@ def set_up(self, model, inputs=None, t_eval=None, ics_only=False):
"initial_conditions",
vars_for_processing,
use_jacobian=False,
ninputs=len(inputs),
)
model.initial_conditions_eval = initial_conditions
model.jacp_initial_conditions_eval = jacp_ic

# evaluate initial condition
y0_total_size = (
y0_total_size = len(inputs) * (
model.len_rhs + model.len_rhs_sens + model.len_alg + model.len_alg_sens
)
y_zero = np.zeros((y0_total_size, 1))
if model.convert_to_format == "casadi":
# stack inputs
inputs_casadi = casadi.vertcat(*[x for x in inputs.values()])
inputs_casadi = casadi.vertcat(
*[x for inpt in inputs for x in inpt.items()]
)
model.y0 = initial_conditions(0.0, y_zero, inputs_casadi)
if jacp_ic is None:
model.y0S = None
Expand Down Expand Up @@ -180,11 +182,14 @@ def set_up(self, model, inputs=None, t_eval=None, ics_only=False):
# Process rhs, algebraic, residual and event expressions
# and wrap in callables
rhs, jac_rhs, jacp_rhs, jac_rhs_action = process(
model.concatenated_rhs, "RHS", vars_for_processing
model.concatenated_rhs, "RHS", vars_for_processing, ninputs=len(inputs)
)

algebraic, jac_algebraic, jacp_algebraic, jac_algebraic_action = process(
model.concatenated_algebraic, "algebraic", vars_for_processing
model.concatenated_algebraic,
"algebraic",
vars_for_processing,
ninputs=len(inputs),
)

# combine rhs and algebraic functions
Expand All @@ -202,7 +207,9 @@ def set_up(self, model, inputs=None, t_eval=None, ics_only=False):
jac_rhs_algebraic,
jacp_rhs_algebraic,
jac_rhs_algebraic_action,
) = process(rhs_algebraic, "rhs_algebraic", vars_for_processing)
) = process(
rhs_algebraic, "rhs_algebraic", vars_for_processing, ninputs=len(inputs)
)

(
casadi_switch_events,
Expand Down Expand Up @@ -250,6 +257,8 @@ def set_up(self, model, inputs=None, t_eval=None, ics_only=False):
model.casadi_rhs = casadi.Function(
"rhs", [t_casadi, y_and_S, p_casadi_stacked], [explicit_rhs]
)
if len(inputs) > 1:
model.casadi_rhs = model.casadi_rhs.map(len(inputs), "openmp")
model.casadi_switch_events = casadi_switch_events
model.casadi_algebraic = algebraic
model.casadi_sensitivities = jacp_rhs_algebraic
Expand Down Expand Up @@ -281,6 +290,7 @@ def set_up(self, model, inputs=None, t_eval=None, ics_only=False):
vars_for_processing,
use_jacobian=True,
return_jacp_stacked=True,
ninputs=len(inputs),
)

pybamm.logger.info("Finish solver set-up")
Expand All @@ -307,7 +317,7 @@ def _wrangle_name(cls, name: str) -> str:
name = name.replace(string, replacement)
return name

def _check_and_prepare_model_inplace(self, model, inputs, ics_only):
def _check_and_prepare_model_inplace(self, model):
"""
Performs checks on the model and prepares it for solving.
"""
Expand Down Expand Up @@ -426,10 +436,11 @@ def _set_up_model_sensitivities_inplace(
num_parameters = 0
for name in model.calculate_sensitivities:
# if not a number, assume its a vector
if isinstance(inputs[name], numbers.Number):
if isinstance(inputs[0][name], numbers.Number):
num_parameters += 1
else:
num_parameters += len(inputs[name])
num_parameters += len(inputs[0][name])
num_parameters *= len(inputs)
model.len_rhs_sens = model.len_rhs * num_parameters
model.len_alg_sens = model.len_alg * num_parameters
else:
Expand Down Expand Up @@ -575,6 +586,7 @@ def _set_up_events(self, model, t_eval, inputs, vars_for_processing):
f"event_{n}",
vars_for_processing,
use_jacobian=False,
ninputs=len(inputs),
)[0]
# use the actual casadi object as this will go into the rhs
casadi_switch_events.append(event_casadi)
Expand All @@ -585,6 +597,7 @@ def _set_up_events(self, model, t_eval, inputs, vars_for_processing):
f"event_{n}",
vars_for_processing,
use_jacobian=False,
ninputs=len(inputs),
)[0]
if event.event_type == pybamm.EventType.TERMINATION:
terminate_events.append(event_call)
Expand Down Expand Up @@ -670,7 +683,7 @@ def calculate_consistent_state(self, model, time=0, inputs=None):
The model for which to calculate initial conditions.
time : float
The time at which to calculate the states
inputs: dict, optional
inputs: list of dict, optional
Any input parameters to pass to the model when solving
Returns
Expand Down Expand Up @@ -826,12 +839,7 @@ def solve(
f'"{existing_model.name}". Please create a separate '
"solver for this model"
)
# It is assumed that when len(inputs_list) > 1, model set
# up (initial condition, time-scale and length-scale) does
# not depend on input parameters. Therefore, only `model_inputs[0]`
# is passed to `set_up`.
# See https://github.com/pybamm-team/PyBaMM/pull/1261
self.set_up(model, model_inputs_list[0], t_eval)
self.set_up(model, model_inputs_list, t_eval)
self._model_set_up.update(
{model: {"initial conditions": model.concatenated_initial_conditions}}
)
Expand All @@ -847,40 +855,20 @@ def solve(
else:
# If the new initial conditions are different
# and cannot be evaluated directly, set up again
self.set_up(model, model_inputs_list[0], t_eval, ics_only=True)
self.set_up(model, model_inputs_list, t_eval, ics_only=True)
self._model_set_up[model]["initial conditions"] = (
model.concatenated_initial_conditions
)

set_up_time = timer.time()
timer.reset()

# (Re-)calculate consistent initial conditions
# Assuming initial conditions do not depend on input parameters
# when len(inputs_list) > 1, only `model_inputs_list[0]`
# is passed to `_set_initial_conditions`.
# See https://github.com/pybamm-team/PyBaMM/pull/1261
if len(inputs_list) > 1:
all_inputs_names = set(
itertools.chain.from_iterable(
[model_inputs.keys() for model_inputs in model_inputs_list]
)
)
initial_conditions_node_names = set(
[it.name for it in model.concatenated_initial_conditions.pre_order()]
)
if all_inputs_names.issubset(initial_conditions_node_names):
raise pybamm.SolverError(
"Input parameters cannot appear in expression "
"for initial conditions."
)

self._set_initial_conditions(
model, t_eval[0], model_inputs_list[0], update_rhs=True
model, t_eval[0], model_inputs_list, update_rhs=True
)

# Check initial conditions don't violate events
self._check_events_with_initial_conditions(t_eval, model, model_inputs_list[0])
self._check_events_with_initial_conditions(t_eval, model, model_inputs_list)

# Process discontinuities
(
Expand All @@ -898,34 +886,12 @@ def solve(
pybamm.logger.verbose(
f"Calling solver for {t_eval[start_index]} < t < {t_eval[end_index - 1]}"
)
ninputs = len(model_inputs_list)
if ninputs == 1:
new_solution = self._integrate(
model,
t_eval[start_index:end_index],
model_inputs_list[0],
)
new_solutions = [new_solution]
else:
if model.convert_to_format == "jax":
# Jax can parallelize over the inputs efficiently
new_solutions = self._integrate(
model,
t_eval[start_index:end_index],
model_inputs_list,
)
else:
with mp.get_context(self._mp_context).Pool(processes=nproc) as p:
new_solutions = p.starmap(
self._integrate,
zip(
[model] * ninputs,
[t_eval[start_index:end_index]] * ninputs,
model_inputs_list,
),
)
p.close()
p.join()
new_solutions = self._integrate(
model,
t_eval[start_index:end_index],
model_inputs_list,
)

# Setting the solve time for each segment.
# pybamm.Solution.__add__ assumes attribute solve_time.
solve_time = timer.time()
Expand All @@ -948,7 +914,7 @@ def solve(
model.y0 = last_state
if len(model.algebraic) > 0:
model.y0 = self.calculate_consistent_state(
model, t_eval[end_index], model_inputs_list[0]
model, t_eval[end_index], model_inputs_list
)
solve_time = timer.time()

Expand Down Expand Up @@ -995,7 +961,7 @@ def solve(
)

# Return solution(s)
if ninputs == 1:
if len(inputs) == 1:
return solutions[0]
else:
return solutions
Expand Down Expand Up @@ -1427,7 +1393,12 @@ def _set_up_model_inputs(model, inputs):


def process(
symbol, name, vars_for_processing, use_jacobian=None, return_jacp_stacked=None
symbol,
name,
vars_for_processing,
use_jacobian=None,
return_jacp_stacked=None,
ninputs=1,
):
"""
Parameters
Expand Down Expand Up @@ -1663,4 +1634,14 @@ def jacp(*args, **kwargs):
name, [t_casadi, y_and_S, p_casadi_stacked], [casadi_expression]
)

if ninputs > 1:
parallelisation = "openmp"
func = func.map(ninputs, parallelisation)
if jac is not None:
jac = jac.map(ninputs, parallelisation)
if jacp is not None:
jacp = jacp.map(ninputs, parallelisation)
if jac_action is not None:
jac_action = jac_action.map(ninputs, parallelisation)

return func, jac, jacp, jac_action
12 changes: 6 additions & 6 deletions pybamm/solvers/casadi_algebraic_solver.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@ def tol(self):
def tol(self, value):
self._tol = value

def _integrate(self, model, t_eval, inputs_dict=None):
def _integrate(self, model, t_eval, inputs_list=None):
"""
Calculate the solution of the algebraic equations through root-finding
Expand All @@ -47,14 +47,14 @@ def _integrate(self, model, t_eval, inputs_dict=None):
The model whose solution to calculate.
t_eval : :class:`numpy.array`, size (k,)
The times at which to compute the solution
inputs_dict : dict, optional
inputs_list: list of dict, optional
Any input parameters to pass to the model when solving.
"""
# Record whether there are any symbolic inputs
inputs_dict = inputs_dict or {}
inputs_list = inputs_list or {}

# Create casadi objects for the root-finder
inputs = casadi.vertcat(*[v for v in inputs_dict.values()])
inputs = casadi.vertcat(*[v for inputs in inputs_list for v in inputs.values()])

y0 = model.y0

Expand Down Expand Up @@ -164,11 +164,11 @@ def _integrate(self, model, t_eval, inputs_dict=None):
except AttributeError:
explicit_sensitivities = False

sol = pybamm.Solution(
sol = pybamm.Solution.from_concatenated_state(
[t_eval],
y_sol,
model,
inputs_dict,
inputs_list,
termination="final time",
sensitivities=explicit_sensitivities,
)
Expand Down
Loading

0 comments on commit c38ab0d

Please sign in to comment.