all three solvers work now in parallel #4087
martinjrobins committed Jul 12, 2024
1 parent 474c4dc commit 9b2b3d1
Showing 10 changed files with 213 additions and 104 deletions.
2 changes: 2 additions & 0 deletions pybamm/expression_tree/operations/evaluate_python.py
@@ -286,6 +286,8 @@ def find_symbols(
# Index has a different syntax than other univariate operations
if isinstance(symbol, pybamm.Index):
symbol_str = f"{children_vars[0]}[{symbol.slice.start}:{symbol.slice.stop}]"
elif isinstance(symbol, pybamm.AbsoluteValue):
symbol_str = f"{symbol.name}({children_vars[0]})"
else:
symbol_str = symbol.name + children_vars[0]

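The new AbsoluteValue branch is needed because find_symbols builds Python source by string concatenation: a prefix operator such as negation can be printed as name + child, but abs is a function call (assuming symbol.name is "abs" for AbsoluteValue). A minimal standalone sketch of the distinction, illustrative only and not the PyBaMM implementation:

def render_unary(name, child_var, is_function_call):
    # function-call style, e.g. "abs(y[0:2])"
    if is_function_call:
        return f"{name}({child_var})"
    # prefix-operator style, e.g. "-y[0:2]"
    return name + child_var

print(render_unary("abs", "y[0:2]", True))   # abs(y[0:2])
print(render_unary("-", "y[0:2]", False))    # -y[0:2]
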
126 changes: 64 additions & 62 deletions pybamm/solvers/base_solver.py
@@ -4,7 +4,9 @@
import sys
import warnings
import platform
import asyncio

# import asyncio
import multiprocessing as mp


import casadi
@@ -772,9 +774,7 @@ def _handle_integrate_defaults(
y0S_list = model.y0S_list
return inputs_list, batched_inputs, nbatches, batch_size, y0S_list

def _integrate(
self, model, t_eval, inputs_list=None, batched_inputs=None, nproc=None
):
def _integrate(self, model, t_eval, inputs_list=None, batched_inputs=None):
"""
Solve a DAE model defined by residuals with initial conditions y0.
@@ -804,53 +804,55 @@ def _integrate(
)

# async io is not parallel, but if solve is io bound, it can be faster
async def solve_model_batches():
async def solve_model_async(y0, y0S, inputs, inputs_array):
return self._integrate_batch(
model, t_eval, y0, y0S, inputs, inputs_array
)

coro = []
for i in range(nbatches):
coro.append(
asyncio.create_task(
solve_model_async(
model.y0_list[i],
y0S_list[i],
inputs_list[i * batch_size : (i + 1) * batch_size],
batched_inputs[i],
)
)
)
return await asyncio.gather(*coro)

new_solutions = asyncio.run(solve_model_batches())
# async def solve_model_batches():
# async def solve_model_async(y0, y0S, inputs, inputs_array):
# return self._integrate_batch(
# model, t_eval, y0, y0S, inputs, inputs_array
# )

# coro = []
# for i in range(nbatches):
# coro.append(
# asyncio.create_task(
# solve_model_async(
# model.y0_list[i],
# y0S_list[i],
# inputs_list[i * batch_size : (i + 1) * batch_size],
# batched_inputs[i],
# )
# )
# )
# return await asyncio.gather(*coro)

# new_solutions = asyncio.run(solve_model_batches())

# new_solutions = []
# for i in range(nbatches):
# new_solutions.append(self._integrate_batch(model, t_eval, model.y0_list[i], y0S_list[i], inputs_list[i * batch_size : (i + 1) * batch_size], batched_inputs[i]))

# with mp.get_context(self._mp_context).Pool(processes=nproc) as p:
# model_list = [model] * nbatches
# t_eval_list = [t_eval] * nbatches
# y0_list = model.y0_list
# inputs_list_of_list = [
# inputs_list[i * batch_size : (i + 1) * batch_size]
# for i in range(nbatches)
# ]
# new_solutions = p.starmap(
# self._integrate_batch,
# zip(
# model_list,
# t_eval_list,
# y0_list,
# y0S_list,
# inputs_list_of_list,
# batched_inputs,
# ),
# )
# p.close()
# p.join()
threads_per_batch = max(self._base_options["num_threads"] // nbatches, 1)
nproc = self._base_options["num_threads"] // threads_per_batch
with mp.get_context(self._mp_context).Pool(processes=nproc) as p:
model_list = [model] * nbatches
t_eval_list = [t_eval] * nbatches
y0_list = model.y0_list
inputs_list_of_list = [
inputs_list[i * batch_size : (i + 1) * batch_size]
for i in range(nbatches)
]
new_solutions = p.starmap(
self._integrate_batch,
zip(
model_list,
t_eval_list,
y0_list,
y0S_list,
inputs_list_of_list,
batched_inputs,
),
)
p.close()
p.join()
new_solutions_flat = [sol for sublist in new_solutions for sol in sublist]
return new_solutions_flat

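The pool above splits a fixed thread budget between worker processes and per-batch threads: threads_per_batch = max(num_threads // nbatches, 1) and nproc = num_threads // threads_per_batch. A self-contained sketch of the same pattern with a toy work function; the real call maps self._integrate_batch over per-batch arguments, and the "spawn" context string stands in for self._mp_context:

import multiprocessing as mp

def integrate_batch(model, t_eval, y0, y0S, inputs, batched_inputs):
    # stand-in for the real per-batch solve; returns a list of "solutions"
    return [f"batch of {len(inputs)} inputs over {len(t_eval)} times"]

if __name__ == "__main__":
    num_threads, nbatches, batch_size = 8, 3, 2
    threads_per_batch = max(num_threads // nbatches, 1)  # 2 threads per batch
    nproc = num_threads // threads_per_batch             # 4 pool workers
    inputs_list = [{"Current function [A]": 0.1 * i} for i in range(nbatches * batch_size)]
    args = [
        ("model", [0.0, 1.0, 2.0], f"y0_{i}", None,
         inputs_list[i * batch_size : (i + 1) * batch_size], f"batched_inputs_{i}")
        for i in range(nbatches)
    ]
    with mp.get_context("spawn").Pool(processes=nproc) as p:
        new_solutions = p.starmap(integrate_batch, args)
    # flatten the per-batch lists, as new_solutions_flat does above
    print([sol for batch in new_solutions for sol in batch])
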
@@ -1700,9 +1702,11 @@ def _zip_state_vector(model, y_diff, y_alg):
len_alg = model.len_alg + model.len_alg_sens
batch_size = model.batch_size
y_diff_list = [
y_diff[i * len_rhs : (i + 1) * len_rhs] for i in range(batch_size)
y_diff[i * len_rhs : (i + 1) * len_rhs, :] for i in range(batch_size)
]
y_alg_list = [
y_alg[i * len_alg : (i + 1) * len_alg, :] for i in range(batch_size)
]
y_alg_list = [y_alg[i * len_alg : (i + 1) * len_alg] for i in range(batch_size)]
if isinstance(y_diff, casadi.DM):
y = casadi.vertcat(
*[val for pair in zip(y_diff_list, y_alg_list) for val in pair]
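The change above keeps the second (column) dimension when slicing y_diff and y_alg, so the blocks can be interleaved per batch member into one state vector: [diff_0, alg_0, diff_1, alg_1, ...]. An illustrative numpy sketch of that ordering (numpy stands in for casadi.DM; the shapes are assumed):

import numpy as np

len_rhs, len_alg, batch_size = 2, 1, 3
y_diff = np.arange(len_rhs * batch_size, dtype=float).reshape(-1, 1)         # stacked differential states
y_alg = 100.0 + np.arange(len_alg * batch_size, dtype=float).reshape(-1, 1)  # stacked algebraic states

y_diff_list = [y_diff[i * len_rhs : (i + 1) * len_rhs, :] for i in range(batch_size)]
y_alg_list = [y_alg[i * len_alg : (i + 1) * len_alg, :] for i in range(batch_size)]
y = np.vstack([val for pair in zip(y_diff_list, y_alg_list) for val in pair])
print(y.ravel())   # 0, 1, 100, 2, 3, 101, 4, 5, 102
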
@@ -1744,10 +1748,7 @@ def map_func_over_inputs_casadi(name, f, vars_for_processing, ninputs, nthreads)
nstates = vars_for_processing["y_and_S"].shape[0]
nparams = vars_for_processing["p_casadi_stacked"].shape[0]

threads_per_input = nthreads // ninputs
if threads_per_input > 1:
threads_per_input = 1
if threads_per_input > 1:
if nthreads > 1:
parallelisation = "thread"
else:
parallelisation = "none"
@@ -1773,7 +1774,7 @@ def map_func_over_inputs_casadi(name, f, vars_for_processing, ninputs, nthreads)
inputs_2d = [t_2d, y_and_S_2d, p_casadi_2d]
inputs_stacked = [t_stacked, y_and_S_inputs_stacked, p_casadi_inputs_stacked]

mapped_f = f.map(ninputs, parallelisation, threads_per_input)(*inputs_2d)
mapped_f = f.map(ninputs, parallelisation, nthreads)(*inputs_2d)
if matrix_output:
# for matrix output we need to stack the outputs in a block diagonal matrix
splits = [i * nstates for i in range(ninputs + 1)]
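map_func_over_inputs_casadi builds on casadi's Function.map, which evaluates one function over several inputs at once, threaded when nthreads > 1. A standalone sketch of that casadi call with a toy function rather than the PyBaMM RHS:

import casadi

x = casadi.MX.sym("x")
f = casadi.Function("f", [x], [x**2 + 1])

ninputs, nthreads = 4, 2
# "thread" matches the parallelisation mode chosen above; the third argument caps the worker threads
fmap = f.map(ninputs, "thread", nthreads)

# each column is one input; outputs come back column-wise in the same order
print(fmap(casadi.DM([[0, 1, 2, 3]])))   # 1, 2, 5, 10
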
@@ -1849,6 +1850,8 @@ def process(
else:
inputs_batch = [inputs[0]]
is_event = "event" in name
nbatches = len(inputs) // batch_size
nthreads_per_batch = max(nthreads // nbatches, 1)

def report(string):
# don't log event conversion
@@ -2028,31 +2031,30 @@ def jacp(t, y, inputs):
name, [t_casadi, y_and_S, p_casadi_stacked], [casadi_expression]
)

ninputs = len(inputs_batch)
if ninputs > 1:
if batch_size > 1:
func = map_func_over_inputs_casadi(
name, func, vars_for_processing, ninputs, nthreads
name, func, vars_for_processing, batch_size, nthreads_per_batch
)
jac = map_func_over_inputs_casadi(
name + "_jac",
jac,
vars_for_processing,
ninputs,
nthreads,
batch_size,
nthreads_per_batch,
)
jacp = map_func_over_inputs_casadi(
name + "_jacp",
jacp,
vars_for_processing,
ninputs,
nthreads,
batch_size,
nthreads_per_batch,
)
jac_action = map_func_over_inputs_casadi(
name + "_jac_action",
jac_action,
vars_for_processing,
ninputs,
nthreads,
batch_size,
nthreads_per_batch,
)

return func, jac, jacp, jac_action
4 changes: 1 addition & 3 deletions pybamm/solvers/c_solvers/idaklu/CasadiSolverGroup.cpp
@@ -56,14 +56,12 @@ std::vector<Solution> CasadiSolverGroup::solve(np_array t_np, np_array y0_np, np
const std::size_t solves_per_thread = n_groups / m_solvers.size();
const std::size_t remainder_solves = n_groups % m_solvers.size();

const std::size_t nthreads = m_solvers.size();

const realtype *t = t_np.data();
const realtype *y0 = y0_np.data();
const realtype *yp0 = yp0_np.data();
const realtype *inputs_data = inputs.data();

omp_set_num_threads(nthreads);
omp_set_num_threads(m_solvers.size());
#pragma omp parallel for
for (int i = 0; i < m_solvers.size(); i++) {
for (int j = 0; j < solves_per_thread; j++) {
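On the C++ side, each of the m_solvers handles solves_per_thread = n_groups / m_solvers.size() solves inside the OpenMP loop, with remainder_solves = n_groups % m_solvers.size() left over and handled outside the part of the loop shown in this hunk. The integer split, sketched here in Python for brevity:

def split_work(n_groups, n_solvers):
    solves_per_thread = n_groups // n_solvers
    remainder_solves = n_groups % n_solvers
    return solves_per_thread, remainder_solves

# e.g. 10 parameter groups across 4 solver threads:
print(split_work(10, 4))   # (2, 2): 2 solves per thread, 2 remainder solves
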
2 changes: 2 additions & 0 deletions pybamm/solvers/c_solvers/idaklu/CasadiSolverOpenMP.cpp
@@ -90,6 +90,7 @@ CasadiSolverOpenMP::CasadiSolverOpenMP(
}
}

// todo: should use std::vector for these...
res = new realtype[max_res_size];
res_dvar_dy = new realtype[max_res_dvar_dy];
res_dvar_dp = new realtype[max_res_dvar_dp];
@@ -261,6 +262,7 @@ void CasadiSolverOpenMP::CalcVarsSensitivities(
for(int k=0; k<number_of_parameters; k++)
dens_dvar_dp[k]=0;
for(int k=0; k<spdp.nnz(); k++)
// todo: get_row() will allocate a new array each time, refactor this
dens_dvar_dp[spdp.get_row()[k]] = res_dvar_dp[k];
// Calculate sensitivities
for(int paramk=0; paramk<number_of_parameters; paramk++) {
2 changes: 2 additions & 0 deletions pybamm/solvers/c_solvers/idaklu/common.hpp
@@ -29,6 +29,8 @@
#include <pybind11/stl.h>

namespace py = pybind11;

// note: we rely on c_style ordering for numpy arrays so don't change this!
using np_array = py::array_t<realtype, py::array::c_style | py::array::forcecast>;
using np_array_int = py::array_t<int64_t>;

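The comment added to common.hpp matters because the bindings index the numpy buffers as flat row-major (C-ordered) memory; the py::array::c_style flag makes pybind11 convert incoming arrays to that layout. A small numpy illustration of why the memory order is load-bearing (sketch only):

import numpy as np

a_c = np.array([[1.0, 2.0], [3.0, 4.0]], order="C")   # row-major
a_f = np.array([[1.0, 2.0], [3.0, 4.0]], order="F")   # column-major

# ravel(order="K") shows the underlying memory order the C++ side would read
print(a_c.ravel(order="K"))   # 1, 2, 3, 4
print(a_f.ravel(order="K"))   # 1, 3, 2, 4
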
8 changes: 8 additions & 0 deletions pybamm/solvers/casadi_solver.py
@@ -69,6 +69,12 @@ class CasadiSolver(pybamm.BaseSolver):
The maximum number of integrators that the solver will retain before
ejecting past integrators using an LRU methodology. A value of 0 or
None leaves the number of integrators unbound. Default is 100.
options: dict, optional
List of options, defaults are:
options = {
# Number of threads available for OpenMP
"num_threads": 1,
}
"""

def __init__(
@@ -86,6 +92,7 @@ def __init__(
return_solution_if_failed_early=False,
perturb_algebraic_initial_conditions=None,
integrators_maxcount=100,
options=None,
):
super().__init__(
"problem dependent",
@@ -94,6 +101,7 @@
root_method,
root_tol,
extrap_tol,
options=options,
)
if mode in ["safe", "fast", "fast with events", "safe without grid"]:
self.mode = mode
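A usage sketch for the new options argument (values are hypothetical; only the "num_threads" key is documented in this diff, and how the threads are split across input batches is decided internally, see base_solver.py above):

import pybamm

model = pybamm.lithium_ion.SPM()
# allow the solver up to 4 threads
solver = pybamm.CasadiSolver(mode="safe", options={"num_threads": 4})
sim = pybamm.Simulation(model, solver=solver)
sol = sim.solve([0, 3600])
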