Merge pull request #253 from jameslamb/branch-0.40-merge-0.39
Merge branch-0.39 into branch-0.40
AyodeAwe authored Jul 25, 2024
2 parents 63a4603 + 4f3c4a6 commit 8cdbcf0
Showing 25 changed files with 682 additions and 470 deletions.
9 changes: 8 additions & 1 deletion ci/test_common.sh
@@ -35,10 +35,17 @@ run_cpp_tests() {
RUNTIME_PATH=${CONDA_PREFIX:-./}
BINARY_PATH=${RUNTIME_PATH}/bin

CMD_LINE="timeout 10m ${BINARY_PATH}/gtests/libucxx/UCXX_TEST"
# Disable memory get/put with RMM in protov1, as it always segfaults.
CMD_LINE="timeout 10m ${BINARY_PATH}/gtests/libucxx/UCXX_TEST --gtest_filter=-*RMM*Memory*"

log_command "${CMD_LINE}"
UCX_TCP_CM_REUSEADDR=y ${CMD_LINE}

# Only test memory get/put with RMM in protov2, as protov1 segfaults.
CMD_LINE="timeout 10m ${BINARY_PATH}/gtests/libucxx/UCXX_TEST --gtest_filter=*RMM*Memory*"

log_command "${CMD_LINE}"
UCX_PROTO_ENABLE=y UCX_TCP_CM_REUSEADDR=y ${CMD_LINE}
}

run_cpp_benchmark() {
36 changes: 21 additions & 15 deletions cpp/benchmarks/perftest.cpp
@@ -36,14 +36,15 @@ typedef std::shared_ptr<BufferMap> BufferMapPtr;
typedef std::shared_ptr<TagMap> TagMapPtr;

struct app_context_t {
ProgressMode progress_mode = ProgressMode::Blocking;
const char* server_addr = NULL;
uint16_t listener_port = 12345;
size_t message_size = 8;
size_t n_iter = 100;
size_t warmup_iter = 3;
bool reuse_alloc = false;
bool verify_results = false;
ProgressMode progress_mode = ProgressMode::Blocking;
const char* server_addr = NULL;
uint16_t listener_port = 12345;
size_t message_size = 8;
size_t n_iter = 100;
size_t warmup_iter = 3;
bool endpoint_error_handling = false;
bool reuse_alloc = false;
bool verify_results = false;
};

class ListenerContext {
@@ -52,9 +53,13 @@ class ListenerContext {
std::shared_ptr<ucxx::Endpoint> _endpoint{nullptr};
std::shared_ptr<ucxx::Listener> _listener{nullptr};
std::atomic<bool> _isAvailable{true};
bool _endpointErrorHandling{false};

public:
explicit ListenerContext(std::shared_ptr<ucxx::Worker> worker) : _worker{worker} {}
ListenerContext(std::shared_ptr<ucxx::Worker> worker, bool endpointErrorHandling)
: _worker{worker}, _endpointErrorHandling(endpointErrorHandling)
{
}

~ListenerContext() { releaseEndpoint(); }

@@ -70,8 +75,7 @@
{
if (!isAvailable()) throw std::runtime_error("Listener context already has an endpoint");

static bool endpoint_error_handling = true;
_endpoint = _listener->createEndpointFromConnRequest(conn_request, endpoint_error_handling);
_endpoint = _listener->createEndpointFromConnRequest(conn_request, _endpointErrorHandling);
_isAvailable = false;
}

@@ -120,6 +124,7 @@ static void printUsage()
std::cerr << " 'thread-polling' and 'thread-blocking' (default: 'blocking')"
<< std::endl;
std::cerr << " -t use thread progress mode (disabled)" << std::endl;
std::cerr << " -e create endpoints with error handling support (disabled)" << std::endl;
std::cerr << " -p <port> port number to listen at (12345)" << std::endl;
std::cerr << " -s <bytes> message size (8)" << std::endl;
std::cerr << " -n <int> number of iterations to run (100)" << std::endl;
@@ -134,7 +139,7 @@ ucs_status_t parseCommand(app_context_t* app_context, int argc, char* const argv
{
optind = 1;
int c;
while ((c = getopt(argc, argv, "m:p:s:w:n:rvh")) != -1) {
while ((c = getopt(argc, argv, "m:p:s:w:n:ervh")) != -1) {
switch (c) {
case 'm':
if (strcmp(optarg, "blocking") == 0) {
@@ -185,6 +190,7 @@
return UCS_ERR_INVALID_PARAM;
}
break;
case 'e': app_context->endpoint_error_handling = true; break;
case 'r': app_context->reuse_alloc = true; break;
case 'v': app_context->verify_results = true; break;
case 'h':
@@ -301,7 +307,7 @@ int main(int argc, char** argv)
std::shared_ptr<ucxx::Endpoint> endpoint;
std::shared_ptr<ucxx::Listener> listener;
if (is_server) {
listener_ctx = std::make_unique<ListenerContext>(worker);
listener_ctx = std::make_unique<ListenerContext>(worker, app_context.endpoint_error_handling);
listener = worker->createListener(app_context.listener_port, listener_cb, listener_ctx.get());
listener_ctx->setListener(listener);
}
@@ -323,8 +329,8 @@
if (is_server)
endpoint = listener_ctx->getEndpoint();
else
endpoint =
worker->createEndpointFromHostname(app_context.server_addr, app_context.listener_port, true);
endpoint = worker->createEndpointFromHostname(
app_context.server_addr, app_context.listener_port, app_context.endpoint_error_handling);

std::vector<std::shared_ptr<ucxx::Request>> requests;

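For reference, a minimal client-side sketch of how the new `-e` option reaches endpoint creation. Only `createWorker` and `createEndpointFromHostname` are taken from this diff; the umbrella header `<ucxx/api.h>` and the `ucxx::createContext(ConfigMap, featureFlags)` setup call are assumed from other UCXX examples and may differ:

#include <cstdint>
#include <memory>

#include <ucxx/api.h>  // assumed umbrella header

// Hypothetical helper mirroring perftest.cpp: the boolean parsed from `-e` is
// forwarded to createEndpointFromHostname() instead of the previously
// hard-coded `true`.
std::shared_ptr<ucxx::Endpoint> connectClient(const char* serverAddr,
                                              uint16_t port,
                                              bool endpointErrorHandling)
{
  // Assumed setup calls; only the endpoint creation below comes from this change.
  auto context = ucxx::createContext({}, ucxx::Context::defaultFeatureFlags);
  auto worker  = ucxx::createWorker(context, /*enableDelayedSubmission=*/false,
                                    /*enableFuture=*/false);
  return worker->createEndpointFromHostname(serverAddr, port, endpointErrorHandling);
}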
8 changes: 5 additions & 3 deletions cpp/include/ucxx/constructors.h
@@ -60,9 +60,11 @@ std::shared_ptr<Worker> createWorker(std::shared_ptr<Context> context,
const bool enableDelayedSubmission,
const bool enableFuture);

std::shared_ptr<MemoryHandle> createMemoryHandle(std::shared_ptr<Context> context,
const size_t size,
void* buffer = nullptr);
std::shared_ptr<MemoryHandle> createMemoryHandle(
std::shared_ptr<Context> context,
const size_t size,
void* buffer = nullptr,
const ucs_memory_type_t memoryType = UCS_MEMORY_TYPE_HOST);

std::shared_ptr<RemoteKey> createRemoteKeyFromMemoryHandle(
std::shared_ptr<MemoryHandle> memoryHandle);
10 changes: 6 additions & 4 deletions cpp/include/ucxx/context.h
@@ -210,13 +210,15 @@ class Context : public Component {
*
* @throws ucxx::Error if either `ucp_mem_map` or `ucp_mem_query` fail.
*
* @param[in] size the minimum size of the memory allocation
* @param[in] buffer the pointer to an existing allocation or `nullptr` to allocate a
* new memory region.
* @param[in] size the minimum size of the memory allocation.
* @param[in] buffer the pointer to an existing allocation or `nullptr` to allocate a
* new memory region.
* @param[in] memoryType the type of memory the handle points to.
*
* @returns The `shared_ptr<ucxx::MemoryHandle>` object
*/
std::shared_ptr<MemoryHandle> createMemoryHandle(const size_t size, void* buffer);
std::shared_ptr<MemoryHandle> createMemoryHandle(
const size_t size, void* buffer, const ucs_memory_type_t memoryType = UCS_MEMORY_TYPE_HOST);
};

} // namespace ucxx
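A minimal sketch of the extended API, assuming a device allocation obtained elsewhere (for example via RMM or `cudaMalloc`). The `createMemoryHandle` member signature is taken from this diff; the umbrella header `<ucxx/api.h>` is assumed:

#include <cstddef>
#include <memory>

#include <ucxx/api.h>  // assumed umbrella header

// Register an existing device allocation with UCX; before this change the
// mapped region was implicitly treated as host memory.
std::shared_ptr<ucxx::MemoryHandle> mapDeviceBuffer(std::shared_ptr<ucxx::Context> context,
                                                    void* deviceBuffer,
                                                    size_t size)
{
  return context->createMemoryHandle(size, deviceBuffer, UCS_MEMORY_TYPE_CUDA);
}

// Passing `nullptr` still lets UCX allocate the region itself, and the memory
// type defaults to UCS_MEMORY_TYPE_HOST as before.
std::shared_ptr<ucxx::MemoryHandle> mapHostBuffer(std::shared_ptr<ucxx::Context> context,
                                                  size_t size)
{
  return context->createMemoryHandle(size, nullptr);
}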
114 changes: 93 additions & 21 deletions cpp/include/ucxx/delayed_submission.h
@@ -4,15 +4,16 @@
*/
#pragma once

#include <deque>
#include <functional>
#include <memory>
#include <mutex>
#include <optional>
#include <set>
#include <stdexcept>
#include <string>
#include <utility>
#include <variant>
#include <vector>

#include <ucp/api/ucp.h>
#include <ucs/memory/memory_type.h>
@@ -30,6 +31,8 @@ namespace ucxx {
*/
typedef std::function<void()> DelayedSubmissionCallbackType;

typedef uint64_t ItemIdType;

/**
* @brief Base type for a collection of delayed submissions.
*
@@ -40,28 +43,32 @@ template <typename T>
class BaseDelayedSubmissionCollection {
protected:
std::string _name{"undefined"}; ///< The human-readable name of the collection, used for logging
bool _enabled{true}; ///< Whether the resource required to process the collection is enabled.
std::vector<T> _collection{}; ///< The collection.
std::mutex _mutex{}; ///< Mutex to provide access to `_collection`.
bool _enabled{true}; ///< Whether the resource required to process the collection is enabled.
ItemIdType _itemId{0}; ///< The item ID counter, used to allow cancelation.
std::deque<std::pair<ItemIdType, T>> _collection{}; ///< The collection.
std::set<ItemIdType> _canceled{}; ///< IDs of canceled items.
std::mutex _mutex{}; ///< Mutex to provide access to `_collection`.

/**
* @brief Log message during `schedule()`.
*
* Log a specialized message while `schedule()` is being executed.
*
* @param[in] id the ID of the scheduled item, as returned by `schedule()`.
* @param[in] item the callback that was passed as argument to `schedule()`.
*/
virtual void scheduleLog(T item) = 0;
virtual void scheduleLog(ItemIdType id, T item) = 0;

/**
* @brief Process a single item during `process()`.
*
* Method called by `process()` to process a single item of the collection.
*
* @param[in] id the ID of the scheduled item, as returned by `schedule()`.
* @param[in] item the callback that was passed as argument to `schedule()` when
* it was first registered.
*/
virtual void processItem(T item) = 0;
virtual void processItem(ItemIdType id, T item) = 0;

public:
/**
@@ -102,16 +109,22 @@ class BaseDelayedSubmissionCollection {
*
* @param[in] item the callback that will be executed by `process()` when the
* operation is submitted.
*
* @returns the ID of the scheduled item which can be used for cancelation requests.
*/
virtual void schedule(T item)
[[nodiscard]] virtual ItemIdType schedule(T item)
{
if (!_enabled) throw std::runtime_error("Resource is disabled.");

ItemIdType id;
{
std::lock_guard<std::mutex> lock(_mutex);
_collection.push_back(item);
id = _itemId++;
_collection.emplace_back(id, item);
}
scheduleLog(item);
scheduleLog(id, item);

return id;
}

/**
@@ -122,20 +135,47 @@
*/
void process()
{
decltype(_collection) itemsToProcess;
// Process only those items that were already inserted, to prevent this loop
// from never returning if `_collection` grows indefinitely.
size_t toProcess = 0;
{
std::lock_guard<std::mutex> lock(_mutex);
// Move _collection to a local copy in order to to hold the lock for as
// short as possible
itemsToProcess = std::move(_collection);
toProcess = _collection.size();
}

if (itemsToProcess.size() > 0) {
ucxx_trace_req("Submitting %lu %s callbacks", itemsToProcess.size(), _name.c_str());
for (auto& item : itemsToProcess)
processItem(item);
for (auto i = 0; i < toProcess; ++i) {
std::pair<ItemIdType, T> item;
{
std::lock_guard<std::mutex> lock(_mutex);
item = std::move(_collection.front());
_collection.pop_front();
if (_canceled.erase(item.first)) continue;
}

processItem(item.first, item.second);
}
}

/**
* @brief Cancel a pending callback.
*
* Cancel a pending callback so that it is not executed, unless its execution has
* already begun, in which case it can no longer be canceled.
*
* @param[in] id the ID of the scheduled item, as returned by `schedule()`.
*/
void cancel(ItemIdType id)
{
std::lock_guard<std::mutex> lock(_mutex);
// TODO: Check if not cancellable anymore? Will likely need a separate set to keep
// track of registered items.
//
// If the callback is already running
// and the user has no way of knowing that but still destroys it, undefined
// behavior may occur.
_canceled.insert(id);
ucxx_trace_req("Canceled item: %lu", id);
}
};
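As a standalone illustration of the pattern above (not the UCXX class itself, which is abstract and templated), a minimal queue with the same semantics: `schedule()` hands out a monotonically increasing ID, `cancel()` records the ID in a set, and `process()` drains only the items that were already queued when it started, skipping canceled ones:

#include <cstdint>
#include <deque>
#include <functional>
#include <iostream>
#include <mutex>
#include <set>
#include <utility>

class CancelableQueue {
 public:
  uint64_t schedule(std::function<void()> item)
  {
    std::lock_guard<std::mutex> lock(_mutex);
    uint64_t id = _nextId++;
    _queue.emplace_back(id, std::move(item));
    return id;
  }

  void cancel(uint64_t id)
  {
    std::lock_guard<std::mutex> lock(_mutex);
    _canceled.insert(id);
  }

  void process()
  {
    size_t toProcess = 0;
    {
      std::lock_guard<std::mutex> lock(_mutex);
      toProcess = _queue.size();  // snapshot: items scheduled later wait for the next call
    }
    for (size_t i = 0; i < toProcess; ++i) {
      std::pair<uint64_t, std::function<void()>> item;
      {
        std::lock_guard<std::mutex> lock(_mutex);
        item = std::move(_queue.front());
        _queue.pop_front();
        if (_canceled.erase(item.first)) continue;  // skipped; lock released on scope exit
      }
      item.second();  // run outside the lock, like processItem()
    }
  }

 private:
  std::mutex _mutex{};
  uint64_t _nextId{0};
  std::deque<std::pair<uint64_t, std::function<void()>>> _queue{};
  std::set<uint64_t> _canceled{};
};

int main()
{
  CancelableQueue queue;
  queue.schedule([] { std::cout << "ran" << std::endl; });
  auto canceled = queue.schedule([] { std::cout << "never printed" << std::endl; });
  queue.cancel(canceled);
  queue.process();  // prints "ran" once; the canceled item is dropped without running
  return 0;
}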

/**
@@ -149,9 +189,11 @@ class RequestDelayedSubmissionCollection
std::pair<std::shared_ptr<Request>, DelayedSubmissionCallbackType>> {
protected:
void scheduleLog(
ItemIdType id,
std::pair<std::shared_ptr<Request>, DelayedSubmissionCallbackType> item) override;

void processItem(
ItemIdType id,
std::pair<std::shared_ptr<Request>, DelayedSubmissionCallbackType> item) override;

public:
@@ -177,9 +219,9 @@ class RequestDelayedSubmissionCollection
class GenericDelayedSubmissionCollection
: public BaseDelayedSubmissionCollection<DelayedSubmissionCallbackType> {
protected:
void scheduleLog(DelayedSubmissionCallbackType item) override;
void scheduleLog(ItemIdType id, DelayedSubmissionCallbackType item) override;

void processItem(DelayedSubmissionCallbackType callback) override;
void processItem(ItemIdType id, DelayedSubmissionCallbackType callback) override;

public:
/**
@@ -276,8 +318,10 @@ class DelayedSubmissionCollection {
* Lifetime of the callback must be ensured by the caller.
*
* @param[in] callback the callback that will be executed by `processPre()`.
*
* @returns the ID of the scheduled item which can be used for cancelation requests.
*/
void registerGenericPre(DelayedSubmissionCallbackType callback);
ItemIdType registerGenericPre(DelayedSubmissionCallbackType callback);

/**
* @brief Register a generic callback to execute during `processPost()`.
@@ -286,8 +330,36 @@
* Lifetime of the callback must be ensured by the caller.
*
* @param[in] callback the callback that will be executed by `processPre()`.
*
* @returns the ID of the scheduled item which can be used for cancelation requests.
*/
ItemIdType registerGenericPost(DelayedSubmissionCallbackType callback);

/**
* @brief Cancel a generic callback scheduled for `processPre()` execution.
*
* Cancel the execution of a generic callback that has been previously scheduled for
* execution during `processPre()`. This can be useful if the caller of
* `registerGenericPre()` has given up and will not anymore be able to guarantee the
* lifetime of the callback.
*
* @param[in] id the ID of the scheduled item, as returned
* by `registerGenericPre()`.
*/
void cancelGenericPre(ItemIdType id);

/**
* @brief Cancel a generic callback scheduled for `processPost()` execution.
*
* Cancel the execution of a generic callback that has been previously scheduled for
* execution during `processPost()`. This can be useful if the caller of
* `registerGenericPost()` has given up and can no longer guarantee the
* lifetime of the callback.
*
* @param[in] id the ID of the scheduled item, as returned
* by `registerGenericPost()`.
*/
void registerGenericPost(DelayedSubmissionCallbackType callback);
void cancelGenericPost(ItemIdType id);

/**
* @brief Inquire if delayed request submission is enabled.
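A short usage sketch of the new cancelation API. Only `registerGenericPre()`, `cancelGenericPre()`, `ucxx::ItemIdType`, and `ucxx::DelayedSubmissionCollection` are taken from this header; the surrounding function and the captured state are hypothetical:

#include <memory>

#include <ucxx/delayed_submission.h>

// Hypothetical caller: schedule a generic pre-progress callback, then withdraw
// it if the work is no longer needed before processPre() gets to run it.
void scheduleThenCancel(ucxx::DelayedSubmissionCollection& delayedSubmissions)
{
  auto state = std::make_shared<int>(0);

  // registerGenericPre() now returns an ID identifying the pending callback.
  ucxx::ItemIdType id = delayedSubmissions.registerGenericPre([state]() { ++(*state); });

  // The pending callback can be withdrawn by ID; if it has already started
  // running, cancelation has no effect (see cancel() above).
  delayedSubmissions.cancelGenericPre(id);
}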
