Skip to content

Commit

Permalink
Remove MPI static globals (espressomd#4858)
Browse files Browse the repository at this point in the history
Fixes espressomd#4856

Description of changes:
- fix multiple bugs caused by undefined behavior due to the static initialization order of MPI global objects
- ESPResSo is now compatible with Boost 1.84+
  • Loading branch information
kodiakhq[bot] authored and jngrad committed Feb 29, 2024
1 parent d09de49 commit dabb318
Show file tree
Hide file tree
Showing 24 changed files with 180 additions and 68 deletions.
9 changes: 7 additions & 2 deletions src/core/EspressoSystemStandAlone.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -29,17 +29,18 @@
#include <utils/Vector.hpp>

#include <boost/mpi.hpp>
#include <boost/mpi/environment.hpp>

#include <memory>

EspressoSystemStandAlone::EspressoSystemStandAlone(int argc, char **argv) {
auto mpi_env = mpi_init(argc, argv);
m_mpi_env = mpi_init(argc, argv);

boost::mpi::communicator world;
head_node = world.rank() == 0;

// initialize the MpiCallbacks framework
Communication::init(mpi_env);
Communication::init(m_mpi_env);

// default-construct global state of the system
#ifdef VIRTUAL_SITES
Expand All @@ -50,6 +51,10 @@ EspressoSystemStandAlone::EspressoSystemStandAlone(int argc, char **argv) {
mpi_loop();
}

EspressoSystemStandAlone::~EspressoSystemStandAlone() {
Communication::deinit();
}

void EspressoSystemStandAlone::set_box_l(Utils::Vector3d const &box_l) const {
if (!head_node)
return;
Expand Down
10 changes: 10 additions & 0 deletions src/core/EspressoSystemStandAlone.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -21,19 +21,29 @@

#include <utils/Vector.hpp>

#include <memory>

namespace boost {
namespace mpi {
class environment;
}
} // namespace boost

/** Manager for a stand-alone ESPResSo system.
* The system is default-initialized, MPI-ready and has no script interface.
*/
class EspressoSystemStandAlone {
public:
EspressoSystemStandAlone(int argc, char **argv);
~EspressoSystemStandAlone();
void set_box_l(Utils::Vector3d const &box_l) const;
void set_node_grid(Utils::Vector3i const &node_grid) const;
void set_time_step(double time_step) const;
void set_skin(double new_skin) const;

private:
bool head_node;
std::shared_ptr<boost::mpi::environment> m_mpi_env;
};

#endif
17 changes: 12 additions & 5 deletions src/core/MpiCallbacks.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,7 @@
#include <boost/mpi/collectives/broadcast.hpp>
#include <boost/mpi/collectives/reduce.hpp>
#include <boost/mpi/communicator.hpp>
#include <boost/mpi/environment.hpp>
#include <boost/optional.hpp>
#include <boost/range/algorithm/remove_if.hpp>

Expand Down Expand Up @@ -364,8 +365,8 @@ class MpiCallbacks {
template <typename F, class = std::enable_if_t<std::is_same<
typename detail::functor_types<F>::argument_types,
std::tuple<Args...>>::value>>
CallbackHandle(MpiCallbacks *cb, F &&f)
: m_id(cb->add(std::forward<F>(f))), m_cb(cb) {}
CallbackHandle(std::shared_ptr<MpiCallbacks> cb, F &&f)
: m_id(cb->add(std::forward<F>(f))), m_cb(std::move(cb)) {}

CallbackHandle(CallbackHandle const &) = delete;
CallbackHandle(CallbackHandle &&rhs) noexcept = default;
Expand All @@ -374,7 +375,7 @@ class MpiCallbacks {

private:
int m_id;
MpiCallbacks *m_cb;
std::shared_ptr<MpiCallbacks> m_cb;

public:
/**
Expand All @@ -400,7 +401,6 @@ class MpiCallbacks {
m_cb->remove(m_id);
}

MpiCallbacks *cb() const { return m_cb; }
int id() const { return m_id; }
};

Expand All @@ -419,8 +419,10 @@ class MpiCallbacks {

public:
explicit MpiCallbacks(boost::mpi::communicator comm,
std::shared_ptr<boost::mpi::environment> mpi_env,
bool abort_on_exit = true)
: m_abort_on_exit(abort_on_exit), m_comm(std::move(comm)) {
: m_abort_on_exit(abort_on_exit), m_comm(std::move(comm)),
m_mpi_env(std::move(mpi_env)) {
/* Add a dummy at id 0 for loop abort. */
m_callback_map.add(nullptr);

Expand Down Expand Up @@ -738,6 +740,11 @@ class MpiCallbacks {
*/
boost::mpi::communicator m_comm;

/**
* The MPI environment used for the callbacks.
*/
std::shared_ptr<boost::mpi::environment> m_mpi_env;

/**
* Internal storage for the callback functions.
*/
Expand Down
23 changes: 13 additions & 10 deletions src/core/communication.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,8 @@
#include <utils/mpi/cart_comm.hpp>

#include <boost/mpi.hpp>
#include <boost/mpi/communicator.hpp>
#include <boost/mpi/environment.hpp>

#include <mpi.h>
#ifdef OPEN_MPI
Expand All @@ -39,22 +41,23 @@
#include <cstdlib>
#include <memory>

namespace Communication {
auto const &mpi_datatype_cache = boost::mpi::detail::mpi_datatype_cache();
std::shared_ptr<boost::mpi::environment> mpi_env;
} // namespace Communication

boost::mpi::communicator comm_cart;

namespace Communication {
std::unique_ptr<MpiCallbacks> m_callbacks;
static std::shared_ptr<MpiCallbacks> m_callbacks;

/* We use a singleton callback class for now. */
MpiCallbacks &mpiCallbacks() {
assert(m_callbacks && "Mpi not initialized!");

return *m_callbacks;
}

std::shared_ptr<MpiCallbacks> mpiCallbacksHandle() {
assert(m_callbacks && "Mpi not initialized!");

return m_callbacks;
}
} // namespace Communication

using Communication::mpiCallbacks;
Expand Down Expand Up @@ -120,8 +123,6 @@ void openmpi_global_namespace() {

namespace Communication {
void init(std::shared_ptr<boost::mpi::environment> mpi_env) {
Communication::mpi_env = std::move(mpi_env);

MPI_Comm_size(MPI_COMM_WORLD, &n_nodes);
node_grid = Utils::Mpi::dims_create<3>(n_nodes);

Expand All @@ -131,12 +132,14 @@ void init(std::shared_ptr<boost::mpi::environment> mpi_env) {
this_node = comm_cart.rank();

Communication::m_callbacks =
std::make_unique<Communication::MpiCallbacks>(comm_cart);
std::make_shared<Communication::MpiCallbacks>(comm_cart, mpi_env);

ErrorHandling::init_error_handling(mpiCallbacks());
ErrorHandling::init_error_handling(Communication::m_callbacks);

on_program_start();
}

void deinit() { Communication::m_callbacks.reset(); }
} // namespace Communication

std::shared_ptr<boost::mpi::environment> mpi_init(int argc, char **argv) {
Expand Down
6 changes: 2 additions & 4 deletions src/core/communication.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -63,6 +63,7 @@ namespace Communication {
* @brief Returns a reference to the global callback class instance.
*/
MpiCallbacks &mpiCallbacks();
std::shared_ptr<MpiCallbacks> mpiCallbacksHandle();
} // namespace Communication

/**************************************************
Expand Down Expand Up @@ -136,12 +137,9 @@ namespace Communication {
/**
* @brief Init globals for communication.
*
* and calls @ref on_program_start. Keeps a copy of
* the pointer to the mpi environment to keep it alive
* while the program is loaded.
*
* @param mpi_env MPI environment that should be used
*/
void init(std::shared_ptr<boost::mpi::environment> mpi_env);
void deinit();
} // namespace Communication
#endif
13 changes: 7 additions & 6 deletions src/core/errorhandling.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@
#include <functional>
#include <memory>
#include <string>
#include <utility>
#include <vector>

namespace ErrorHandling {
Expand All @@ -45,14 +46,14 @@ namespace {
std::unique_ptr<RuntimeErrorCollector> runtimeErrorCollector;

/** The callback loop we are on. */
Communication::MpiCallbacks *m_callbacks = nullptr;
std::weak_ptr<Communication::MpiCallbacks> m_callbacks;
} // namespace

void init_error_handling(Communication::MpiCallbacks &cb) {
m_callbacks = &cb;
void init_error_handling(std::weak_ptr<Communication::MpiCallbacks> callbacks) {
m_callbacks = std::move(callbacks);

runtimeErrorCollector =
std::make_unique<RuntimeErrorCollector>(m_callbacks->comm());
std::make_unique<RuntimeErrorCollector>(m_callbacks.lock()->comm());
}

RuntimeErrorStream _runtimeMessageStream(RuntimeError::ErrorLevel level,
Expand All @@ -69,7 +70,7 @@ static void mpi_gather_runtime_errors_local() {
REGISTER_CALLBACK(mpi_gather_runtime_errors_local)

std::vector<RuntimeError> mpi_gather_runtime_errors() {
m_callbacks->call(mpi_gather_runtime_errors_local);
m_callbacks.lock()->call(mpi_gather_runtime_errors_local);
return runtimeErrorCollector->gather();
}

Expand All @@ -83,7 +84,7 @@ std::vector<RuntimeError> mpi_gather_runtime_errors_all(bool is_head_node) {
} // namespace ErrorHandling

void errexit() {
ErrorHandling::m_callbacks->comm().abort(1);
ErrorHandling::m_callbacks.lock()->comm().abort(1);

std::abort();
}
Expand Down
3 changes: 2 additions & 1 deletion src/core/errorhandling.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@
#include "error_handling/RuntimeError.hpp"
#include "error_handling/RuntimeErrorStream.hpp"

#include <memory>
#include <string>
#include <vector>

Expand Down Expand Up @@ -85,7 +86,7 @@ namespace ErrorHandling {
*
* @param callbacks Callbacks system the error handler should be on.
*/
void init_error_handling(Communication::MpiCallbacks &callbacks);
void init_error_handling(std::weak_ptr<Communication::MpiCallbacks> callbacks);

RuntimeErrorStream _runtimeMessageStream(RuntimeError::ErrorLevel level,
const std::string &file, int line,
Expand Down
2 changes: 1 addition & 1 deletion src/core/particle_node.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -407,7 +407,7 @@ static int mpi_place_new_particle(int p_id, const Utils::Vector3d &pos) {

void mpi_place_particle_local(int pnode, int p_id) {
if (pnode == this_node) {
Utils::Vector3d pos;
Utils::Vector3d pos{};
comm_cart.recv(0, some_tag, pos);
local_move_particle(p_id, pos);
}
Expand Down
4 changes: 3 additions & 1 deletion src/core/unit_tests/EspressoSystemStandAlone_test.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -357,5 +357,7 @@ BOOST_FIXTURE_TEST_CASE(espresso_system_stand_alone, ParticleFactory,
int main(int argc, char **argv) {
espresso::system = std::make_unique<EspressoSystemStandAlone>(argc, argv);

return boost::unit_test::unit_test_main(init_unit_test, argc, argv);
int retval = boost::unit_test::unit_test_main(init_unit_test, argc, argv);
espresso::system.reset();
return retval;
}
Loading

0 comments on commit dabb318

Please sign in to comment.