Skip to content

Commit

Permalink
EAMxx: add support for passing MPI comm from py to C in pyeamxx
Browse files Browse the repository at this point in the history
  • Loading branch information
bartgol committed Jun 11, 2024
1 parent 8d8e2af commit 8cf5d38
Show file tree
Hide file tree
Showing 5 changed files with 93 additions and 7 deletions.
36 changes: 36 additions & 0 deletions components/eamxx/cmake/Findmpi4py.cmake
Original file line number Diff line number Diff line change
@@ -0,0 +1,36 @@
# - FindMPI4PY
# Find mpi4py includes
# This module defines:
# MPI4PY_INCLUDE_DIR, where to find mpi4py.h, etc.
# MPI4PY_FOUND

function (SetMpi4pyIncludeDir)
endfunction()

if (NOT TARGET mpi4py)
# If user provided an include dir, we will use that, otherwise we'll ask python to find it
if (NOT MPI4PY_INCLUDE_DIR)
execute_process(COMMAND
"${PYTHON_EXECUTABLE}" "-c" "import mpi4py; print (mpi4py.get_include())"
OUTPUT_VARIABLE OUTPUT
RESULT_VARIABLE RESULT
OUTPUT_STRIP_TRAILING_WHITESPACE)
if (RESULT)
set(MPI4PY_FOUND FALSE)
else ()
set (MPI4PY_INCLUDE_DIR ${OUTPUT} CACHE PATH "Path to mpi4py include directory" FORCE)
endif()
endif()

# If we still don't have an include dir, it means we have no mpi4py installed
if (NOT MPI4PY_INCLUDE_DIR)
set(MPI4PY_FOUND FALSE)
else ()
add_library(mpi4py INTERFACE)
target_include_directories(mpi4py INTERFACE SYSTEM ${MPI4PY_INCLUDE_DIR})
set(MPI4PY_FOUND TRUE)
endif()
endif()

include(FindPackageHandleStandardArgs)
find_package_handle_standard_args(mpi4py DEFAULT_MSG MPI4PY_INCLUDE_DIR)
3 changes: 2 additions & 1 deletion components/eamxx/src/python/CMakeLists.txt
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
find_package(pybind11 REQUIRED)
find_package(mpi4py REQUIRED)

pybind11_add_module(pyeamxx pyeamxx.cpp)
target_link_libraries(pyeamxx PUBLIC scream_share scream_io diagnostics eamxx_physics)
target_link_libraries(pyeamxx PUBLIC mpi4py scream_share scream_io diagnostics eamxx_physics)
22 changes: 18 additions & 4 deletions components/eamxx/src/python/pyeamxx.cpp
Original file line number Diff line number Diff line change
@@ -1,19 +1,32 @@
#include "scream_session.hpp"
#include <share/scream_session.hpp>
#include "pyfield.hpp"
#include "pygrid.hpp"
#include "pyatmproc.hpp"
#include "pyparamlist.hpp"
#include "pyutils.hpp"

#include <ekat/mpi/ekat_comm.hpp>

#include <pybind11/pybind11.h>

#include <mpi.h>

namespace py = pybind11;

namespace scream {

void initialize () {
ekat::Comm comm(MPI_COMM_WORLD);
void initialize (MPI_Comm mpi_comm) {
ekat::Comm comm(mpi_comm);
initialize_scream_session(comm.am_i_root());
scorpio::init_subsystem(comm);
}

void initialize () {
initialize(MPI_COMM_WORLD);
}
void initialize (pybind11::object py_comm) {
initialize(get_c_comm(py_comm));
}
void finalize () {
scorpio::finalize_subsystem();
finalize_scream_session();
Expand All @@ -24,7 +37,8 @@ PYBIND11_MODULE (pyeamxx,m) {
m.doc() = "Python interfaces to certain EAMxx infrastructure code";

// Scream Session
m.def("init",&initialize);
m.def("init",py::overload_cast<>(&initialize));
m.def("init",py::overload_cast<pybind11::object>(&initialize));
m.def("finalize",&finalize);

// Call all other headers' registration routines
Expand Down
16 changes: 14 additions & 2 deletions components/eamxx/src/python/pygrid.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -4,9 +4,12 @@
#include "share/grid/grids_manager.hpp"
#include "share/grid/point_grid.hpp"
#include "pyfield.hpp"
#include "pyutils.hpp"

#include <pybind11/pybind11.h>

#include <mpi.h>

namespace scream {

// Small grids manager class, to hold a pre-built grid
Expand Down Expand Up @@ -38,15 +41,24 @@ struct PyGrid {
PyGrid () = default;

PyGrid(const std::string& name, int ncols, int nlevs)
: PyGrid (name,ncols,nlevs,MPI_COMM_WORLD)
{}

PyGrid(const std::string& name, int ncols, int nlevs, pybind11::object py_comm)
: PyGrid (name,ncols,nlevs,get_c_comm(py_comm))
{}

PyGrid(const std::string& name, int ncols, int nlevs, MPI_Comm mpi_comm)
{
ekat::Comm comm(MPI_COMM_WORLD);
ekat::Comm comm(mpi_comm);
grid = create_point_grid(name,ncols,nlevs,comm);
}
};

inline void pybind_pygrid (pybind11::module& m) {
pybind11::class_<PyGrid>(m,"Grid")
.def(pybind11::init<const std::string&,int,int>());
.def(pybind11::init<const std::string&,int,int>())
.def(pybind11::init<const std::string&,int,int,pybind11::object>());
}

} // namespace scream
Expand Down
23 changes: 23 additions & 0 deletions components/eamxx/src/python/pyutils.hpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
#ifndef PYUTILS_HPP
#define PYUTILS_HPP

#include <pybind11/pybind11.h>
#include <mpi4py/mpi4py.h>
#include <mpi.h>

#include <typeinfo>

MPI_Comm get_c_comm (pybind11::object py_comm) {
if (import_mpi4py() < 0) {
throw pybind11::error_already_set();
}
auto py_src = py_comm.ptr();
if (not PyObject_TypeCheck(py_src, &PyMPIComm_Type)) {
throw std::bad_cast();
}

auto comm_ptr = PyMPIComm_Get(py_src);
return *comm_ptr;
}

#endif // PYUTILS_HPP

0 comments on commit 8cf5d38

Please sign in to comment.