Skip to content

Commit

Permalink
Merge pull request #1261 from pybamm-team/issue-849-parallel-processing
Browse files Browse the repository at this point in the history
Solving for an ensemble of input parameters in parallel
  • Loading branch information
tlestang authored Jan 8, 2021
2 parents a15ba7c + d503182 commit 85b3a20
Show file tree
Hide file tree
Showing 5 changed files with 277 additions and 42 deletions.
3 changes: 3 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,8 @@

## Features


- Updated solvers' method `solve()` so it can take a list of inputs dictionaries as the `inputs` keyword argument. In this case the model is solved for each input set in the list, and a list of solutions mapping the set of inputs to the solutions is returned. Note that `solve()` can still take a single dictionary as the `inputs` keyword argument. In this case the behaviour is unchanged compared to previous versions.
- Added temperature dependence to the transference number (`t_plus`) ([1317](https://github.com/pybamm-team/PyBaMM/pull/1317))
- Added new functionality for `Interpolant` ([#1312](https://github.com/pybamm-team/PyBaMM/pull/1312))
- Added option to express experiments (and extract solutions) in terms of cycles of operating condition ([#1309](https://github.com/pybamm-team/PyBaMM/pull/1309))
Expand All @@ -14,6 +16,7 @@

## Optimizations

- If solver method `solve()` is passed a list of inputs as the `inputs` keyword argument, the resolution of the model for each input set is spread acrosss several Python processes, usually running in parallel on different processors. The default number of processes is the number of processors available. `solve()` takes a new keyword argument `nproc` which can be used to set this number a manually.
- Variables are now post-processed using CasADi ([#1316](https://github.com/pybamm-team/PyBaMM/pull/1316))
- Operations such as `1*x` and `0+x` now directly return `x` ([#1252](https://github.com/pybamm-team/PyBaMM/pull/1252))

Expand Down
18 changes: 18 additions & 0 deletions pybamm/expression_tree/operations/evaluate.py
Original file line number Diff line number Diff line change
Expand Up @@ -485,6 +485,7 @@ def __init__(self, symbol):
python_str = python_str + "\nself._evaluate = evaluate"

self._python_str = python_str
self._result_var = result_var
self._symbol = symbol

# compile and run the generated python code,
Expand All @@ -507,6 +508,23 @@ def evaluate(self, t=None, y=None, y_dot=None, inputs=None, known_evals=None):
else:
return result

def __getstate__(self):
# Control the state of instances of EvaluatorPython
# before pickling. Method "_evaluate" cannot be pickled.
# See https://github.com/pybamm-team/PyBaMM/issues/1283
state = self.__dict__.copy()
del state["_evaluate"]
return state

def __setstate__(self, state):
# Restore pickled attributes and
# compile code from "python_str"
# Execution of bytecode (re)adds attribute
# "_method"
self.__dict__.update(state)
compiled_function = compile(self._python_str, self._result_var, "exec")
exec(compiled_function)


class EvaluatorJax:
"""
Expand Down
166 changes: 127 additions & 39 deletions pybamm/solvers/base_solver.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
import numpy as np
import sys
import itertools
import multiprocessing as mp
import warnings


Expand Down Expand Up @@ -516,6 +517,7 @@ def solve(
external_variables=None,
inputs=None,
initial_conditions=None,
nproc=None,
):
"""
Execute the solver setup and calculate the solution of the model at
Expand All @@ -531,12 +533,22 @@ def solve(
external_variables : dict
A dictionary of external variables and their corresponding
values at the current time
inputs : dict, optional
Any input parameters to pass to the model when solving
inputs : dict or list, optional
A dictionary or list of dictionaries describing any input parameters to
pass to the model when solving
initial_conditions : :class:`pybamm.Symbol`, optional
Initial conditions to use when solving the model. If None (default),
`model.concatenated_initial_conditions` is used. Otherwise, must be a symbol
of size `len(model.rhs) + len(model.algebraic)`.
nproc : int, optional
Number of processes to use when solving for more than one set of input
parameters. Defaults to value returned by "os.cpu_count()".
Returns
-------
:class:`pybamm.Solution` or list of :class:`pybamm.Solution` objects.
If type of `inputs` is `list`, return a list of corresponding
:class:`pybamm.Solution` objects.
Raises
------
Expand Down Expand Up @@ -580,14 +592,35 @@ def solve(
raise pybamm.SolverError("t_eval must increase monotonically")

# Set up external variables and inputs
ext_and_inputs = self._set_up_ext_and_inputs(model, external_variables, inputs)
#
# Argument "inputs" can be either a list of input dicts or
# a single dict. The remaining of this function is only working
# with variable "input_list", which is a list of dictionaries.
# If "inputs" is a single dict, "inputs_list" is a list of only one dict.
inputs_list = inputs if isinstance(inputs, list) else [inputs]
ext_and_inputs_list = [
self._set_up_ext_and_inputs(model, external_variables, inputs)
for inputs in inputs_list
]

# Cannot use multiprocessing with model in "jax" format
if(len(inputs_list) > 1) and model.convert_to_format == "jax":
raise pybamm.SolverError(
"Cannot solve list of inputs with multiprocessing "
"when model in format \"jax\"."
)

# Set up
timer = pybamm.Timer()

# Set up (if not done already)
if model not in self.models_set_up:
self.set_up(model, ext_and_inputs, t_eval)
# It is assumed that when len(inputs_list) > 1, model set
# up (initial condition, time-scale and length-scale) does
# not depend on input parameters. Thefore only `ext_and_inputs[0]`
# is passed to `set_up`.
# See https://github.com/pybamm-team/PyBaMM/pull/1261
self.set_up(model, ext_and_inputs_list[0], t_eval)
self.models_set_up.update(
{model: {"initial conditions": model.concatenated_initial_conditions}}
)
Expand All @@ -598,22 +631,45 @@ def solve(
# If the new initial conditions are different, set up again
# Doing the whole setup again might be slow, but no need to prematurely
# optimize this
self.set_up(model, ext_and_inputs, t_eval)
self.set_up(model, ext_and_inputs_list[0], t_eval)
self.models_set_up[model][
"initial conditions"
] = model.concatenated_initial_conditions
set_up_time = timer.time()
timer.reset()

# (Re-)calculate consistent initial conditions
self._set_initial_conditions(model, ext_and_inputs, update_rhs=True)
# Assuming initial conditions do not depend on input parameters
# when len(inputs_list) > 1, only `ext_and_inputs_list[0]`
# is passed to `_set_initial_conditions`.
# See https://github.com/pybamm-team/PyBaMM/pull/1261
if len(inputs_list) > 1:
all_inputs_names = set(
itertools.chain.from_iterable(
[ext_and_inputs.keys() for ext_and_inputs in ext_and_inputs_list]
)
)
initial_conditions_node_names = set(
[it.name for it in model.concatenated_initial_conditions.pre_order()]
)
if all_inputs_names.issubset(initial_conditions_node_names):
raise pybamm.SolverError(
"Input parameters cannot appear in expression "
"for initial conditions."
)

self._set_initial_conditions(model, ext_and_inputs_list[0], update_rhs=True)

# Non-dimensionalise time
t_eval_dimensionless = t_eval / model.timescale_eval

# Calculate discontinuities
discontinuities = [
event.expression.evaluate(inputs=inputs)
# Assuming that discontinuities do not depend on
# input parameters when len(input_list) > 1, only
# `input_list[0]` is passed to `evaluate`.
# See https://github.com/pybamm-team/PyBaMM/pull/1261
event.expression.evaluate(inputs=inputs_list[0])
for event in model.discontinuity_events_eval
]

Expand All @@ -638,6 +694,11 @@ def solve(
pybamm.logger.info(
"Discontinuity events found at t = {}".format(discontinuities)
)
if isinstance(inputs, list):
raise pybamm.SolverError(
"Cannot solve for a list of input parameters"
" sets with discontinuities"
)
else:
pybamm.logger.info("No discontinuity events found")

Expand Down Expand Up @@ -665,51 +726,72 @@ def solve(
# object, restarting the solver at each discontinuity (and recalculating a
# consistent state afterwards if a dae)
old_y0 = model.y0
solution = None
solutions = None
for start_index, end_index in zip(start_indices, end_indices):
pybamm.logger.info(
"Calling solver for {} < t < {}".format(
t_eval_dimensionless[start_index] * model.timescale_eval,
t_eval_dimensionless[end_index - 1] * model.timescale_eval,
)
)
new_solution = self._integrate(
model, t_eval_dimensionless[start_index:end_index], ext_and_inputs
)
new_solution.solve_time = timer.time()
if solution is None:
solution = new_solution
ninputs = len(ext_and_inputs_list)
if ninputs == 1:
new_solution = self._integrate(
model,
t_eval_dimensionless[start_index:end_index],
ext_and_inputs_list[0],
)
new_solutions = [new_solution]
else:
with mp.Pool(processes=nproc) as p:
new_solutions = p.starmap(
self._integrate,
zip(
[model] * ninputs,
[t_eval_dimensionless[start_index:end_index]] * ninputs,
ext_and_inputs_list,
),
)
p.close()
p.join()
# Setting the solve time for each segment.
# pybamm.Solution.append assumes attribute
# solve_time.
solve_time = timer.time()
for sol in new_solutions:
sol.solve_time = solve_time
if start_index == start_indices[0]:
solutions = [sol for sol in new_solutions]
else:
solution.append(new_solution, start_index=0)
for i, new_solution in enumerate(new_solutions):
solutions[i].append(new_solution, start_index=0)

if solution.termination != "final time":
if solutions[0].termination != "final time":
break

if end_index != len(t_eval_dimensionless):
# setup for next integration subsection
last_state = solution.y[:, -1]
last_state = solutions[0].y[:, -1]
# update y0 (for DAE solvers, this updates the initial guess for the
# rootfinder)
model.y0 = last_state
if len(model.algebraic) > 0:
model.y0 = self.calculate_consistent_state(
model, t_eval_dimensionless[end_index], ext_and_inputs
model, t_eval_dimensionless[end_index], ext_and_inputs_list[0]
)

# Assign times
solution.set_up_time = set_up_time
solution.solve_time = timer.time()

# restore old y0
model.y0 = old_y0

# Add model and inputs to solution
solution.model = model
solution.inputs = ext_and_inputs
solve_time = timer.time()
for i, solution in enumerate(solutions):
# Assign times
solution.set_up_time = set_up_time
solution.solve_time = solve_time
# Add model and inputs to solution
solution.model = model
solution.inputs = ext_and_inputs_list[i]

# Copy the timescale_eval and lengthscale_evals
solution.timescale_eval = model.timescale_eval
solution.length_scales_eval = model.length_scales_eval
# Copy the timescale_eval and lengthscale_evals
solution.timescale_eval = model.timescale_eval
solution.length_scales_eval = model.length_scales_eval

# Check if extrapolation occurred
extrapolation = self.check_extrapolation(solution, model.events)
Expand All @@ -722,30 +804,36 @@ def solve(
)

# Identify the event that caused termination
termination = self.get_termination_reason(solution, model.events)
termination = self.get_termination_reason(solutions[0], model.events)

# restore old y0
model.y0 = old_y0

pybamm.logger.info("Finish solving {} ({})".format(model.name, termination))
pybamm.logger.info(
(
"Set-up time: {}, Solve time: {} (of which integration time: {}), "
"Total time: {}"
).format(
solution.set_up_time,
solution.solve_time,
solution.integration_time,
solution.total_time,
solutions[0].set_up_time,
solutions[0].solve_time,
solutions[0].integration_time,
solutions[0].total_time,
)
)

# Raise error if solution only contains one timestep (except for algebraic
# Raise error if solutions[0] only contains one timestep (except for algebraic
# solvers, where we may only expect one time in the solution)
if self.algebraic_solver is False and len(solution.t) == 1:
if self.algebraic_solver is False and len(solutions[0].t) == 1:
raise pybamm.SolverError(
"Solution time vector has length 1. "
"Check whether simulation terminated too early."
)

return solution
if ninputs == 1:
return solutions[0]
else:
return solutions

def step(
self,
Expand Down
Loading

0 comments on commit 85b3a20

Please sign in to comment.