Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Enable multithreading in IDAKLU #2947

Merged
merged 4 commits into from
May 17, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

## Features

- Enable multithreading in IDAKLU solver ([#2947](https://github.com/pybamm-team/PyBaMM/pull/2947))
- If a solution contains cycles and steps, the cycle number and step number are now saved when `solution.save_data()` is called ([#2931](https://github.com/pybamm-team/PyBaMM/pull/2931))

## Optimizations
Expand Down
1 change: 1 addition & 0 deletions FindSUNDIALS.cmake
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,7 @@ set(SUNDIALS_WANT_COMPONENTS
sundials_sunlinsollapackdense
sundials_sunmatrixsparse
sundials_nvecserial
sundials_nvecopenmp
)

# find the SUNDIALS libraries
Expand Down
18 changes: 10 additions & 8 deletions pybamm/solvers/c_solvers/idaklu/casadi_solver.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
#include "common.hpp"
#include <memory>


CasadiSolver *
create_casadi_solver(int number_of_states, int number_of_parameters,
const Function &rhs_alg, const Function &jac_times_cjmass,
Expand Down Expand Up @@ -53,16 +54,17 @@ CasadiSolver::CasadiSolver(np_array atol_np, double rel_tol,
#endif

// allocate vectors
int num_threads = options.num_threads;
#if SUNDIALS_VERSION_MAJOR >= 6
yy = N_VNew_Serial(number_of_states, sunctx);
yp = N_VNew_Serial(number_of_states, sunctx);
avtol = N_VNew_Serial(number_of_states, sunctx);
id = N_VNew_Serial(number_of_states, sunctx);
yy = N_VNew_OpenMP(number_of_states, num_threads, sunctx);
yp = N_VNew_OpenMP(number_of_states, num_threads, sunctx);
avtol = N_VNew_OpenMP(number_of_states, num_threads, sunctx);
id = N_VNew_OpenMP(number_of_states, num_threads, sunctx);
#else
yy = N_VNew_Serial(number_of_states);
yp = N_VNew_Serial(number_of_states);
avtol = N_VNew_Serial(number_of_states);
id = N_VNew_Serial(number_of_states);
yy = N_VNew_OpenMP(number_of_states, num_threads);
yp = N_VNew_OpenMP(number_of_states, num_threads);
avtol = N_VNew_OpenMP(number_of_states, num_threads);
id = N_VNew_OpenMP(number_of_states, num_threads);
#endif

if (number_of_parameters > 0)
Expand Down
36 changes: 18 additions & 18 deletions pybamm/solvers/c_solvers/idaklu/casadi_sundials_functions.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -10,19 +10,19 @@ int residual_casadi(realtype tres, N_Vector yy, N_Vector yp, N_Vector rr,
static_cast<CasadiFunctions *>(user_data);

p_python_functions->rhs_alg.m_arg[0] = &tres;
p_python_functions->rhs_alg.m_arg[1] = NV_DATA_S(yy);
p_python_functions->rhs_alg.m_arg[1] = NV_DATA_OMP(yy);
p_python_functions->rhs_alg.m_arg[2] = p_python_functions->inputs.data();
p_python_functions->rhs_alg.m_res[0] = NV_DATA_S(rr);
p_python_functions->rhs_alg.m_res[0] = NV_DATA_OMP(rr);
p_python_functions->rhs_alg();

realtype *tmp = p_python_functions->get_tmp_state_vector();
p_python_functions->mass_action.m_arg[0] = NV_DATA_S(yp);
p_python_functions->mass_action.m_arg[0] = NV_DATA_OMP(yp);
p_python_functions->mass_action.m_res[0] = tmp;
p_python_functions->mass_action();

// AXPY: y <- a*x + y
const int ns = p_python_functions->number_of_states;
casadi::casadi_axpy(ns, -1., tmp, NV_DATA_S(rr));
casadi::casadi_axpy(ns, -1., tmp, NV_DATA_OMP(rr));

DEBUG_VECTOR(yy);
DEBUG_VECTOR(yp);
Expand Down Expand Up @@ -101,22 +101,22 @@ int jtimes_casadi(realtype tt, N_Vector yy, N_Vector yp, N_Vector rr,

// Jv has ∂F/∂y v
p_python_functions->jac_action.m_arg[0] = &tt;
p_python_functions->jac_action.m_arg[1] = NV_DATA_S(yy);
p_python_functions->jac_action.m_arg[1] = NV_DATA_OMP(yy);
p_python_functions->jac_action.m_arg[2] = p_python_functions->inputs.data();
p_python_functions->jac_action.m_arg[3] = NV_DATA_S(v);
p_python_functions->jac_action.m_res[0] = NV_DATA_S(Jv);
p_python_functions->jac_action.m_arg[3] = NV_DATA_OMP(v);
p_python_functions->jac_action.m_res[0] = NV_DATA_OMP(Jv);
p_python_functions->jac_action();

// tmp has -∂F/∂y˙ v
realtype *tmp = p_python_functions->get_tmp_state_vector();
p_python_functions->mass_action.m_arg[0] = NV_DATA_S(v);
p_python_functions->mass_action.m_arg[0] = NV_DATA_OMP(v);
p_python_functions->mass_action.m_res[0] = tmp;
p_python_functions->mass_action();

// AXPY: y <- a*x + y
// Jv has ∂F/∂y v + cj ∂F/∂y˙ v
const int ns = p_python_functions->number_of_states;
casadi::casadi_axpy(ns, -cj, tmp, NV_DATA_S(Jv));
casadi::casadi_axpy(ns, -cj, tmp, NV_DATA_OMP(Jv));

return 0;
}
Expand Down Expand Up @@ -163,7 +163,7 @@ int jacobian_casadi(realtype tt, realtype cj, N_Vector yy, N_Vector yp,

// args are t, y, cj, put result in jacobian data matrix
p_python_functions->jac_times_cjmass.m_arg[0] = &tt;
p_python_functions->jac_times_cjmass.m_arg[1] = NV_DATA_S(yy);
p_python_functions->jac_times_cjmass.m_arg[1] = NV_DATA_OMP(yy);
p_python_functions->jac_times_cjmass.m_arg[2] =
p_python_functions->inputs.data();
p_python_functions->jac_times_cjmass.m_arg[3] = &cj;
Expand Down Expand Up @@ -227,7 +227,7 @@ int events_casadi(realtype t, N_Vector yy, N_Vector yp, realtype *events_ptr,

// args are t, y, put result in events_ptr
p_python_functions->events.m_arg[0] = &t;
p_python_functions->events.m_arg[1] = NV_DATA_S(yy);
p_python_functions->events.m_arg[1] = NV_DATA_OMP(yy);
p_python_functions->events.m_arg[2] = p_python_functions->inputs.data();
p_python_functions->events.m_res[0] = events_ptr;
p_python_functions->events();
Expand Down Expand Up @@ -270,11 +270,11 @@ int sensitivities_casadi(int Ns, realtype t, N_Vector yy, N_Vector yp,

// args are t, y put result in rr
p_python_functions->sens.m_arg[0] = &t;
p_python_functions->sens.m_arg[1] = NV_DATA_S(yy);
p_python_functions->sens.m_arg[1] = NV_DATA_OMP(yy);
p_python_functions->sens.m_arg[2] = p_python_functions->inputs.data();
for (int i = 0; i < np; i++)
{
p_python_functions->sens.m_res[i] = NV_DATA_S(resvalS[i]);
p_python_functions->sens.m_res[i] = NV_DATA_OMP(resvalS[i]);
}
// resvalsS now has (∂F/∂p i )
p_python_functions->sens();
Expand All @@ -284,23 +284,23 @@ int sensitivities_casadi(int Ns, realtype t, N_Vector yy, N_Vector yp,
// put (∂F/∂y)s i (t) in tmp
realtype *tmp = p_python_functions->get_tmp_state_vector();
p_python_functions->jac_action.m_arg[0] = &t;
p_python_functions->jac_action.m_arg[1] = NV_DATA_S(yy);
p_python_functions->jac_action.m_arg[1] = NV_DATA_OMP(yy);
p_python_functions->jac_action.m_arg[2] = p_python_functions->inputs.data();
p_python_functions->jac_action.m_arg[3] = NV_DATA_S(yS[i]);
p_python_functions->jac_action.m_arg[3] = NV_DATA_OMP(yS[i]);
p_python_functions->jac_action.m_res[0] = tmp;
p_python_functions->jac_action();

const int ns = p_python_functions->number_of_states;
casadi::casadi_axpy(ns, 1., tmp, NV_DATA_S(resvalS[i]));
casadi::casadi_axpy(ns, 1., tmp, NV_DATA_OMP(resvalS[i]));

// put -(∂F/∂ ẏ) ṡ i (t) in tmp2
p_python_functions->mass_action.m_arg[0] = NV_DATA_S(ypS[i]);
p_python_functions->mass_action.m_arg[0] = NV_DATA_OMP(ypS[i]);
p_python_functions->mass_action.m_res[0] = tmp;
p_python_functions->mass_action();

// (∂F/∂y)s i (t)+(∂F/∂ ẏ) ṡ i (t)+(∂F/∂p i )
// AXPY: y <- a*x + y
casadi::casadi_axpy(ns, -1., tmp, NV_DATA_S(resvalS[i]));
casadi::casadi_axpy(ns, -1., tmp, NV_DATA_OMP(resvalS[i]));
}

return 0;
Expand Down
1 change: 1 addition & 0 deletions pybamm/solvers/c_solvers/idaklu/common.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
#include <idas/idas_bbdpre.h> /* access to IDABBDPRE preconditioner */

#include <nvector/nvector_serial.h> /* access to serial N_Vector */
#include <nvector/nvector_openmp.h> /* access to openmp N_Vector */
#include <sundials/sundials_math.h> /* defs. of SUNRabs, SUNRexp, etc. */
#include <sundials/sundials_config.h> /* defs. of SUNRabs, SUNRexp, etc. */
#include <sundials/sundials_types.h> /* defs. of realtype, sunindextype */
Expand Down
3 changes: 2 additions & 1 deletion pybamm/solvers/c_solvers/idaklu/options.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,8 @@ Options::Options(py::dict options)
linsol_max_iterations(options["linsol_max_iterations"].cast<int>()),
linear_solver(options["linear_solver"].cast<std::string>()),
precon_half_bandwidth(options["precon_half_bandwidth"].cast<int>()),
precon_half_bandwidth_keep(options["precon_half_bandwidth_keep"].cast<int>())
precon_half_bandwidth_keep(options["precon_half_bandwidth_keep"].cast<int>()),
num_threads(options["num_threads"].cast<int>())
{

using_sparse_matrix = true;
Expand Down
1 change: 1 addition & 0 deletions pybamm/solvers/c_solvers/idaklu/options.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ struct Options {
int linsol_max_iterations;
int precon_half_bandwidth;
int precon_half_bandwidth_keep;
int num_threads;
explicit Options(py::dict options);

};
Expand Down
4 changes: 4 additions & 0 deletions pybamm/solvers/idaklu_solver.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,6 +74,9 @@ class IDAKLUSolver(pybamm.BaseSolver):
# for iterative linear solver preconditioner, bandwidth of
# approximate jacobian that is kept
"precon_half_bandwidth_keep": 5

# Number of threads available for OpenMP
"num_threads": 1
}

Note: These options only have an effect if model.convert_to_format == 'casadi'
Expand All @@ -100,6 +103,7 @@ def __init__(
"linsol_max_iterations": 5,
"precon_half_bandwidth": 5,
"precon_half_bandwidth_keep": 5,
"num_threads": 1,
}
if options is None:
options = default_options
Expand Down
5 changes: 3 additions & 2 deletions scripts/install_KLU_Sundials.py
Original file line number Diff line number Diff line change
Expand Up @@ -101,10 +101,11 @@ def download_extract_library(url, download_dir):
KLU_INCLUDE_DIR = os.path.join(install_dir, "include")
KLU_LIBRARY_DIR = os.path.join(install_dir, "lib")
cmake_args = [
"-DLAPACK_ENABLE=ON",
"-DENABLE_LAPACK=ON",
"-DSUNDIALS_INDEX_SIZE=32",
"-DEXAMPLES_ENABLE:BOOL=OFF",
"-DKLU_ENABLE=ON",
"-DENABLE_KLU=ON",
"-DENABLE_OPENMP=ON",
"-DKLU_INCLUDE_DIR={}".format(KLU_INCLUDE_DIR),
"-DKLU_LIBRARY_DIR={}".format(KLU_LIBRARY_DIR),
"-DCMAKE_INSTALL_PREFIX=" + install_dir,
Expand Down