From 36d19a009f9f560026b5a4517eec59b4cf5fa649 Mon Sep 17 00:00:00 2001 From: martinjrobins Date: Thu, 11 Jul 2024 16:32:47 +0000 Subject: [PATCH] openmp parallel loop for idaklu #4087 --- CMakeLists.txt | 8 + compile_commands.json | 74 ++++++ .../operations/evaluate_python.py | 3 +- pybamm/solvers/base_solver.py | 30 ++- pybamm/solvers/c_solvers/idaklu.cpp | 11 +- .../solvers/c_solvers/idaklu/CasadiSolver.hpp | 18 +- .../c_solvers/idaklu/CasadiSolverGroup.cpp | 165 ++++++++++++ .../c_solvers/idaklu/CasadiSolverGroup.hpp | 50 ++++ .../c_solvers/idaklu/CasadiSolverOpenMP.cpp | 168 ++++-------- .../c_solvers/idaklu/CasadiSolverOpenMP.hpp | 18 +- .../c_solvers/idaklu/casadi_functions.hpp | 2 +- .../c_solvers/idaklu/casadi_solver.cpp | 97 +++++-- .../c_solvers/idaklu/casadi_solver.hpp | 32 ++- pybamm/solvers/c_solvers/idaklu/common.hpp | 3 +- pybamm/solvers/c_solvers/idaklu/options.cpp | 10 +- pybamm/solvers/c_solvers/idaklu/options.hpp | 2 +- pybamm/solvers/c_solvers/idaklu/solution.hpp | 5 + pybamm/solvers/idaklu_solver.py | 247 ++++++++++-------- tests/unit/test_solvers/test_idaklu_solver.py | 1 + 19 files changed, 651 insertions(+), 293 deletions(-) create mode 100644 compile_commands.json create mode 100644 pybamm/solvers/c_solvers/idaklu/CasadiSolverGroup.cpp create mode 100644 pybamm/solvers/c_solvers/idaklu/CasadiSolverGroup.hpp diff --git a/CMakeLists.txt b/CMakeLists.txt index b9fe37c331..f6905f8b74 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -45,6 +45,8 @@ pybind11_add_module(idaklu pybamm/solvers/c_solvers/idaklu/casadi_solver.hpp pybamm/solvers/c_solvers/idaklu/CasadiSolver.cpp pybamm/solvers/c_solvers/idaklu/CasadiSolver.hpp + pybamm/solvers/c_solvers/idaklu/CasadiSolverGroup.cpp + pybamm/solvers/c_solvers/idaklu/CasadiSolverGroup.hpp pybamm/solvers/c_solvers/idaklu/CasadiSolverOpenMP.cpp pybamm/solvers/c_solvers/idaklu/CasadiSolverOpenMP.hpp pybamm/solvers/c_solvers/idaklu/CasadiSolverOpenMP_solvers.cpp @@ -95,6 +97,12 @@ set_target_properties( INSTALL_RPATH_USE_LINK_PATH TRUE ) +# openmp +find_package(OpenMP) +if(OpenMP_CXX_FOUND) + target_link_libraries(idaklu PRIVATE OpenMP::OpenMP_CXX) +endif() + set(CMAKE_MODULE_PATH ${CMAKE_MODULE_PATH} ${PROJECT_SOURCE_DIR}) # Sundials find_package(SUNDIALS REQUIRED) diff --git a/compile_commands.json b/compile_commands.json new file mode 100644 index 0000000000..2baf2c2a59 --- /dev/null +++ b/compile_commands.json @@ -0,0 +1,74 @@ +[ +{ + "directory": "/home/mrobins/git/PyBaMM", + "command": "/usr/bin/c++ -DCASADI_SNPRINTF=snprintf -D_GLIBCXX_USE_CXX11_ABI=0 -Didaklu_EXPORTS -I/usr/include/suitesparse -isystem /home/mrobins/git/PyBaMM/pybind11/include -isystem /usr/include/python3.10 -isystem /home/mrobins/git/PyBaMM/env/lib/python3.10/site-packages/casadi/include -Werror=vla -O3 -DNDEBUG -std=c++14 -fPIC -fvisibility=hidden -flto -fno-fat-lto-objects -o CMakeFiles/idaklu.dir/pybamm/solvers/c_solvers/idaklu/casadi_functions.cpp.o -c /home/mrobins/git/PyBaMM/pybamm/solvers/c_solvers/idaklu/casadi_functions.cpp", + "file": "/home/mrobins/git/PyBaMM/pybamm/solvers/c_solvers/idaklu/casadi_functions.cpp", + "output": "CMakeFiles/idaklu.dir/pybamm/solvers/c_solvers/idaklu/casadi_functions.cpp.o" +}, +{ + "directory": "/home/mrobins/git/PyBaMM", + "command": "/usr/bin/c++ -DCASADI_SNPRINTF=snprintf -D_GLIBCXX_USE_CXX11_ABI=0 -Didaklu_EXPORTS -I/usr/include/suitesparse -isystem /home/mrobins/git/PyBaMM/pybind11/include -isystem /usr/include/python3.10 -isystem /home/mrobins/git/PyBaMM/env/lib/python3.10/site-packages/casadi/include -Werror=vla -O3 -DNDEBUG -std=c++14 -fPIC -fvisibility=hidden -flto -fno-fat-lto-objects -o CMakeFiles/idaklu.dir/pybamm/solvers/c_solvers/idaklu/casadi_solver.cpp.o -c /home/mrobins/git/PyBaMM/pybamm/solvers/c_solvers/idaklu/casadi_solver.cpp", + "file": "/home/mrobins/git/PyBaMM/pybamm/solvers/c_solvers/idaklu/casadi_solver.cpp", + "output": "CMakeFiles/idaklu.dir/pybamm/solvers/c_solvers/idaklu/casadi_solver.cpp.o" +}, +{ + "directory": "/home/mrobins/git/PyBaMM", + "command": "/usr/bin/c++ -DCASADI_SNPRINTF=snprintf -D_GLIBCXX_USE_CXX11_ABI=0 -Didaklu_EXPORTS -I/usr/include/suitesparse -isystem /home/mrobins/git/PyBaMM/pybind11/include -isystem /usr/include/python3.10 -isystem /home/mrobins/git/PyBaMM/env/lib/python3.10/site-packages/casadi/include -Werror=vla -O3 -DNDEBUG -std=c++14 -fPIC -fvisibility=hidden -flto -fno-fat-lto-objects -o CMakeFiles/idaklu.dir/pybamm/solvers/c_solvers/idaklu/CasadiSolver.cpp.o -c /home/mrobins/git/PyBaMM/pybamm/solvers/c_solvers/idaklu/CasadiSolver.cpp", + "file": "/home/mrobins/git/PyBaMM/pybamm/solvers/c_solvers/idaklu/CasadiSolver.cpp", + "output": "CMakeFiles/idaklu.dir/pybamm/solvers/c_solvers/idaklu/CasadiSolver.cpp.o" +}, +{ + "directory": "/home/mrobins/git/PyBaMM", + "command": "/usr/bin/c++ -DCASADI_SNPRINTF=snprintf -D_GLIBCXX_USE_CXX11_ABI=0 -Didaklu_EXPORTS -I/usr/include/suitesparse -isystem /home/mrobins/git/PyBaMM/pybind11/include -isystem /usr/include/python3.10 -isystem /home/mrobins/git/PyBaMM/env/lib/python3.10/site-packages/casadi/include -Werror=vla -O3 -DNDEBUG -std=c++14 -fPIC -fvisibility=hidden -flto -fno-fat-lto-objects -o CMakeFiles/idaklu.dir/pybamm/solvers/c_solvers/idaklu/CasadiSolverGroup.cpp.o -c /home/mrobins/git/PyBaMM/pybamm/solvers/c_solvers/idaklu/CasadiSolverGroup.cpp", + "file": "/home/mrobins/git/PyBaMM/pybamm/solvers/c_solvers/idaklu/CasadiSolverGroup.cpp", + "output": "CMakeFiles/idaklu.dir/pybamm/solvers/c_solvers/idaklu/CasadiSolverGroup.cpp.o" +}, +{ + "directory": "/home/mrobins/git/PyBaMM", + "command": "/usr/bin/c++ -DCASADI_SNPRINTF=snprintf -D_GLIBCXX_USE_CXX11_ABI=0 -Didaklu_EXPORTS -I/usr/include/suitesparse -isystem /home/mrobins/git/PyBaMM/pybind11/include -isystem /usr/include/python3.10 -isystem /home/mrobins/git/PyBaMM/env/lib/python3.10/site-packages/casadi/include -Werror=vla -O3 -DNDEBUG -std=c++14 -fPIC -fvisibility=hidden -flto -fno-fat-lto-objects -o CMakeFiles/idaklu.dir/pybamm/solvers/c_solvers/idaklu/CasadiSolverOpenMP.cpp.o -c /home/mrobins/git/PyBaMM/pybamm/solvers/c_solvers/idaklu/CasadiSolverOpenMP.cpp", + "file": "/home/mrobins/git/PyBaMM/pybamm/solvers/c_solvers/idaklu/CasadiSolverOpenMP.cpp", + "output": "CMakeFiles/idaklu.dir/pybamm/solvers/c_solvers/idaklu/CasadiSolverOpenMP.cpp.o" +}, +{ + "directory": "/home/mrobins/git/PyBaMM", + "command": "/usr/bin/c++ -DCASADI_SNPRINTF=snprintf -D_GLIBCXX_USE_CXX11_ABI=0 -Didaklu_EXPORTS -I/usr/include/suitesparse -isystem /home/mrobins/git/PyBaMM/pybind11/include -isystem /usr/include/python3.10 -isystem /home/mrobins/git/PyBaMM/env/lib/python3.10/site-packages/casadi/include -Werror=vla -O3 -DNDEBUG -std=c++14 -fPIC -fvisibility=hidden -flto -fno-fat-lto-objects -o CMakeFiles/idaklu.dir/pybamm/solvers/c_solvers/idaklu/CasadiSolverOpenMP_solvers.cpp.o -c /home/mrobins/git/PyBaMM/pybamm/solvers/c_solvers/idaklu/CasadiSolverOpenMP_solvers.cpp", + "file": "/home/mrobins/git/PyBaMM/pybamm/solvers/c_solvers/idaklu/CasadiSolverOpenMP_solvers.cpp", + "output": "CMakeFiles/idaklu.dir/pybamm/solvers/c_solvers/idaklu/CasadiSolverOpenMP_solvers.cpp.o" +}, +{ + "directory": "/home/mrobins/git/PyBaMM", + "command": "/usr/bin/c++ -DCASADI_SNPRINTF=snprintf -D_GLIBCXX_USE_CXX11_ABI=0 -Didaklu_EXPORTS -I/usr/include/suitesparse -isystem /home/mrobins/git/PyBaMM/pybind11/include -isystem /usr/include/python3.10 -isystem /home/mrobins/git/PyBaMM/env/lib/python3.10/site-packages/casadi/include -Werror=vla -O3 -DNDEBUG -std=c++14 -fPIC -fvisibility=hidden -flto -fno-fat-lto-objects -o CMakeFiles/idaklu.dir/pybamm/solvers/c_solvers/idaklu/casadi_sundials_functions.cpp.o -c /home/mrobins/git/PyBaMM/pybamm/solvers/c_solvers/idaklu/casadi_sundials_functions.cpp", + "file": "/home/mrobins/git/PyBaMM/pybamm/solvers/c_solvers/idaklu/casadi_sundials_functions.cpp", + "output": "CMakeFiles/idaklu.dir/pybamm/solvers/c_solvers/idaklu/casadi_sundials_functions.cpp.o" +}, +{ + "directory": "/home/mrobins/git/PyBaMM", + "command": "/usr/bin/c++ -DCASADI_SNPRINTF=snprintf -D_GLIBCXX_USE_CXX11_ABI=0 -Didaklu_EXPORTS -I/usr/include/suitesparse -isystem /home/mrobins/git/PyBaMM/pybind11/include -isystem /usr/include/python3.10 -isystem /home/mrobins/git/PyBaMM/env/lib/python3.10/site-packages/casadi/include -Werror=vla -O3 -DNDEBUG -std=c++14 -fPIC -fvisibility=hidden -flto -fno-fat-lto-objects -o CMakeFiles/idaklu.dir/pybamm/solvers/c_solvers/idaklu/idaklu_jax.cpp.o -c /home/mrobins/git/PyBaMM/pybamm/solvers/c_solvers/idaklu/idaklu_jax.cpp", + "file": "/home/mrobins/git/PyBaMM/pybamm/solvers/c_solvers/idaklu/idaklu_jax.cpp", + "output": "CMakeFiles/idaklu.dir/pybamm/solvers/c_solvers/idaklu/idaklu_jax.cpp.o" +}, +{ + "directory": "/home/mrobins/git/PyBaMM", + "command": "/usr/bin/c++ -DCASADI_SNPRINTF=snprintf -D_GLIBCXX_USE_CXX11_ABI=0 -Didaklu_EXPORTS -I/usr/include/suitesparse -isystem /home/mrobins/git/PyBaMM/pybind11/include -isystem /usr/include/python3.10 -isystem /home/mrobins/git/PyBaMM/env/lib/python3.10/site-packages/casadi/include -Werror=vla -O3 -DNDEBUG -std=c++14 -fPIC -fvisibility=hidden -flto -fno-fat-lto-objects -o CMakeFiles/idaklu.dir/pybamm/solvers/c_solvers/idaklu/python.cpp.o -c /home/mrobins/git/PyBaMM/pybamm/solvers/c_solvers/idaklu/python.cpp", + "file": "/home/mrobins/git/PyBaMM/pybamm/solvers/c_solvers/idaklu/python.cpp", + "output": "CMakeFiles/idaklu.dir/pybamm/solvers/c_solvers/idaklu/python.cpp.o" +}, +{ + "directory": "/home/mrobins/git/PyBaMM", + "command": "/usr/bin/c++ -DCASADI_SNPRINTF=snprintf -D_GLIBCXX_USE_CXX11_ABI=0 -Didaklu_EXPORTS -I/usr/include/suitesparse -isystem /home/mrobins/git/PyBaMM/pybind11/include -isystem /usr/include/python3.10 -isystem /home/mrobins/git/PyBaMM/env/lib/python3.10/site-packages/casadi/include -Werror=vla -O3 -DNDEBUG -std=c++14 -fPIC -fvisibility=hidden -flto -fno-fat-lto-objects -o CMakeFiles/idaklu.dir/pybamm/solvers/c_solvers/idaklu/solution.cpp.o -c /home/mrobins/git/PyBaMM/pybamm/solvers/c_solvers/idaklu/solution.cpp", + "file": "/home/mrobins/git/PyBaMM/pybamm/solvers/c_solvers/idaklu/solution.cpp", + "output": "CMakeFiles/idaklu.dir/pybamm/solvers/c_solvers/idaklu/solution.cpp.o" +}, +{ + "directory": "/home/mrobins/git/PyBaMM", + "command": "/usr/bin/c++ -DCASADI_SNPRINTF=snprintf -D_GLIBCXX_USE_CXX11_ABI=0 -Didaklu_EXPORTS -I/usr/include/suitesparse -isystem /home/mrobins/git/PyBaMM/pybind11/include -isystem /usr/include/python3.10 -isystem /home/mrobins/git/PyBaMM/env/lib/python3.10/site-packages/casadi/include -Werror=vla -O3 -DNDEBUG -std=c++14 -fPIC -fvisibility=hidden -flto -fno-fat-lto-objects -o CMakeFiles/idaklu.dir/pybamm/solvers/c_solvers/idaklu/options.cpp.o -c /home/mrobins/git/PyBaMM/pybamm/solvers/c_solvers/idaklu/options.cpp", + "file": "/home/mrobins/git/PyBaMM/pybamm/solvers/c_solvers/idaklu/options.cpp", + "output": "CMakeFiles/idaklu.dir/pybamm/solvers/c_solvers/idaklu/options.cpp.o" +}, +{ + "directory": "/home/mrobins/git/PyBaMM", + "command": "/usr/bin/c++ -DCASADI_SNPRINTF=snprintf -D_GLIBCXX_USE_CXX11_ABI=0 -Didaklu_EXPORTS -I/usr/include/suitesparse -isystem /home/mrobins/git/PyBaMM/pybind11/include -isystem /usr/include/python3.10 -isystem /home/mrobins/git/PyBaMM/env/lib/python3.10/site-packages/casadi/include -Werror=vla -O3 -DNDEBUG -std=c++14 -fPIC -fvisibility=hidden -flto -fno-fat-lto-objects -o CMakeFiles/idaklu.dir/pybamm/solvers/c_solvers/idaklu.cpp.o -c /home/mrobins/git/PyBaMM/pybamm/solvers/c_solvers/idaklu.cpp", + "file": "/home/mrobins/git/PyBaMM/pybamm/solvers/c_solvers/idaklu.cpp", + "output": "CMakeFiles/idaklu.dir/pybamm/solvers/c_solvers/idaklu.cpp.o" +} +] diff --git a/pybamm/expression_tree/operations/evaluate_python.py b/pybamm/expression_tree/operations/evaluate_python.py index 5ad74d939b..93a056d00d 100644 --- a/pybamm/expression_tree/operations/evaluate_python.py +++ b/pybamm/expression_tree/operations/evaluate_python.py @@ -516,9 +516,10 @@ def __call__(self, t=None, y=None, inputs=None): """ evaluate function """ - # generated code assumes y is a column vector + # generated code assumes y and inputs are column vectors if y is not None and y.ndim == 1: y = y.reshape(-1, 1) + inputs = inputs.reshape(-1, 1) if self._ninputs == 1: # nothing to do for a single input diff --git a/pybamm/solvers/base_solver.py b/pybamm/solvers/base_solver.py index b0f793e94a..d0031ece9d 100644 --- a/pybamm/solvers/base_solver.py +++ b/pybamm/solvers/base_solver.py @@ -73,7 +73,6 @@ def __init__( if options is None: options = default_options else: - print("options", options) for key, value in default_options.items(): if key not in options: options[key] = value @@ -928,15 +927,6 @@ def solve( f"len(inputs) = {len(inputs)} and batch_size = {batch_size}" ) - # get a list-only version of calculate_sensitivities - if isinstance(calculate_sensitivities, bool): - if calculate_sensitivities: - calculate_sensitivities_list = [p for p in inputs.keys()] - else: - calculate_sensitivities_list = [] - else: - calculate_sensitivities_list = calculate_sensitivities - # Make sure model isn't empty if len(model.rhs) == 0 and len(model.algebraic) == 0: if not isinstance(self, pybamm.DummySolver): @@ -989,6 +979,15 @@ def solve( self._set_up_model_inputs(model, inputs) for inputs in inputs_list ] + # get a list-only version of calculate_sensitivities + if isinstance(calculate_sensitivities, bool): + if calculate_sensitivities: + calculate_sensitivities_list = [p for p in inputs_list[0].keys()] + else: + calculate_sensitivities_list = [] + else: + calculate_sensitivities_list = calculate_sensitivities + # Check that calculate_sensitivites or batch size have not been updated calculate_sensitivities_list.sort() if not hasattr(model, "calculate_sensitivities"): @@ -1745,10 +1744,13 @@ def map_func_over_inputs_casadi(name, f, vars_for_processing, ninputs, nthreads) nstates = vars_for_processing["y_and_S"].shape[0] nparams = vars_for_processing["p_casadi_stacked"].shape[0] - if nthreads == 1: - parallelisation = "none" - else: + threads_per_input = nthreads // ninputs + if threads_per_input > 1: + threads_per_input = 1 + if threads_per_input > 1: parallelisation = "thread" + else: + parallelisation = "none" y_and_S_inputs_stacked = casadi.MX.sym("y_and_S_stacked", nstates * ninputs) p_casadi_inputs_stacked = casadi.MX.sym("p_stacked", nparams * ninputs) v_inputs_stacked = casadi.MX.sym("v_stacked", nstates * ninputs) @@ -1771,7 +1773,7 @@ def map_func_over_inputs_casadi(name, f, vars_for_processing, ninputs, nthreads) inputs_2d = [t_2d, y_and_S_2d, p_casadi_2d] inputs_stacked = [t_stacked, y_and_S_inputs_stacked, p_casadi_inputs_stacked] - mapped_f = f.map(ninputs, parallelisation, nthreads)(*inputs_2d) + mapped_f = f.map(ninputs, parallelisation, threads_per_input)(*inputs_2d) if matrix_output: # for matrix output we need to stack the outputs in a block diagonal matrix splits = [i * nstates for i in range(ninputs + 1)] diff --git a/pybamm/solvers/c_solvers/idaklu.cpp b/pybamm/solvers/c_solvers/idaklu.cpp index 9f99d4d3f4..7dbc9c3111 100644 --- a/pybamm/solvers/c_solvers/idaklu.cpp +++ b/pybamm/solvers/c_solvers/idaklu.cpp @@ -21,12 +21,14 @@ Function generate_function(const std::string &data) namespace py = pybind11; PYBIND11_MAKE_OPAQUE(std::vector); +PYBIND11_MAKE_OPAQUE(std::vector); PYBIND11_MODULE(idaklu, m) { m.doc() = "sundials solvers"; // optional module docstring py::bind_vector>(m, "VectorNdArray"); + py::bind_vector>(m, "VectorSolution"); m.def("solve_python", &solve_python, "The solve function for python evaluators", @@ -50,8 +52,8 @@ PYBIND11_MODULE(idaklu, m) py::arg("number_of_sensitivity_parameters"), py::return_value_policy::take_ownership); - py::class_(m, "CasadiSolver") - .def("solve", &CasadiSolver::solve, + py::class_(m, "CasadiSolverGroup") + .def("solve", &CasadiSolverGroup::solve, "perform a solve", py::arg("t"), py::arg("y0"), @@ -59,8 +61,8 @@ PYBIND11_MODULE(idaklu, m) py::arg("inputs"), py::return_value_policy::take_ownership); - m.def("create_casadi_solver", &create_casadi_solver, - "Create a casadi idaklu solver object", + m.def("create_casadi_solver_group", &create_casadi_solver_group, + "Create a casadi idaklu solver group object", py::arg("number_of_states"), py::arg("number_of_parameters"), py::arg("rhs_alg"), @@ -83,6 +85,7 @@ PYBIND11_MODULE(idaklu, m) py::arg("dvar_dy_fcns"), py::arg("dvar_dp_fcns"), py::arg("options"), + py::arg("nsolvers"), py::return_value_policy::take_ownership); m.def("generate_function", &generate_function, diff --git a/pybamm/solvers/c_solvers/idaklu/CasadiSolver.hpp b/pybamm/solvers/c_solvers/idaklu/CasadiSolver.hpp index dac94579f3..7663aba439 100644 --- a/pybamm/solvers/c_solvers/idaklu/CasadiSolver.hpp +++ b/pybamm/solvers/c_solvers/idaklu/CasadiSolver.hpp @@ -32,11 +32,19 @@ class CasadiSolver /** * @brief Abstract solver method that returns a Solution class */ - virtual Solution solve( - np_array t_np, - np_array y0_np, - np_array yp0_np, - np_array_dense inputs) = 0; + virtual void solve( + const realtype *t, + const int number_of_timesteps, + const realtype *y0, + const realtype *yp0, + const realtype *inputs, + const int length_of_return_vector, + realtype *y_return, + realtype *yS_return, + realtype *t_return, + int &t_i, + int &retval + ) = 0; /** * Abstract method to initialize the solver, once vectors and solver classes diff --git a/pybamm/solvers/c_solvers/idaklu/CasadiSolverGroup.cpp b/pybamm/solvers/c_solvers/idaklu/CasadiSolverGroup.cpp new file mode 100644 index 0000000000..e672df87bc --- /dev/null +++ b/pybamm/solvers/c_solvers/idaklu/CasadiSolverGroup.cpp @@ -0,0 +1,165 @@ +#include "CasadiSolverGroup.hpp" +#include + +std::vector CasadiSolverGroup::solve(np_array t_np, np_array y0_np, np_array yp0_np, np_array inputs) { + auto n_coeffs = number_of_states + number_of_parameters * number_of_states; + + if (y0_np.ndim() != 2) + throw std::domain_error("y0 has wrong number of dimensions. Expected 2 but got " + std::to_string(y0_np.ndim())); + if (yp0_np.ndim() != 2) + throw std::domain_error("yp0 has wrong number of dimensions. Expected 2 but got " + std::to_string(yp0_np.ndim())); + if (inputs.ndim() != 2) + throw std::domain_error("inputs has wrong number of dimensions. Expected 2 but got " + std::to_string(inputs.ndim())); + + auto n_groups = y0_np.shape()[0]; + + if (y0_np.shape()[1] != n_coeffs) + throw std::domain_error( + "y0 has wrong number of cols. Expected " + std::to_string(n_coeffs) + + " but got " + std::to_string(y0_np.shape()[1])); + + if (yp0_np.shape()[1] != n_coeffs) + throw std::domain_error( + "yp0 has wrong number of cols. Expected " + std::to_string(n_coeffs) + + " but got " + std::to_string(yp0_np.shape()[1])); + + if (yp0_np.shape()[0] != n_groups) + throw std::domain_error( + "yp0 has wrong number of rows. Expected " + std::to_string(n_groups) + + " but got " + std::to_string(yp0_np.shape()[0])); + + if (inputs.shape()[0] != n_groups) + throw std::domain_error( + "inputs has wrong number of rows. Expected " + std::to_string(n_groups) + + " but got " + std::to_string(inputs.shape()[0])); + + + const int number_of_timesteps = t_np.shape(0); + + // set return vectors + std::vector retval_returns(n_groups); + std::vector t_i_returns(n_groups); + std::vector t_returns(n_groups); + std::vector y_returns(n_groups); + std::vector yS_returns(n_groups); + for (int i = 0; i < n_groups; i++) { + + t_returns[i] = new realtype[number_of_timesteps]; + y_returns[i] = new realtype[number_of_timesteps * + length_of_return_vector]; + yS_returns[i] = new realtype[number_of_parameters * + number_of_timesteps * + length_of_return_vector]; + } + + + const std::size_t solves_per_thread = n_groups / m_solvers.size(); + const std::size_t remainder_solves = n_groups % m_solvers.size(); + + const std::size_t nthreads = m_solvers.size(); + + const realtype *t = t_np.data(); + const realtype *y0 = y0_np.data(); + const realtype *yp0 = yp0_np.data(); + const realtype *inputs_data = inputs.data(); + + omp_set_num_threads(nthreads); + #pragma omp parallel for + for (int i = 0; i < m_solvers.size(); i++) { + for (int j = 0; j < solves_per_thread; j++) { + const std::size_t index = i * solves_per_thread + j; + const realtype *y = y0 + index * y0_np.shape(1); + const realtype *yp = yp0 + index * yp0_np.shape(1); + const realtype *input = inputs_data + index * inputs.shape(1); + realtype *y_return = y_returns[index]; + realtype *yS_return = yS_returns[index]; + realtype *t_return = t_returns[index]; + int &t_i = t_i_returns[index]; + int &retval = retval_returns[index]; + m_solvers[i]->solve(t, number_of_timesteps, y, yp, input, length_of_return_vector, y_return, yS_return, t_return, t_i, retval); + } + } + + for (int i = 0; i < remainder_solves; i++) { + const std::size_t index = n_groups - remainder_solves + i; + const realtype *y = y0 + index * y0_np.shape(1); + const realtype *yp = yp0 + index * yp0_np.shape(1); + const realtype *input = inputs_data + index * inputs.shape(1); + realtype *y_return = y_returns[index]; + realtype *yS_return = yS_returns[index]; + realtype *t_return = t_returns[index]; + int &t_i = t_i_returns[index]; + int &retval = retval_returns[index]; + m_solvers[i]->solve(t, number_of_timesteps, y, yp, input, length_of_return_vector, y_return, yS_return, t_return, t_i, retval); + } + + // create solutions + std::vector solutions(n_groups); + for (int i = 0; i < n_groups; i++) { + int t_i = t_i_returns[i]; + int retval = retval_returns[i]; + realtype *t_return = t_returns[i]; + realtype *y_return = y_returns[i]; + realtype *yS_return = yS_returns[i]; + + py::capsule free_t_when_done( + t_return, + [](void *f) { + realtype *vect = reinterpret_cast(f); + delete[] vect; + } + ); + py::capsule free_y_when_done( + y_return, + [](void *f) { + realtype *vect = reinterpret_cast(f); + delete[] vect; + } + ); + py::capsule free_yS_when_done( + yS_return, + [](void *f) { + realtype *vect = reinterpret_cast(f); + delete[] vect; + } + ); + + np_array t_ret = np_array( + t_i, + &t_return[0], + free_t_when_done + ); + np_array y_ret = np_array( + t_i * length_of_return_vector, + &y_return[0], + free_y_when_done + ); + // Note: Ordering of vector is differnet if computing variables vs returning + // the complete state vector + np_array yS_ret; + if (is_output_variables) { + yS_ret = np_array( + std::vector { + number_of_timesteps, + length_of_return_vector, + number_of_parameters + }, + &yS_return[0], + free_yS_when_done + ); + } else { + yS_ret = np_array( + std::vector { + number_of_parameters, + number_of_timesteps, + length_of_return_vector + }, + &yS_return[0], + free_yS_when_done + ); + } + solutions[i] = Solution(retval, t_ret, y_ret, yS_ret); + } + + return solutions; +} diff --git a/pybamm/solvers/c_solvers/idaklu/CasadiSolverGroup.hpp b/pybamm/solvers/c_solvers/idaklu/CasadiSolverGroup.hpp new file mode 100644 index 0000000000..8436a05430 --- /dev/null +++ b/pybamm/solvers/c_solvers/idaklu/CasadiSolverGroup.hpp @@ -0,0 +1,50 @@ +#ifndef PYBAMM_IDAKLU_CASADI_SOLVER_GROUP_HPP +#define PYBAMM_IDAKLU_CASADI_SOLVER_GROUP_HPP + +#include "CasadiSolver.hpp" +#include "common.hpp" + +/** + * @brief class for a group of solvers. + */ +class CasadiSolverGroup +{ +public: + + /** + * @brief Default constructor + */ + CasadiSolverGroup(std::vector> solvers, int number_of_states, int number_of_parameters, int length_of_return_vector, bool is_output_variables): + m_solvers(std::move(solvers)), + number_of_states(number_of_states), + number_of_parameters(number_of_parameters), + length_of_return_vector(length_of_return_vector), + is_output_variables(is_output_variables) + {} + + // no copy constructor (unique_ptr cannot be copied) + CasadiSolverGroup(CasadiSolverGroup &) = delete; + + /** + * @brief Default destructor + */ + ~CasadiSolverGroup() = default; + + /** + * @brief solver method that returns a vector of Solutions + */ + std::vector solve( + np_array t_np, + np_array y0_np, + np_array yp0_np, + np_array inputs); + + private: + std::vector> m_solvers; + int number_of_states; + int number_of_parameters; + int length_of_return_vector; + bool is_output_variables; +}; + +#endif // PYBAMM_IDAKLU_CASADI_SOLVER_GROUP_HPP diff --git a/pybamm/solvers/c_solvers/idaklu/CasadiSolverOpenMP.cpp b/pybamm/solvers/c_solvers/idaklu/CasadiSolverOpenMP.cpp index ad51eda4e1..8c1f06711f 100644 --- a/pybamm/solvers/c_solvers/idaklu/CasadiSolverOpenMP.cpp +++ b/pybamm/solvers/c_solvers/idaklu/CasadiSolverOpenMP.cpp @@ -75,14 +75,39 @@ CasadiSolverOpenMP::CasadiSolverOpenMP( if (options.preconditioner != "none") { precon_type = SUN_PREC_LEFT; } + + // allocate temp buffers for output variables + size_t max_res_size = 0; // maximum result size (for common result buffer) + size_t max_res_dvar_dy = 0, max_res_dvar_dp = 0; + if (functions->var_casadi_fcns.size() > 0) { + // return only the requested variables list after computation + for (auto& var_fcn : functions->var_casadi_fcns) { + max_res_size = std::max(max_res_size, size_t(var_fcn.nnz_out())); + for (auto& dvar_fcn : functions->dvar_dy_fcns) + max_res_dvar_dy = std::max(max_res_dvar_dy, size_t(dvar_fcn.nnz_out())); + for (auto& dvar_fcn : functions->dvar_dp_fcns) + max_res_dvar_dp = std::max(max_res_dvar_dp, size_t(dvar_fcn.nnz_out())); + } + } + + res = new realtype[max_res_size]; + res_dvar_dy = new realtype[max_res_dvar_dy]; + res_dvar_dp = new realtype[max_res_dvar_dp]; } void CasadiSolverOpenMP::AllocateVectors() { // Create vectors - yy = N_VNew_OpenMP(number_of_states, options.num_threads, sunctx); - yp = N_VNew_OpenMP(number_of_states, options.num_threads, sunctx); - avtol = N_VNew_OpenMP(number_of_states, options.num_threads, sunctx); - id = N_VNew_OpenMP(number_of_states, options.num_threads, sunctx); + if (options.num_threads == 1) { + yy = N_VNew_Serial(number_of_states, sunctx); + yp = N_VNew_Serial(number_of_states, sunctx); + avtol = N_VNew_Serial(number_of_states, sunctx); + id = N_VNew_Serial(number_of_states, sunctx); + } else { + yy = N_VNew_OpenMP(number_of_states, options.num_threads, sunctx); + yp = N_VNew_OpenMP(number_of_states, options.num_threads, sunctx); + avtol = N_VNew_OpenMP(number_of_states, options.num_threads, sunctx); + id = N_VNew_OpenMP(number_of_states, options.num_threads, sunctx); + } } void CasadiSolverOpenMP::SetMatrix() { @@ -247,36 +272,27 @@ void CasadiSolverOpenMP::CalcVarsSensitivities( } } -Solution CasadiSolverOpenMP::solve( - np_array t_np, - np_array y0_np, - np_array yp0_np, - np_array_dense inputs +void CasadiSolverOpenMP::solve( + const realtype *t, + const int number_of_timesteps, + const realtype *y0, + const realtype *yp0, + const realtype *inputs, + const int length_of_return_vector, + realtype *y_return, + realtype *yS_return, + realtype *t_return, + int &t_i, + int &retval ) { DEBUG("CasadiSolver::solve"); - int number_of_timesteps = t_np.request().size; - auto t = t_np.unchecked<1>(); - realtype t0 = RCONST(t(0)); - auto y0 = y0_np.unchecked<1>(); - auto yp0 = yp0_np.unchecked<1>(); - auto n_coeffs = number_of_states + number_of_parameters * number_of_states; - - if (y0.size() != n_coeffs) - throw std::domain_error( - "y0 has wrong size. Expected " + std::to_string(n_coeffs) + - " but got " + std::to_string(y0.size())); - - if (yp0.size() != n_coeffs) - throw std::domain_error( - "yp0 has wrong size. Expected " + std::to_string(n_coeffs) + - " but got " + std::to_string(yp0.size())); + realtype t0 = RCONST(t[0]); // set inputs - auto p_inputs = inputs.unchecked<2>(); for (int i = 0; i < functions->inputs.size(); i++) - functions->inputs[i] = p_inputs(i, 0); + functions->inputs[i] = inputs[i]; // set initial conditions realtype *yval = N_VGetArrayPointer(yy); @@ -304,68 +320,19 @@ Solution CasadiSolverOpenMP::solve( // correct initial values DEBUG("IDACalcIC"); - IDACalcIC(ida_mem, IDA_YA_YDP_INIT, t(1)); + IDACalcIC(ida_mem, IDA_YA_YDP_INIT, t[1]); if (number_of_parameters > 0) IDAGetSens(ida_mem, &t0, yyS); realtype tret; - realtype t_final = t(number_of_timesteps - 1); - - // set return vectors - int length_of_return_vector = 0; - size_t max_res_size = 0; // maximum result size (for common result buffer) - size_t max_res_dvar_dy = 0, max_res_dvar_dp = 0; - if (functions->var_casadi_fcns.size() > 0) { - // return only the requested variables list after computation - for (auto& var_fcn : functions->var_casadi_fcns) { - max_res_size = std::max(max_res_size, size_t(var_fcn.nnz_out())); - length_of_return_vector += var_fcn.nnz_out(); - for (auto& dvar_fcn : functions->dvar_dy_fcns) - max_res_dvar_dy = std::max(max_res_dvar_dy, size_t(dvar_fcn.nnz_out())); - for (auto& dvar_fcn : functions->dvar_dp_fcns) - max_res_dvar_dp = std::max(max_res_dvar_dp, size_t(dvar_fcn.nnz_out())); - } - } else { - // Return full y state-vector - length_of_return_vector = number_of_states; - } - realtype *t_return = new realtype[number_of_timesteps]; - realtype *y_return = new realtype[number_of_timesteps * - length_of_return_vector]; - realtype *yS_return = new realtype[number_of_parameters * - number_of_timesteps * - length_of_return_vector]; + realtype t_final = t[number_of_timesteps - 1]; - res = new realtype[max_res_size]; - res_dvar_dy = new realtype[max_res_dvar_dy]; - res_dvar_dp = new realtype[max_res_dvar_dp]; - py::capsule free_t_when_done( - t_return, - [](void *f) { - realtype *vect = reinterpret_cast(f); - delete[] vect; - } - ); - py::capsule free_y_when_done( - y_return, - [](void *f) { - realtype *vect = reinterpret_cast(f); - delete[] vect; - } - ); - py::capsule free_yS_when_done( - yS_return, - [](void *f) { - realtype *vect = reinterpret_cast(f); - delete[] vect; - } - ); // Initial state (t_i=0) - int t_i = 0; + t_i = 0; size_t ySk = 0; - t_return[t_i] = t(t_i); + t_return[t_i] = t[t_i]; if (functions->var_casadi_fcns.size() > 0) { // Evaluate casadi functions for each requested variable and store CalcVars(y_return, length_of_return_vector, t_i, @@ -383,11 +350,10 @@ Solution CasadiSolverOpenMP::solve( } // Subsequent states (t_i>0) - int retval; t_i = 1; while (true) { - realtype t_next = t(t_i); + realtype t_next = t[t_i]; IDASetStopTime(ida_mem, t_next); DEBUG("IDASolve"); retval = IDASolve(ida_mem, t_final, &tret, yy, yp, IDA_NORMAL); @@ -433,42 +399,6 @@ Solution CasadiSolverOpenMP::solve( } } - np_array t_ret = np_array( - t_i, - &t_return[0], - free_t_when_done - ); - np_array y_ret = np_array( - t_i * length_of_return_vector, - &y_return[0], - free_y_when_done - ); - // Note: Ordering of vector is differnet if computing variables vs returning - // the complete state vector - np_array yS_ret; - if (functions->var_casadi_fcns.size() > 0) { - yS_ret = np_array( - std::vector { - number_of_timesteps, - length_of_return_vector, - number_of_parameters - }, - &yS_return[0], - free_yS_when_done - ); - } else { - yS_ret = np_array( - std::vector { - number_of_parameters, - number_of_timesteps, - length_of_return_vector - }, - &yS_return[0], - free_yS_when_done - ); - } - - Solution sol(retval, t_ret, y_ret, yS_ret); if (options.print_stats) { @@ -513,6 +443,4 @@ Solution CasadiSolverOpenMP::solve( py::print("\tNumber of nonlinear iterations performed =", nniters); py::print("\tNumber of nonlinear convergence failures =", nncfails); } - - return sol; } diff --git a/pybamm/solvers/c_solvers/idaklu/CasadiSolverOpenMP.hpp b/pybamm/solvers/c_solvers/idaklu/CasadiSolverOpenMP.hpp index 2312f9cf8f..ab34db455e 100644 --- a/pybamm/solvers/c_solvers/idaklu/CasadiSolverOpenMP.hpp +++ b/pybamm/solvers/c_solvers/idaklu/CasadiSolverOpenMP.hpp @@ -122,11 +122,19 @@ class CasadiSolverOpenMP : public CasadiSolver /** * @brief The main solve method that solves for each variable and time step */ - Solution solve( - np_array t_np, - np_array y0_np, - np_array yp0_np, - np_array_dense inputs) override; + void solve( + const realtype *t, + const int number_of_timesteps, + const realtype *y0_np, + const realtype *yp0_np, + const realtype *inputs, + const int length_of_return_vector, + realtype *y_return, + realtype *yS_return, + realtype *t_return, + int &t_i, + int &retval + ) override; /** * @brief Concrete implementation of initialization method diff --git a/pybamm/solvers/c_solvers/idaklu/casadi_functions.hpp b/pybamm/solvers/c_solvers/idaklu/casadi_functions.hpp index 1aaee0b77a..400f4b8492 100644 --- a/pybamm/solvers/c_solvers/idaklu/casadi_functions.hpp +++ b/pybamm/solvers/c_solvers/idaklu/casadi_functions.hpp @@ -86,7 +86,7 @@ class CasadiFunction std::vector m_res; private: - const Function &m_func; + const Function m_func; std::vector m_iw; std::vector m_w; }; diff --git a/pybamm/solvers/c_solvers/idaklu/casadi_solver.cpp b/pybamm/solvers/c_solvers/idaklu/casadi_solver.cpp index 9fcfa06510..4cc936cdcd 100644 --- a/pybamm/solvers/c_solvers/idaklu/casadi_solver.cpp +++ b/pybamm/solvers/c_solvers/idaklu/casadi_solver.cpp @@ -5,8 +5,9 @@ #include "common.hpp" #include #include +#include -CasadiSolver *create_casadi_solver( +CasadiSolverGroup *create_casadi_solver_group( int number_of_states, int number_of_parameters, const Function &rhs_alg, @@ -28,9 +29,12 @@ CasadiSolver *create_casadi_solver( const std::vector& var_casadi_fcns, const std::vector& dvar_dy_fcns, const std::vector& dvar_dp_fcns, - py::dict options + py::dict options, + const int nsolvers ) { - auto options_cpp = Options(options); + const int nthreads = options["num_threads"].cast(); + const int nsolvers_limited = std::min(nsolvers, nthreads); + auto options_cpp = Options(options, nsolvers_limited); auto functions = std::make_unique( rhs_alg, jac_times_cjmass, @@ -53,13 +57,63 @@ CasadiSolver *create_casadi_solver( options_cpp ); - CasadiSolver *casadiSolver = nullptr; + std::vector> solvers; + for (int i = 0; i < nsolvers_limited; i++) { + solvers.emplace_back(create_casadi_solver( + std::make_unique(*functions), + number_of_parameters, + jac_times_cjmass_colptrs, + jac_times_cjmass_rowvals, + jac_times_cjmass_nnz, + jac_bandwidth_lower, + jac_bandwidth_upper, + number_of_events, + rhs_alg_id, + atol_np, + rel_tol, + inputs_length, + options_cpp + )); + } + + // calculate length of return vector as needed for allocating ouput + int length_of_return_vector = 0; + if (functions->var_casadi_fcns.size() > 0) { + // return only the requested variables list after computation + for (auto& var_fcn : functions->var_casadi_fcns) { + length_of_return_vector += var_fcn.nnz_out(); + } + } else { + // Return full y state-vector + length_of_return_vector = number_of_states; + } + + const bool is_output_variables = functions->var_casadi_fcns.size() > 0; + return new CasadiSolverGroup(std::move(solvers), number_of_states, number_of_parameters, length_of_return_vector, is_output_variables); +} + +std::unique_ptr create_casadi_solver( + std::unique_ptr functions, + int number_of_parameters, + const np_array_int &jac_times_cjmass_colptrs, + const np_array_int &jac_times_cjmass_rowvals, + const int jac_times_cjmass_nnz, + const int jac_bandwidth_lower, + const int jac_bandwidth_upper, + const int number_of_events, + np_array rhs_alg_id, + np_array atol_np, + double rel_tol, + int inputs_length, + Options options_cpp +) { + // Instantiate solver class if (options_cpp.linear_solver == "SUNLinSol_Dense") { DEBUG("\tsetting SUNLinSol_Dense linear solver"); - casadiSolver = new CasadiSolverOpenMP_Dense( + return std::unique_ptr(new CasadiSolverOpenMP_Dense( atol_np, rel_tol, rhs_alg_id, @@ -70,12 +124,12 @@ CasadiSolver *create_casadi_solver( jac_bandwidth_upper, std::move(functions), options_cpp - ); + )); } else if (options_cpp.linear_solver == "SUNLinSol_KLU") { DEBUG("\tsetting SUNLinSol_KLU linear solver"); - casadiSolver = new CasadiSolverOpenMP_KLU( + return std::unique_ptr(new CasadiSolverOpenMP_KLU( atol_np, rel_tol, rhs_alg_id, @@ -86,12 +140,12 @@ CasadiSolver *create_casadi_solver( jac_bandwidth_upper, std::move(functions), options_cpp - ); + )); } else if (options_cpp.linear_solver == "SUNLinSol_Band") { DEBUG("\tsetting SUNLinSol_Band linear solver"); - casadiSolver = new CasadiSolverOpenMP_Band( + return std::unique_ptr(new CasadiSolverOpenMP_Band( atol_np, rel_tol, rhs_alg_id, @@ -102,12 +156,12 @@ CasadiSolver *create_casadi_solver( jac_bandwidth_upper, std::move(functions), options_cpp - ); + )); } else if (options_cpp.linear_solver == "SUNLinSol_SPBCGS") { DEBUG("\tsetting SUNLinSol_SPBCGS_linear solver"); - casadiSolver = new CasadiSolverOpenMP_SPBCGS( + return std::unique_ptr(new CasadiSolverOpenMP_SPBCGS( atol_np, rel_tol, rhs_alg_id, @@ -118,12 +172,12 @@ CasadiSolver *create_casadi_solver( jac_bandwidth_upper, std::move(functions), options_cpp - ); + )); } else if (options_cpp.linear_solver == "SUNLinSol_SPFGMR") { DEBUG("\tsetting SUNLinSol_SPFGMR_linear solver"); - casadiSolver = new CasadiSolverOpenMP_SPFGMR( + return std::unique_ptr(new CasadiSolverOpenMP_SPFGMR( atol_np, rel_tol, rhs_alg_id, @@ -134,12 +188,12 @@ CasadiSolver *create_casadi_solver( jac_bandwidth_upper, std::move(functions), options_cpp - ); + )); } else if (options_cpp.linear_solver == "SUNLinSol_SPGMR") { DEBUG("\tsetting SUNLinSol_SPGMR solver"); - casadiSolver = new CasadiSolverOpenMP_SPGMR( + return std::unique_ptr(new CasadiSolverOpenMP_SPGMR( atol_np, rel_tol, rhs_alg_id, @@ -150,12 +204,12 @@ CasadiSolver *create_casadi_solver( jac_bandwidth_upper, std::move(functions), options_cpp - ); + )); } else if (options_cpp.linear_solver == "SUNLinSol_SPTFQMR") { DEBUG("\tsetting SUNLinSol_SPGMR solver"); - casadiSolver = new CasadiSolverOpenMP_SPTFQMR( + return std::unique_ptr(new CasadiSolverOpenMP_SPTFQMR( atol_np, rel_tol, rhs_alg_id, @@ -166,12 +220,7 @@ CasadiSolver *create_casadi_solver( jac_bandwidth_upper, std::move(functions), options_cpp - ); + )); } - - if (casadiSolver == nullptr) { - throw std::invalid_argument("Unsupported solver requested"); - } - - return casadiSolver; + throw std::invalid_argument("Unsupported solver requested"); } diff --git a/pybamm/solvers/c_solvers/idaklu/casadi_solver.hpp b/pybamm/solvers/c_solvers/idaklu/casadi_solver.hpp index 335907a93a..0a15d9d1d3 100644 --- a/pybamm/solvers/c_solvers/idaklu/casadi_solver.hpp +++ b/pybamm/solvers/c_solvers/idaklu/casadi_solver.hpp @@ -1,14 +1,14 @@ #ifndef PYBAMM_IDAKLU_CREATE_CASADI_SOLVER_HPP #define PYBAMM_IDAKLU_CREATE_CASADI_SOLVER_HPP -#include "CasadiSolver.hpp" +#include "CasadiSolverGroup.hpp" /** - * Creates a concrete casadi solver given a linear solver, as specified in + * Creates a group of casadi solvers given a linear solver, as specified in * options_cpp.linear_solver. - * @brief Create a concrete casadi solver given a linear solver + * @brief Create a group of casadi solvers given a linear solver */ -CasadiSolver *create_casadi_solver( +CasadiSolverGroup *create_casadi_solver_group( int number_of_states, int number_of_parameters, const Function &rhs_alg, @@ -30,7 +30,29 @@ CasadiSolver *create_casadi_solver( const std::vector& var_casadi_fcns, const std::vector& dvar_dy_fcns, const std::vector& dvar_dp_fcns, - py::dict options + py::dict options, + const int nsolvers +); + +/** + * Creates a concrete casadi solver given a linear solver, as specified in + * options_cpp.linear_solver. + * @brief Create a concrete casadi solver given a linear solver + */ +std::unique_ptr create_casadi_solver( + std::unique_ptr functions, + int number_of_parameters, + const np_array_int &jac_times_cjmass_colptrs, + const np_array_int &jac_times_cjmass_rowvals, + const int jac_times_cjmass_nnz, + const int jac_bandwidth_lower, + const int jac_bandwidth_upper, + const int number_of_events, + np_array rhs_alg_id, + np_array atol_np, + double rel_tol, + int inputs_length, + Options options_cpp ); #endif // PYBAMM_IDAKLU_CREATE_CASADI_SOLVER_HPP diff --git a/pybamm/solvers/c_solvers/idaklu/common.hpp b/pybamm/solvers/c_solvers/idaklu/common.hpp index e0abbb5a1d..f31156868e 100644 --- a/pybamm/solvers/c_solvers/idaklu/common.hpp +++ b/pybamm/solvers/c_solvers/idaklu/common.hpp @@ -29,8 +29,7 @@ #include namespace py = pybind11; -using np_array = py::array_t; -using np_array_dense = py::array_t; +using np_array = py::array_t; using np_array_int = py::array_t; #ifdef NDEBUG diff --git a/pybamm/solvers/c_solvers/idaklu/options.cpp b/pybamm/solvers/c_solvers/idaklu/options.cpp index efad4d5de0..a14d15a57f 100644 --- a/pybamm/solvers/c_solvers/idaklu/options.cpp +++ b/pybamm/solvers/c_solvers/idaklu/options.cpp @@ -5,7 +5,7 @@ using namespace std::string_literals; -Options::Options(py::dict options) +Options::Options(py::dict options, const int nsolvers) : print_stats(options["print_stats"].cast()), jacobian(options["jacobian"].cast()), preconditioner(options["preconditioner"].cast()), @@ -13,9 +13,13 @@ Options::Options(py::dict options) linear_solver(options["linear_solver"].cast()), precon_half_bandwidth(options["precon_half_bandwidth"].cast()), precon_half_bandwidth_keep(options["precon_half_bandwidth_keep"].cast()), - num_threads(options["num_threads"].cast()) + num_threads(options["num_threads"].cast() / nsolvers) { - + // need at least one thread + if (num_threads < 1) + { + num_threads = 1; + } using_sparse_matrix = true; using_banded_matrix = false; if (jacobian == "sparse") diff --git a/pybamm/solvers/c_solvers/idaklu/options.hpp b/pybamm/solvers/c_solvers/idaklu/options.hpp index b70d0f4a30..7870d01e25 100644 --- a/pybamm/solvers/c_solvers/idaklu/options.hpp +++ b/pybamm/solvers/c_solvers/idaklu/options.hpp @@ -18,7 +18,7 @@ struct Options { int precon_half_bandwidth; int precon_half_bandwidth_keep; int num_threads; - explicit Options(py::dict options); + explicit Options(py::dict options, const int nsolvers); }; diff --git a/pybamm/solvers/c_solvers/idaklu/solution.hpp b/pybamm/solvers/c_solvers/idaklu/solution.hpp index 92e22d02b6..7ead212251 100644 --- a/pybamm/solvers/c_solvers/idaklu/solution.hpp +++ b/pybamm/solvers/c_solvers/idaklu/solution.hpp @@ -17,6 +17,11 @@ class Solution { } + /** + * @brief default Constructor + */ + Solution() = default; + int flag; np_array t; np_array y; diff --git a/pybamm/solvers/idaklu_solver.py b/pybamm/solvers/idaklu_solver.py index 5ea8086b7d..900b3cf2e8 100644 --- a/pybamm/solvers/idaklu_solver.py +++ b/pybamm/solvers/idaklu_solver.py @@ -163,6 +163,7 @@ def _check_atol_type(self, atol, size): def set_up(self, model, inputs=None, t_eval=None, ics_only=False, batch_size=1): base_set_up_return = super().set_up(model, inputs, t_eval, ics_only, batch_size) + nbatches = len(inputs) // batch_size if isinstance(inputs, dict): inputs_list = [inputs] @@ -471,7 +472,7 @@ def sensfn(resvalS, t, y, inputs, yp, yS, ypS): "dvar_dp_idaklu_fcns": self.dvar_dp_idaklu_fcns, } - solver = idaklu.create_casadi_solver( + solver = idaklu.create_casadi_solver_group( number_of_states=nstates, number_of_parameters=self._setup["number_of_sensitivity_parameters"], rhs_alg=self._setup["rhs_algebraic"], @@ -494,6 +495,7 @@ def sensfn(resvalS, t, y, inputs, yp, yS, ypS): dvar_dy_fcns=self._setup["dvar_dy_idaklu_fcns"], dvar_dp_fcns=self._setup["dvar_dp_idaklu_fcns"], options=self._options, + nsolvers=nbatches, ) self._setup["solver"] = solver @@ -512,156 +514,185 @@ def sensfn(resvalS, t, y, inputs, yp, yS, ypS): return base_set_up_return - def _integrate_batch(self, model, t_eval, y0, y0S, inputs_list, inputs): + def _integrate( + self, model, t_eval, inputs_list=None, batched_inputs=None, nproc=None + ): """ - Overloads the _integrate_batch method from BaseSolver to use the IDAKLU solver + Overloads the _integrate method from BaseSolver to use the IDAKLU solver """ + inputs_list, batched_inputs, nbatches, batch_size, y0S_list = ( + self._handle_integrate_defaults(model, inputs_list, batched_inputs) + ) + # do this here cause y0 is set after set_up (calc consistent conditions) def handle_y0(y0): + if y0 is None: + return y0 if isinstance(y0, casadi.DM): y0 = y0.full() y0 = y0.flatten() return y0 - y0 = handle_y0(y0) + y0_list = [handle_y0(y0) for y0 in model.y0_list] + batched_inputs = [handle_y0(inputs) for inputs in batched_inputs] # only casadi solver needs sensitivity ics - if model.convert_to_format == "casadi" and y0S is not None: + if model.convert_to_format == "casadi" and y0S_list[0] is not None: # concatentate the senstivity initial conditions to the state vector - y0S = handle_y0(y0S) - y0full = np.concatenate([y0, y0S]) + y0S_list = [handle_y0(y0S) for y0S in y0S_list] + y0full_list = [ + np.concatenate([y0, y0S]) for y0, y0S in zip(y0_list, y0S_list) + ] else: - y0full = y0 + y0full_list = y0_list # solver works with ydot0 set to zero - ydot0full = np.zeros_like(y0full) + ydot0full_list = [np.zeros_like(y0full) for y0full in y0full_list] try: atol = model.atol except AttributeError: atol = self.atol + # solver takes individual y0s and ydot0s as rows in a 2d array + y0full = np.vstack(y0full_list) + ydot0full = np.vstack(ydot0full_list) + inputs = np.vstack(batched_inputs) + rtol = self.rtol - atol = self._check_atol_type(atol, y0.size) + atol = self._check_atol_type(atol, y0_list[0].size) timer = pybamm.Timer() if model.convert_to_format == "casadi": - sol = self._setup["solver"].solve( + sols = self._setup["solver"].solve( t_eval, y0full, ydot0full, inputs, ) else: - ydot0 = np.zeros_like(y0) - sol = idaklu.solve_python( - t_eval, - y0, - ydot0, - self._setup["resfn"], - self._setup["jac_class"].jac_res, - self._setup["sensfn"], - self._setup["jac_class"].get_jac_data, - self._setup["jac_class"].get_jac_row_vals, - self._setup["jac_class"].get_jac_col_ptrs, - self._setup["jac_class"].nnz, - self._setup["rootfn"], - self._setup["num_of_events"], - 1, - self._setup["ids"], - atol, - rtol, - inputs, - self._setup["number_of_sensitivity_parameters"], - ) + ydot0 = np.zeros_like(y0_list[0]) + sols = [] + for y0, inputs in zip(y0_list, batched_inputs): + inputs = inputs.reshape(-1, 1) + sols.append( + idaklu.solve_python( + t_eval, + y0, + ydot0, + self._setup["resfn"], + self._setup["jac_class"].jac_res, + self._setup["sensfn"], + self._setup["jac_class"].get_jac_data, + self._setup["jac_class"].get_jac_row_vals, + self._setup["jac_class"].get_jac_col_ptrs, + self._setup["jac_class"].nnz, + self._setup["rootfn"], + self._setup["num_of_events"], + 1, + self._setup["ids"], + atol, + rtol, + inputs, + self._setup["number_of_sensitivity_parameters"], + ) + ) integration_time = timer.time() - if sol.flag not in [0, 2]: - raise pybamm.SolverError("idaklu solver failed") + for sol in sols: + if sol.flag not in [0, 2]: + raise pybamm.SolverError("idaklu solver failed") number_of_sensitivity_parameters = self._setup[ "number_of_sensitivity_parameters" ] sensitivity_names = self._setup["sensitivity_names"] - t = sol.t - number_of_timesteps = t.size - number_of_states = y0.shape[0] + pybamm_sols = [] - sol_y = sol.y - sol_yS = sol.yS - if self.output_variables: - # Substitute empty vectors for state vector 'y' - y_out = np.zeros((number_of_timesteps * number_of_states, 0)) - else: - y_out = sol_y.reshape((number_of_timesteps, number_of_states)) - - # return sensitivity solution, we need to flatten yS to - # (#timesteps * #states (where t is changing the quickest),) - # to match format used by Solution - # note that yS is (n_p, n_t, n_y) - if number_of_sensitivity_parameters != 0: - yS_out = { - name: sol_yS[i].reshape(-1, 1) - for i, name in enumerate(sensitivity_names) - } - # add "all" stacked sensitivities ((#timesteps * #states,#sens_params)) - yS_out["all"] = np.hstack([yS_out[name] for name in sensitivity_names]) - else: - yS_out = False - - # 0 = solved for all t_eval - if sol.flag == 0: - termination = "final time" - # 2 = found root(s) - elif sol.flag == 2: - termination = "event" - - batchsols = pybamm.Solution.from_concatenated_state( - sol.t, - np.transpose(y_out), - model, - inputs_list, - np.array([t[-1]]), - np.transpose(y_out[-1])[:, np.newaxis], - termination, - sensitivities=yS_out, - ) - for s in batchsols: - s.integration_time = integration_time + for i, sol in enumerate(sols): + inputs_sublist = inputs_list[i * batch_size : (i + 1) * batch_size] + t = sol.t + number_of_timesteps = t.size + number_of_states = y0_list[0].shape[0] + + sol_y = sol.y + sol_yS = sol.yS if self.output_variables: - # Populate variables and sensititivies dictionaries directly - number_of_samples = sol_y.shape[0] // number_of_timesteps - sol_y = sol_y.reshape((number_of_timesteps, number_of_samples)) - startk = 0 - for _, var in enumerate(self.output_variables): - # ExplicitTimeIntegral's are not computed as part of the solver and - # do not need to be converted - if isinstance( - model.variables_and_events[var], pybamm.ExplicitTimeIntegral - ): - continue - len_of_var = ( - self._setup["var_casadi_fcns"][var](0, 0, 0).sparsity().nnz() - ) - s._variables[var] = pybamm.ProcessedVariableComputed( - [model.variables_and_events[var]], - [self._setup["var_casadi_fcns"][var]], - [sol_y[:, startk : (startk + len_of_var)]], - s, - ) - # Add sensitivities - s[var]._sensitivities = {} - if model.calculate_sensitivities: - for paramk, param in enumerate(inputs_list[0].keys()): - s[var].add_sensitivity( - param, - [sol_yS[:, startk : (startk + len_of_var), paramk]], - ) - startk += len_of_var - return batchsols + # Substitute empty vectors for state vector 'y' + y_out = np.zeros((number_of_timesteps * number_of_states, 0)) + else: + y_out = sol_y.reshape((number_of_timesteps, number_of_states)) + + # return sensitivity solution, we need to flatten yS to + # (#timesteps * #states (where t is changing the quickest),) + # to match format used by Solution + # note that yS is (n_p, n_t, n_y) + if number_of_sensitivity_parameters != 0: + yS_out = { + name: sol_yS[i].reshape(-1, 1) + for i, name in enumerate(sensitivity_names) + } + # add "all" stacked sensitivities ((#timesteps * #states,#sens_params)) + yS_out["all"] = np.hstack([yS_out[name] for name in sensitivity_names]) + else: + yS_out = False + + # 0 = solved for all t_eval + if sol.flag == 0: + termination = "final time" + # 2 = found root(s) + elif sol.flag == 2: + termination = "event" + + batchsols = pybamm.Solution.from_concatenated_state( + sol.t, + np.transpose(y_out), + model, + inputs_sublist, + np.array([t[-1]]), + np.transpose(y_out[-1])[:, np.newaxis], + termination, + sensitivities=yS_out, + ) + for s in batchsols: + s.integration_time = integration_time + if self.output_variables: + # Populate variables and sensititivies dictionaries directly + number_of_samples = sol_y.shape[0] // number_of_timesteps + sol_y = sol_y.reshape((number_of_timesteps, number_of_samples)) + startk = 0 + for _, var in enumerate(self.output_variables): + # ExplicitTimeIntegral's are not computed as part of the solver and + # do not need to be converted + if isinstance( + model.variables_and_events[var], pybamm.ExplicitTimeIntegral + ): + continue + len_of_var = ( + self._setup["var_casadi_fcns"][var](0, 0, 0) + .sparsity() + .nnz() + ) + s._variables[var] = pybamm.ProcessedVariableComputed( + [model.variables_and_events[var]], + [self._setup["var_casadi_fcns"][var]], + [sol_y[:, startk : (startk + len_of_var)]], + s, + ) + # Add sensitivities + s[var]._sensitivities = {} + if model.calculate_sensitivities: + for paramk, param in enumerate(inputs_list[0].keys()): + s[var].add_sensitivity( + param, + [sol_yS[:, startk : (startk + len_of_var), paramk]], + ) + startk += len_of_var + pybamm_sols += batchsols + return pybamm_sols def jaxify( self, diff --git a/tests/unit/test_solvers/test_idaklu_solver.py b/tests/unit/test_solvers/test_idaklu_solver.py index ee13c8efd8..cf541f6e88 100644 --- a/tests/unit/test_solvers/test_idaklu_solver.py +++ b/tests/unit/test_solvers/test_idaklu_solver.py @@ -253,6 +253,7 @@ def test_multiple_inputs_initial_conditions(self): # check solution for inputs, solution in zip(inputs_list, solutions): + print("checking input", inputs) np.testing.assert_array_equal(solution.t, t_eval) np.testing.assert_allclose( solution.y[0],