Skip to content

Commit

Permalink
[Cherry-pick][Framework] Support custom allocator (PaddlePaddle#10013)
Browse files Browse the repository at this point in the history
  • Loading branch information
shentanyue committed Feb 22, 2023
1 parent 22d87bf commit 7d27aea
Show file tree
Hide file tree
Showing 6 changed files with 91 additions and 44 deletions.
7 changes: 7 additions & 0 deletions lite/api/paddle_api.cc
Original file line number Diff line number Diff line change
Expand Up @@ -371,6 +371,13 @@ void ConfigBase::add_discarded_pass(const std::string pass) {
return;
}

// Set external custom allocator
void ConfigBase::set_custom_allocator(TargetType target_type,
CustomAllocator custom_allocator) {
// TODO(shentanyue): TargetType will be supported in the future.
lite::Allocator::Global().SetCustomAllocator(custom_allocator);
}

#ifdef LITE_WITH_X86
void ConfigBase::set_x86_math_num_threads(int threads) {
x86_math_num_threads_ = threads;
Expand Down
4 changes: 4 additions & 0 deletions lite/api/paddle_api.h
Original file line number Diff line number Diff line change
Expand Up @@ -346,6 +346,10 @@ class LITE_API ConfigBase {
const std::vector<std::string> get_discarded_passes() const {
return discarded_passes_;
}

// Set external custom allocator
void set_custom_allocator(TargetType target_type,
CustomAllocator custom_allocator);
};

class LITE_API CxxModelBuffer {
Expand Down
5 changes: 5 additions & 0 deletions lite/api/paddle_place.h
Original file line number Diff line number Diff line change
Expand Up @@ -278,5 +278,10 @@ struct LITE_API Place {
std::string DebugString() const;
};

struct LITE_API CustomAllocator {
void* (*alloc)(size_t size, size_t alignment) = nullptr;
void (*free)(void* ptr) = nullptr;
};

} // namespace lite_api
} // namespace paddle
2 changes: 2 additions & 0 deletions lite/backends/host/target_wrapper.cc
Original file line number Diff line number Diff line change
Expand Up @@ -38,11 +38,13 @@ void* TargetWrapper<TARGET(kHost)>::Malloc(size_t size) {
static_cast<void**>(r)[-1] = p;
return r;
}

void TargetWrapper<TARGET(kHost)>::Free(void* ptr) {
if (ptr) {
free(static_cast<void**>(ptr)[-1]);
}
}

void TargetWrapper<TARGET(kHost)>::MemcpySync(void* dst,
const void* src,
size_t size,
Expand Down
97 changes: 53 additions & 44 deletions lite/core/memory.cc
Original file line number Diff line number Diff line change
Expand Up @@ -24,67 +24,76 @@ namespace lite {

void* TargetMalloc(TargetType target, size_t size) {
void* data{nullptr};
switch (target) {
case TargetType::kHost:
case TargetType::kX86:
case TargetType::kARM:
data = TargetWrapper<TARGET(kHost)>::Malloc(size);
break;
if (lite::Allocator::Global().GetCustomAllocator().alloc) {
data = lite::Allocator::Global().GetCustomAllocator().alloc(
size, host::MALLOC_ALIGN);
} else {
switch (target) {
case TargetType::kHost:
case TargetType::kX86:
case TargetType::kARM:
data = TargetWrapper<TARGET(kHost)>::Malloc(size);
break;
#ifdef LITE_WITH_OPENCL
case TargetType::kOpenCL:
data = TargetWrapperCL::Malloc(size);
break;
case TargetType::kOpenCL:
data = TargetWrapperCL::Malloc(size);
break;
#endif // LITE_WITH_OPENCL
#ifdef LITE_WITH_XPU
case TargetType::kXPU:
data = TargetWrapperXPU::Malloc(size);
break;
case TargetType::kXPU:
data = TargetWrapperXPU::Malloc(size);
break;
#endif // LITE_WITH_XPU
#ifdef LITE_WITH_METAL
case TargetType::kMetal: {
data = TargetWrapperMetal::Malloc(size);
break;
}
case TargetType::kMetal: {
data = TargetWrapperMetal::Malloc(size);
break;
}
#endif // LITE_WITH_METAL
default:
LOG(FATAL) << "Unknown supported target " << TargetToStr(target);
default:
LOG(FATAL) << "Unknown supported target " << TargetToStr(target);
}
}
return data;
}

void TargetFree(TargetType target, void* data, std::string free_flag) {
switch (target) {
case TargetType::kHost:
case TargetType::kX86:
case TargetType::kARM:
TargetWrapper<TARGET(kHost)>::Free(data);
break;
if (lite::Allocator::Global().GetCustomAllocator().free) {
lite::Allocator::Global().GetCustomAllocator().free(data);
} else {
switch (target) {
case TargetType::kHost:
case TargetType::kX86:
case TargetType::kARM:
TargetWrapper<TARGET(kHost)>::Free(data);
break;

#ifdef LITE_WITH_OPENCL
case TargetType::kOpenCL:
if (free_flag == "cl_use_image2d_") {
TargetWrapperCL::FreeImage(data);
} else {
TargetWrapperCL::Free(data);
}
break;
case TargetType::kOpenCL:
if (free_flag == "cl_use_image2d_") {
TargetWrapperCL::FreeImage(data);
} else {
TargetWrapperCL::Free(data);
}
break;
#endif // LITE_WITH_OPENCL
#ifdef LITE_WITH_XPU
case TargetType::kXPU:
TargetWrapperXPU::Free(data);
break;
case TargetType::kXPU:
TargetWrapperXPU::Free(data);
break;
#endif // LITE_WITH_XPU
#ifdef LITE_WITH_METAL
case TargetType::kMetal:
if (free_flag == "metal_use_image2d_") {
TargetWrapperMetal::FreeImage(data);
} else {
TargetWrapperMetal::Free(data);
}
break;
case TargetType::kMetal:
if (free_flag == "metal_use_image2d_") {
TargetWrapperMetal::FreeImage(data);
} else {
TargetWrapperMetal::Free(data);
}
break;
#endif
default:
LOG(FATAL) << "Unknown type";
default:
LOG(FATAL) << "Unknown supported target:" << TargetToStr(target);
}
}
}

Expand Down Expand Up @@ -113,7 +122,7 @@ void TargetCopy(TargetType target, void* dst, const void* src, size_t size) {
break;
#endif
default:
LOG(FATAL) << "unsupported type";
LOG(FATAL) << "Unknown supported target:" << TargetToStr(target);
}
}

Expand Down
20 changes: 20 additions & 0 deletions lite/core/target_wrapper.h
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@ using lite_api::DataLayoutToStr;
using lite_api::TargetRepr;
using lite_api::PrecisionRepr;
using lite_api::DataLayoutRepr;
using lite_api::CustomAllocator;

namespace host {
const int MALLOC_ALIGN = 64;
Expand Down Expand Up @@ -93,6 +94,25 @@ enum class IoDirection {
DtoD, // Device to device
};

// Allocator
class Allocator {
public:
static Allocator& Global() {
static auto* allocator = new Allocator;
return *allocator;
}

void SetCustomAllocator(CustomAllocator custom_allocator) {
custom_allocator_ = custom_allocator;
}

CustomAllocator GetCustomAllocator() { return custom_allocator_; }

private:
CustomAllocator custom_allocator_;
Allocator() = default;
};

// This interface should be specified by each kind of target.
template <TargetType Target, typename StreamTy = int, typename EventTy = int>
class TargetWrapper {
Expand Down

0 comments on commit 7d27aea

Please sign in to comment.