Skip to content

Commit

Permalink
[DeviceSanitizer] Handle the case of urMemGetNativeHandle getting a n…
Browse files Browse the repository at this point in the history
…ullptr Device (#1969)


Co-authored-by: Yang Zhao <allanzyne@outlook.com>
  • Loading branch information
yingcong-wu and AllanZyne authored Oct 10, 2024
1 parent d52dccb commit ea13b2f
Show file tree
Hide file tree
Showing 4 changed files with 37 additions and 9 deletions.
9 changes: 9 additions & 0 deletions source/loader/layers/sanitizer/asan_buffer.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -75,6 +75,15 @@ ur_result_t MemBuffer::getHandle(ur_device_handle_t Device, char *&Handle) {
return UR_RESULT_SUCCESS;
}

// Device may be null, we follow the L0 adapter's practice to use the first
// device
if (!Device) {
auto Devices = GetDevices(Context);
assert(Devices.size() > 0 && "Devices should not be empty");
Device = Devices[0];
}
assert((void *)Device != nullptr && "Device cannot be nullptr");

std::scoped_lock<ur_shared_mutex> Guard(Mutex);
auto &Allocation = Allocations[Device];
ur_result_t URes = UR_RESULT_SUCCESS;
Expand Down
2 changes: 1 addition & 1 deletion source/loader/layers/sanitizer/asan_interceptor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -571,7 +571,7 @@ ur_result_t SanitizerInterceptor::updateShadowMemory(
ur_result_t
SanitizerInterceptor::registerDeviceGlobals(ur_context_handle_t Context,
ur_program_handle_t Program) {
std::vector<ur_device_handle_t> Devices = GetProgramDevices(Program);
std::vector<ur_device_handle_t> Devices = GetDevices(Program);

auto ContextInfo = getContextInfo(Context);

Expand Down
32 changes: 25 additions & 7 deletions source/loader/layers/sanitizer/ur_sanitizer_utils.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -72,6 +72,22 @@ ur_device_handle_t GetDevice(ur_queue_handle_t Queue) {
return Device;
}

std::vector<ur_device_handle_t> GetDevices(ur_context_handle_t Context) {
std::vector<ur_device_handle_t> Devices{};
uint32_t DeviceNum = 0;
[[maybe_unused]] ur_result_t Result;
Result = getContext()->urDdiTable.Context.pfnGetInfo(
Context, UR_CONTEXT_INFO_NUM_DEVICES, sizeof(uint32_t), &DeviceNum,
nullptr);
assert(Result == UR_RESULT_SUCCESS && "getDevices(Context) failed");
Devices.resize(DeviceNum);
Result = getContext()->urDdiTable.Context.pfnGetInfo(
Context, UR_CONTEXT_INFO_DEVICES,
sizeof(ur_device_handle_t) * DeviceNum, Devices.data(), nullptr);
assert(Result == UR_RESULT_SUCCESS && "getDevices(Context) failed");
return Devices;
}

ur_program_handle_t GetProgram(ur_kernel_handle_t Kernel) {
ur_program_handle_t Program{};
[[maybe_unused]] auto Result = getContext()->urDdiTable.Kernel.pfnGetInfo(
Expand Down Expand Up @@ -169,18 +185,20 @@ bool GetDeviceUSMCapability(ur_device_handle_t Device,
return (bool)Flag;
}

std::vector<ur_device_handle_t> GetProgramDevices(ur_program_handle_t Program) {
size_t PropSize;
std::vector<ur_device_handle_t> GetDevices(ur_program_handle_t Program) {
uint32_t DeviceNum = 0;
[[maybe_unused]] ur_result_t Result =
getContext()->urDdiTable.Program.pfnGetInfo(
Program, UR_PROGRAM_INFO_DEVICES, 0, nullptr, &PropSize);
assert(Result == UR_RESULT_SUCCESS);
Program, UR_PROGRAM_INFO_NUM_DEVICES, sizeof(DeviceNum), &DeviceNum,
nullptr);
assert(Result == UR_RESULT_SUCCESS && "getDevices(Program) failed");

std::vector<ur_device_handle_t> Devices;
Devices.resize(PropSize / sizeof(ur_device_handle_t));
Devices.resize(DeviceNum);
Result = getContext()->urDdiTable.Program.pfnGetInfo(
Program, UR_PROGRAM_INFO_DEVICES, PropSize, Devices.data(), nullptr);
assert(Result == UR_RESULT_SUCCESS);
Program, UR_PROGRAM_INFO_DEVICES,
DeviceNum * sizeof(ur_device_handle_t), Devices.data(), nullptr);
assert(Result == UR_RESULT_SUCCESS && "getDevices(Program) failed");

return Devices;
}
Expand Down
3 changes: 2 additions & 1 deletion source/loader/layers/sanitizer/ur_sanitizer_utils.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,8 @@ ur_context_handle_t GetContext(ur_queue_handle_t Queue);
ur_context_handle_t GetContext(ur_program_handle_t Program);
ur_context_handle_t GetContext(ur_kernel_handle_t Kernel);
ur_device_handle_t GetDevice(ur_queue_handle_t Queue);
std::vector<ur_device_handle_t> GetDevices(ur_context_handle_t Context);
std::vector<ur_device_handle_t> GetDevices(ur_program_handle_t Program);
DeviceType GetDeviceType(ur_context_handle_t Context,
ur_device_handle_t Device);
ur_device_handle_t GetParentDevice(ur_device_handle_t Device);
Expand All @@ -42,7 +44,6 @@ bool GetDeviceUSMCapability(ur_device_handle_t Device,
std::string GetKernelName(ur_kernel_handle_t Kernel);
size_t GetDeviceLocalMemorySize(ur_device_handle_t Device);
ur_program_handle_t GetProgram(ur_kernel_handle_t Kernel);
std::vector<ur_device_handle_t> GetProgramDevices(ur_program_handle_t Program);
ur_device_handle_t GetUSMAllocDevice(ur_context_handle_t Context,
const void *MemPtr);
uint32_t GetKernelNumArgs(ur_kernel_handle_t Kernel);
Expand Down

0 comments on commit ea13b2f

Please sign in to comment.