Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merge branch-0.39 into branch-0.40 #253

Merged
merged 5 commits into from
Jul 25, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 8 additions & 1 deletion ci/test_common.sh
Original file line number Diff line number Diff line change
Expand Up @@ -35,10 +35,17 @@ run_cpp_tests() {
RUNTIME_PATH=${CONDA_PREFIX:-./}
BINARY_PATH=${RUNTIME_PATH}/bin

CMD_LINE="timeout 10m ${BINARY_PATH}/gtests/libucxx/UCXX_TEST"
# Disable memory get/put with RMM in protov1, it always segfaults.
CMD_LINE="timeout 10m ${BINARY_PATH}/gtests/libucxx/UCXX_TEST --gtest_filter=-*RMM*Memory*"

log_command "${CMD_LINE}"
UCX_TCP_CM_REUSEADDR=y ${CMD_LINE}

# Only test memory get/put with RMM in protov2, as protov1 segfaults.
CMD_LINE="timeout 10m ${BINARY_PATH}/gtests/libucxx/UCXX_TEST --gtest_filter=*RMM*Memory*"

log_command "${CMD_LINE}"
UCX_PROTO_ENABLE=y UCX_TCP_CM_REUSEADDR=y ${CMD_LINE}
}

run_cpp_benchmark() {
Expand Down
36 changes: 21 additions & 15 deletions cpp/benchmarks/perftest.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -36,14 +36,15 @@ typedef std::shared_ptr<BufferMap> BufferMapPtr;
typedef std::shared_ptr<TagMap> TagMapPtr;

struct app_context_t {
ProgressMode progress_mode = ProgressMode::Blocking;
const char* server_addr = NULL;
uint16_t listener_port = 12345;
size_t message_size = 8;
size_t n_iter = 100;
size_t warmup_iter = 3;
bool reuse_alloc = false;
bool verify_results = false;
ProgressMode progress_mode = ProgressMode::Blocking;
const char* server_addr = NULL;
uint16_t listener_port = 12345;
size_t message_size = 8;
size_t n_iter = 100;
size_t warmup_iter = 3;
bool endpoint_error_handling = false;
bool reuse_alloc = false;
bool verify_results = false;
};

class ListenerContext {
Expand All @@ -52,9 +53,13 @@ class ListenerContext {
std::shared_ptr<ucxx::Endpoint> _endpoint{nullptr};
std::shared_ptr<ucxx::Listener> _listener{nullptr};
std::atomic<bool> _isAvailable{true};
bool _endpointErrorHandling{false};

public:
explicit ListenerContext(std::shared_ptr<ucxx::Worker> worker) : _worker{worker} {}
ListenerContext(std::shared_ptr<ucxx::Worker> worker, bool endpointErrorHandling)
: _worker{worker}, _endpointErrorHandling(endpointErrorHandling)
{
}

~ListenerContext() { releaseEndpoint(); }

Expand All @@ -70,8 +75,7 @@ class ListenerContext {
{
if (!isAvailable()) throw std::runtime_error("Listener context already has an endpoint");

static bool endpoint_error_handling = true;
_endpoint = _listener->createEndpointFromConnRequest(conn_request, endpoint_error_handling);
_endpoint = _listener->createEndpointFromConnRequest(conn_request, _endpointErrorHandling);
_isAvailable = false;
}

Expand Down Expand Up @@ -120,6 +124,7 @@ static void printUsage()
std::cerr << " 'thread-polling' and 'thread-blocking' (default: 'blocking')"
<< std::endl;
std::cerr << " -t use thread progress mode (disabled)" << std::endl;
std::cerr << " -e create endpoints with error handling support (disabled)" << std::endl;
std::cerr << " -p <port> port number to listen at (12345)" << std::endl;
std::cerr << " -s <bytes> message size (8)" << std::endl;
std::cerr << " -n <int> number of iterations to run (100)" << std::endl;
Expand All @@ -134,7 +139,7 @@ ucs_status_t parseCommand(app_context_t* app_context, int argc, char* const argv
{
optind = 1;
int c;
while ((c = getopt(argc, argv, "m:p:s:w:n:rvh")) != -1) {
while ((c = getopt(argc, argv, "m:p:s:w:n:ervh")) != -1) {
switch (c) {
case 'm':
if (strcmp(optarg, "blocking") == 0) {
Expand Down Expand Up @@ -185,6 +190,7 @@ ucs_status_t parseCommand(app_context_t* app_context, int argc, char* const argv
return UCS_ERR_INVALID_PARAM;
}
break;
case 'e': app_context->endpoint_error_handling = true; break;
case 'r': app_context->reuse_alloc = true; break;
case 'v': app_context->verify_results = true; break;
case 'h':
Expand Down Expand Up @@ -301,7 +307,7 @@ int main(int argc, char** argv)
std::shared_ptr<ucxx::Endpoint> endpoint;
std::shared_ptr<ucxx::Listener> listener;
if (is_server) {
listener_ctx = std::make_unique<ListenerContext>(worker);
listener_ctx = std::make_unique<ListenerContext>(worker, app_context.endpoint_error_handling);
listener = worker->createListener(app_context.listener_port, listener_cb, listener_ctx.get());
listener_ctx->setListener(listener);
}
Expand All @@ -323,8 +329,8 @@ int main(int argc, char** argv)
if (is_server)
endpoint = listener_ctx->getEndpoint();
else
endpoint =
worker->createEndpointFromHostname(app_context.server_addr, app_context.listener_port, true);
endpoint = worker->createEndpointFromHostname(
app_context.server_addr, app_context.listener_port, app_context.endpoint_error_handling);

std::vector<std::shared_ptr<ucxx::Request>> requests;

Expand Down
8 changes: 5 additions & 3 deletions cpp/include/ucxx/constructors.h
Original file line number Diff line number Diff line change
Expand Up @@ -60,9 +60,11 @@ std::shared_ptr<Worker> createWorker(std::shared_ptr<Context> context,
const bool enableDelayedSubmission,
const bool enableFuture);

std::shared_ptr<MemoryHandle> createMemoryHandle(std::shared_ptr<Context> context,
const size_t size,
void* buffer = nullptr);
std::shared_ptr<MemoryHandle> createMemoryHandle(
std::shared_ptr<Context> context,
const size_t size,
void* buffer = nullptr,
const ucs_memory_type_t memoryType = UCS_MEMORY_TYPE_HOST);

std::shared_ptr<RemoteKey> createRemoteKeyFromMemoryHandle(
std::shared_ptr<MemoryHandle> memoryHandle);
Expand Down
10 changes: 6 additions & 4 deletions cpp/include/ucxx/context.h
Original file line number Diff line number Diff line change
Expand Up @@ -210,13 +210,15 @@ class Context : public Component {
*
* @throws ucxx::Error if either `ucp_mem_map` or `ucp_mem_query` fail.
*
* @param[in] size the minimum size of the memory allocation
* @param[in] buffer the pointer to an existing allocation or `nullptr` to allocate a
* new memory region.
* @param[in] size the minimum size of the memory allocation.
* @param[in] buffer the pointer to an existing allocation or `nullptr` to allocate a
* new memory region.
* @param[in] memoryType the type of memory the handle points to.
*
* @returns The `shared_ptr<ucxx::MemoryHandle>` object
*/
std::shared_ptr<MemoryHandle> createMemoryHandle(const size_t size, void* buffer);
std::shared_ptr<MemoryHandle> createMemoryHandle(
const size_t size, void* buffer, const ucs_memory_type_t memoryType = UCS_MEMORY_TYPE_HOST);
};

} // namespace ucxx
114 changes: 93 additions & 21 deletions cpp/include/ucxx/delayed_submission.h
Original file line number Diff line number Diff line change
Expand Up @@ -4,15 +4,16 @@
*/
#pragma once

#include <deque>
#include <functional>
#include <memory>
#include <mutex>
#include <optional>
#include <set>
#include <stdexcept>
#include <string>
#include <utility>
#include <variant>
#include <vector>

#include <ucp/api/ucp.h>
#include <ucs/memory/memory_type.h>
Expand All @@ -30,6 +31,8 @@ namespace ucxx {
*/
typedef std::function<void()> DelayedSubmissionCallbackType;

typedef uint64_t ItemIdType;

/**
* @brief Base type for a collection of delayed submissions.
*
Expand All @@ -40,28 +43,32 @@ template <typename T>
class BaseDelayedSubmissionCollection {
protected:
std::string _name{"undefined"}; ///< The human-readable name of the collection, used for logging
bool _enabled{true}; ///< Whether the resource required to process the collection is enabled.
std::vector<T> _collection{}; ///< The collection.
std::mutex _mutex{}; ///< Mutex to provide access to `_collection`.
bool _enabled{true}; ///< Whether the resource required to process the collection is enabled.
ItemIdType _itemId{0}; ///< The item ID counter, used to allow cancelation.
std::deque<std::pair<ItemIdType, T>> _collection{}; ///< The collection.
std::set<ItemIdType> _canceled{}; ///< IDs of canceled items.
std::mutex _mutex{}; ///< Mutex to provide access to `_collection`.

/**
* @brief Log message during `schedule()`.
*
* Log a specialized message while `schedule()` is being executed.
*
* @param[in] id the ID of the scheduled item, as returned by `schedule()`.
* @param[in] item the callback that was passed as argument to `schedule()`.
*/
virtual void scheduleLog(T item) = 0;
virtual void scheduleLog(ItemIdType id, T item) = 0;

/**
* @brief Process a single item during `process()`.
*
* Method called by `process()` to process a single item of the collection.
*
* @param[in] id the ID of the scheduled item, as returned by `schedule()`.
* @param[in] item the callback that was passed as argument to `schedule()` when
* the first registered.
*/
virtual void processItem(T item) = 0;
virtual void processItem(ItemIdType id, T item) = 0;

public:
/**
Expand Down Expand Up @@ -102,16 +109,22 @@ class BaseDelayedSubmissionCollection {
*
* @param[in] item the callback that will be executed by `process()` when the
* operation is submitted.
*
* @returns the ID of the scheduled item which can be used cancelation requests.
*/
virtual void schedule(T item)
[[nodiscard]] virtual ItemIdType schedule(T item)
{
if (!_enabled) throw std::runtime_error("Resource is disabled.");

ItemIdType id;
{
std::lock_guard<std::mutex> lock(_mutex);
_collection.push_back(item);
id = _itemId++;
_collection.emplace_back(id, item);
}
scheduleLog(item);
scheduleLog(id, item);

return id;
}

/**
Expand All @@ -122,20 +135,47 @@ class BaseDelayedSubmissionCollection {
*/
void process()
{
decltype(_collection) itemsToProcess;
// Process only those that were already inserted to prevent from never
// returning if `_collection` grows indefinitely.
size_t toProcess = 0;
{
std::lock_guard<std::mutex> lock(_mutex);
// Move _collection to a local copy in order to to hold the lock for as
// short as possible
itemsToProcess = std::move(_collection);
toProcess = _collection.size();
}

if (itemsToProcess.size() > 0) {
ucxx_trace_req("Submitting %lu %s callbacks", itemsToProcess.size(), _name.c_str());
for (auto& item : itemsToProcess)
processItem(item);
for (auto i = 0; i < toProcess; ++i) {
std::pair<ItemIdType, T> item;
{
std::lock_guard<std::mutex> lock(_mutex);
item = std::move(_collection.front());
_collection.pop_front();
if (_canceled.erase(item.first)) continue;
}

processItem(item.first, item.second);
}
}

/**
* @brief Cancel a pending callback.
*
* Cancel a pending callback and thus do not execute it, unless the execution has
* already begun, in which case cancelation cannot be done.
*
* @param[in] id the ID of the scheduled item, as returned by `schedule()`.
*/
void cancel(ItemIdType id)
{
std::lock_guard<std::mutex> lock(_mutex);
// TODO: Check if not cancellable anymore? Will likely need a separate set to keep
// track of registered items.
//
// If the callback is already running
// and the user has no way of knowing that but still destroys it, undefined
// behavior may occur.
_canceled.insert(id);
ucxx_trace_req("Canceled item: %lu", id);
}
};

/**
Expand All @@ -149,9 +189,11 @@ class RequestDelayedSubmissionCollection
std::pair<std::shared_ptr<Request>, DelayedSubmissionCallbackType>> {
protected:
void scheduleLog(
ItemIdType id,
std::pair<std::shared_ptr<Request>, DelayedSubmissionCallbackType> item) override;

void processItem(
ItemIdType id,
std::pair<std::shared_ptr<Request>, DelayedSubmissionCallbackType> item) override;

public:
Expand All @@ -177,9 +219,9 @@ class RequestDelayedSubmissionCollection
class GenericDelayedSubmissionCollection
: public BaseDelayedSubmissionCollection<DelayedSubmissionCallbackType> {
protected:
void scheduleLog(DelayedSubmissionCallbackType item) override;
void scheduleLog(ItemIdType id, DelayedSubmissionCallbackType item) override;

void processItem(DelayedSubmissionCallbackType callback) override;
void processItem(ItemIdType id, DelayedSubmissionCallbackType callback) override;

public:
/**
Expand Down Expand Up @@ -276,8 +318,10 @@ class DelayedSubmissionCollection {
* Lifetime of the callback must be ensured by the caller.
*
* @param[in] callback the callback that will be executed by `processPre()`.
*
* @returns the ID of the scheduled item which can be used cancelation requests.
*/
void registerGenericPre(DelayedSubmissionCallbackType callback);
ItemIdType registerGenericPre(DelayedSubmissionCallbackType callback);

/**
* @brief Register a generic callback to execute during `processPost()`.
Expand All @@ -286,8 +330,36 @@ class DelayedSubmissionCollection {
* Lifetime of the callback must be ensured by the caller.
*
* @param[in] callback the callback that will be executed by `processPre()`.
*
* @returns the ID of the scheduled item which can be used cancelation requests.
*/
ItemIdType registerGenericPost(DelayedSubmissionCallbackType callback);

/**
* @brief Cancel a generic callback scheduled for `processPre()` execution.
*
* Cancel the execution of a generic callback that has been previously scheduled for
* execution during `processPre()`. This can be useful if the caller of
* `registerGenericPre()` has given up and will not anymore be able to guarantee the
* lifetime of the callback.
*
* @param[in] id the ID of the scheduled item, as returned
* by `registerGenericPre()`.
*/
void cancelGenericPre(ItemIdType id);

/**
* @brief Cancel a generic callback scheduled for `processPost()` execution.
*
* Cancel the execution of a generic callback that has been previously scheduled for
* execution during `processPos()`. This can be useful if the caller of
* `registerGenericPre()` has given up and will not anymore be able to guarantee the
* lifetime of the callback.
*
* @param[in] id the ID of the scheduled item, as returned
* by `registerGenericPos()`.
*/
void registerGenericPost(DelayedSubmissionCallbackType callback);
void cancelGenericPost(ItemIdType id);

/**
* @brief Inquire if delayed request submission is enabled.
Expand Down
Loading
Loading