From 8cf5d386989ebfcba4f124f6fd28e04b26bd911e Mon Sep 17 00:00:00 2001 From: Luca Bertagna Date: Tue, 11 Jun 2024 13:20:34 -0600 Subject: [PATCH] EAMxx: add support for passing MPI comm from py to C in pyeamxx --- components/eamxx/cmake/Findmpi4py.cmake | 36 ++++++++++++++++++++++ components/eamxx/src/python/CMakeLists.txt | 3 +- components/eamxx/src/python/pyeamxx.cpp | 22 ++++++++++--- components/eamxx/src/python/pygrid.hpp | 16 ++++++++-- components/eamxx/src/python/pyutils.hpp | 23 ++++++++++++++ 5 files changed, 93 insertions(+), 7 deletions(-) create mode 100644 components/eamxx/cmake/Findmpi4py.cmake create mode 100644 components/eamxx/src/python/pyutils.hpp diff --git a/components/eamxx/cmake/Findmpi4py.cmake b/components/eamxx/cmake/Findmpi4py.cmake new file mode 100644 index 00000000000..685b46ca586 --- /dev/null +++ b/components/eamxx/cmake/Findmpi4py.cmake @@ -0,0 +1,36 @@ +# - FindMPI4PY +# Find mpi4py includes +# This module defines: +# MPI4PY_INCLUDE_DIR, where to find mpi4py.h, etc. +# MPI4PY_FOUND + +function (SetMpi4pyIncludeDir) +endfunction() + +if (NOT TARGET mpi4py) + # If user provided an include dir, we will use that, otherwise we'll ask python to find it + if (NOT MPI4PY_INCLUDE_DIR) + execute_process(COMMAND + "${PYTHON_EXECUTABLE}" "-c" "import mpi4py; print (mpi4py.get_include())" + OUTPUT_VARIABLE OUTPUT + RESULT_VARIABLE RESULT + OUTPUT_STRIP_TRAILING_WHITESPACE) + if (RESULT) + set(MPI4PY_FOUND FALSE) + else () + set (MPI4PY_INCLUDE_DIR ${OUTPUT} CACHE PATH "Path to mpi4py include directory" FORCE) + endif() + endif() + + # If we still don't have an include dir, it means we have no mpi4py installed + if (NOT MPI4PY_INCLUDE_DIR) + set(MPI4PY_FOUND FALSE) + else () + add_library(mpi4py INTERFACE) + target_include_directories(mpi4py INTERFACE SYSTEM ${MPI4PY_INCLUDE_DIR}) + set(MPI4PY_FOUND TRUE) + endif() +endif() + +include(FindPackageHandleStandardArgs) +find_package_handle_standard_args(mpi4py DEFAULT_MSG MPI4PY_INCLUDE_DIR) diff --git a/components/eamxx/src/python/CMakeLists.txt b/components/eamxx/src/python/CMakeLists.txt index 3e1dd22072f..8c7cf2aaba9 100644 --- a/components/eamxx/src/python/CMakeLists.txt +++ b/components/eamxx/src/python/CMakeLists.txt @@ -1,4 +1,5 @@ find_package(pybind11 REQUIRED) +find_package(mpi4py REQUIRED) pybind11_add_module(pyeamxx pyeamxx.cpp) -target_link_libraries(pyeamxx PUBLIC scream_share scream_io diagnostics eamxx_physics) +target_link_libraries(pyeamxx PUBLIC mpi4py scream_share scream_io diagnostics eamxx_physics) diff --git a/components/eamxx/src/python/pyeamxx.cpp b/components/eamxx/src/python/pyeamxx.cpp index b6fc908eb5f..5834ca10f74 100644 --- a/components/eamxx/src/python/pyeamxx.cpp +++ b/components/eamxx/src/python/pyeamxx.cpp @@ -1,19 +1,32 @@ -#include "scream_session.hpp" +#include #include "pyfield.hpp" #include "pygrid.hpp" #include "pyatmproc.hpp" #include "pyparamlist.hpp" +#include "pyutils.hpp" + +#include #include +#include + namespace py = pybind11; + namespace scream { -void initialize () { - ekat::Comm comm(MPI_COMM_WORLD); +void initialize (MPI_Comm mpi_comm) { + ekat::Comm comm(mpi_comm); initialize_scream_session(comm.am_i_root()); scorpio::init_subsystem(comm); } + +void initialize () { + initialize(MPI_COMM_WORLD); +} +void initialize (pybind11::object py_comm) { + initialize(get_c_comm(py_comm)); +} void finalize () { scorpio::finalize_subsystem(); finalize_scream_session(); @@ -24,7 +37,8 @@ PYBIND11_MODULE (pyeamxx,m) { m.doc() = "Python interfaces to certain EAMxx infrastructure code"; // Scream Session - m.def("init",&initialize); + m.def("init",py::overload_cast<>(&initialize)); + m.def("init",py::overload_cast(&initialize)); m.def("finalize",&finalize); // Call all other headers' registration routines diff --git a/components/eamxx/src/python/pygrid.hpp b/components/eamxx/src/python/pygrid.hpp index aed51ead889..874b1aef17a 100644 --- a/components/eamxx/src/python/pygrid.hpp +++ b/components/eamxx/src/python/pygrid.hpp @@ -4,9 +4,12 @@ #include "share/grid/grids_manager.hpp" #include "share/grid/point_grid.hpp" #include "pyfield.hpp" +#include "pyutils.hpp" #include +#include + namespace scream { // Small grids manager class, to hold a pre-built grid @@ -38,15 +41,24 @@ struct PyGrid { PyGrid () = default; PyGrid(const std::string& name, int ncols, int nlevs) + : PyGrid (name,ncols,nlevs,MPI_COMM_WORLD) + {} + + PyGrid(const std::string& name, int ncols, int nlevs, pybind11::object py_comm) + : PyGrid (name,ncols,nlevs,get_c_comm(py_comm)) + {} + + PyGrid(const std::string& name, int ncols, int nlevs, MPI_Comm mpi_comm) { - ekat::Comm comm(MPI_COMM_WORLD); + ekat::Comm comm(mpi_comm); grid = create_point_grid(name,ncols,nlevs,comm); } }; inline void pybind_pygrid (pybind11::module& m) { pybind11::class_(m,"Grid") - .def(pybind11::init()); + .def(pybind11::init()) + .def(pybind11::init()); } } // namespace scream diff --git a/components/eamxx/src/python/pyutils.hpp b/components/eamxx/src/python/pyutils.hpp new file mode 100644 index 00000000000..6467df57eda --- /dev/null +++ b/components/eamxx/src/python/pyutils.hpp @@ -0,0 +1,23 @@ +#ifndef PYUTILS_HPP +#define PYUTILS_HPP + +#include +#include +#include + +#include + +MPI_Comm get_c_comm (pybind11::object py_comm) { + if (import_mpi4py() < 0) { + throw pybind11::error_already_set(); + } + auto py_src = py_comm.ptr(); + if (not PyObject_TypeCheck(py_src, &PyMPIComm_Type)) { + throw std::bad_cast(); + } + + auto comm_ptr = PyMPIComm_Get(py_src); + return *comm_ptr; +} + +#endif // PYUTILS_HPP