From aa3b66a58f78dbe3d62ef5be4f99ef1da08478f4 Mon Sep 17 00:00:00 2001 From: martinjrobins Date: Thu, 17 Mar 2022 13:52:17 +0000 Subject: [PATCH] #1863 finish draft of casadi solver, and restructure cpp code --- pybamm/solvers/c_solvers/idaklu.cpp | 46 +--- pybamm/solvers/c_solvers/idaklu.hpp | 14 + pybamm/solvers/c_solvers/idaklu_casadi.cpp | 295 +++++++++++++-------- pybamm/solvers/c_solvers/idaklu_casadi.hpp | 27 ++ pybamm/solvers/c_solvers/idaklu_python.cpp | 54 ++++ pybamm/solvers/c_solvers/idaklu_python.hpp | 16 ++ pybamm/solvers/c_solvers/solution.cpp | 0 pybamm/solvers/c_solvers/solution.hpp | 21 ++ setup.py | 6 +- 9 files changed, 327 insertions(+), 152 deletions(-) create mode 100644 pybamm/solvers/c_solvers/idaklu.hpp create mode 100644 pybamm/solvers/c_solvers/idaklu_casadi.hpp create mode 100644 pybamm/solvers/c_solvers/idaklu_python.cpp create mode 100644 pybamm/solvers/c_solvers/idaklu_python.hpp create mode 100644 pybamm/solvers/c_solvers/solution.cpp create mode 100644 pybamm/solvers/c_solvers/solution.hpp diff --git a/pybamm/solvers/c_solvers/idaklu.cpp b/pybamm/solvers/c_solvers/idaklu.cpp index d08796ba95..bf655de059 100644 --- a/pybamm/solvers/c_solvers/idaklu.cpp +++ b/pybamm/solvers/c_solvers/idaklu.cpp @@ -1,24 +1,5 @@ -#include -#include +#include "idaklu_python.hpp" -#include /* prototypes for IDAS fcts., consts. */ -#include /* access to serial N_Vector */ -#include /* defs. of SUNRabs, SUNRexp, etc. */ -#include /* defs. of realtype, sunindextype */ -#include /* access to KLU linear solver */ -#include /* access to sparse SUNMatrix */ - -#include -#include -#include -#include - -//#include -namespace py = pybind11; - - -using np_array = py::array_t; -PYBIND11_MAKE_OPAQUE(std::vector); using residual_type = std::function; using sensitivities_type = std::function&, realtype, const np_array&, @@ -177,8 +158,8 @@ int jacobian(realtype tt, realtype cj, N_Vector yy, N_Vector yp, np_array jac_np_row_vals = python_functions.get_jac_row_vals(); int n_row_vals = jac_np_row_vals.request().size; - auto jac_np_row_vals_ptr = jac_np_row_vals.unchecked<1>(); + auto jac_np_row_vals_ptr = jac_np_row_vals.unchecked<1>(); // just copy across row vals (this might be unneeded) for (i = 0; i < n_row_vals; i++) { @@ -308,7 +289,7 @@ class Solution }; /* main program */ -Solution solve(np_array t_np, np_array y0_np, np_array yp0_np, +Solution solve_python(np_array t_np, np_array y0_np, np_array yp0_np, residual_type res, jacobian_type jac, sensitivities_type sens, jac_get_type gjd, jac_get_type gjrv, jac_get_type gjcp, @@ -497,24 +478,3 @@ Solution solve(np_array t_np, np_array y0_np, np_array yp0_np, return sol; } -PYBIND11_MODULE(idaklu, m) -{ - m.doc() = "sundials solvers"; // optional module docstring - - py::bind_vector>(m, "VectorNdArray"); - - m.def("solve", &solve, "The solve function", py::arg("t"), py::arg("y0"), - py::arg("yp0"), py::arg("res"), py::arg("jac"), py::arg("sens"), - py::arg("get_jac_data"), - py::arg("get_jac_row_vals"), py::arg("get_jac_col_ptr"), py::arg("nnz"), - py::arg("events"), py::arg("number_of_events"), py::arg("use_jacobian"), - py::arg("rhs_alg_id"), py::arg("atol"), py::arg("rtol"), - py::arg("number_of_sensitivity_parameters"), - py::return_value_policy::take_ownership); - - py::class_(m, "solution") - .def_readwrite("t", &Solution::t) - .def_readwrite("y", &Solution::y) - .def_readwrite("yS", &Solution::yS) - .def_readwrite("flag", &Solution::flag); -} diff --git a/pybamm/solvers/c_solvers/idaklu.hpp b/pybamm/solvers/c_solvers/idaklu.hpp new file mode 100644 index 0000000000..a1f325cde0 --- /dev/null +++ b/pybamm/solvers/c_solvers/idaklu.hpp @@ -0,0 +1,14 @@ +#ifndef PYBAMM_IDAKLU_HPP +#define PYBAMM_IDAKLU_HPP + +#include "solution.hpp" + +Solution solve_python(np_array t_np, np_array y0_np, np_array yp0_np, + residual_type res, jacobian_type jac, + sensitivities_type sens, + jac_get_type gjd, jac_get_type gjrv, jac_get_type gjcp, + int nnz, event_type event, + int number_of_events, int use_jacobian, np_array rhs_alg_id, + np_array atol_np, double rel_tol, int number_of_parameters); + +#endif // PYBAMM_IDAKLU_HPP diff --git a/pybamm/solvers/c_solvers/idaklu_casadi.cpp b/pybamm/solvers/c_solvers/idaklu_casadi.cpp index b4ea08ce26..47282d1d5e 100644 --- a/pybamm/solvers/c_solvers/idaklu_casadi.cpp +++ b/pybamm/solvers/c_solvers/idaklu_casadi.cpp @@ -1,26 +1,10 @@ -#include -#include - -#include /* prototypes for IDAS fcts., consts. */ -#include /* access to serial N_Vector */ -#include /* defs. of SUNRabs, SUNRexp, etc. */ -#include /* defs. of realtype, sunindextype */ -#include /* access to KLU linear solver */ -#include /* access to sparse SUNMatrix */ - -#include -#include -#include -#include - -#include + +#include "idaklu_python.hpp" + #include //#include -namespace py = pybind11; - -using Function = casadi::Function -using casadi_int= casadi::casadi_int +using casadi_int = casadi::casadi_int using casadi_axpy = casadi::casadi_axpy class CasadiFunction { @@ -61,22 +45,45 @@ class PybammFunctions { int number_of_events; CasadiFunction rhs_alg; CasadiFunction sens; - CasadiFunction jac; + CasadiFunction jac_times_cjmass; + const np_array &jac_times_cjmass_rowvals, + const np_array &jac_times_cjmass_colptrs, + CasadiFunction jac_action; + CasadiFunction jacp_action; CasadiFunction mass_action; CasadiFunction event; - PybammFunctions(const Function &rhs_alg, const Function &jac, + PybammFunctions(const Function &rhs_alg, + const Function &jac_times_cjmass, + const np_array &jac_times_cjmass_rowvals, + const np_array &jac_times_cjmass_colptrs, + const Function &jac_action, + const Function &jacp_action, const Function &mass_action, const Function &sens, const Function &event, const int n_s, int n_e, const int n_p) : number_of_states(n_s), number_of_events(n_e), number_of_parameters(n_p), - res(res), jac(jac), + rhs_alg(rhs_alg), + jac_times_cjmass(jac_times_cjmass), + jac_times_cjmass_rowvals(jac_times_cjmass_rowvals), + jac_times_cjmass_colptrs(jac_times_cjmass_colptrs), + jac_action(jac_times), + jacp_action(jacp_action), mass_action(mass_action), sens(sens), - event(event) + event(event), + tmp(number_of_states) {} + + realtype *get_tmp() { + return tmp.data() + } + +private: + std::vector tmp; + }; int residual(realtype tres, N_Vector yy, N_Vector yp, N_Vector rr, @@ -105,10 +112,67 @@ int residual(realtype tres, N_Vector yy, N_Vector yp, N_Vector rr, return 0; } +// Purpose This function computes the product Jv of the DAE system Jacobian J +// (or an approximation to it) and a given vector v, where J is defined by Eq. (2.6). +// J = ∂F/∂y + cj ∂F/∂y˙ +// Arguments tt is the current value of the independent variable. +// yy is the current value of the dependent variable vector, y(t). +// yp is the current value of ˙y(t). +// rr is the current value of the residual vector F(t, y, y˙). +// v is the vector by which the Jacobian must be multiplied to the right. +// Jv is the computed output vector. +// cj is the scalar in the system Jacobian, proportional to the inverse of the step +// size (α in Eq. (2.6) ). +// user data is a pointer to user data, the same as the user data parameter passed to +// IDASetUserData. +// tmp1 +// tmp2 are pointers to memory allocated for variables of type N Vector which can +// be used by IDALsJacTimesVecFn as temporary storage or work space. +int jtimes(realtype tt, N_Vector yy, N_Vector yp, N_Vector rr, + N_Vector v, N_Vector Jv, realtype cj, void *user data, + N_Vector tmp1, N_Vector tmp2) { + PybammFunctions *p_python_functions = + static_cast(user_data); + + // rr has ∂F/∂y v + p_python_functions->jac_action.m_arg[0] = &tres; + p_python_functions->jac_action.m_arg[1] = NV_DATA_S(yy); + p_python_functions->jac_action.m_arg[2] = &cj; + p_python_functions->jac_action.m_arg[3] = NV_DATA_S(v); + p_python_functions->jac_action.m_res[0] = NV_DATA_S(rr); + p_python_functions->jac_action(); + + // tmp1 has -∂F/∂y˙ v + p_python_functions->mass_action.m_arg[0] = NV_DATA_S(v); + p_python_functions->mass_action.m_res[0] = NV_DATA_S(tmp1); + p_python_functions->mass_action(); + + // AXPY: y <- a*x + y + // rr has ∂F/∂y v + cj ∂F/∂y˙ v + const ns = p_python_functions->number_of_states; + casadi_axpy(ns, -cj, NV_DATA_S(tmp1), NV_DATA_S(rr)); + + return 0; +} + + +// Arguments tt is the current value of the independent variable t. +// cj is the scalar in the system Jacobian, proportional to the inverse of the step +// size (α in Eq. (2.6) ). +// yy is the current value of the dependent variable vector, y(t). +// yp is the current value of ˙y(t). +// rr is the current value of the residual vector F(t, y, y˙). +// Jac is the output (approximate) Jacobian matrix (of type SUNMatrix), J = +// ∂F/∂y + cj ∂F/∂y˙. +// user data is a pointer to user data, the same as the user data parameter passed to +// IDASetUserData. +// tmp1 +// tmp2 +// tmp3 are pointers to memory allocated for variables of type N Vector which can +// be used by IDALsJacFn function as temporary storage or work space. int jacobian(realtype tt, realtype cj, N_Vector yy, N_Vector yp, N_Vector resvec, SUNMatrix JJ, void *user_data, N_Vector tempv1, - N_Vector tempv2, N_Vector tempv3) -{ + N_Vector tempv2, N_Vector tempv3) { PybammFunctions *p_python_functions = static_cast(user_data); @@ -119,12 +183,34 @@ int jacobian(realtype tt, realtype cj, N_Vector yy, N_Vector yp, realtype *jac_data = SUNSparseMatrix_Data(JJ); // args are t, y, cj, put result in jacobian data matrix - p_python_functions->jac.m_arg[0] = &tres ; - p_python_functions->jac.m_arg[1] = NV_DATA_S(yy); - p_python_functions->jac.m_arg[2] = &cj; - p_python_functions->jac.m_res[0] = jac_data; + p_python_functions->jac_times_cjmass.m_arg[0] = &tt; + p_python_functions->jac_times_cjmass.m_arg[1] = NV_DATA_S(yy); + p_python_functions->jac_times_cjmass.m_arg[2] = &cj; + p_python_functions->jac_times_cjmass.m_res[0] = jac_data; p_python_functions->jac(); + // row vals and col ptrs + const np_array &jac_times_cjmass_rowvals = python_functions.jac_times_cjmass_rowvals; + const int n_row_vals = jac_times_cjmass_rowvals.request().size; + auto p_jac_times_cjmass_rowvals = jac_times_cjmass_rowvals.unchecked<1>(); + + // just copy across row vals (do I need to do this every time?) + // (or just in the setup?) + for (i = 0; i < n_row_vals; i++) { + std::cout << "check row vals " << jac_rowvals[i] << " " << p_jac_times_cjmass_rowvals[i] << std::endl; + jac_rowvals[i] = p_jac_times_cjmass_rowvals[i]; + } + + const np_array &jac_times_cjmass_colptrs = python_functions.jac_times_cjmass_colptrs; + const int n_col_ptrs = jac_times_cjmass_colptrs.request().size; + auto p_jac_times_cjmass_colptrs = jac_times_cjmass_colptrs.unchecked<1>(); + + // just copy across col ptrs (do I need to do this every time?) + for (i = 0; i < n_col_ptrs; i++) { + std::cout << "check col ptrs " << jac_colptrs[i] << " " << p_jac_times_cjmass_colptrs[i] << std::endl; + jac_colptrs[i] = p_jac_times_cjmass_colptrs[i]; + } + return (0); } @@ -135,17 +221,14 @@ int events(realtype t, N_Vector yy, N_Vector yp, realtype *events_ptr, static_cast(user_data); // args are t, y, put result in events_ptr - p_python_functions->jac.m_arg[0] = &tres ; - p_python_functions->jac.m_arg[1] = NV_DATA_S(yy); - p_python_functions->jac.m_res[0] = events_ptr; - p_python_functions->jac(); + p_python_functions->events.m_arg[0] = &tres ; + p_python_functions->events.m_arg[1] = NV_DATA_S(yy); + p_python_functions->events.m_res[0] = events_ptr; + p_python_functions->events(); return (0); } -int sensitivities(int Ns, realtype t, N_Vector yy, N_Vector yp, - N_Vector resval, N_Vector *yS, N_Vector *ypS, N_Vector *resvalS, - void *user_data, N_Vector tmp1, N_Vector tmp2, N_Vector tmp3) { // This function computes the sensitivity residual for all sensitivity // equations. It must compute the vectors // (∂F/∂y)s i (t)+(∂F/∂ ẏ) ṡ i (t)+(∂F/∂p i ) and store them in resvalS[i]. @@ -167,6 +250,10 @@ int sensitivities(int Ns, realtype t, N_Vector yy, N_Vector yp, // occurred (in which case idas will attempt to correct), // or a negative value if it failed unrecoverably (in which case the integration is halted and IDA SRES FAIL is returned) // +int sensitivities(int Ns, realtype t, N_Vector yy, N_Vector yp, + N_Vector resval, N_Vector *yS, N_Vector *ypS, N_Vector *resvalS, + void *user_data, N_Vector tmp1, N_Vector tmp2, N_Vector tmp3) { + PybammFunctions *p_python_functions = static_cast(user_data); @@ -177,66 +264,52 @@ int sensitivities(int Ns, realtype t, N_Vector yy, N_Vector yp, for (int i = 0; i < np; i++) { p_python_functions->sens.m_res[i] = NV_DATA_S(resvalS[i]); } + // resvalsS now has (∂F/∂p i ) p_python_functions->sens(); - // memory managed by sundials, so pass a destructor that does nothing - auto state_vector_shape = std::vector{n, 1}; - np_array y_np = np_array(state_vector_shape, N_VGetArrayPointer(yy), - py::capsule(&yy, [](void* p) {})); - np_array yp_np = np_array(state_vector_shape, N_VGetArrayPointer(yp), - py::capsule(&yp, [](void* p) {})); - - std::vector yS_np(np); - for (int i = 0; i < np; i++) { - auto capsule = py::capsule(yS + i, [](void* p) {}); - yS_np[i] = np_array(state_vector_shape, N_VGetArrayPointer(yS[i]), capsule); - } - - std::vector ypS_np(np); for (int i = 0; i < np; i++) { - auto capsule = py::capsule(ypS + i, [](void* p) {}); - ypS_np[i] = np_array(state_vector_shape, N_VGetArrayPointer(ypS[i]), capsule); + // put (∂F/∂y)s i (t) in tmp1 + p_python_functions->jac_action.m_arg[0] = &tres; + p_python_functions->jac_action.m_arg[1] = NV_DATA_S(yy); + p_python_functions->jac_action.m_arg[2] = NV_DATA_S(yS[i]); + p_python_functions->jac_action.m_res[1] = NV_DATA_S(tmp1); + p_python_functions->jac_action(); + + // put -(∂F/∂ ẏ) ṡ i (t) in tmp2 + p_python_functions->mass_action.m_arg[0] = NV_DATA_S(ypS[i]); + p_python_functions->mass_action.m_res[1] = NV_DATA_S(tmp2); + p_python_functions->mass_action(); + + // (∂F/∂y)s i (t)+(∂F/∂ ẏ) ṡ i (t)+(∂F/∂p i ) + // AXPY: y <- a*x + y + const ns = p_python_functions->number_of_states; + casadi_axpy(ns, 1., NV_DATA_S(tmp1), NV_DATA_S(resvalsS[i])); + casadi_axpy(ns, -1., NV_DATA_S(tmp2), NV_DATA_S(resvalsS[i])); } - std::vector resvalS_np(np); - for (int i = 0; i < np; i++) { - auto capsule = py::capsule(resvalS + i, [](void* p) {}); - resvalS_np[i] = np_array(state_vector_shape, - N_VGetArrayPointer(resvalS[i]), capsule); - } - - realtype *ptr1 = static_cast(resvalS_np[0].request().ptr); - const realtype* resvalSval = N_VGetArrayPointer(resvalS[0]); - - python_functions.sensitivities(resvalS_np, t, y_np, yp_np, yS_np, ypS_np); - return 0; } -class Solution -{ -public: - Solution(int retval, np_array t_np, np_array y_np, np_array yS_np) - : flag(retval), t(t_np), y(y_np), yS(yS_np) - { - } - int flag; - np_array t; - np_array y; - np_array yS; -}; /* main program */ -Solution solve(np_array t_np, np_array y0_np, np_array yp0_np, - residual_type res, jacobian_type jac, - sensitivities_type sens, - jac_get_type gjd, jac_get_type gjrv, jac_get_type gjcp, - int nnz, event_type event, - int number_of_events, int use_jacobian, np_array rhs_alg_id, + +Solution solve_casadi(np_array t_np, np_array y0_np, np_array yp0_np, + const Function &rhs_alg, + const Function &jac_times_cjmass, + const np_array &jac_times_cjmass_rowvals, + const np_array &jac_times_cjmass_colptrs, + const int jac_times_cjmass_nnz, + const Function &jac_action, + const Function &jacp_action, + const Function &mass_action, + const Function &sens, + const Function &event, + const int number_of_events, + int use_jacobian, + np_array rhs_alg_id, np_array atol_np, double rel_tol, int number_of_parameters) { - IdasInterface interface("pybamm_idaklu_casadi", dae); auto t = t_np.unchecked<1>(); auto y0 = y0_np.unchecked<1>(); @@ -253,6 +326,8 @@ Solution solve(np_array t_np, np_array y0_np, np_array yp0_np, SUNMatrix J; SUNLinearSolver LS; + + // allocate vectors yy = N_VNew_Serial(number_of_states); yp = N_VNew_Serial(number_of_states); @@ -299,15 +374,40 @@ Solution solve(np_array t_np, np_array y0_np, np_array yp0_np, IDARootInit(ida_mem, number_of_events, events); // set pybamm functions by passing pointer to it - PybammFunctions pybamm_functions(res, jac, sens, gjd, gjrv, gjcp, event, - number_of_states, number_of_events, - number_of_parameters); + PybammFunctions pybamm_functions( + rhs_alg, + jac_times_cjmass, + jac_times_cjmass_rowvals, + jac_times_cjmass_colptrs, + jac_action, jacp_action, mass_action, + sens, event, + number_of_states, number_of_events, + number_of_parameters); + void *user_data = &pybamm_functions; IDASetUserData(ida_mem, user_data); // set linear solver J = SUNSparseMatrix(number_of_states, number_of_states, nnz, CSR_MAT); + // copy across row vals and col ptrs + const int n_row_vals = jac_times_cjmass_rowvals.request().size; + auto p_jac_times_cjmass_rowvals = jac_times_cjmass_rowvals.unchecked<1>(); + + sunindextype *jac_rowvals = SUNSparseMatrix_IndexValues(J); + for (i = 0; i < n_row_vals; i++) { + jac_rowvals[i] = p_jac_times_cjmass_rowvals[i]; + } + + const np_array &jac_times_cjmass_colptrs = python_functions.jac_times_cjmass_colptrs; + const int n_col_ptrs = jac_times_cjmass_colptrs.request().size; + auto p_jac_times_cjmass_colptrs = jac_times_cjmass_colptrs.unchecked<1>(); + + sunindextype *jac_colptrs = SUNSparseMatrix_IndexPointers(J); + for (i = 0; i < n_col_ptrs; i++) { + jac_colptrs[i] = p_jac_times_cjmass_colptrs[i]; + } + LS = SUNLinSol_KLU(yy, J); IDASetLinearSolver(ida_mem, LS, J); @@ -419,24 +519,3 @@ Solution solve(np_array t_np, np_array y0_np, np_array yp0_np, return sol; } -PYBIND11_MODULE(idaklu, m) -{ - m.doc() = "sundials solvers"; // optional module docstring - - py::bind_vector>(m, "VectorNdArray"); - - m.def("solve", &solve, "The solve function", py::arg("t"), py::arg("y0"), - py::arg("yp0"), py::arg("res"), py::arg("jac"), py::arg("sens"), - py::arg("get_jac_data"), - py::arg("get_jac_row_vals"), py::arg("get_jac_col_ptr"), py::arg("nnz"), - py::arg("events"), py::arg("number_of_events"), py::arg("use_jacobian"), - py::arg("rhs_alg_id"), py::arg("atol"), py::arg("rtol"), - py::arg("number_of_sensitivity_parameters"), - py::return_value_policy::take_ownership); - - py::class_(m, "solution") - .def_readwrite("t", &Solution::t) - .def_readwrite("y", &Solution::y) - .def_readwrite("yS", &Solution::yS) - .def_readwrite("flag", &Solution::flag); -} diff --git a/pybamm/solvers/c_solvers/idaklu_casadi.hpp b/pybamm/solvers/c_solvers/idaklu_casadi.hpp new file mode 100644 index 0000000000..d6113e51d1 --- /dev/null +++ b/pybamm/solvers/c_solvers/idaklu_casadi.hpp @@ -0,0 +1,27 @@ + +#ifndef PYBAMM_IDAKLU_CASADI_HPP +#define PYBAMM_IDAKLU_CASADI_HPP + +#include "solution.hpp" +#include + +using Function = casadi::Function + +Solution solve_casadi(np_array t_np, np_array y0_np, np_array yp0_np, + const Function &rhs_alg, + const Function &jac_times_cjmass, + const np_array &jac_times_cjmass_rowvals, + const np_array &jac_times_cjmass_colptrs, + const int jac_times_cjmass_nnz, + const Function &jac_action, + const Function &jacp_action, + const Function &mass_action, + const Function &sens, + const Function &event, + const int number_of_events, + int use_jacobian, + np_array rhs_alg_id, + np_array atol_np, double rel_tol, int number_of_parameters) + + +#endif // PYBAMM_IDAKLU_CASADI_HPP diff --git a/pybamm/solvers/c_solvers/idaklu_python.cpp b/pybamm/solvers/c_solvers/idaklu_python.cpp new file mode 100644 index 0000000000..9409d636e3 --- /dev/null +++ b/pybamm/solvers/c_solvers/idaklu_python.cpp @@ -0,0 +1,54 @@ + +#include "idaklu_python.hpp" +#include "idaklu.hpp" +#include "idaklu_casadi.hpp" + +#include +#include +#include + +PYBIND11_MAKE_OPAQUE(std::vector); + +PYBIND11_MODULE(idaklu, m) +{ + m.doc() = "sundials solvers"; // optional module docstring + + py::bind_vector>(m, "VectorNdArray"); + + m.def("solve_python", &solve_python, "The solve function for python evaluators", + py::arg("t"), py::arg("y0"), + py::arg("yp0"), py::arg("res"), py::arg("jac"), py::arg("sens"), + py::arg("get_jac_data"), + py::arg("get_jac_row_vals"), py::arg("get_jac_col_ptr"), py::arg("nnz"), + py::arg("events"), py::arg("number_of_events"), py::arg("use_jacobian"), + py::arg("rhs_alg_id"), py::arg("atol"), py::arg("rtol"), + py::arg("number_of_sensitivity_parameters"), + py::return_value_policy::take_ownership); + + m.def("solve_casadi", &solve_casadi, "The solve function for casadi evaluators", + py::arg("t"), py::arg("y0"), py::arg("yp0"), + py::arg("rhs_alg"), + py::arg("jac_times_cjmass"), + py::arg("jac_times_cjmass_rowvals"), + py::arg("jac_times_cjmass_colptrs"), + py::arg("jac_times_cjmass_nnz"), + py::arg("jac_action"), + py::arg("jacp_action"), + py::arg("mass_action"), + py::arg("sens"), + py::arg("events"), py::arg("number_of_events"), + py::arg("use_jacobian"), + py::arg("rhs_alg_id"), + py::arg("atol"), py::arg("rtol"), + py::arg("number_of_sensitivity_parameters"), + py::return_value_policy::take_ownership); + + + py::class_(m, "solution") + .def_readwrite("t", &Solution::t) + .def_readwrite("y", &Solution::y) + .def_readwrite("yS", &Solution::yS) + .def_readwrite("flag", &Solution::flag); +} + + diff --git a/pybamm/solvers/c_solvers/idaklu_python.hpp b/pybamm/solvers/c_solvers/idaklu_python.hpp new file mode 100644 index 0000000000..b5f6e04461 --- /dev/null +++ b/pybamm/solvers/c_solvers/idaklu_python.hpp @@ -0,0 +1,16 @@ +#ifndef PYBAMM_IDAKLU_PYTHON_HPP +#define PYBAMM_IDAKLU_PYTHON_HPP + +#include /* prototypes for IDAS fcts., consts. */ +#include /* access to serial N_Vector */ +#include /* defs. of SUNRabs, SUNRexp, etc. */ +#include /* defs. of realtype, sunindextype */ +#include /* access to KLU linear solver */ +#include /* access to sparse SUNMatrix */ + +#include + +namespace py = pybind11; +using np_array = py::array_t; + +#endif // PYBAMM_IDAKLU_PYTHON_HPP diff --git a/pybamm/solvers/c_solvers/solution.cpp b/pybamm/solvers/c_solvers/solution.cpp new file mode 100644 index 0000000000..e69de29bb2 diff --git a/pybamm/solvers/c_solvers/solution.hpp b/pybamm/solvers/c_solvers/solution.hpp new file mode 100644 index 0000000000..c22dcf2b76 --- /dev/null +++ b/pybamm/solvers/c_solvers/solution.hpp @@ -0,0 +1,21 @@ +#ifndef PYBAMM_SOLUTION_HPP +#define PYBAMM_SOLUTION_HPP + +#include "idaklu_python.hpp" + +class Solution +{ +public: + Solution(int retval, np_array t_np, np_array y_np, np_array yS_np) + : flag(retval), t(t_np), y(y_np), yS(yS_np) + { + } + + int flag; + np_array t; + np_array y; + np_array yS; +}; + + +#endif // PYBAMM_SOLUTION_HPP diff --git a/setup.py b/setup.py index 06dfc4f2aa..7bfbcbfb4b 100644 --- a/setup.py +++ b/setup.py @@ -147,7 +147,11 @@ def compile_KLU(): pybamm_data.append("./plotting/pybamm.mplstyle") pybamm_data.append("../CMakeBuild.py") -idaklu_ext = Extension("pybamm.solvers.idaklu", ["pybamm/solvers/c_solvers/idaklu.cpp"]) +idaklu_ext = Extension("pybamm.solvers.idaklu", [ + "pybamm/solvers/c_solvers/idaklu.cpp" + "pybamm/solvers/c_solvers/idaklu_casadi.cpp" + "pybamm/solvers/c_solvers/idaklu_python.cpp" +]) ext_modules = [idaklu_ext] if compile_KLU() else [] # Defines __version__