Skip to content

Commit

Permalink
Create C wrapper for the xgboost plugin.
Browse files Browse the repository at this point in the history
  • Loading branch information
trivialfis committed Jun 12, 2024
1 parent c16f453 commit 4a65a60
Show file tree
Hide file tree
Showing 18 changed files with 514 additions and 325 deletions.
24 changes: 9 additions & 15 deletions integration/xgboost/processor/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -5,28 +5,22 @@ set(CMAKE_BUILD_TYPE Debug)

option(GOOGLE_TEST "Build google tests" OFF)

file(GLOB_RECURSE LIB_SRC
"src/*.h"
"src/*.cc"
)
file(GLOB_RECURSE LIB_SRC "src/*.cc")

add_library(proc_nvflare SHARED ${LIB_SRC})
set(XGB_SRC ${proc_nvflare_SOURCE_DIR}/../../../../xgboost)
target_include_directories(proc_nvflare PRIVATE ${proc_nvflare_SOURCE_DIR}/src/include
${XGB_SRC}/src
${XGB_SRC}/rabit/include
${XGB_SRC}/include
${XGB_SRC}/dmlc-core/include)

link_directories(${XGB_SRC}/lib/)
set_target_properties(proc_nvflare PROPERTIES
CXX_STANDARD 17
CXX_STANDARD_REQUIRED ON
POSITION_INDEPENDENT_CODE ON
ENABLE_EXPORTS ON
)
target_include_directories(proc_nvflare PRIVATE ${proc_nvflare_SOURCE_DIR}/src/include)

if (APPLE)
add_link_options("LINKER:-object_path_lto,$<TARGET_PROPERTY:NAME>_lto.o")
add_link_options("LINKER:-cache_path_lto,${CMAKE_BINARY_DIR}/LTOCache")
endif ()

target_link_libraries(proc_nvflare ${XGB_SRC}/lib/libxgboost${CMAKE_SHARED_LIBRARY_SUFFIX})

#-- Unit Tests
if(GOOGLE_TEST)
find_package(GTest REQUIRED)
Expand All @@ -49,4 +43,4 @@ if(GOOGLE_TEST)
COMMAND proc_test
WORKING_DIRECTORY ${proc_nvflare_BINARY_DIR})

endif()
endif()
5 changes: 1 addition & 4 deletions integration/xgboost/processor/src/README.md
Original file line number Diff line number Diff line change
@@ -1,11 +1,8 @@
# encoding-plugins
Processor Plugin for NVFlare

This plugin is a companion for NVFlare-based encryption; it processes the data so it can
This plugin is a companion for NVFlare-based encryption; it processes the data so it can
be properly decoded by Python code running on NVFlare.

All encryption happens on the local gRPC client/server, so no encryption is needed
in this plugin.



5 changes: 1 addition & 4 deletions integration/xgboost/processor/src/dam/README.md
Original file line number Diff line number Diff line change
@@ -1,12 +1,9 @@
# DAM (Direct-Accessible Marshaller)

A simple serialization library that has no dependencies; the data
A simple serialization library that has no dependencies; the data
is directly accessible in C/C++ without copying.

To make the data accessible in C, the following rules must be followed:

1. Numeric values must be stored in native byte-order.
2. Numeric values must start at 64-bit (8-byte) boundaries.



66 changes: 33 additions & 33 deletions integration/xgboost/processor/src/dam/dam.cc
Original file line number Diff line number Diff line change
Expand Up @@ -27,15 +27,14 @@ void print_buffer(uint8_t *buffer, int size) {

// DamEncoder ======
// Appends a copy of `value` to the encoder as a float-array entry.
//
// The elements are copied into a freshly malloc'ed buffer of
// value.size() * 8 bytes; that buffer is wrapped in a new Entry tagged
// kDataTypeFloatArray (with value.size() as the element count) and pushed
// onto `entries`. If `encoded` is already set, a message is printed and
// nothing is added.
//
// NOTE(review): this span comes from a rendered diff, so the body appears
// twice -- presumably the pre-change lines followed by the post-change lines
// (same logic, reformatted, debug comment dropped). Confirm against the real
// dam.cc before relying on it.
// NOTE(review): the malloc result is not checked before memcpy, and nothing
// visible here shows who releases `buffer` or the Entry -- verify ownership.
void DamEncoder::AddFloatArray(const std::vector<double> &value) {
// Refuse to add more data once `encoded` has been set.
if (encoded) {
std::cout << "Buffer is already encoded" << std::endl;
return;
}
// 8 bytes per element; assumes sizeof(double) == 8 -- TODO confirm.
auto buf_size = value.size()*8;
uint8_t *buffer = static_cast<uint8_t *>(malloc(buf_size));
memcpy(buffer, value.data(), buf_size);
// print_buffer(reinterpret_cast<uint8_t *>(value.data()), value.size() * 8);
// The Entry records the element count (value.size()), not the byte count.
entries->push_back(new Entry(kDataTypeFloatArray, buffer, value.size()));
// Refuse to add more data once `encoded` has been set.
if (encoded) {
std::cout << "Buffer is already encoded" << std::endl;
return;
}
// 8 bytes per element; assumes sizeof(double) == 8 -- TODO confirm.
auto buf_size = value.size() * 8;
uint8_t *buffer = static_cast<uint8_t *>(malloc(buf_size));
memcpy(buffer, value.data(), buf_size);
// The Entry records the element count (value.size()), not the byte count.
entries->push_back(new Entry(kDataTypeFloatArray, buffer, value.size()));
}

void DamEncoder::AddIntArray(const std::vector<int64_t> &value) {
Expand All @@ -52,15 +51,15 @@ void DamEncoder::AddIntArray(const std::vector<int64_t> &value) {
entries->push_back(new Entry(kDataTypeIntArray, buffer, value.size()));
}

std::uint8_t * DamEncoder::Finish(size_t &size) {
std::vector<std::uint8_t> DamEncoder::Finish(size_t &size) {
encoded = true;

size = calculate_size();
auto buf = static_cast<uint8_t *>(malloc(size));
auto pointer = buf;
std::vector<std::uint8_t> buf(size);
auto pointer = buf.data();
memcpy(pointer, kSignature, strlen(kSignature));
memcpy(pointer+8, &size, 8);
memcpy(pointer+16, &data_set_id, 8);
memcpy(pointer + 8, &size, 8);
memcpy(pointer + 16, &data_set_id, 8);

pointer += kPrefixLen;
for (auto entry : *entries) {
Expand All @@ -75,9 +74,9 @@ std::uint8_t * DamEncoder::Finish(size_t &size) {
// print_buffer(entry->pointer, entry->size*8);
}

if ((pointer - buf) != size) {
std::cout << "Invalid encoded size: " << (pointer - buf) << std::endl;
return nullptr;
if ((pointer - buf.data()) != size) {
throw std::runtime_error{"Invalid encoded size: " +
std::to_string(pointer - buf.data())};
}

return buf;
Expand All @@ -97,7 +96,7 @@ std::size_t DamEncoder::calculate_size() {

// DamDecoder ======

DamDecoder::DamDecoder(std::uint8_t *buffer, std::size_t size) {
DamDecoder::DamDecoder(std::uint8_t const *buffer, std::size_t size) {
this->buffer = buffer;
this->buf_size = size;
this->pos = buffer + kPrefixLen;
Expand All @@ -115,32 +114,33 @@ bool DamDecoder::IsValid() {
}

// Decodes the int64 array located at the current read position `pos`.
//
// Layout read from `pos`: an int64 type tag, an int64 element count, then
// that many int64 values. On success the values are returned as a new vector
// and `pos` is advanced past the tag, the count and the data. If the tag is
// not kDataTypeIntArray, a message is printed, an empty vector is returned
// and `pos` is left unchanged.
//
// NOTE(review): this span comes from a rendered diff, so the body appears
// twice -- presumably the pre-change lines followed by the post-change
// (const-qualified) lines. Confirm against the real dam.cc.
// NOTE(review): nothing in this function checks the decoded length against
// the buffer size, so a negative or oversized length would read out of
// bounds -- verify whether callers validate the buffer first.
std::vector<int64_t> DamDecoder::DecodeIntArray() {
// First 8 bytes: the type tag.
auto type = *reinterpret_cast<int64_t *>(pos);
if (type != kDataTypeIntArray) {
std::cout << "Data type " << type << " doesn't match Int Array" << std::endl;
return std::vector<int64_t>();
}
pos += 8;

// Next 8 bytes: the element count, followed by the elements themselves.
auto len = *reinterpret_cast<int64_t *>(pos);
pos += 8;
auto ptr = reinterpret_cast<int64_t *>(pos);
pos += 8*len;
return std::vector<int64_t>(ptr, ptr + len);
// First 8 bytes: the type tag.
auto type = *reinterpret_cast<int64_t const*>(pos);
if (type != kDataTypeIntArray) {
std::cout << "Data type " << type << " doesn't match Int Array"
<< std::endl;
return std::vector<int64_t>();
}
pos += 8;

// Next 8 bytes: the element count, followed by the elements themselves.
auto len = *reinterpret_cast<int64_t const *>(pos);
pos += 8;
auto ptr = reinterpret_cast<int64_t const *>(pos);
pos += 8 * len;
return std::vector<int64_t>(ptr, ptr + len);
}

// Decodes the double array located at the current read position `pos`.
//
// Layout read from `pos`: an int64 type tag, an int64 element count, then
// that many doubles. On success the values are returned as a new vector and
// `pos` is advanced past the tag, the count and the data. If the tag is not
// kDataTypeFloatArray, a message is printed, an empty vector is returned and
// `pos` is left unchanged.
//
// NOTE(review): this span comes from a rendered diff; each cast appears
// twice -- presumably the pre-change line followed by its const-qualified
// replacement. Confirm against the real dam.cc.
// NOTE(review): nothing in this function checks the decoded length against
// the buffer size, so a negative or oversized length would read out of
// bounds -- verify whether callers validate the buffer first.
std::vector<double> DamDecoder::DecodeFloatArray() {
// First 8 bytes: the type tag.
auto type = *reinterpret_cast<int64_t *>(pos);
auto type = *reinterpret_cast<int64_t const*>(pos);
if (type != kDataTypeFloatArray) {
std::cout << "Data type " << type << " doesn't match Float Array" << std::endl;
return std::vector<double>();
}
pos += 8;

// Next 8 bytes: the element count.
auto len = *reinterpret_cast<int64_t *>(pos);
auto len = *reinterpret_cast<int64_t const *>(pos);
pos += 8;

// The elements follow the count; each occupies 8 bytes.
auto ptr = reinterpret_cast<double *>(pos);
auto ptr = reinterpret_cast<double const *>(pos);
pos += 8*len;
return std::vector<double>(ptr, ptr + len);
}
16 changes: 8 additions & 8 deletions integration/xgboost/processor/src/include/dam.h
Original file line number Diff line number Diff line change
Expand Up @@ -14,9 +14,9 @@
* limitations under the License.
*/
#pragma once
#include <string>
#include <vector>
#include <map>
#include <cstdint> // for int64_t
#include <cstddef> // for size_t

const char kSignature[] = "NVDADAM1"; // DAM (Direct Accessible Marshalling) V1
const int kPrefixLen = 24;
Expand Down Expand Up @@ -57,23 +57,23 @@ class DamEncoder {

void AddFloatArray(const std::vector<double> &value);

std::uint8_t * Finish(size_t &size);
std::vector<std::uint8_t> Finish(size_t &size);

private:
private:
std::size_t calculate_size();
};

class DamDecoder {
private:
std::uint8_t *buffer = nullptr;
std::uint8_t const *buffer = nullptr;
std::size_t buf_size = 0;
std::uint8_t *pos = nullptr;
std::uint8_t const *pos = nullptr;
std::size_t remaining = 0;
int64_t data_set_id = 0;
int64_t len = 0;

public:
explicit DamDecoder(std::uint8_t *buffer, std::size_t size);
public:
explicit DamDecoder(std::uint8_t const *buffer, std::size_t size);

size_t Size() {
return len;
Expand Down
92 changes: 45 additions & 47 deletions integration/xgboost/processor/src/include/nvflare_processor.h
Original file line number Diff line number Diff line change
Expand Up @@ -14,11 +14,10 @@
* limitations under the License.
*/
#pragma once
#include <iostream>
#include <string>
#include <vector>
#include <map>
#include "processing/processor.h"
#include <cstdint> // for uint8_t, uint32_t, int32_t, int64_t
#include <string_view> // for string_view
#include <utility> // for pair
#include <vector> // for vector

const int kDataSetHGPairs = 1;
const int kDataSetAggregation = 2;
Expand All @@ -27,50 +26,49 @@ const int kDataSetAggregationResult = 4;
const int kDataSetHistograms = 5;
const int kDataSetHistogramResult = 6;

class NVFlareProcessor: public processing::Processor {
private:
bool active_ = false;
const std::map<std::string, std::string> *params_;
std::vector<double> *gh_pairs_{nullptr};
std::vector<uint32_t> cuts_;
std::vector<int> slots_;
bool feature_sent_ = false;
std::vector<int64_t> features_;
// Opaque pointer type for the C API.
typedef void *FederatedPluginHandle; // NOLINT

public:
void Initialize(bool active, std::map<std::string, std::string> params) override {
this->active_ = active;
this->params_ = &params;
}
namespace nvflare {
// Plugin that uses Python tenseal and GRPC.
class TensealPlugin {
// Buffer for storing encrypted gradient pairs.
std::vector<std::uint8_t> encrypted_gpairs_;
// Buffer for histogram cut pointers (indptr of a CSC).
std::vector<std::uint32_t> cut_ptrs_;
// Buffer for histogram index.
std::vector<std::int32_t> bin_idx_;

void Shutdown() override {
this->gh_pairs_ = nullptr;
this->cuts_.clear();
this->slots_.clear();
}
bool feature_sent_{false};
// The feature index.
std::vector<std::int64_t> features_;
// Buffer for output histogram.
std::vector<std::uint8_t> encrypted_hist_;
std::vector<double> hist_;

void FreeBuffer(void *buffer) override {
free(buffer);
}
public:
TensealPlugin(
std::vector<std::pair<std::string_view, std::string_view>> const &args);
// Gradient pairs
void EncryptGPairs(float const *in_gpair, std::size_t n_in,
std::uint8_t **out_gpair, std::size_t *n_out);
void SyncEncryptedGPairs(std::uint8_t const *in_gpair, std::size_t n_bytes,
std::uint8_t const **out_gpair,
std::size_t *out_n_bytes);

void* ProcessGHPairs(size_t *size, const std::vector<double>& pairs) override;
// Histogram
void ResetHistContext(std::uint32_t const *cutptrs, std::size_t cutptr_len,
std::int32_t const *bin_idx, std::size_t n_idx);
void BuildEncryptedHistHori(double const *in_histogram, std::size_t len,
std::uint8_t **out_hist, std::size_t *out_len);
void SyncEncryptedHistHori(std::uint8_t const *buffer, std::size_t len,
double **out_hist, std::size_t *out_len);

void* HandleGHPairs(size_t *size, void *buffer, size_t buf_size) override;

void InitAggregationContext(const std::vector<uint32_t> &cuts, const std::vector<int> &slots) override {
if (this->slots_.empty()) {
this->cuts_ = std::vector<uint32_t>(cuts);
this->slots_ = std::vector<int>(slots);
} else {
std::cout << "Multiple calls to InitAggregationContext" << std::endl;
}
}

void *ProcessAggregation(size_t *size, std::map<int, std::vector<int>> nodes) override;

std::vector<double> HandleAggregation(void *buffer, size_t buf_size) override;

void *ProcessHistograms(size_t *size, const std::vector<double>& histograms) override;

std::vector<double> HandleHistograms(void *buffer, size_t buf_size) override;
};
void BuildEncryptedHistVert(std::size_t const **ridx,
std::size_t const *sizes,
std::int32_t const *nidx, std::size_t len,
std::uint8_t **out_hist, std::size_t *out_len);
void SyncEncryptedHistVert(std::uint8_t *hist_buffer, std::size_t len,
double **out, size_t *out_len);
};
} // namespace nvflare
Loading

0 comments on commit 4a65a60

Please sign in to comment.