Skip to content

Commit

Permalink
#1863 finish draft of casadi solver, and restructure cpp code
Browse files Browse the repository at this point in the history
  • Loading branch information
martinjrobins committed Mar 17, 2022
1 parent 6910821 commit aa3b66a
Show file tree
Hide file tree
Showing 9 changed files with 327 additions and 152 deletions.
46 changes: 3 additions & 43 deletions pybamm/solvers/c_solvers/idaklu.cpp
Original file line number Diff line number Diff line change
@@ -1,24 +1,5 @@
#include <math.h>
#include <stdio.h>
#include "idaklu_python.hpp"

#include <idas/idas.h> /* prototypes for IDAS fcts., consts. */
#include <nvector/nvector_serial.h> /* access to serial N_Vector */
#include <sundials/sundials_math.h> /* defs. of SUNRabs, SUNRexp, etc. */
#include <sundials/sundials_types.h> /* defs. of realtype, sunindextype */
#include <sunlinsol/sunlinsol_klu.h> /* access to KLU linear solver */
#include <sunmatrix/sunmatrix_sparse.h> /* access to sparse SUNMatrix */

#include <pybind11/functional.h>
#include <pybind11/numpy.h>
#include <pybind11/pybind11.h>
#include <pybind11/stl_bind.h>

//#include <iostream>
namespace py = pybind11;


using np_array = py::array_t<realtype>;
PYBIND11_MAKE_OPAQUE(std::vector<np_array>);
using residual_type = std::function<np_array(realtype, np_array, np_array)>;
using sensitivities_type = std::function<void(
std::vector<np_array>&, realtype, const np_array&,
Expand Down Expand Up @@ -177,8 +158,8 @@ int jacobian(realtype tt, realtype cj, N_Vector yy, N_Vector yp,

np_array jac_np_row_vals = python_functions.get_jac_row_vals();
int n_row_vals = jac_np_row_vals.request().size;
auto jac_np_row_vals_ptr = jac_np_row_vals.unchecked<1>();

auto jac_np_row_vals_ptr = jac_np_row_vals.unchecked<1>();
// just copy across row vals (this might be unneeded)
for (i = 0; i < n_row_vals; i++)
{
Expand Down Expand Up @@ -308,7 +289,7 @@ class Solution
};

/* main program */
Solution solve(np_array t_np, np_array y0_np, np_array yp0_np,
Solution solve_python(np_array t_np, np_array y0_np, np_array yp0_np,
residual_type res, jacobian_type jac,
sensitivities_type sens,
jac_get_type gjd, jac_get_type gjrv, jac_get_type gjcp,
Expand Down Expand Up @@ -497,24 +478,3 @@ Solution solve(np_array t_np, np_array y0_np, np_array yp0_np,
return sol;
}

PYBIND11_MODULE(idaklu, m)
{
m.doc() = "sundials solvers"; // optional module docstring

py::bind_vector<std::vector<np_array>>(m, "VectorNdArray");

m.def("solve", &solve, "The solve function", py::arg("t"), py::arg("y0"),
py::arg("yp0"), py::arg("res"), py::arg("jac"), py::arg("sens"),
py::arg("get_jac_data"),
py::arg("get_jac_row_vals"), py::arg("get_jac_col_ptr"), py::arg("nnz"),
py::arg("events"), py::arg("number_of_events"), py::arg("use_jacobian"),
py::arg("rhs_alg_id"), py::arg("atol"), py::arg("rtol"),
py::arg("number_of_sensitivity_parameters"),
py::return_value_policy::take_ownership);

py::class_<Solution>(m, "solution")
.def_readwrite("t", &Solution::t)
.def_readwrite("y", &Solution::y)
.def_readwrite("yS", &Solution::yS)
.def_readwrite("flag", &Solution::flag);
}
14 changes: 14 additions & 0 deletions pybamm/solvers/c_solvers/idaklu.hpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
#ifndef PYBAMM_IDAKLU_HPP
#define PYBAMM_IDAKLU_HPP

#include "solution.hpp"

Solution solve_python(np_array t_np, np_array y0_np, np_array yp0_np,
residual_type res, jacobian_type jac,
sensitivities_type sens,
jac_get_type gjd, jac_get_type gjrv, jac_get_type gjcp,
int nnz, event_type event,
int number_of_events, int use_jacobian, np_array rhs_alg_id,
np_array atol_np, double rel_tol, int number_of_parameters);

#endif // PYBAMM_IDAKLU_HPP
Loading

0 comments on commit aa3b66a

Please sign in to comment.