Support PyTorch CUDACachingAllocator (#12)
* Set `export PJRT_USE_TORCH_ALLOCATOR=1` to use the PyTorch CUDACachingAllocator.
yitongh committed Sep 12, 2024
1 parent 5070f86 commit a32543b
Showing 6 changed files with 227 additions and 1 deletion.
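
For context, the new environment variable gates client creation inside `InitializePjRt` (see the `pjrt_registry.cc` hunk below). The following is a minimal C++ sketch of that flow, not part of this commit; the header path and the `"CUDA"` device-type string are assumptions, and in practice the variable is simply exported in the shell before launching a torch_xla program.

```cpp
// Hypothetical sketch: the flag must be in the process environment before
// InitializePjRt builds the GPU client; otherwise the default StreamExecutor
// GPU client with XLA's BFC allocator is used.
#include <cstdlib>

#include "torch_xla/csrc/runtime/pjrt_registry.h"  // assumed header for InitializePjRt

void InitWithTorchAllocator() {
  setenv("PJRT_USE_TORCH_ALLOCATOR", "1", /*overwrite=*/1);
  auto [client, coordinator] =
      torch_xla::runtime::InitializePjRt(/*device_type=*/"CUDA");
  // `client` now serves device allocations from PyTorch's caching pool.
}
```
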
13 changes: 13 additions & 0 deletions torch_xla/csrc/runtime/BUILD
@@ -204,6 +204,7 @@ cc_library(
":profiler",
":sys_util",
":tf_logging",
":torch_allocator",
":xla_coordinator",
"@xla//xla/service:gpu_plugin",
"@xla//xla/pjrt/gpu:se_gpu_pjrt_client",
@@ -371,6 +372,18 @@ cc_library(
],
)

cc_library(
name = "torch_allocator",
srcs = ["torch_allocator.cc"],
hdrs = ["torch_allocator.h"],
deps = [
":tf_logging",
"@tsl//tsl/framework:allocator",
"@torch//:headers",
"@xla//xla/stream_executor/gpu:gpu_types_header",
],
)

cc_library(
name = "tensor_source",
hdrs = ["tensor_source.h"],
1 change: 1 addition & 0 deletions torch_xla/csrc/runtime/env_vars.cc
@@ -24,6 +24,7 @@ const char* const kEnvPjrtAllocatorCudaAsync = "PJRT_ALLOCATOR_CUDA_ASYNC";
const char* const kEnvPjrtAllocatorPreallocate = "PJRT_ALLOCATOR_PREALLOCATE";
const char* const kEnvPjrtAllocatorFraction = "PJRT_ALLOCATOR_FRACTION";
const char* const kEnvPjrtDynamicPlugins = "PJRT_DYNAMIC_PLUGINS";
const char* const kEnvPjrtUseTorchAllocator = "PJRT_USE_TORCH_ALLOCATOR";

} // namespace env
} // namespace runtime
1 change: 1 addition & 0 deletions torch_xla/csrc/runtime/env_vars.h
@@ -34,6 +34,7 @@ extern const char* const kEnvPjrtAllocatorCudaAsync;
extern const char* const kEnvPjrtAllocatorPreallocate;
extern const char* const kEnvPjrtAllocatorFraction;
extern const char* const kEnvPjrtDynamicPlugins;
extern const char* const kEnvPjrtUseTorchAllocator;

} // namespace env
} // namespace runtime
128 changes: 127 additions & 1 deletion torch_xla/csrc/runtime/pjrt_registry.cc
@@ -5,6 +5,7 @@
#include "torch_xla/csrc/runtime/profiler.h"
#include "torch_xla/csrc/runtime/sys_util.h"
#include "torch_xla/csrc/runtime/tf_logging.h"
#include "torch_xla/csrc/runtime/torch_allocator.h"
#include "torch_xla/csrc/runtime/xla_coordinator.h"
#include "xla/pjrt/c/pjrt_c_api.h"
#include "xla/pjrt/distributed/client.h"
@@ -14,6 +15,14 @@
#include "xla/pjrt/pjrt_api.h"
#include "xla/pjrt/pjrt_c_api_client.h"
#include "xla/pjrt/tfrt_cpu_pjrt_client.h"
#include "xla/stream_executor/device_description.h"
#include "xla/stream_executor/device_memory.h"
#include "xla/stream_executor/device_memory_allocator.h"
#include "xla/stream_executor/integrations/device_mem_allocator.h"
#include "xla/stream_executor/integrations/tf_allocator_adapter.h"
#include "xla/stream_executor/platform.h"
#include "xla/stream_executor/stream.h"
#include "xla/stream_executor/stream_executor.h"

namespace torch_xla {
namespace runtime {
@@ -54,6 +63,117 @@ void RegisterPjRtPlugin(std::string name,
pjrt_plugins_[name] = plugin;
}

// Copied from openxla's
// xla/pjrt/gpu/se_gpu_pjrt_client.cc::BuildLocalDeviceStates
absl::StatusOr<std::map<int, std::unique_ptr<xla::LocalDeviceState>>>
BuildLocalDeviceStates(xla::LocalClient* xla_client) {
std::map<int, std::unique_ptr<xla::LocalDeviceState>> addressable_devices;
for (stream_executor::StreamExecutor* executor :
xla_client->backend().stream_executors()) {
addressable_devices.emplace(
executor->device_ordinal(),
std::make_unique<xla::LocalDeviceState>(
executor, xla_client, xla::LocalDeviceState::kComputeSynchronized,
/*max_inflight_computations=*/32,
/*allow_event_reuse=*/true, /*use_callback_stream=*/true));
}
return std::move(addressable_devices);
}

// Modified from openxla's
// xla/pjrt/gpu/se_gpu_pjrt_client.cc::GetStreamExecutorGpuDeviceAllocator,
// changed to use the torch allocator.
absl::StatusOr<std::unique_ptr<stream_executor::DeviceMemoryAllocator>>
GetTorchAllocator(stream_executor::Platform* platform,
const xla::GpuAllocatorConfig& allocator_config,
const std::map<int, std::unique_ptr<xla::LocalDeviceState>>&
addressable_devices) {
std::vector<stream_executor::MultiDeviceAdapter::AllocatorInfo> allocators;
LOG(INFO) << "Using PyTorch CUDACachingAllocator.";
for (const auto& ordinal_and_device : addressable_devices) {
stream_executor::StreamExecutor* executor =
ordinal_and_device.second->executor();
int device_ordinal = executor->device_ordinal();
auto allocator =
std::make_unique<TorchCUDACachingAllocator>(device_ordinal);
allocator->SetStreamAndPreallocateMemory(
ordinal_and_device.second->compute_stream()
->platform_specific_handle()
.stream);
allocators.emplace_back(std::move(allocator),
ordinal_and_device.second->compute_stream(),
/*memory_space=*/0);
}

// Add any additional allocators for alternate memory spaces.
for (const auto& ordinal_and_device : addressable_devices) {
TF_ASSIGN_OR_RETURN(
auto collective_bfc_allocator,
xla::CreateCollectiveBFCAllocator(
ordinal_and_device.second->executor(),
/*memory_fraction=*/1.0 - allocator_config.memory_fraction,
allocator_config.collective_memory_size));
allocators.emplace_back(std::move(collective_bfc_allocator),
ordinal_and_device.second->compute_stream(),
/*memory_space=*/1);
}

for (const auto& ordinal_and_device : addressable_devices) {
auto host_allocator =
xla::GetGpuHostAllocator(ordinal_and_device.second->executor());
allocators.emplace_back(
std::move(host_allocator), ordinal_and_device.second->compute_stream(),
/*memory_space=*/
static_cast<int>(stream_executor::MemoryType::kHost));
}

return std::make_unique<stream_executor::MultiDeviceAdapter>(
platform, std::move(allocators));
}

// Modified from xla::GetStreamExecutorGpuClient, changed to use the torch allocator.
absl::StatusOr<std::unique_ptr<xla::PjRtClient>>
GetPjRtClientWithTorchAllocator(const xla::GpuClientOptions& options) {
auto pjrt_platform_name = xla::CudaName();

TF_ASSIGN_OR_RETURN(
xla::LocalClient * xla_client,
xla::GetGpuXlaClient(options.platform_name, options.allowed_devices));
std::map<int, std::unique_ptr<xla::LocalDeviceState>> local_device_states;
TF_ASSIGN_OR_RETURN(local_device_states, BuildLocalDeviceStates(xla_client));
xla::EnablePeerAccess(xla_client->backend().stream_executors());

TF_ASSIGN_OR_RETURN(
auto allocator,
GetTorchAllocator(xla_client->platform(), options.allocator_config,
local_device_states));

auto host_memory_allocator =
xla::GetGpuHostAllocator(local_device_states.begin()->second->executor());

std::vector<std::unique_ptr<xla::PjRtStreamExecutorDevice>> devices;
auto gpu_run_options = std::make_unique<xla::gpu::GpuExecutableRunOptions>();
if (options.enable_mock_nccl) {
gpu_run_options->set_enable_mock_nccl_collectives();
}
std::shared_ptr<xla::KeyValueStoreInterface> kv_store = options.kv_store;
if (options.enable_mock_nccl) {
kv_store = std::make_shared<xla::InMemoryKeyValueStore>();
}
TF_RET_CHECK(options.num_nodes == 1 || kv_store != nullptr);
TF_RETURN_IF_ERROR(xla::BuildDistributedDevices(
pjrt_platform_name, std::move(local_device_states), options.node_id,
options.num_nodes, &devices, gpu_run_options.get(), kv_store,
options.enable_mock_nccl));

return std::unique_ptr<xla::PjRtClient>(
std::make_unique<xla::StreamExecutorGpuClient>(
pjrt_platform_name, xla_client, std::move(devices), options.node_id,
std::move(allocator), std::move(host_memory_allocator),
options.should_stage_host_to_device_transfers,
std::move(gpu_run_options)));
}

std::tuple<std::unique_ptr<xla::PjRtClient>, std::unique_ptr<XlaCoordinator>>
InitializePjRt(const std::string& device_type) {
std::unique_ptr<xla::PjRtClient> client;
@@ -167,7 +287,13 @@ InitializePjRt(const std::string& device_type) {
options.platform_name = "gpu";
options.should_stage_host_to_device_transfers = true;
options.kv_store = kv_store;
client = std::move(xla::GetStreamExecutorGpuClient(options).value());
bool use_torch_allocator =
sys_util::GetEnvBool(env::kEnvPjrtUseTorchAllocator, false);
if (use_torch_allocator) {
client = std::move(GetPjRtClientWithTorchAllocator(options).value());
} else {
client = std::move(xla::GetStreamExecutorGpuClient(options).value());
}
} else if (device_type == "XPU") {
TF_VLOG(1) << "Initializing PjRt XPU client...";
XLA_CHECK_OK(
48 changes: 48 additions & 0 deletions torch_xla/csrc/runtime/torch_allocator.cc
@@ -0,0 +1,48 @@
#include "torch_xla/csrc/runtime/torch_allocator.h"

#include <c10/cuda/CUDACachingAllocator.h>
#include <c10/cuda/CUDAGuard.h>

#include "torch_xla/csrc/runtime/tf_logging.h"

namespace torch_xla {
namespace runtime {

TorchCUDACachingAllocator::TorchCUDACachingAllocator(int device_ordinal) {
VLOG(3) << "Creating TorchCUDACachingAllocator for device " << device_ordinal;
name_ = c10::cuda::CUDACachingAllocator::name();
cuda_stream_ = nullptr;
device_index_ = static_cast<c10::DeviceIndex>(device_ordinal);
}

void* TorchCUDACachingAllocator::AllocateRaw(size_t alignment,
size_t num_bytes) {
CHECK(cuda_stream_ != nullptr)
<< "A stream must be added to the TorchCUDACachingAllocator allocator";
if (num_bytes == 0) {
return nullptr;
}
at::cuda::CUDAGuard device_guard{device_index_};
auto ptr = c10::cuda::CUDACachingAllocator::raw_alloc_with_stream(
num_bytes, cuda_stream_);
VLOG(3) << "Alloc num_bytes " << num_bytes << " with ptr " << ptr
<< " for device " << static_cast<int>(device_index_);
return ptr;
}

void TorchCUDACachingAllocator::DeallocateRaw(void* ptr) {
VLOG(3) << "Dealloc ptr " << ptr << " for device "
<< static_cast<int>(device_index_);
c10::cuda::CUDACachingAllocator::raw_delete(ptr);
}

void TorchCUDACachingAllocator::SetStreamAndPreallocateMemory(void* stream) {
auto new_cuda_stream = static_cast<cudaStream_t>(stream);
VLOG(3) << "Setting cuda stream " << stream
<< " for TorchCUDACachingAllocator on device "
<< static_cast<int>(device_index_);
cuda_stream_ = new_cuda_stream;
}

} // namespace runtime
} // namespace torch_xla
37 changes: 37 additions & 0 deletions torch_xla/csrc/runtime/torch_allocator.h
@@ -0,0 +1,37 @@
#ifndef XLA_CLIENT_TORCH_ALLOCATOR_H_
#define XLA_CLIENT_TORCH_ALLOCATOR_H_

#include <c10/cuda/CUDAStream.h>
#include <cuda_runtime_api.h>

#include "tsl/framework/allocator.h"

namespace torch_xla {
namespace runtime {

class TorchCUDACachingAllocator : public tsl::Allocator {
public:
TorchCUDACachingAllocator(int device_ordinal);
  ~TorchCUDACachingAllocator() override {}

std::string Name() override { return name_; }

void* AllocateRaw(size_t alignment, size_t num_bytes) override;
void DeallocateRaw(void* ptr) override;

void SetStreamAndPreallocateMemory(void* stream) override;

tsl::AllocatorMemoryType GetMemoryType() const override {
return tsl::AllocatorMemoryType::kDevice;
}

private:
std::string name_;
cudaStream_t cuda_stream_;
c10::DeviceIndex device_index_;
};

} // namespace runtime
} // namespace torch_xla

#endif // XLA_CLIENT_TORCH_ALLOCATOR_H_
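
To make the new class's contract concrete, here is a minimal standalone sketch (not part of this commit) that exercises the interface declared above: a stream must be attached via `SetStreamAndPreallocateMemory` before `AllocateRaw` is called, and allocations are served from PyTorch's caching pool on the given device.

```cpp
// Hypothetical usage sketch of TorchCUDACachingAllocator; assumes CUDA and the
// torch_xla headers from this commit are available.
#include <cuda_runtime_api.h>

#include "torch_xla/csrc/runtime/torch_allocator.h"

void Demo() {
  cudaStream_t stream = nullptr;
  cudaStreamCreate(&stream);

  torch_xla::runtime::TorchCUDACachingAllocator allocator(/*device_ordinal=*/0);
  // The allocator CHECK-fails without a stream, so attach one first.
  allocator.SetStreamAndPreallocateMemory(stream);

  // 1 MiB from PyTorch's caching pool; alignment is handled by the pool.
  void* buf = allocator.AllocateRaw(/*alignment=*/256, /*num_bytes=*/1 << 20);
  // ... use buf on `stream` ...
  allocator.DeallocateRaw(buf);  // returned to the PyTorch cache, not cudaFree'd

  cudaStreamDestroy(stream);
}
```

Because both PyTorch eager tensors and XLA device buffers now draw from the same caching pool, memory freed by one side can be reused by the other instead of each framework pre-reserving its own slice of GPU memory.
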
