Skip to content

Commit

Permalink
openmp parallel loop for idaklu #4087
Browse files Browse the repository at this point in the history
  • Loading branch information
martinjrobins committed Jul 11, 2024
1 parent 4c19d51 commit 36d19a0
Show file tree
Hide file tree
Showing 19 changed files with 651 additions and 293 deletions.
8 changes: 8 additions & 0 deletions CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,8 @@ pybind11_add_module(idaklu
pybamm/solvers/c_solvers/idaklu/casadi_solver.hpp
pybamm/solvers/c_solvers/idaklu/CasadiSolver.cpp
pybamm/solvers/c_solvers/idaklu/CasadiSolver.hpp
pybamm/solvers/c_solvers/idaklu/CasadiSolverGroup.cpp
pybamm/solvers/c_solvers/idaklu/CasadiSolverGroup.hpp
pybamm/solvers/c_solvers/idaklu/CasadiSolverOpenMP.cpp
pybamm/solvers/c_solvers/idaklu/CasadiSolverOpenMP.hpp
pybamm/solvers/c_solvers/idaklu/CasadiSolverOpenMP_solvers.cpp
Expand Down Expand Up @@ -95,6 +97,12 @@ set_target_properties(
INSTALL_RPATH_USE_LINK_PATH TRUE
)

# openmp
find_package(OpenMP)
if(OpenMP_CXX_FOUND)
target_link_libraries(idaklu PRIVATE OpenMP::OpenMP_CXX)
endif()

set(CMAKE_MODULE_PATH ${CMAKE_MODULE_PATH} ${PROJECT_SOURCE_DIR})
# Sundials
find_package(SUNDIALS REQUIRED)
Expand Down
74 changes: 74 additions & 0 deletions compile_commands.json
Original file line number Diff line number Diff line change
@@ -0,0 +1,74 @@
[
{
"directory": "/home/mrobins/git/PyBaMM",
"command": "/usr/bin/c++ -DCASADI_SNPRINTF=snprintf -D_GLIBCXX_USE_CXX11_ABI=0 -Didaklu_EXPORTS -I/usr/include/suitesparse -isystem /home/mrobins/git/PyBaMM/pybind11/include -isystem /usr/include/python3.10 -isystem /home/mrobins/git/PyBaMM/env/lib/python3.10/site-packages/casadi/include -Werror=vla -O3 -DNDEBUG -std=c++14 -fPIC -fvisibility=hidden -flto -fno-fat-lto-objects -o CMakeFiles/idaklu.dir/pybamm/solvers/c_solvers/idaklu/casadi_functions.cpp.o -c /home/mrobins/git/PyBaMM/pybamm/solvers/c_solvers/idaklu/casadi_functions.cpp",
"file": "/home/mrobins/git/PyBaMM/pybamm/solvers/c_solvers/idaklu/casadi_functions.cpp",
"output": "CMakeFiles/idaklu.dir/pybamm/solvers/c_solvers/idaklu/casadi_functions.cpp.o"
},
{
"directory": "/home/mrobins/git/PyBaMM",
"command": "/usr/bin/c++ -DCASADI_SNPRINTF=snprintf -D_GLIBCXX_USE_CXX11_ABI=0 -Didaklu_EXPORTS -I/usr/include/suitesparse -isystem /home/mrobins/git/PyBaMM/pybind11/include -isystem /usr/include/python3.10 -isystem /home/mrobins/git/PyBaMM/env/lib/python3.10/site-packages/casadi/include -Werror=vla -O3 -DNDEBUG -std=c++14 -fPIC -fvisibility=hidden -flto -fno-fat-lto-objects -o CMakeFiles/idaklu.dir/pybamm/solvers/c_solvers/idaklu/casadi_solver.cpp.o -c /home/mrobins/git/PyBaMM/pybamm/solvers/c_solvers/idaklu/casadi_solver.cpp",
"file": "/home/mrobins/git/PyBaMM/pybamm/solvers/c_solvers/idaklu/casadi_solver.cpp",
"output": "CMakeFiles/idaklu.dir/pybamm/solvers/c_solvers/idaklu/casadi_solver.cpp.o"
},
{
"directory": "/home/mrobins/git/PyBaMM",
"command": "/usr/bin/c++ -DCASADI_SNPRINTF=snprintf -D_GLIBCXX_USE_CXX11_ABI=0 -Didaklu_EXPORTS -I/usr/include/suitesparse -isystem /home/mrobins/git/PyBaMM/pybind11/include -isystem /usr/include/python3.10 -isystem /home/mrobins/git/PyBaMM/env/lib/python3.10/site-packages/casadi/include -Werror=vla -O3 -DNDEBUG -std=c++14 -fPIC -fvisibility=hidden -flto -fno-fat-lto-objects -o CMakeFiles/idaklu.dir/pybamm/solvers/c_solvers/idaklu/CasadiSolver.cpp.o -c /home/mrobins/git/PyBaMM/pybamm/solvers/c_solvers/idaklu/CasadiSolver.cpp",
"file": "/home/mrobins/git/PyBaMM/pybamm/solvers/c_solvers/idaklu/CasadiSolver.cpp",
"output": "CMakeFiles/idaklu.dir/pybamm/solvers/c_solvers/idaklu/CasadiSolver.cpp.o"
},
{
"directory": "/home/mrobins/git/PyBaMM",
"command": "/usr/bin/c++ -DCASADI_SNPRINTF=snprintf -D_GLIBCXX_USE_CXX11_ABI=0 -Didaklu_EXPORTS -I/usr/include/suitesparse -isystem /home/mrobins/git/PyBaMM/pybind11/include -isystem /usr/include/python3.10 -isystem /home/mrobins/git/PyBaMM/env/lib/python3.10/site-packages/casadi/include -Werror=vla -O3 -DNDEBUG -std=c++14 -fPIC -fvisibility=hidden -flto -fno-fat-lto-objects -o CMakeFiles/idaklu.dir/pybamm/solvers/c_solvers/idaklu/CasadiSolverGroup.cpp.o -c /home/mrobins/git/PyBaMM/pybamm/solvers/c_solvers/idaklu/CasadiSolverGroup.cpp",
"file": "/home/mrobins/git/PyBaMM/pybamm/solvers/c_solvers/idaklu/CasadiSolverGroup.cpp",
"output": "CMakeFiles/idaklu.dir/pybamm/solvers/c_solvers/idaklu/CasadiSolverGroup.cpp.o"
},
{
"directory": "/home/mrobins/git/PyBaMM",
"command": "/usr/bin/c++ -DCASADI_SNPRINTF=snprintf -D_GLIBCXX_USE_CXX11_ABI=0 -Didaklu_EXPORTS -I/usr/include/suitesparse -isystem /home/mrobins/git/PyBaMM/pybind11/include -isystem /usr/include/python3.10 -isystem /home/mrobins/git/PyBaMM/env/lib/python3.10/site-packages/casadi/include -Werror=vla -O3 -DNDEBUG -std=c++14 -fPIC -fvisibility=hidden -flto -fno-fat-lto-objects -o CMakeFiles/idaklu.dir/pybamm/solvers/c_solvers/idaklu/CasadiSolverOpenMP.cpp.o -c /home/mrobins/git/PyBaMM/pybamm/solvers/c_solvers/idaklu/CasadiSolverOpenMP.cpp",
"file": "/home/mrobins/git/PyBaMM/pybamm/solvers/c_solvers/idaklu/CasadiSolverOpenMP.cpp",
"output": "CMakeFiles/idaklu.dir/pybamm/solvers/c_solvers/idaklu/CasadiSolverOpenMP.cpp.o"
},
{
"directory": "/home/mrobins/git/PyBaMM",
"command": "/usr/bin/c++ -DCASADI_SNPRINTF=snprintf -D_GLIBCXX_USE_CXX11_ABI=0 -Didaklu_EXPORTS -I/usr/include/suitesparse -isystem /home/mrobins/git/PyBaMM/pybind11/include -isystem /usr/include/python3.10 -isystem /home/mrobins/git/PyBaMM/env/lib/python3.10/site-packages/casadi/include -Werror=vla -O3 -DNDEBUG -std=c++14 -fPIC -fvisibility=hidden -flto -fno-fat-lto-objects -o CMakeFiles/idaklu.dir/pybamm/solvers/c_solvers/idaklu/CasadiSolverOpenMP_solvers.cpp.o -c /home/mrobins/git/PyBaMM/pybamm/solvers/c_solvers/idaklu/CasadiSolverOpenMP_solvers.cpp",
"file": "/home/mrobins/git/PyBaMM/pybamm/solvers/c_solvers/idaklu/CasadiSolverOpenMP_solvers.cpp",
"output": "CMakeFiles/idaklu.dir/pybamm/solvers/c_solvers/idaklu/CasadiSolverOpenMP_solvers.cpp.o"
},
{
"directory": "/home/mrobins/git/PyBaMM",
"command": "/usr/bin/c++ -DCASADI_SNPRINTF=snprintf -D_GLIBCXX_USE_CXX11_ABI=0 -Didaklu_EXPORTS -I/usr/include/suitesparse -isystem /home/mrobins/git/PyBaMM/pybind11/include -isystem /usr/include/python3.10 -isystem /home/mrobins/git/PyBaMM/env/lib/python3.10/site-packages/casadi/include -Werror=vla -O3 -DNDEBUG -std=c++14 -fPIC -fvisibility=hidden -flto -fno-fat-lto-objects -o CMakeFiles/idaklu.dir/pybamm/solvers/c_solvers/idaklu/casadi_sundials_functions.cpp.o -c /home/mrobins/git/PyBaMM/pybamm/solvers/c_solvers/idaklu/casadi_sundials_functions.cpp",
"file": "/home/mrobins/git/PyBaMM/pybamm/solvers/c_solvers/idaklu/casadi_sundials_functions.cpp",
"output": "CMakeFiles/idaklu.dir/pybamm/solvers/c_solvers/idaklu/casadi_sundials_functions.cpp.o"
},
{
"directory": "/home/mrobins/git/PyBaMM",
"command": "/usr/bin/c++ -DCASADI_SNPRINTF=snprintf -D_GLIBCXX_USE_CXX11_ABI=0 -Didaklu_EXPORTS -I/usr/include/suitesparse -isystem /home/mrobins/git/PyBaMM/pybind11/include -isystem /usr/include/python3.10 -isystem /home/mrobins/git/PyBaMM/env/lib/python3.10/site-packages/casadi/include -Werror=vla -O3 -DNDEBUG -std=c++14 -fPIC -fvisibility=hidden -flto -fno-fat-lto-objects -o CMakeFiles/idaklu.dir/pybamm/solvers/c_solvers/idaklu/idaklu_jax.cpp.o -c /home/mrobins/git/PyBaMM/pybamm/solvers/c_solvers/idaklu/idaklu_jax.cpp",
"file": "/home/mrobins/git/PyBaMM/pybamm/solvers/c_solvers/idaklu/idaklu_jax.cpp",
"output": "CMakeFiles/idaklu.dir/pybamm/solvers/c_solvers/idaklu/idaklu_jax.cpp.o"
},
{
"directory": "/home/mrobins/git/PyBaMM",
"command": "/usr/bin/c++ -DCASADI_SNPRINTF=snprintf -D_GLIBCXX_USE_CXX11_ABI=0 -Didaklu_EXPORTS -I/usr/include/suitesparse -isystem /home/mrobins/git/PyBaMM/pybind11/include -isystem /usr/include/python3.10 -isystem /home/mrobins/git/PyBaMM/env/lib/python3.10/site-packages/casadi/include -Werror=vla -O3 -DNDEBUG -std=c++14 -fPIC -fvisibility=hidden -flto -fno-fat-lto-objects -o CMakeFiles/idaklu.dir/pybamm/solvers/c_solvers/idaklu/python.cpp.o -c /home/mrobins/git/PyBaMM/pybamm/solvers/c_solvers/idaklu/python.cpp",
"file": "/home/mrobins/git/PyBaMM/pybamm/solvers/c_solvers/idaklu/python.cpp",
"output": "CMakeFiles/idaklu.dir/pybamm/solvers/c_solvers/idaklu/python.cpp.o"
},
{
"directory": "/home/mrobins/git/PyBaMM",
"command": "/usr/bin/c++ -DCASADI_SNPRINTF=snprintf -D_GLIBCXX_USE_CXX11_ABI=0 -Didaklu_EXPORTS -I/usr/include/suitesparse -isystem /home/mrobins/git/PyBaMM/pybind11/include -isystem /usr/include/python3.10 -isystem /home/mrobins/git/PyBaMM/env/lib/python3.10/site-packages/casadi/include -Werror=vla -O3 -DNDEBUG -std=c++14 -fPIC -fvisibility=hidden -flto -fno-fat-lto-objects -o CMakeFiles/idaklu.dir/pybamm/solvers/c_solvers/idaklu/solution.cpp.o -c /home/mrobins/git/PyBaMM/pybamm/solvers/c_solvers/idaklu/solution.cpp",
"file": "/home/mrobins/git/PyBaMM/pybamm/solvers/c_solvers/idaklu/solution.cpp",
"output": "CMakeFiles/idaklu.dir/pybamm/solvers/c_solvers/idaklu/solution.cpp.o"
},
{
"directory": "/home/mrobins/git/PyBaMM",
"command": "/usr/bin/c++ -DCASADI_SNPRINTF=snprintf -D_GLIBCXX_USE_CXX11_ABI=0 -Didaklu_EXPORTS -I/usr/include/suitesparse -isystem /home/mrobins/git/PyBaMM/pybind11/include -isystem /usr/include/python3.10 -isystem /home/mrobins/git/PyBaMM/env/lib/python3.10/site-packages/casadi/include -Werror=vla -O3 -DNDEBUG -std=c++14 -fPIC -fvisibility=hidden -flto -fno-fat-lto-objects -o CMakeFiles/idaklu.dir/pybamm/solvers/c_solvers/idaklu/options.cpp.o -c /home/mrobins/git/PyBaMM/pybamm/solvers/c_solvers/idaklu/options.cpp",
"file": "/home/mrobins/git/PyBaMM/pybamm/solvers/c_solvers/idaklu/options.cpp",
"output": "CMakeFiles/idaklu.dir/pybamm/solvers/c_solvers/idaklu/options.cpp.o"
},
{
"directory": "/home/mrobins/git/PyBaMM",
"command": "/usr/bin/c++ -DCASADI_SNPRINTF=snprintf -D_GLIBCXX_USE_CXX11_ABI=0 -Didaklu_EXPORTS -I/usr/include/suitesparse -isystem /home/mrobins/git/PyBaMM/pybind11/include -isystem /usr/include/python3.10 -isystem /home/mrobins/git/PyBaMM/env/lib/python3.10/site-packages/casadi/include -Werror=vla -O3 -DNDEBUG -std=c++14 -fPIC -fvisibility=hidden -flto -fno-fat-lto-objects -o CMakeFiles/idaklu.dir/pybamm/solvers/c_solvers/idaklu.cpp.o -c /home/mrobins/git/PyBaMM/pybamm/solvers/c_solvers/idaklu.cpp",
"file": "/home/mrobins/git/PyBaMM/pybamm/solvers/c_solvers/idaklu.cpp",
"output": "CMakeFiles/idaklu.dir/pybamm/solvers/c_solvers/idaklu.cpp.o"
}
]
3 changes: 2 additions & 1 deletion pybamm/expression_tree/operations/evaluate_python.py
Original file line number Diff line number Diff line change
Expand Up @@ -516,9 +516,10 @@ def __call__(self, t=None, y=None, inputs=None):
"""
evaluate function
"""
# generated code assumes y is a column vector
# generated code assumes y and inputs are column vectors
if y is not None and y.ndim == 1:
y = y.reshape(-1, 1)
inputs = inputs.reshape(-1, 1)

if self._ninputs == 1:
# nothing to do for a single input
Expand Down
30 changes: 16 additions & 14 deletions pybamm/solvers/base_solver.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,7 +73,6 @@ def __init__(
if options is None:
options = default_options
else:
print("options", options)
for key, value in default_options.items():
if key not in options:
options[key] = value
Expand Down Expand Up @@ -928,15 +927,6 @@ def solve(
f"len(inputs) = {len(inputs)} and batch_size = {batch_size}"
)

# get a list-only version of calculate_sensitivities
if isinstance(calculate_sensitivities, bool):
if calculate_sensitivities:
calculate_sensitivities_list = [p for p in inputs.keys()]
else:
calculate_sensitivities_list = []
else:
calculate_sensitivities_list = calculate_sensitivities

# Make sure model isn't empty
if len(model.rhs) == 0 and len(model.algebraic) == 0:
if not isinstance(self, pybamm.DummySolver):
Expand Down Expand Up @@ -989,6 +979,15 @@ def solve(
self._set_up_model_inputs(model, inputs) for inputs in inputs_list
]

# get a list-only version of calculate_sensitivities
if isinstance(calculate_sensitivities, bool):
if calculate_sensitivities:
calculate_sensitivities_list = [p for p in inputs_list[0].keys()]
else:
calculate_sensitivities_list = []
else:
calculate_sensitivities_list = calculate_sensitivities

# Check that calculate_sensitivites or batch size have not been updated
calculate_sensitivities_list.sort()
if not hasattr(model, "calculate_sensitivities"):
Expand Down Expand Up @@ -1745,10 +1744,13 @@ def map_func_over_inputs_casadi(name, f, vars_for_processing, ninputs, nthreads)
nstates = vars_for_processing["y_and_S"].shape[0]
nparams = vars_for_processing["p_casadi_stacked"].shape[0]

if nthreads == 1:
parallelisation = "none"
else:
threads_per_input = nthreads // ninputs
if threads_per_input > 1:
threads_per_input = 1
if threads_per_input > 1:
parallelisation = "thread"
else:
parallelisation = "none"
y_and_S_inputs_stacked = casadi.MX.sym("y_and_S_stacked", nstates * ninputs)
p_casadi_inputs_stacked = casadi.MX.sym("p_stacked", nparams * ninputs)
v_inputs_stacked = casadi.MX.sym("v_stacked", nstates * ninputs)
Expand All @@ -1771,7 +1773,7 @@ def map_func_over_inputs_casadi(name, f, vars_for_processing, ninputs, nthreads)
inputs_2d = [t_2d, y_and_S_2d, p_casadi_2d]
inputs_stacked = [t_stacked, y_and_S_inputs_stacked, p_casadi_inputs_stacked]

mapped_f = f.map(ninputs, parallelisation, nthreads)(*inputs_2d)
mapped_f = f.map(ninputs, parallelisation, threads_per_input)(*inputs_2d)
if matrix_output:
# for matrix output we need to stack the outputs in a block diagonal matrix
splits = [i * nstates for i in range(ninputs + 1)]
Expand Down
11 changes: 7 additions & 4 deletions pybamm/solvers/c_solvers/idaklu.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -21,12 +21,14 @@ Function generate_function(const std::string &data)
namespace py = pybind11;

PYBIND11_MAKE_OPAQUE(std::vector<np_array>);
PYBIND11_MAKE_OPAQUE(std::vector<Solution>);

PYBIND11_MODULE(idaklu, m)
{
m.doc() = "sundials solvers"; // optional module docstring

py::bind_vector<std::vector<np_array>>(m, "VectorNdArray");
py::bind_vector<std::vector<Solution>>(m, "VectorSolution");

m.def("solve_python", &solve_python,
"The solve function for python evaluators",
Expand All @@ -50,17 +52,17 @@ PYBIND11_MODULE(idaklu, m)
py::arg("number_of_sensitivity_parameters"),
py::return_value_policy::take_ownership);

py::class_<CasadiSolver>(m, "CasadiSolver")
.def("solve", &CasadiSolver::solve,
py::class_<CasadiSolverGroup>(m, "CasadiSolverGroup")
.def("solve", &CasadiSolverGroup::solve,
"perform a solve",
py::arg("t"),
py::arg("y0"),
py::arg("yp0"),
py::arg("inputs"),
py::return_value_policy::take_ownership);

m.def("create_casadi_solver", &create_casadi_solver,
"Create a casadi idaklu solver object",
m.def("create_casadi_solver_group", &create_casadi_solver_group,
"Create a casadi idaklu solver group object",
py::arg("number_of_states"),
py::arg("number_of_parameters"),
py::arg("rhs_alg"),
Expand All @@ -83,6 +85,7 @@ PYBIND11_MODULE(idaklu, m)
py::arg("dvar_dy_fcns"),
py::arg("dvar_dp_fcns"),
py::arg("options"),
py::arg("nsolvers"),
py::return_value_policy::take_ownership);

m.def("generate_function", &generate_function,
Expand Down
18 changes: 13 additions & 5 deletions pybamm/solvers/c_solvers/idaklu/CasadiSolver.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -32,11 +32,19 @@ class CasadiSolver
/**
* @brief Abstract solver method that returns a Solution class
*/
virtual Solution solve(
np_array t_np,
np_array y0_np,
np_array yp0_np,
np_array_dense inputs) = 0;
virtual void solve(
const realtype *t,
const int number_of_timesteps,
const realtype *y0,
const realtype *yp0,
const realtype *inputs,
const int length_of_return_vector,
realtype *y_return,
realtype *yS_return,
realtype *t_return,
int &t_i,
int &retval
) = 0;

/**
* Abstract method to initialize the solver, once vectors and solver classes
Expand Down
Loading

0 comments on commit 36d19a0

Please sign in to comment.