diff --git a/integration/xgboost/processor/CMakeLists.txt b/integration/xgboost/processor/CMakeLists.txt index d29b246377..056fd365e2 100644 --- a/integration/xgboost/processor/CMakeLists.txt +++ b/integration/xgboost/processor/CMakeLists.txt @@ -5,28 +5,22 @@ set(CMAKE_BUILD_TYPE Debug) option(GOOGLE_TEST "Build google tests" OFF) -file(GLOB_RECURSE LIB_SRC - "src/*.h" - "src/*.cc" - ) +file(GLOB_RECURSE LIB_SRC "src/*.cc") add_library(proc_nvflare SHARED ${LIB_SRC}) -set(XGB_SRC ${proc_nvflare_SOURCE_DIR}/../../../../xgboost) -target_include_directories(proc_nvflare PRIVATE ${proc_nvflare_SOURCE_DIR}/src/include - ${XGB_SRC}/src - ${XGB_SRC}/rabit/include - ${XGB_SRC}/include - ${XGB_SRC}/dmlc-core/include) - -link_directories(${XGB_SRC}/lib/) +set_target_properties(proc_nvflare PROPERTIES + CXX_STANDARD 17 + CXX_STANDARD_REQUIRED ON + POSITION_INDEPENDENT_CODE ON + ENABLE_EXPORTS ON +) +target_include_directories(proc_nvflare PRIVATE ${proc_nvflare_SOURCE_DIR}/src/include) if (APPLE) add_link_options("LINKER:-object_path_lto,$_lto.o") add_link_options("LINKER:-cache_path_lto,${CMAKE_BINARY_DIR}/LTOCache") endif () -target_link_libraries(proc_nvflare ${XGB_SRC}/lib/libxgboost${CMAKE_SHARED_LIBRARY_SUFFIX}) - #-- Unit Tests if(GOOGLE_TEST) find_package(GTest REQUIRED) @@ -49,4 +43,4 @@ if(GOOGLE_TEST) COMMAND proc_test WORKING_DIRECTORY ${proc_nvflare_BINARY_DIR}) -endif() \ No newline at end of file +endif() diff --git a/integration/xgboost/processor/README.md b/integration/xgboost/processor/README.md index 08afc24e42..e879081b84 100644 --- a/integration/xgboost/processor/README.md +++ b/integration/xgboost/processor/README.md @@ -1,15 +1,11 @@ # Build Instruction -This plugin build requires xgboost source code, checkout xgboost source and build it with FEDERATED plugin, - -cd xgboost -mkdir build -cd build -cmake .. -DPLUGIN_FEDERATED=ON -make - +``` sh cd NVFlare/integration/xgboost/processor mkdir build cd build cmake .. make +``` + +See [tests](./tests) for simple examples. \ No newline at end of file diff --git a/integration/xgboost/processor/src/README.md b/integration/xgboost/processor/src/README.md index a10dae75ed..f0e4bb14dc 100644 --- a/integration/xgboost/processor/src/README.md +++ b/integration/xgboost/processor/src/README.md @@ -1,11 +1,8 @@ # encoding-plugins Processor Plugin for NVFlare -This plugin is a companion for NVFlare based encryption, it processes the data so it can +This plugin is a companion for NVFlare based encryption, it processes the data so it can be properly decoded by Python code running on NVFlare. All the encryption is happening on the local GRPC client/server so no encryption is needed in this plugin. - - - diff --git a/integration/xgboost/processor/src/dam/README.md b/integration/xgboost/processor/src/dam/README.md index ba65423e65..8cce132900 100644 --- a/integration/xgboost/processor/src/dam/README.md +++ b/integration/xgboost/processor/src/dam/README.md @@ -1,12 +1,9 @@ # DAM (Direct-Accessible Marshaller) -A simple serialization library that doesn't have dependencies, and the data +A simple serialization library that doesn't have dependencies, and the data is directly accessible in C/C++ without copying. To make the data accessible in C, following rules must be followed, 1. Numeric values must be stored in native byte-order. 2. Numeric values must start at the 64-bit boundaries (8-bytes) - - - diff --git a/integration/xgboost/processor/src/dam/dam.cc b/integration/xgboost/processor/src/dam/dam.cc index 27c3512946..10625ab9b5 100644 --- a/integration/xgboost/processor/src/dam/dam.cc +++ b/integration/xgboost/processor/src/dam/dam.cc @@ -27,15 +27,14 @@ void print_buffer(uint8_t *buffer, int size) { // DamEncoder ====== void DamEncoder::AddFloatArray(const std::vector &value) { - if (encoded) { - std::cout << "Buffer is already encoded" << std::endl; - return; - } - auto buf_size = value.size()*8; - uint8_t *buffer = static_cast(malloc(buf_size)); - memcpy(buffer, value.data(), buf_size); - // print_buffer(reinterpret_cast(value.data()), value.size() * 8); - entries->push_back(new Entry(kDataTypeFloatArray, buffer, value.size())); + if (encoded) { + std::cout << "Buffer is already encoded" << std::endl; + return; + } + auto buf_size = value.size() * 8; + uint8_t *buffer = static_cast(malloc(buf_size)); + memcpy(buffer, value.data(), buf_size); + entries->push_back(new Entry(kDataTypeFloatArray, buffer, value.size())); } void DamEncoder::AddIntArray(const std::vector &value) { @@ -52,15 +51,15 @@ void DamEncoder::AddIntArray(const std::vector &value) { entries->push_back(new Entry(kDataTypeIntArray, buffer, value.size())); } -std::uint8_t * DamEncoder::Finish(size_t &size) { +std::vector DamEncoder::Finish(size_t &size) { encoded = true; size = calculate_size(); - auto buf = static_cast(malloc(size)); - auto pointer = buf; + std::vector buf(size); + auto pointer = buf.data(); memcpy(pointer, kSignature, strlen(kSignature)); - memcpy(pointer+8, &size, 8); - memcpy(pointer+16, &data_set_id, 8); + memcpy(pointer + 8, &size, 8); + memcpy(pointer + 16, &data_set_id, 8); pointer += kPrefixLen; for (auto entry : *entries) { @@ -75,9 +74,9 @@ std::uint8_t * DamEncoder::Finish(size_t &size) { // print_buffer(entry->pointer, entry->size*8); } - if ((pointer - buf) != size) { - std::cout << "Invalid encoded size: " << (pointer - buf) << std::endl; - return nullptr; + if ((pointer - buf.data()) != size) { + throw std::runtime_error{"Invalid encoded size: " + + std::to_string(pointer - buf.data())}; } return buf; @@ -97,7 +96,7 @@ std::size_t DamEncoder::calculate_size() { // DamDecoder ====== -DamDecoder::DamDecoder(std::uint8_t *buffer, std::size_t size) { +DamDecoder::DamDecoder(std::uint8_t const *buffer, std::size_t size) { this->buffer = buffer; this->buf_size = size; this->pos = buffer + kPrefixLen; @@ -115,32 +114,33 @@ bool DamDecoder::IsValid() { } std::vector DamDecoder::DecodeIntArray() { - auto type = *reinterpret_cast(pos); - if (type != kDataTypeIntArray) { - std::cout << "Data type " << type << " doesn't match Int Array" << std::endl; - return std::vector(); - } - pos += 8; - - auto len = *reinterpret_cast(pos); - pos += 8; - auto ptr = reinterpret_cast(pos); - pos += 8*len; - return std::vector(ptr, ptr + len); + auto type = *reinterpret_cast(pos); + if (type != kDataTypeIntArray) { + std::cout << "Data type " << type << " doesn't match Int Array" + << std::endl; + return std::vector(); + } + pos += 8; + + auto len = *reinterpret_cast(pos); + pos += 8; + auto ptr = reinterpret_cast(pos); + pos += 8 * len; + return std::vector(ptr, ptr + len); } std::vector DamDecoder::DecodeFloatArray() { - auto type = *reinterpret_cast(pos); + auto type = *reinterpret_cast(pos); if (type != kDataTypeFloatArray) { std::cout << "Data type " << type << " doesn't match Float Array" << std::endl; return std::vector(); } pos += 8; - auto len = *reinterpret_cast(pos); + auto len = *reinterpret_cast(pos); pos += 8; - auto ptr = reinterpret_cast(pos); + auto ptr = reinterpret_cast(pos); pos += 8*len; return std::vector(ptr, ptr + len); } diff --git a/integration/xgboost/processor/src/include/dam.h b/integration/xgboost/processor/src/include/dam.h index 1f113d92fe..7afdf983af 100644 --- a/integration/xgboost/processor/src/include/dam.h +++ b/integration/xgboost/processor/src/include/dam.h @@ -14,9 +14,9 @@ * limitations under the License. */ #pragma once -#include #include -#include +#include // for int64_t +#include // for size_t const char kSignature[] = "NVDADAM1"; // DAM (Direct Accessible Marshalling) V1 const int kPrefixLen = 24; @@ -57,23 +57,23 @@ class DamEncoder { void AddFloatArray(const std::vector &value); - std::uint8_t * Finish(size_t &size); + std::vector Finish(size_t &size); - private: + private: std::size_t calculate_size(); }; class DamDecoder { private: - std::uint8_t *buffer = nullptr; + std::uint8_t const *buffer = nullptr; std::size_t buf_size = 0; - std::uint8_t *pos = nullptr; + std::uint8_t const *pos = nullptr; std::size_t remaining = 0; int64_t data_set_id = 0; int64_t len = 0; - public: - explicit DamDecoder(std::uint8_t *buffer, std::size_t size); + public: + explicit DamDecoder(std::uint8_t const *buffer, std::size_t size); size_t Size() { return len; diff --git a/integration/xgboost/processor/src/include/nvflare_processor.h b/integration/xgboost/processor/src/include/nvflare_processor.h index cc6fb6b1a4..cb7076eaf4 100644 --- a/integration/xgboost/processor/src/include/nvflare_processor.h +++ b/integration/xgboost/processor/src/include/nvflare_processor.h @@ -14,11 +14,10 @@ * limitations under the License. */ #pragma once -#include -#include -#include -#include -#include "processing/processor.h" +#include // for uint8_t, uint32_t, int32_t, int64_t +#include // for string_view +#include // for pair +#include // for vector const int kDataSetHGPairs = 1; const int kDataSetAggregation = 2; @@ -27,50 +26,49 @@ const int kDataSetAggregationResult = 4; const int kDataSetHistograms = 5; const int kDataSetHistogramResult = 6; -class NVFlareProcessor: public processing::Processor { - private: - bool active_ = false; - const std::map *params_; - std::vector *gh_pairs_{nullptr}; - std::vector cuts_; - std::vector slots_; - bool feature_sent_ = false; - std::vector features_; +// Opaque pointer type for the C API. +typedef void *FederatedPluginHandle; // NOLINT - public: - void Initialize(bool active, std::map params) override { - this->active_ = active; - this->params_ = ¶ms; - } +namespace nvflare { +// Plugin that uses Python tenseal and GRPC. +class TensealPlugin { + // Buffer for storing encrypted gradient pairs. + std::vector encrypted_gpairs_; + // Buffer for histogram cut pointers (indptr of a CSC). + std::vector cut_ptrs_; + // Buffer for histogram index. + std::vector bin_idx_; - void Shutdown() override { - this->gh_pairs_ = nullptr; - this->cuts_.clear(); - this->slots_.clear(); - } + bool feature_sent_{false}; + // The feature index. + std::vector features_; + // Buffer for output histogram. + std::vector encrypted_hist_; + std::vector hist_; - void FreeBuffer(void *buffer) override { - free(buffer); - } +public: + TensealPlugin( + std::vector> const &args); + // Gradient pairs + void EncryptGPairs(float const *in_gpair, std::size_t n_in, + std::uint8_t **out_gpair, std::size_t *n_out); + void SyncEncryptedGPairs(std::uint8_t const *in_gpair, std::size_t n_bytes, + std::uint8_t const **out_gpair, + std::size_t *out_n_bytes); - void* ProcessGHPairs(size_t *size, const std::vector& pairs) override; + // Histogram + void ResetHistContext(std::uint32_t const *cutptrs, std::size_t cutptr_len, + std::int32_t const *bin_idx, std::size_t n_idx); + void BuildEncryptedHistHori(double const *in_histogram, std::size_t len, + std::uint8_t **out_hist, std::size_t *out_len); + void SyncEncryptedHistHori(std::uint8_t const *buffer, std::size_t len, + double **out_hist, std::size_t *out_len); - void* HandleGHPairs(size_t *size, void *buffer, size_t buf_size) override; - - void InitAggregationContext(const std::vector &cuts, const std::vector &slots) override { - if (this->slots_.empty()) { - this->cuts_ = std::vector(cuts); - this->slots_ = std::vector(slots); - } else { - std::cout << "Multiple calls to InitAggregationContext" << std::endl; - } - } - - void *ProcessAggregation(size_t *size, std::map> nodes) override; - - std::vector HandleAggregation(void *buffer, size_t buf_size) override; - - void *ProcessHistograms(size_t *size, const std::vector& histograms) override; - - std::vector HandleHistograms(void *buffer, size_t buf_size) override; -}; \ No newline at end of file + void BuildEncryptedHistVert(std::size_t const **ridx, + std::size_t const *sizes, + std::int32_t const *nidx, std::size_t len, + std::uint8_t **out_hist, std::size_t *out_len); + void SyncEncryptedHistVert(std::uint8_t *hist_buffer, std::size_t len, + double **out, std::size_t *out_len); +}; +} // namespace nvflare diff --git a/integration/xgboost/processor/src/nvflare-plugin/nvflare_processor.cc b/integration/xgboost/processor/src/nvflare-plugin/nvflare_processor.cc index 749d8e98b5..3e742b14ef 100644 --- a/integration/xgboost/processor/src/nvflare-plugin/nvflare_processor.cc +++ b/integration/xgboost/processor/src/nvflare-plugin/nvflare_processor.cc @@ -13,181 +13,366 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -#include #include "nvflare_processor.h" -#include "dam.h" -const char kPluginName[] = "nvflare"; +#include "dam.h" // for DamEncoder +#include +#include // for copy_n, transform +#include // for memcpy +#include // for shared_ptr +#include // for invalid_argument +#include // for string_view +#include // for vector -using std::vector; -using std::cout; -using std::endl; +namespace nvflare { +namespace { +// The opaque type for the C handle. +using CHandleT = std::shared_ptr *; +// Actual representation used in C++ code base. +using HandleT = std::remove_pointer_t; -void* NVFlareProcessor::ProcessGHPairs(size_t *size, const std::vector& pairs) { - cout << "ProcessGHPairs called with pairs size: " << pairs.size() << endl; - gh_pairs_ = new std::vector(pairs); +std::string &GlobalErrorMsg() { + static thread_local std::string msg; + return msg; +} - DamEncoder encoder(kDataSetHGPairs); - encoder.AddFloatArray(pairs); - auto buffer = encoder.Finish(*size); +// Perform handle handling for C API functions. +template auto CApiGuard(FederatedPluginHandle handle, Fn &&fn) { + auto pptr = static_cast(handle); + if (!pptr) { + return 1; + } - return buffer; + try { + if constexpr (std::is_void_v>) { + fn(*pptr); + return 0; + } else { + return fn(*pptr); + } + } catch (std::exception const &e) { + GlobalErrorMsg() = e.what(); + return 1; + } } +} // namespace -void* NVFlareProcessor::HandleGHPairs(size_t *size, void *buffer, size_t buf_size) { - cout << "HandleGHPairs called with buffer size: " << buf_size << " Active: " << active_ << endl; - *size = buf_size; - return buffer; +TensealPlugin::TensealPlugin( + std::vector> const &args) { + if (!args.empty()) { + throw std::invalid_argument{"Invaid arguments for the tenseal plugin."}; + } } -void *NVFlareProcessor::ProcessAggregation(size_t *size, std::map> nodes) { - cout << "ProcessAggregation called with " << nodes.size() << " nodes" << endl; +void TensealPlugin::EncryptGPairs(float const *in_gpair, std::size_t n_in, + std::uint8_t **out_gpair, + std::size_t *n_out) { + std::vector pairs(n_in); + std::copy_n(in_gpair, n_in, pairs.begin()); + DamEncoder encoder(kDataSetHGPairs); + encoder.AddFloatArray(pairs); + encrypted_gpairs_ = encoder.Finish(*n_out); + if (!out_gpair) { + throw std::invalid_argument{"Invalid pointer to output gpair."}; + } + *out_gpair = encrypted_gpairs_.data(); + *n_out = encrypted_gpairs_.size(); +} - int64_t data_set_id; - if (!feature_sent_) { - data_set_id = kDataSetAggregationWithFeatures; - feature_sent_ = true; - } else { - data_set_id = kDataSetAggregation; - } +void TensealPlugin::SyncEncryptedGPairs(std::uint8_t const *in_gpair, + std::size_t n_bytes, + std::uint8_t const **out_gpair, + std::size_t *out_n_bytes) { + *out_n_bytes = n_bytes; + *out_gpair = in_gpair; +} - DamEncoder encoder(data_set_id); +void TensealPlugin::ResetHistContext(std::uint32_t const *cutptrs, + std::size_t cutptr_len, + std::int32_t const *bin_idx, + std::size_t n_idx) { + // fixme: this doesn't have to be copied multiple times. + this->cut_ptrs_.resize(cutptr_len); + std::copy_n(cutptrs, cutptr_len, cut_ptrs_.begin()); + this->bin_idx_.resize(n_idx); + std::copy_n(bin_idx, n_idx, this->bin_idx_.begin()); +} - // Add cuts pointers - vector cuts_vec; - for (auto value : cuts_) { - cuts_vec.push_back(value); - } - encoder.AddIntArray(cuts_vec); - - auto num_features = cuts_.size() - 1; - auto num_samples = slots_.size() / num_features; - cout << "Samples: " << num_samples << " Features: " << num_features << endl; - - if (data_set_id == kDataSetAggregationWithFeatures) { - if (features_.empty()) { - for (std::size_t f = 0; f < num_features; f++) { - auto slot = slots_[f]; - if (slot >= 0) { - features_.push_back(f); - } - } - } - cout << "Including feature size: " << features_.size() << endl; - encoder.AddIntArray(features_); - - vector bins; - for (int i = 0; i < num_samples; i++) { - for (auto f : features_) { - auto index = f + i * num_features; - if (index > slots_.size()) { - cout << "Index is out of range " << index << endl; - } - auto slot = slots_[index]; - bins.push_back(slot); - } - } - encoder.AddIntArray(bins); - } +void TensealPlugin::BuildEncryptedHistVert(std::size_t const **ridx, + std::size_t const *sizes, + std::int32_t const *nidx, + std::size_t len, + std::uint8_t** out_hist, + std::size_t* out_len) { + std::int64_t data_set_id; + if (!feature_sent_) { + data_set_id = kDataSetAggregationWithFeatures; + feature_sent_ = true; + } else { + data_set_id = kDataSetAggregation; + } - // Add nodes to build - vector node_vec; - for (const auto &kv : nodes) { - std::cout << "Node: " << kv.first << " Rows: " << kv.second.size() << std::endl; - node_vec.push_back(kv.first); - } - encoder.AddIntArray(node_vec); + DamEncoder encoder(data_set_id); + + // Add cuts pointers + std::vector cuts_vec(cut_ptrs_.cbegin(), cut_ptrs_.cend()); + encoder.AddIntArray(cuts_vec); + + auto num_features = cut_ptrs_.size() - 1; + auto num_samples = bin_idx_.size() / num_features; - // For each node, get the row_id/slot pair - for (const auto &kv : nodes) { - vector rows; - for (auto row : kv.second) { - rows.push_back(row); + if (data_set_id == kDataSetAggregationWithFeatures) { + if (features_.empty()) { // when is it not empty? + for (std::size_t f = 0; f < num_features; f++) { + auto slot = bin_idx_[f]; + if (slot >= 0) { + // what happens if it's missing? + features_.push_back(f); } - encoder.AddIntArray(rows); + } } + encoder.AddIntArray(features_); - auto buffer = encoder.Finish(*size); - return buffer; -} - -std::vector NVFlareProcessor::HandleAggregation(void *buffer, size_t buf_size) { - cout << "HandleAggregation called with buffer size: " << buf_size << endl; - auto remaining = buf_size; - char *pointer = reinterpret_cast(buffer); - - // The buffer is concatenated by AllGather. It may contain multiple DAM buffers - std::vector result; - auto max_slot = cuts_.back(); - auto array_size = 2 * max_slot * sizeof(double); - double *slots = static_cast(malloc(array_size)); - while (remaining > kPrefixLen) { - DamDecoder decoder(reinterpret_cast(pointer), remaining); - if (!decoder.IsValid()) { - cout << "Not DAM encoded buffer ignored at offset: " - << static_cast((pointer - reinterpret_cast(buffer))) << endl; - break; + std::vector bins; + for (int i = 0; i < num_samples; i++) { + for (auto f : features_) { + auto index = f + i * num_features; + if (index > bin_idx_.size()) { + throw std::out_of_range{"Index is out of range: " + + std::to_string(index)}; } - auto size = decoder.Size(); - auto node_list = decoder.DecodeIntArray(); - for (auto node : node_list) { - memset(slots, 0, array_size); - auto feature_list = decoder.DecodeIntArray(); - // Convert per-feature histo to a flat one - for (auto f : feature_list) { - auto base = cuts_[f]; - auto bins = decoder.DecodeFloatArray(); - auto n = bins.size() / 2; - for (int i = 0; i < n; i++) { - auto index = base + i; - slots[2 * index] += bins[2 * i]; - slots[2 * index + 1] += bins[2 * i + 1]; - } - } - result.insert(result.end(), slots, slots + 2 * max_slot); - } - remaining -= size; - pointer += size; + auto slot = bin_idx_[index]; + bins.push_back(slot); + } } - free(slots); + encoder.AddIntArray(bins); + } - return result; -} + // Add nodes to build + std::vector node_vec(len); + std::copy_n(nidx, len, node_vec.begin()); + encoder.AddIntArray(node_vec); + + // For each node, get the row_id/slot pair + for (std::size_t i = 0; i < len; ++i) { + std::vector rows(sizes[i]); + std::copy_n(ridx[i], sizes[i], rows.begin()); + encoder.AddIntArray(rows); + } -void *NVFlareProcessor::ProcessHistograms(size_t *size, const std::vector& histograms) { - cout << "ProcessHistograms called with " << histograms.size() << " entries" << endl; + std::size_t n{0}; + encrypted_hist_ = encoder.Finish(n); - DamEncoder encoder(kDataSetHistograms); - encoder.AddFloatArray(histograms); - return encoder.Finish(*size); + *out_hist = encrypted_hist_.data(); + *out_len = encrypted_hist_.size(); } -std::vector NVFlareProcessor::HandleHistograms(void *buffer, size_t buf_size) { - cout << "HandleHistograms called with buffer size: " << buf_size << endl; +void TensealPlugin::SyncEncryptedHistVert(std::uint8_t *buffer, + std::size_t buf_size, double **out, + std::size_t *out_len) { + auto remaining = buf_size; + char *pointer = reinterpret_cast(buffer); - DamDecoder decoder(reinterpret_cast(buffer), buf_size); + // The buffer is concatenated by AllGather. It may contain multiple DAM + // buffers + std::vector &result = hist_; + result.clear(); + auto max_slot = cut_ptrs_.back(); + auto array_size = 2 * max_slot * sizeof(double); + // A new histogram array? + double *slots = static_cast(malloc(array_size)); + while (remaining > kPrefixLen) { + DamDecoder decoder(reinterpret_cast(pointer), remaining); if (!decoder.IsValid()) { - cout << "Not DAM encoded buffer, ignored" << endl; - return std::vector(); + std::cout << "Not DAM encoded buffer ignored at offset: " + << static_cast( + (pointer - reinterpret_cast(buffer))) + << std::endl; + break; } - - if (decoder.GetDataSetId() != kDataSetHistogramResult) { - cout << "Invalid dataset: " << decoder.GetDataSetId() << endl; - return std::vector(); + auto size = decoder.Size(); + auto node_list = decoder.DecodeIntArray(); + for (auto node : node_list) { + std::memset(slots, 0, array_size); + auto feature_list = decoder.DecodeIntArray(); + // Convert per-feature histo to a flat one + for (auto f : feature_list) { + auto base = cut_ptrs_[f]; // cut pointer for the current feature + auto bins = decoder.DecodeFloatArray(); + auto n = bins.size() / 2; + for (int i = 0; i < n; i++) { + auto index = base + i; + // [Q] Build local histogram? Why does it need to be built here? + slots[2 * index] += bins[2 * i]; + slots[2 * index + 1] += bins[2 * i + 1]; + } + } + result.insert(result.end(), slots, slots + 2 * max_slot); } + remaining -= size; + pointer += size; + } + free(slots); - return decoder.DecodeFloatArray(); + *out_len = result.size(); + *out = result.data(); } +void TensealPlugin::BuildEncryptedHistHori(double const *in_histogram, + std::size_t len, + std::uint8_t **out_hist, + std::size_t *out_len) { + DamEncoder encoder(kDataSetHistograms); + std::vector copy(in_histogram, in_histogram + len); + encoder.AddFloatArray(copy); + + std::size_t size{0}; + this->encrypted_hist_ = encoder.Finish(size); + + *out_hist = this->encrypted_hist_.data(); + *out_len = this->encrypted_hist_.size(); +} + +void TensealPlugin::SyncEncryptedHistHori(std::uint8_t const *buffer, + std::size_t len, double **out_hist, + std::size_t *out_len) { + DamDecoder decoder(reinterpret_cast(buffer), len); + if (!decoder.IsValid()) { + std::cout << "Not DAM encoded buffer, ignored" << std::endl; + } + + if (decoder.GetDataSetId() != kDataSetHistogramResult) { + throw std::runtime_error{"Invalid dataset: " + + std::to_string(decoder.GetDataSetId())}; + } + this->hist_ = decoder.DecodeFloatArray(); + *out_hist = this->hist_.data(); + *out_len = this->hist_.size(); +} +} // namespace nvflare + +#if defined(_MSC_VER) || defined(_WIN32) +#define NVF_C __declspec(dllexport) +#else +#define NVF_C __attribute__((visibility("default"))) +#endif // defined(_MSC_VER) || defined(_WIN32) + extern "C" { +NVF_C char const *FederatedPluginErrorMsg() { + return nvflare::GlobalErrorMsg().c_str(); +} -processing::Processor *LoadProcessor(char *plugin_name) { - if (strcasecmp(plugin_name, kPluginName) != 0) { - cout << "Unknown plugin name: " << plugin_name << endl; - return nullptr; - } +FederatedPluginHandle NVF_C FederatedPluginCreate(int argc, char const **argv) { + using namespace nvflare; + try { + CHandleT pptr = new std::shared_ptr; + std::vector> args; + std::transform( + argv, argv + argc, std::back_inserter(args), [](char const *carg) { + // Split a key value pair in contructor argument: `key=value` + std::string_view arg{carg}; + auto idx = arg.find('='); + if (idx == std::string_view::npos) { + // `=` not found + throw std::invalid_argument{"Invalid argument:" + std::string{arg}}; + } + auto key = arg.substr(0, idx); + auto value = arg.substr(idx + 1); + return std::make_pair(key, value); + }); + *pptr = std::make_shared(args); + return pptr; + } catch (std::exception const &e) { + GlobalErrorMsg() = e.what(); + return nullptr; + } +} + +int NVF_C FederatedPluginClose(FederatedPluginHandle handle) { + using namespace nvflare; + auto pptr = static_cast(handle); + if (!pptr) { + return 1; + } + + try { + delete pptr; + } catch (std::exception const &e) { + GlobalErrorMsg() = e.what(); + return 1; + } + return 0; +} - return new NVFlareProcessor(); +int NVF_C FederatedPluginEncryptGPairs(FederatedPluginHandle handle, + float const *in_gpair, size_t n_in, + uint8_t **out_gpair, size_t *n_out) { + using namespace nvflare; + return CApiGuard(handle, [&](HandleT plugin) { + plugin->EncryptGPairs(in_gpair, n_in, out_gpair, n_out); + return 0; + }); } -} // extern "C" +int NVF_C FederatedPluginSyncEncryptedGPairs(FederatedPluginHandle handle, + uint8_t const *in_gpair, + size_t n_bytes, + uint8_t const **out_gpair, + size_t *n_out) { + using namespace nvflare; + return CApiGuard(handle, [&](HandleT plugin) { + plugin->SyncEncryptedGPairs(in_gpair, n_bytes, out_gpair, n_out); + }); +} + +int NVF_C FederatedPluginResetHistContextVert(FederatedPluginHandle handle, + uint32_t const *cutptrs, + size_t cutptr_len, + int32_t const *bin_idx, + size_t n_idx) { + using namespace nvflare; + return CApiGuard(handle, [&](HandleT plugin) { + plugin->ResetHistContext(cutptrs, cutptr_len, bin_idx, n_idx); + }); +} + +int NVF_C FederatedPluginBuildEncryptedHistVert( + FederatedPluginHandle handle, uint64_t const **ridx, size_t const *sizes, + int32_t const *nidx, size_t len, uint8_t **out_hist, size_t *out_len) { + using namespace nvflare; + return CApiGuard(handle, [&](HandleT plugin) { + plugin->BuildEncryptedHistVert(ridx, sizes, nidx, len, out_hist, out_len); + }); +} + +int NVF_C FederatedPluginSyncEnrcyptedHistVert(FederatedPluginHandle handle, + uint8_t *in_hist, size_t len, + double **out_hist, + size_t *out_len) { + using namespace nvflare; + return CApiGuard(handle, [&](HandleT plugin) { + plugin->SyncEncryptedHistVert(in_hist, len, out_hist, out_len); + }); +} + +int NVF_C FederatedPluginBuildEncryptedHistHori(FederatedPluginHandle handle, + double const *in_hist, + size_t len, uint8_t **out_hist, + size_t *out_len) { + using namespace nvflare; + return CApiGuard(handle, [&](HandleT plugin) { + plugin->BuildEncryptedHistHori(in_hist, len, out_hist, out_len); + }); +} + +int NVF_C FederatedPluginSyncEnrcyptedHistHori(FederatedPluginHandle handle, + uint8_t const *in_hist, + size_t len, double **out_hist, + size_t *out_len) { + using namespace nvflare; + return CApiGuard(handle, [&](HandleT plugin) { + plugin->SyncEncryptedHistHori(in_hist, len, out_hist, out_len); + return 0; + }); +} +} // extern "C" diff --git a/integration/xgboost/processor/tests/test_dam.cc b/integration/xgboost/processor/tests/test_dam.cc index 1cf5c151fa..5573d5440d 100644 --- a/integration/xgboost/processor/tests/test_dam.cc +++ b/integration/xgboost/processor/tests/test_dam.cc @@ -29,7 +29,7 @@ TEST(DamTest, TestEncodeDecode) { auto buf = encoder.Finish(size); std::cout << "Encoded size is " << size << std::endl; - DamDecoder decoder(buf, size); + DamDecoder decoder(buf.data(), size); EXPECT_EQ(decoder.IsValid(), true); EXPECT_EQ(decoder.GetDataSetId(), 123); diff --git a/integration/xgboost/processor/tests/test_tenseal.py b/integration/xgboost/processor/tests/test_tenseal.py new file mode 100644 index 0000000000..ace7699873 --- /dev/null +++ b/integration/xgboost/processor/tests/test_tenseal.py @@ -0,0 +1,106 @@ +# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import ctypes +import os +from contextlib import contextmanager + +import numpy as np +from typing import Generator, Tuple + + +def _check_call(rc: int) -> None: + assert rc == 0 + + +plugin_path = os.path.join( + os.path.dirname(os.path.normpath(os.path.abspath(__file__))), os.pardir, "build", "libproc_nvflare.so" +) + + +@contextmanager +def load_plugin() -> Generator[Tuple[ctypes.CDLL, ctypes.c_void_p], None, None]: + nvflare = ctypes.cdll.LoadLibrary(plugin_path) + nvflare.FederatedPluginCreate.restype = ctypes.c_void_p + nvflare.FederatedPluginErrorMsg.restype = ctypes.c_char_p + handle = ctypes.c_void_p(nvflare.FederatedPluginCreate(ctypes.c_int(0), None)) + try: + yield nvflare, handle + finally: + _check_call(nvflare.FederatedPluginClose(handle)) + + +def test_load() -> None: + with load_plugin() as nvflare: + pass + + +def test_grad() -> None: + array = np.arange(16, dtype=np.float32) + out = ctypes.POINTER(ctypes.c_uint8)() + out_len = ctypes.c_size_t() + + with load_plugin() as (nvflare, handle): + _check_call( + nvflare.FederatedPluginEncryptGPairs( + handle, + array.ctypes.data_as(ctypes.POINTER(ctypes.c_float)), + array.size, + ctypes.byref(out), + ctypes.byref(out_len), + ) + ) + + out1 = ctypes.POINTER(ctypes.c_uint8)() + out_len1 = ctypes.c_size_t() + + _check_call( + nvflare.FederatedPluginSyncEncryptedGPairs( + handle, + out, + out_len, + ctypes.byref(out1), + ctypes.byref(out_len1), + ) + ) + +def test_hori() -> None: + array = np.arange(16, dtype=np.float32) + # This is a DAM, we might use the Python DAM class to verify its content + out = ctypes.POINTER(ctypes.c_uint8)() + out_len = ctypes.c_size_t() + + with load_plugin() as (nvflare, handle): + _check_call( + nvflare.FederatedPluginBuildEncryptedHistHori( + handle, + array.ctypes.data_as(ctypes.POINTER(ctypes.c_double)), + array.size, + ctypes.byref(out), + ctypes.byref(out_len), + ) + ) + + out1 = ctypes.POINTER(ctypes.c_double)() + out_len1 = ctypes.c_size_t() + + nvflare.FederatedPluginSyncEnrcyptedHistHori( + handle, + out, + out_len, + ctypes.byref(out1), + ctypes.byref(out_len1), + ) + # Needs the GRPC server to process the message. + msg = nvflare.FederatedPluginErrorMsg().decode("utf-8") + assert msg.find("Invalid dataset") != -1 diff --git a/nvflare/app_opt/xgboost/data_loader.py b/nvflare/app_opt/xgboost/data_loader.py index 2fa8855c99..d9a56552bf 100644 --- a/nvflare/app_opt/xgboost/data_loader.py +++ b/nvflare/app_opt/xgboost/data_loader.py @@ -21,7 +21,7 @@ class XGBDataLoader(ABC): @abstractmethod - def load_data(self, client_id: str) -> Tuple[xgb.core.DMatrix, xgb.core.DMatrix]: + def load_data(self, client_id: str) -> Tuple[xgb.DMatrix, xgb.DMatrix]: """Loads data for xgboost. Returns: diff --git a/nvflare/app_opt/xgboost/histogram_based_v2/proto/federated.proto b/nvflare/app_opt/xgboost/histogram_based_v2/proto/federated.proto index a37e63526b..fbc2adf503 100644 --- a/nvflare/app_opt/xgboost/histogram_based_v2/proto/federated.proto +++ b/nvflare/app_opt/xgboost/histogram_based_v2/proto/federated.proto @@ -1,10 +1,9 @@ /*! - * Copyright 2022 XGBoost contributors - * This is federated.old.proto from XGBoost + * Copyright 2022-2023 XGBoost contributors */ syntax = "proto3"; -package xgboost.federated; +package xgboost.collective.federated; service Federated { rpc Allgather(AllgatherRequest) returns (AllgatherReply) {} @@ -14,14 +13,18 @@ service Federated { } enum DataType { - INT8 = 0; - UINT8 = 1; - INT32 = 2; - UINT32 = 3; - INT64 = 4; - UINT64 = 5; - FLOAT = 6; - DOUBLE = 7; + HALF = 0; + FLOAT = 1; + DOUBLE = 2; + LONG_DOUBLE = 3; + INT8 = 4; + INT16 = 5; + INT32 = 6; + INT64 = 7; + UINT8 = 8; + UINT16 = 9; + UINT32 = 10; + UINT64 = 11; } enum ReduceOperation { diff --git a/nvflare/app_opt/xgboost/histogram_based_v2/proto/federated_pb2.py b/nvflare/app_opt/xgboost/histogram_based_v2/proto/federated_pb2.py index 6c77fc334e..e69d5d5e07 100644 --- a/nvflare/app_opt/xgboost/histogram_based_v2/proto/federated_pb2.py +++ b/nvflare/app_opt/xgboost/histogram_based_v2/proto/federated_pb2.py @@ -28,33 +28,33 @@ -DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n\x0f\x66\x65\x64\x65rated.proto\x12\x11xgboost.federated\"N\n\x10\x41llgatherRequest\x12\x17\n\x0fsequence_number\x18\x01 \x01(\x04\x12\x0c\n\x04rank\x18\x02 \x01(\x05\x12\x13\n\x0bsend_buffer\x18\x03 \x01(\x0c\"(\n\x0e\x41llgatherReply\x12\x16\n\x0ereceive_buffer\x18\x01 \x01(\x0c\"O\n\x11\x41llgatherVRequest\x12\x17\n\x0fsequence_number\x18\x01 \x01(\x04\x12\x0c\n\x04rank\x18\x02 \x01(\x05\x12\x13\n\x0bsend_buffer\x18\x03 \x01(\x0c\")\n\x0f\x41llgatherVReply\x12\x16\n\x0ereceive_buffer\x18\x01 \x01(\x0c\"\xbc\x01\n\x10\x41llreduceRequest\x12\x17\n\x0fsequence_number\x18\x01 \x01(\x04\x12\x0c\n\x04rank\x18\x02 \x01(\x05\x12\x13\n\x0bsend_buffer\x18\x03 \x01(\x0c\x12.\n\tdata_type\x18\x04 \x01(\x0e\x32\x1b.xgboost.federated.DataType\x12<\n\x10reduce_operation\x18\x05 \x01(\x0e\x32\".xgboost.federated.ReduceOperation\"(\n\x0e\x41llreduceReply\x12\x16\n\x0ereceive_buffer\x18\x01 \x01(\x0c\"\\\n\x10\x42roadcastRequest\x12\x17\n\x0fsequence_number\x18\x01 \x01(\x04\x12\x0c\n\x04rank\x18\x02 \x01(\x05\x12\x13\n\x0bsend_buffer\x18\x03 \x01(\x0c\x12\x0c\n\x04root\x18\x04 \x01(\x05\"(\n\x0e\x42roadcastReply\x12\x16\n\x0ereceive_buffer\x18\x01 \x01(\x0c*d\n\x08\x44\x61taType\x12\x08\n\x04INT8\x10\x00\x12\t\n\x05UINT8\x10\x01\x12\t\n\x05INT32\x10\x02\x12\n\n\x06UINT32\x10\x03\x12\t\n\x05INT64\x10\x04\x12\n\n\x06UINT64\x10\x05\x12\t\n\x05\x46LOAT\x10\x06\x12\n\n\x06\x44OUBLE\x10\x07*^\n\x0fReduceOperation\x12\x07\n\x03MAX\x10\x00\x12\x07\n\x03MIN\x10\x01\x12\x07\n\x03SUM\x10\x02\x12\x0f\n\x0b\x42ITWISE_AND\x10\x03\x12\x0e\n\nBITWISE_OR\x10\x04\x12\x0f\n\x0b\x42ITWISE_XOR\x10\x05\x32\xea\x02\n\tFederated\x12U\n\tAllgather\x12#.xgboost.federated.AllgatherRequest\x1a!.xgboost.federated.AllgatherReply\"\x00\x12X\n\nAllgatherV\x12$.xgboost.federated.AllgatherVRequest\x1a\".xgboost.federated.AllgatherVReply\"\x00\x12U\n\tAllreduce\x12#.xgboost.federated.AllreduceRequest\x1a!.xgboost.federated.AllreduceReply\"\x00\x12U\n\tBroadcast\x12#.xgboost.federated.BroadcastRequest\x1a!.xgboost.federated.BroadcastReply\"\x00\x62\x06proto3') +DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n\x0f\x66\x65\x64\x65rated.proto\x12\x1cxgboost.collective.federated\"N\n\x10\x41llgatherRequest\x12\x17\n\x0fsequence_number\x18\x01 \x01(\x04\x12\x0c\n\x04rank\x18\x02 \x01(\x05\x12\x13\n\x0bsend_buffer\x18\x03 \x01(\x0c\"(\n\x0e\x41llgatherReply\x12\x16\n\x0ereceive_buffer\x18\x01 \x01(\x0c\"O\n\x11\x41llgatherVRequest\x12\x17\n\x0fsequence_number\x18\x01 \x01(\x04\x12\x0c\n\x04rank\x18\x02 \x01(\x05\x12\x13\n\x0bsend_buffer\x18\x03 \x01(\x0c\")\n\x0f\x41llgatherVReply\x12\x16\n\x0ereceive_buffer\x18\x01 \x01(\x0c\"\xd2\x01\n\x10\x41llreduceRequest\x12\x17\n\x0fsequence_number\x18\x01 \x01(\x04\x12\x0c\n\x04rank\x18\x02 \x01(\x05\x12\x13\n\x0bsend_buffer\x18\x03 \x01(\x0c\x12\x39\n\tdata_type\x18\x04 \x01(\x0e\x32&.xgboost.collective.federated.DataType\x12G\n\x10reduce_operation\x18\x05 \x01(\x0e\x32-.xgboost.collective.federated.ReduceOperation\"(\n\x0e\x41llreduceReply\x12\x16\n\x0ereceive_buffer\x18\x01 \x01(\x0c\"\\\n\x10\x42roadcastRequest\x12\x17\n\x0fsequence_number\x18\x01 \x01(\x04\x12\x0c\n\x04rank\x18\x02 \x01(\x05\x12\x13\n\x0bsend_buffer\x18\x03 \x01(\x0c\x12\x0c\n\x04root\x18\x04 \x01(\x05\"(\n\x0e\x42roadcastReply\x12\x16\n\x0ereceive_buffer\x18\x01 \x01(\x0c*\x96\x01\n\x08\x44\x61taType\x12\x08\n\x04HALF\x10\x00\x12\t\n\x05\x46LOAT\x10\x01\x12\n\n\x06\x44OUBLE\x10\x02\x12\x0f\n\x0bLONG_DOUBLE\x10\x03\x12\x08\n\x04INT8\x10\x04\x12\t\n\x05INT16\x10\x05\x12\t\n\x05INT32\x10\x06\x12\t\n\x05INT64\x10\x07\x12\t\n\x05UINT8\x10\x08\x12\n\n\x06UINT16\x10\t\x12\n\n\x06UINT32\x10\n\x12\n\n\x06UINT64\x10\x0b*^\n\x0fReduceOperation\x12\x07\n\x03MAX\x10\x00\x12\x07\n\x03MIN\x10\x01\x12\x07\n\x03SUM\x10\x02\x12\x0f\n\x0b\x42ITWISE_AND\x10\x03\x12\x0e\n\nBITWISE_OR\x10\x04\x12\x0f\n\x0b\x42ITWISE_XOR\x10\x05\x32\xc2\x03\n\tFederated\x12k\n\tAllgather\x12..xgboost.collective.federated.AllgatherRequest\x1a,.xgboost.collective.federated.AllgatherReply\"\x00\x12n\n\nAllgatherV\x12/.xgboost.collective.federated.AllgatherVRequest\x1a-.xgboost.collective.federated.AllgatherVReply\"\x00\x12k\n\tAllreduce\x12..xgboost.collective.federated.AllreduceRequest\x1a,.xgboost.collective.federated.AllreduceReply\"\x00\x12k\n\tBroadcast\x12..xgboost.collective.federated.BroadcastRequest\x1a,.xgboost.collective.federated.BroadcastReply\"\x00\x62\x06proto3') _globals = globals() _builder.BuildMessageAndEnumDescriptors(DESCRIPTOR, _globals) _builder.BuildTopDescriptorsAndMessages(DESCRIPTOR, 'federated_pb2', _globals) if _descriptor._USE_C_DESCRIPTORS == False: DESCRIPTOR._options = None - _globals['_DATATYPE']._serialized_start=653 - _globals['_DATATYPE']._serialized_end=753 - _globals['_REDUCEOPERATION']._serialized_start=755 - _globals['_REDUCEOPERATION']._serialized_end=849 - _globals['_ALLGATHERREQUEST']._serialized_start=38 - _globals['_ALLGATHERREQUEST']._serialized_end=116 - _globals['_ALLGATHERREPLY']._serialized_start=118 - _globals['_ALLGATHERREPLY']._serialized_end=158 - _globals['_ALLGATHERVREQUEST']._serialized_start=160 - _globals['_ALLGATHERVREQUEST']._serialized_end=239 - _globals['_ALLGATHERVREPLY']._serialized_start=241 - _globals['_ALLGATHERVREPLY']._serialized_end=282 - _globals['_ALLREDUCEREQUEST']._serialized_start=285 - _globals['_ALLREDUCEREQUEST']._serialized_end=473 - _globals['_ALLREDUCEREPLY']._serialized_start=475 - _globals['_ALLREDUCEREPLY']._serialized_end=515 - _globals['_BROADCASTREQUEST']._serialized_start=517 - _globals['_BROADCASTREQUEST']._serialized_end=609 - _globals['_BROADCASTREPLY']._serialized_start=611 - _globals['_BROADCASTREPLY']._serialized_end=651 - _globals['_FEDERATED']._serialized_start=852 - _globals['_FEDERATED']._serialized_end=1214 + _globals['_DATATYPE']._serialized_start=687 + _globals['_DATATYPE']._serialized_end=837 + _globals['_REDUCEOPERATION']._serialized_start=839 + _globals['_REDUCEOPERATION']._serialized_end=933 + _globals['_ALLGATHERREQUEST']._serialized_start=49 + _globals['_ALLGATHERREQUEST']._serialized_end=127 + _globals['_ALLGATHERREPLY']._serialized_start=129 + _globals['_ALLGATHERREPLY']._serialized_end=169 + _globals['_ALLGATHERVREQUEST']._serialized_start=171 + _globals['_ALLGATHERVREQUEST']._serialized_end=250 + _globals['_ALLGATHERVREPLY']._serialized_start=252 + _globals['_ALLGATHERVREPLY']._serialized_end=293 + _globals['_ALLREDUCEREQUEST']._serialized_start=296 + _globals['_ALLREDUCEREQUEST']._serialized_end=506 + _globals['_ALLREDUCEREPLY']._serialized_start=508 + _globals['_ALLREDUCEREPLY']._serialized_end=548 + _globals['_BROADCASTREQUEST']._serialized_start=550 + _globals['_BROADCASTREQUEST']._serialized_end=642 + _globals['_BROADCASTREPLY']._serialized_start=644 + _globals['_BROADCASTREPLY']._serialized_end=684 + _globals['_FEDERATED']._serialized_start=936 + _globals['_FEDERATED']._serialized_end=1386 # @@protoc_insertion_point(module_scope) diff --git a/nvflare/app_opt/xgboost/histogram_based_v2/proto/federated_pb2.pyi b/nvflare/app_opt/xgboost/histogram_based_v2/proto/federated_pb2.pyi index 750db95a25..7ad47596df 100644 --- a/nvflare/app_opt/xgboost/histogram_based_v2/proto/federated_pb2.pyi +++ b/nvflare/app_opt/xgboost/histogram_based_v2/proto/federated_pb2.pyi @@ -6,32 +6,40 @@ from typing import ClassVar as _ClassVar, Optional as _Optional, Union as _Union DESCRIPTOR: _descriptor.FileDescriptor class DataType(int, metaclass=_enum_type_wrapper.EnumTypeWrapper): - __slots__ = () + __slots__ = [] + HALF: _ClassVar[DataType] + FLOAT: _ClassVar[DataType] + DOUBLE: _ClassVar[DataType] + LONG_DOUBLE: _ClassVar[DataType] INT8: _ClassVar[DataType] - UINT8: _ClassVar[DataType] + INT16: _ClassVar[DataType] INT32: _ClassVar[DataType] - UINT32: _ClassVar[DataType] INT64: _ClassVar[DataType] + UINT8: _ClassVar[DataType] + UINT16: _ClassVar[DataType] + UINT32: _ClassVar[DataType] UINT64: _ClassVar[DataType] - FLOAT: _ClassVar[DataType] - DOUBLE: _ClassVar[DataType] class ReduceOperation(int, metaclass=_enum_type_wrapper.EnumTypeWrapper): - __slots__ = () + __slots__ = [] MAX: _ClassVar[ReduceOperation] MIN: _ClassVar[ReduceOperation] SUM: _ClassVar[ReduceOperation] BITWISE_AND: _ClassVar[ReduceOperation] BITWISE_OR: _ClassVar[ReduceOperation] BITWISE_XOR: _ClassVar[ReduceOperation] +HALF: DataType +FLOAT: DataType +DOUBLE: DataType +LONG_DOUBLE: DataType INT8: DataType -UINT8: DataType +INT16: DataType INT32: DataType -UINT32: DataType INT64: DataType +UINT8: DataType +UINT16: DataType +UINT32: DataType UINT64: DataType -FLOAT: DataType -DOUBLE: DataType MAX: ReduceOperation MIN: ReduceOperation SUM: ReduceOperation @@ -40,7 +48,7 @@ BITWISE_OR: ReduceOperation BITWISE_XOR: ReduceOperation class AllgatherRequest(_message.Message): - __slots__ = ("sequence_number", "rank", "send_buffer") + __slots__ = ["sequence_number", "rank", "send_buffer"] SEQUENCE_NUMBER_FIELD_NUMBER: _ClassVar[int] RANK_FIELD_NUMBER: _ClassVar[int] SEND_BUFFER_FIELD_NUMBER: _ClassVar[int] @@ -50,13 +58,13 @@ class AllgatherRequest(_message.Message): def __init__(self, sequence_number: _Optional[int] = ..., rank: _Optional[int] = ..., send_buffer: _Optional[bytes] = ...) -> None: ... class AllgatherReply(_message.Message): - __slots__ = ("receive_buffer",) + __slots__ = ["receive_buffer"] RECEIVE_BUFFER_FIELD_NUMBER: _ClassVar[int] receive_buffer: bytes def __init__(self, receive_buffer: _Optional[bytes] = ...) -> None: ... class AllgatherVRequest(_message.Message): - __slots__ = ("sequence_number", "rank", "send_buffer") + __slots__ = ["sequence_number", "rank", "send_buffer"] SEQUENCE_NUMBER_FIELD_NUMBER: _ClassVar[int] RANK_FIELD_NUMBER: _ClassVar[int] SEND_BUFFER_FIELD_NUMBER: _ClassVar[int] @@ -66,13 +74,13 @@ class AllgatherVRequest(_message.Message): def __init__(self, sequence_number: _Optional[int] = ..., rank: _Optional[int] = ..., send_buffer: _Optional[bytes] = ...) -> None: ... class AllgatherVReply(_message.Message): - __slots__ = ("receive_buffer",) + __slots__ = ["receive_buffer"] RECEIVE_BUFFER_FIELD_NUMBER: _ClassVar[int] receive_buffer: bytes def __init__(self, receive_buffer: _Optional[bytes] = ...) -> None: ... class AllreduceRequest(_message.Message): - __slots__ = ("sequence_number", "rank", "send_buffer", "data_type", "reduce_operation") + __slots__ = ["sequence_number", "rank", "send_buffer", "data_type", "reduce_operation"] SEQUENCE_NUMBER_FIELD_NUMBER: _ClassVar[int] RANK_FIELD_NUMBER: _ClassVar[int] SEND_BUFFER_FIELD_NUMBER: _ClassVar[int] @@ -86,13 +94,13 @@ class AllreduceRequest(_message.Message): def __init__(self, sequence_number: _Optional[int] = ..., rank: _Optional[int] = ..., send_buffer: _Optional[bytes] = ..., data_type: _Optional[_Union[DataType, str]] = ..., reduce_operation: _Optional[_Union[ReduceOperation, str]] = ...) -> None: ... class AllreduceReply(_message.Message): - __slots__ = ("receive_buffer",) + __slots__ = ["receive_buffer"] RECEIVE_BUFFER_FIELD_NUMBER: _ClassVar[int] receive_buffer: bytes def __init__(self, receive_buffer: _Optional[bytes] = ...) -> None: ... class BroadcastRequest(_message.Message): - __slots__ = ("sequence_number", "rank", "send_buffer", "root") + __slots__ = ["sequence_number", "rank", "send_buffer", "root"] SEQUENCE_NUMBER_FIELD_NUMBER: _ClassVar[int] RANK_FIELD_NUMBER: _ClassVar[int] SEND_BUFFER_FIELD_NUMBER: _ClassVar[int] @@ -104,7 +112,7 @@ class BroadcastRequest(_message.Message): def __init__(self, sequence_number: _Optional[int] = ..., rank: _Optional[int] = ..., send_buffer: _Optional[bytes] = ..., root: _Optional[int] = ...) -> None: ... class BroadcastReply(_message.Message): - __slots__ = ("receive_buffer",) + __slots__ = ["receive_buffer"] RECEIVE_BUFFER_FIELD_NUMBER: _ClassVar[int] receive_buffer: bytes def __init__(self, receive_buffer: _Optional[bytes] = ...) -> None: ... diff --git a/nvflare/app_opt/xgboost/histogram_based_v2/proto/federated_pb2_grpc.py b/nvflare/app_opt/xgboost/histogram_based_v2/proto/federated_pb2_grpc.py index 36bbbbea0c..45eee5c8dd 100644 --- a/nvflare/app_opt/xgboost/histogram_based_v2/proto/federated_pb2_grpc.py +++ b/nvflare/app_opt/xgboost/histogram_based_v2/proto/federated_pb2_grpc.py @@ -29,22 +29,22 @@ def __init__(self, channel): channel: A grpc.Channel. """ self.Allgather = channel.unary_unary( - '/xgboost.federated.Federated/Allgather', + '/xgboost.collective.federated.Federated/Allgather', request_serializer=federated__pb2.AllgatherRequest.SerializeToString, response_deserializer=federated__pb2.AllgatherReply.FromString, ) self.AllgatherV = channel.unary_unary( - '/xgboost.federated.Federated/AllgatherV', + '/xgboost.collective.federated.Federated/AllgatherV', request_serializer=federated__pb2.AllgatherVRequest.SerializeToString, response_deserializer=federated__pb2.AllgatherVReply.FromString, ) self.Allreduce = channel.unary_unary( - '/xgboost.federated.Federated/Allreduce', + '/xgboost.collective.federated.Federated/Allreduce', request_serializer=federated__pb2.AllreduceRequest.SerializeToString, response_deserializer=federated__pb2.AllreduceReply.FromString, ) self.Broadcast = channel.unary_unary( - '/xgboost.federated.Federated/Broadcast', + '/xgboost.collective.federated.Federated/Broadcast', request_serializer=federated__pb2.BroadcastRequest.SerializeToString, response_deserializer=federated__pb2.BroadcastReply.FromString, ) @@ -102,7 +102,7 @@ def add_FederatedServicer_to_server(servicer, server): ), } generic_handler = grpc.method_handlers_generic_handler( - 'xgboost.federated.Federated', rpc_method_handlers) + 'xgboost.collective.federated.Federated', rpc_method_handlers) server.add_generic_rpc_handlers((generic_handler,)) @@ -121,7 +121,7 @@ def Allgather(request, wait_for_ready=None, timeout=None, metadata=None): - return grpc.experimental.unary_unary(request, target, '/xgboost.federated.Federated/Allgather', + return grpc.experimental.unary_unary(request, target, '/xgboost.collective.federated.Federated/Allgather', federated__pb2.AllgatherRequest.SerializeToString, federated__pb2.AllgatherReply.FromString, options, channel_credentials, @@ -138,7 +138,7 @@ def AllgatherV(request, wait_for_ready=None, timeout=None, metadata=None): - return grpc.experimental.unary_unary(request, target, '/xgboost.federated.Federated/AllgatherV', + return grpc.experimental.unary_unary(request, target, '/xgboost.collective.federated.Federated/AllgatherV', federated__pb2.AllgatherVRequest.SerializeToString, federated__pb2.AllgatherVReply.FromString, options, channel_credentials, @@ -155,7 +155,7 @@ def Allreduce(request, wait_for_ready=None, timeout=None, metadata=None): - return grpc.experimental.unary_unary(request, target, '/xgboost.federated.Federated/Allreduce', + return grpc.experimental.unary_unary(request, target, '/xgboost.collective.federated.Federated/Allreduce', federated__pb2.AllreduceRequest.SerializeToString, federated__pb2.AllreduceReply.FromString, options, channel_credentials, @@ -172,7 +172,7 @@ def Broadcast(request, wait_for_ready=None, timeout=None, metadata=None): - return grpc.experimental.unary_unary(request, target, '/xgboost.federated.Federated/Broadcast', + return grpc.experimental.unary_unary(request, target, '/xgboost.collective.federated.Federated/Broadcast', federated__pb2.BroadcastRequest.SerializeToString, federated__pb2.BroadcastReply.FromString, options, channel_credentials, diff --git a/nvflare/app_opt/xgboost/histogram_based_v2/proto/gen_proto.sh b/nvflare/app_opt/xgboost/histogram_based_v2/proto/gen_proto.sh index 10afcf5b3b..f174f5d30f 100644 --- a/nvflare/app_opt/xgboost/histogram_based_v2/proto/gen_proto.sh +++ b/nvflare/app_opt/xgboost/histogram_based_v2/proto/gen_proto.sh @@ -1 +1,6 @@ +#!/usr/bin/env sh +# Install grpcio-tools: +# pip install grpcio-tools +# or +# mamba install grpcio-tools python -m grpc_tools.protoc -I. --python_out=. --pyi_out=. --grpc_python_out=. federated.proto diff --git a/nvflare/app_opt/xgboost/histogram_based_v2/runners/xgb_client_runner.py b/nvflare/app_opt/xgboost/histogram_based_v2/runners/xgb_client_runner.py index f3a9dbc905..35627bc3cd 100644 --- a/nvflare/app_opt/xgboost/histogram_based_v2/runners/xgb_client_runner.py +++ b/nvflare/app_opt/xgboost/histogram_based_v2/runners/xgb_client_runner.py @@ -72,7 +72,7 @@ def initialize(self, fl_ctx: FLContext): if not isinstance(self._metrics_writer, LogWriter): self.system_panic("writer should be type LogWriter", fl_ctx) - def _xgb_train(self, params: XGBoostParams, train_data, val_data) -> xgb.core.Booster: + def _xgb_train(self, params: XGBoostParams, train_data: xgb.DMatrix, val_data) -> xgb.core.Booster: """XGBoost training logic. Args: @@ -129,14 +129,14 @@ def run(self, ctx: dict): self.logger.info(f"server address is {self._server_addr}") communicator_env = { - "xgboost_communicator": "federated", + "dmlc_communicator": "federated", "federated_server_address": f"{self._server_addr}", "federated_world_size": self._world_size, "federated_rank": self._rank, - "plugin_name": "nvflare", - "loader_params": { - "LIBRARY_PATH": "/tmp", - }, + # FIXME: It should be possible to customize this or find a better location + # to distribut the shared object, preferably along side the nvflare Python + # package. + "federated_plugin": {"path": "/tmp/libproc_nvflare.so"}, } with xgb.collective.CommunicatorContext(**communicator_env): # Load the data. Dmatrix must be created with column split mode in CommunicatorContext for vertical FL diff --git a/nvflare/app_opt/xgboost/histogram_based_v2/runners/xgb_server_runner.py b/nvflare/app_opt/xgboost/histogram_based_v2/runners/xgb_server_runner.py index 32e708c90e..4f7752faee 100644 --- a/nvflare/app_opt/xgboost/histogram_based_v2/runners/xgb_server_runner.py +++ b/nvflare/app_opt/xgboost/histogram_based_v2/runners/xgb_server_runner.py @@ -30,7 +30,7 @@ def run(self, ctx: dict): xgb_federated.run_federated_server( port=self._port, - world_size=self._world_size, + n_workers=self._world_size, ) self._stopped = True