Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

XGBoost plugin with new API #2725

Merged
merged 41 commits into from
Aug 2, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
41 commits
Select commit Hold shift + click to select a range
a655d9d
Updated FOBS readme to add DatumManager, added agrpcs as secure scheme
nvidianz Mar 18, 2024
4e5ba5d
Merge branch 'NVIDIA:main' into main
nvidianz Mar 25, 2024
9f90d48
Merge branch 'NVIDIA:main' into main
nvidianz Apr 10, 2024
84fc5bd
Merge branch 'NVIDIA:main' into main
nvidianz Apr 15, 2024
f972506
Merge branch 'NVIDIA:main' into main
nvidianz Apr 23, 2024
5f6e76f
Merge branch 'NVIDIA:main' into main
nvidianz Apr 26, 2024
15884c4
Merge branch 'NVIDIA:main' into main
nvidianz May 3, 2024
6fbcbaa
Merge branch 'NVIDIA:main' into main
nvidianz May 6, 2024
07d3787
Merge branch 'NVIDIA:main' into main
nvidianz May 8, 2024
98289c4
Merge branch 'NVIDIA:main' into main
nvidianz May 13, 2024
cb8cdef
Merge branch 'NVIDIA:main' into main
nvidianz May 14, 2024
47facb8
Merge branch 'NVIDIA:main' into main
nvidianz May 16, 2024
cdb4b92
Merge branch 'NVIDIA:main' into main
nvidianz May 16, 2024
a194f6d
Merge branch 'NVIDIA:main' into main
nvidianz May 20, 2024
79c955e
Merge branch 'NVIDIA:main' into main
nvidianz May 22, 2024
b8d5f56
Merge branch 'NVIDIA:main' into main
nvidianz Jun 4, 2024
f6f6dbc
Merge branch 'NVIDIA:main' into main
nvidianz Jun 6, 2024
e0732cc
Merge branch 'NVIDIA:main' into main
nvidianz Jun 6, 2024
f77c56b
Merge branch 'NVIDIA:main' into main
nvidianz Jun 7, 2024
5c43908
Merge branch 'NVIDIA:main' into main
nvidianz Jun 11, 2024
f5496b3
Merge branch 'NVIDIA:main' into main
nvidianz Jun 14, 2024
3fdb95b
Merge branch 'NVIDIA:main' into main
nvidianz Jul 3, 2024
654bc17
Merge branch 'NVIDIA:main' into main
nvidianz Jul 12, 2024
c721e0e
Merge branch 'NVIDIA:main' into main
nvidianz Jul 17, 2024
517ad90
Merge branch 'NVIDIA:main' into main
nvidianz Jul 18, 2024
7c66bd2
Merge branch 'NVIDIA:main' into main
nvidianz Jul 23, 2024
d0b51c1
Implemented LocalPlugin
nvidianz Jul 23, 2024
43610ee
Refactoring plugin
nvidianz Jul 25, 2024
729ae37
Fixed formats
nvidianz Jul 25, 2024
1133cfc
Fixed horizontal secure isses with mismatching algather-v sizes
nvidianz Jul 26, 2024
9c5d8f6
Added padding to the buffer so it's big enough for histograms
nvidianz Jul 26, 2024
b6641c8
Format fix
nvidianz Jul 26, 2024
6b48d51
Merge branch 'main' into xgboost-plugin-new-api
YuanTingHsieh Jul 30, 2024
35f0a92
Merge branch 'main' into xgboost-plugin-new-api
YuanTingHsieh Jul 31, 2024
f3e32ac
Changed log level for tenseal exceptions
nvidianz Jul 31, 2024
9736bbf
Fixed a typo
nvidianz Jul 31, 2024
3943313
Added debug statements
nvidianz Aug 2, 2024
e2283e4
Fixed LocalPlugin horizontal bug
nvidianz Aug 2, 2024
df7cb0f
Added #include <chrono>
nvidianz Aug 2, 2024
7acccb4
Added docstring to BasePlugin
nvidianz Aug 2, 2024
9d960b9
Merge branch 'main' into xgboost-plugin-new-api
YuanTingHsieh Aug 2, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 11 additions & 0 deletions integration/xgboost/encryption_plugins/.editorconfig
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
root = true

[*]
charset=utf-8
indent_style = space
indent_size = 2
insert_final_newline = true

[*.py]
indent_style = space
indent_size = 4
41 changes: 41 additions & 0 deletions integration/xgboost/encryption_plugins/CMakeLists.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1,41 @@
cmake_minimum_required(VERSION 3.19)
project(xgb_nvflare LANGUAGES CXX C VERSION 1.0)
set(CMAKE_CXX_STANDARD 17)
set(CMAKE_BUILD_TYPE Debug)

option(GOOGLE_TEST "Build google tests" OFF)

file(GLOB_RECURSE LIB_SRC "src/*.cc")

add_library(nvflare SHARED ${LIB_SRC})
set_target_properties(nvflare PROPERTIES
CXX_STANDARD 17
CXX_STANDARD_REQUIRED ON
POSITION_INDEPENDENT_CODE ON
ENABLE_EXPORTS ON
)
target_include_directories(nvflare PRIVATE ${xgb_nvflare_SOURCE_DIR}/src/include)

if (APPLE)
add_link_options("LINKER:-object_path_lto,$<TARGET_PROPERTY:NAME>_lto.o")
add_link_options("LINKER:-cache_path_lto,${CMAKE_BINARY_DIR}/LTOCache")
endif ()

#-- Unit Tests
if(GOOGLE_TEST)
find_package(GTest REQUIRED)
enable_testing()
add_executable(nvflare_test)
target_link_libraries(nvflare_test PRIVATE nvflare)


target_include_directories(nvflare_test PRIVATE ${xgb_nvflare_SOURCE_DIR}/src/include)

add_subdirectory(${xgb_nvflare_SOURCE_DIR}/tests)

add_test(
NAME TestNvflarePlugins
COMMAND nvflare_test
WORKING_DIRECTORY ${xgb_nvflare_BINARY_DIR})

endif()
9 changes: 9 additions & 0 deletions integration/xgboost/encryption_plugins/README.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
# Build Instruction

cd NVFlare/integration/xgboost/encryption_plugins
mkdir build
cd build
cmake ..
make

The library is libxgb_nvflare.so
274 changes: 274 additions & 0 deletions integration/xgboost/encryption_plugins/src/dam/dam.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,274 @@
/**
* Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include <iostream>
#include <cstring>
#include "dam.h"


void print_hex(const uint8_t *buffer, std::size_t size) {
std::cout << std::hex;
for (int i = 0; i < size; i++) {
int c = buffer[i];
std::cout << c << " ";
}
std::cout << std::endl << std::dec;
}

void print_buffer(const uint8_t *buffer, std::size_t size) {
if (size <= 64) {
std::cout << "Whole buffer: " << size << " bytes" << std::endl;
print_hex(buffer, size);
return;
}

std::cout << "First chunk, Total: " << size << " bytes" << std::endl;
print_hex(buffer, 32);
std::cout << "Last chunk, Offset: " << size-16 << " bytes" << std::endl;
print_hex(buffer+size-32, 32);
}

size_t align(const size_t length) {
return ((length + 7)/8)*8;
}

// DamEncoder ======
void DamEncoder::AddBuffer(const Buffer &buffer) {
if (debug_) {
std::cout << "AddBuffer called, size: " << buffer.buf_size << std::endl;
}
if (encoded_) {
std::cout << "Buffer is already encoded" << std::endl;
return;
}
// print_buffer(buffer, buf_size);
entries_.emplace_back(kDataTypeBuffer, static_cast<const uint8_t *>(buffer.buffer), buffer.buf_size);
}

void DamEncoder::AddFloatArray(const std::vector<double> &value) {
if (debug_) {
std::cout << "AddFloatArray called, size: " << value.size() << std::endl;
}

if (encoded_) {
std::cout << "Buffer is already encoded" << std::endl;
return;
}
// print_buffer(reinterpret_cast<uint8_t *>(value.data()), value.size() * 8);
entries_.emplace_back(kDataTypeFloatArray, reinterpret_cast<const uint8_t *>(value.data()), value.size());
}

void DamEncoder::AddIntArray(const std::vector<int64_t> &value) {
if (debug_) {
std::cout << "AddIntArray called, size: " << value.size() << std::endl;
}

if (encoded_) {
std::cout << "Buffer is already encoded" << std::endl;
return;
}
// print_buffer(buffer, buf_size);
entries_.emplace_back(kDataTypeIntArray, reinterpret_cast<const uint8_t *>(value.data()), value.size());
}

void DamEncoder::AddBufferArray(const std::vector<Buffer> &value) {
if (debug_) {
std::cout << "AddBufferArray called, size: " << value.size() << std::endl;
}

if (encoded_) {
std::cout << "Buffer is already encoded" << std::endl;
return;
}
size_t size = 0;
for (auto &buf: value) {
size += buf.buf_size;
}
size += 8*value.size();
entries_.emplace_back(kDataTypeBufferArray, reinterpret_cast<const uint8_t *>(&value), size);
}


std::uint8_t * DamEncoder::Finish(size_t &size) {
encoded_ = true;

size = CalculateSize();
auto buf = static_cast<uint8_t *>(calloc(size, 1));
auto pointer = buf;
auto sig = local_version_ ? kSignatureLocal : kSignature;
memcpy(pointer, sig, strlen(sig));
memcpy(pointer+8, &size, 8);
memcpy(pointer+16, &data_set_id_, 8);

pointer += kPrefixLen;
for (auto& entry : entries_) {
std::size_t len;
if (entry.data_type == kDataTypeBufferArray) {
auto buffers = reinterpret_cast<const std::vector<Buffer> *>(entry.pointer);
memcpy(pointer, &entry.data_type, 8);
pointer += 8;
auto array_size = static_cast<int64_t>(buffers->size());
memcpy(pointer, &array_size, 8);
pointer += 8;
auto sizes = reinterpret_cast<int64_t *>(pointer);
for (auto &item : *buffers) {
*sizes = static_cast<int64_t>(item.buf_size);
sizes++;
}
len = 8*buffers->size();
auto buf_ptr = pointer + len;
for (auto &item : *buffers) {
if (item.buf_size > 0) {
memcpy(buf_ptr, item.buffer, item.buf_size);
}
buf_ptr += item.buf_size;
len += item.buf_size;
}
} else {
memcpy(pointer, &entry.data_type, 8);
pointer += 8;
memcpy(pointer, &entry.size, 8);
pointer += 8;
len = entry.size * entry.ItemSize();
if (len) {
memcpy(pointer, entry.pointer, len);
}
}
pointer += align(len);
}

if ((pointer - buf) != size) {
std::cout << "Invalid encoded size: " << (pointer - buf) << std::endl;
return nullptr;
}

return buf;
}

std::size_t DamEncoder::CalculateSize() {
std::size_t size = kPrefixLen;

for (auto& entry : entries_) {
size += 16; // The Type and Len
auto len = entry.size * entry.ItemSize();
size += align(len);
}

return size;
}


// DamDecoder ======

DamDecoder::DamDecoder(std::uint8_t *buffer, std::size_t size, bool local_version, bool debug) {
local_version_ = local_version;
buffer_ = buffer;
buf_size_ = size;
pos_ = buffer + kPrefixLen;
debug_ = debug;

if (size >= kPrefixLen) {
memcpy(&len_, buffer + 8, 8);
memcpy(&data_set_id_, buffer + 16, 8);
} else {
len_ = 0;
data_set_id_ = 0;
}
}

bool DamDecoder::IsValid() const {
auto sig = local_version_ ? kSignatureLocal : kSignature;
return buf_size_ >= kPrefixLen && memcmp(buffer_, sig, strlen(sig)) == 0;
}

Buffer DamDecoder::DecodeBuffer() {
auto type = *reinterpret_cast<int64_t *>(pos_);
if (type != kDataTypeBuffer) {
std::cout << "Data type " << type << " doesn't match bytes" << std::endl;
return {};
}
pos_ += 8;

auto size = *reinterpret_cast<int64_t *>(pos_);
pos_ += 8;

if (size == 0) {
return {};
}

auto ptr = reinterpret_cast<void *>(pos_);
pos_ += align(size);
return{ ptr, static_cast<std::size_t>(size)};
}

std::vector<int64_t> DamDecoder::DecodeIntArray() {
auto type = *reinterpret_cast<int64_t *>(pos_);
if (type != kDataTypeIntArray) {
std::cout << "Data type " << type << " doesn't match Int Array" << std::endl;
return {};
}
pos_ += 8;

auto array_size = *reinterpret_cast<int64_t *>(pos_);
pos_ += 8;
auto ptr = reinterpret_cast<int64_t *>(pos_);
pos_ += align(8 * array_size);
return {ptr, ptr + array_size};
}

std::vector<double> DamDecoder::DecodeFloatArray() {
auto type = *reinterpret_cast<int64_t *>(pos_);
if (type != kDataTypeFloatArray) {
std::cout << "Data type " << type << " doesn't match Float Array" << std::endl;
return {};
}
pos_ += 8;

auto array_size = *reinterpret_cast<int64_t *>(pos_);
pos_ += 8;

auto ptr = reinterpret_cast<double *>(pos_);
pos_ += align(8 * array_size);
return {ptr, ptr + array_size};
}

std::vector<Buffer> DamDecoder::DecodeBufferArray() {
auto type = *reinterpret_cast<int64_t *>(pos_);
if (type != kDataTypeBufferArray) {
std::cout << "Data type " << type << " doesn't match Bytes Array" << std::endl;
return {};
}
pos_ += 8;

auto num = *reinterpret_cast<int64_t *>(pos_);
pos_ += 8;

auto size_ptr = reinterpret_cast<int64_t *>(pos_);
auto buf_ptr = pos_ + 8 * num;
size_t total_size = 8 * num;
auto result = std::vector<Buffer>(num);
for (int i = 0; i < num; i++) {
auto size = size_ptr[i];
if (buf_size_ > 0) {
result[i].buf_size = size;
result[i].buffer = buf_ptr;
buf_ptr += size;
}
total_size += size;
}

pos_ += align(total_size);
return result;
}
Loading
Loading