diff --git a/src/runtime/vulkan/vulkan_device.cc b/src/runtime/vulkan/vulkan_device.cc index 16987210b232..156f86dbb03e 100644 --- a/src/runtime/vulkan/vulkan_device.cc +++ b/src/runtime/vulkan/vulkan_device.cc @@ -156,6 +156,27 @@ VulkanDeviceProperties::VulkanDeviceProperties(const VulkanInstance& instance, device_name = properties.properties.deviceName; driver_version = properties.properties.driverVersion; + switch (properties.properties.deviceType) { + case VK_PHYSICAL_DEVICE_TYPE_OTHER: + device_type = "other"; + break; + case VK_PHYSICAL_DEVICE_TYPE_INTEGRATED_GPU: + device_type = "integrated"; + break; + case VK_PHYSICAL_DEVICE_TYPE_DISCRETE_GPU: + device_type = "discrete"; + break; + case VK_PHYSICAL_DEVICE_TYPE_VIRTUAL_GPU: + device_type = "virtual"; + break; + case VK_PHYSICAL_DEVICE_TYPE_CPU: + device_type = "cpu"; + break; + default: + LOG(FATAL) << "Unknown vulkan device type: " << properties.properties.deviceType; + break; + } + // By default, use the maximum API version that the driver allows, // so that any supported features can be used by TVM shaders. // However, if we can query the conformance version, then limit to diff --git a/src/runtime/vulkan/vulkan_device.h b/src/runtime/vulkan/vulkan_device.h index 045628bc9092..412542029209 100644 --- a/src/runtime/vulkan/vulkan_device.h +++ b/src/runtime/vulkan/vulkan_device.h @@ -92,7 +92,8 @@ struct VulkanDeviceProperties { uint32_t max_storage_buffer_range{1 << 27}; uint32_t max_per_stage_descriptor_storage_buffer{4}; uint32_t max_shared_memory_per_block{16384}; - std::string device_name{"unknown device name"}; + std::string device_type{"unknown_device_type"}; + std::string device_name{"unknown_device_name"}; uint32_t driver_version{0}; uint32_t vulkan_api_version{VK_API_VERSION_1_0}; uint32_t max_spirv_version{0x10000}; diff --git a/src/runtime/vulkan/vulkan_device_api.cc b/src/runtime/vulkan/vulkan_device_api.cc index 1fede98f7211..b4987eb321cf 100644 --- a/src/runtime/vulkan/vulkan_device_api.cc +++ b/src/runtime/vulkan/vulkan_device_api.cc @@ -50,6 +50,28 @@ VulkanDeviceAPI::VulkanDeviceAPI() { devices_.push_back(std::move(device)); } } + + // Move discrete GPUs to the start of the list, so the default + // device_id=0 preferentially uses a discrete GPU. + auto preference = [](const VulkanDevice& device) { + const std::string& type = device.device_properties.device_type; + if (type == "discrete") { + return 0; + } else if (type == "integrated") { + return 1; + } else if (type == "virtual") { + return 2; + } else if (type == "cpu") { + return 3; + } else { + return 4; + } + }; + + std::stable_sort(devices_.begin(), devices_.end(), + [&preference](const VulkanDevice& a, const VulkanDevice& b) { + return preference(a) < preference(b); + }); } VulkanDeviceAPI::~VulkanDeviceAPI() {} @@ -214,8 +236,8 @@ void VulkanDeviceAPI::GetTargetProperty(Device dev, const std::string& property, if (property == "max_shared_memory_per_block") { *rv = int64_t(prop.max_shared_memory_per_block); } - if (property == ":string device_name") { - *rv = prop.device_name; + if (property == "device_name") { + *rv = String(prop.device_name); } if (property == "driver_version") { *rv = int64_t(prop.driver_version); diff --git a/src/target/target_kind.cc b/src/target/target_kind.cc index d037b9dfdbdb..a56916248858 100644 --- a/src/target/target_kind.cc +++ b/src/target/target_kind.cc @@ -249,7 +249,7 @@ Map UpdateVulkanAttrs(Map attrs) { "driver_version", "vulkan_api_version", "max_spirv_version"}; - std::vector str_opts = {"device_name"}; + std::vector str_opts = {"device_name", "device_type"}; for (auto& key : bool_opts) { if (!attrs.count(key)) { @@ -387,6 +387,7 @@ TVM_REGISTER_TARGET_KIND("vulkan", kDLVulkan) .add_attr_option("max_per_stage_descriptor_storage_buffer") .add_attr_option("max_shared_memory_per_block") // Other device properties + .add_attr_option("device_type") .add_attr_option("device_name") .add_attr_option("driver_version") .add_attr_option("vulkan_api_version")