diff --git a/src/runtime/vulkan/vulkan.cc b/src/runtime/vulkan/vulkan.cc index 0edbe683aece..0290fafe7b8a 100644 --- a/src/runtime/vulkan/vulkan.cc +++ b/src/runtime/vulkan/vulkan.cc @@ -786,7 +786,9 @@ class VulkanModuleNode final : public runtime::ModuleNode { public: explicit VulkanModuleNode(std::unordered_map smap, std::unordered_map fmap, std::string source) - : smap_(smap), fmap_(fmap), source_(source) {} + : smap_(smap), fmap_(fmap), source_(source), max_push_constants_(GetMaxPushConstantsSize()) { + LOG(INFO) << "VulkanModuleNode, max_push_constants: " << max_push_constants_; + } const char* type_key() const final { return "vulkan"; } @@ -896,7 +898,8 @@ class VulkanModuleNode final : public runtime::ModuleNode { } size_t nbytes_scalars = num_pod * sizeof(ArgUnion64); - if (nbytes_scalars > MAX_PUSHCONSTANTS) { + if (nbytes_scalars > max_push_constants_) { + LOG(INFO) << "Using ubo"; push_arg_info(num_buffer, VK_DESCRIPTOR_TYPE_UNIFORM_BUFFER); } @@ -951,7 +954,7 @@ class VulkanModuleNode final : public runtime::ModuleNode { playout_cinfo.setLayoutCount = 1; playout_cinfo.pSetLayouts = &(pe->descriptor_set_layout); - if (0 < nbytes_scalars && nbytes_scalars <= MAX_PUSHCONSTANTS) { + if (0 < nbytes_scalars && nbytes_scalars <= max_push_constants_) { playout_cinfo.pushConstantRangeCount = 1; playout_cinfo.pPushConstantRanges = &crange; ICHECK_LE(crange.size, vctx.phy_device_prop.limits.maxPushConstantsSize); @@ -980,7 +983,7 @@ class VulkanModuleNode final : public runtime::ModuleNode { VULKAN_CALL(vkCreateComputePipelines(vctx.device, VK_NULL_HANDLE, 1, &pipeline_cinfo, nullptr, &(pe->pipeline))); - if (nbytes_scalars > MAX_PUSHCONSTANTS) { + if (nbytes_scalars > max_push_constants_) { // Allocate, bind and map UBO UniformBuffer& ubo = pe->ubo; ubo.host_buf = new ArgUnion64[num_pod]; @@ -1031,6 +1034,8 @@ class VulkanModuleNode final : public runtime::ModuleNode { return source_; } + uint32_t MaxPushConstantsSize() const { return max_push_constants_; } +
private: // function information table. std::unordered_map smap_; @@ -1040,6 +1045,8 @@ class VulkanModuleNode final : public runtime::ModuleNode { std::string fmt_{"vulkan"}; // The source std::string source_; + // The maximum size of push constants in bytes, queried from the device at construction + uint32_t max_push_constants_; // Guards accesses to `ecache_` std::mutex mutex_; @@ -1142,7 +1149,7 @@ void VulkanWrappedFunc::operator()(TVMArgs args, TVMRetValue* rv, descriptor_buffers[i] = binfo; } const size_t nbytes_scalars = num_pack_args_ * sizeof(ArgUnion64); - bool use_ubo = num_pack_args_ != 0 && nbytes_scalars > MAX_PUSHCONSTANTS; + bool use_ubo = num_pack_args_ != 0 && nbytes_scalars > m_->MaxPushConstantsSize(); if (use_ubo) { CHECK(pipeline->ubo.host_buf) << "The UBO host buffer is not allocated"; memcpy(pipeline->ubo.host_buf, pack_args, nbytes_scalars); @@ -1160,7 +1167,7 @@ void VulkanWrappedFunc::operator()(TVMArgs args, TVMRetValue* rv, vctx.descriptor_template_khr_functions->vkCmdPushDescriptorSetWithTemplateKHR( state->cmd_buffer_, pipeline->descriptor_update_template, pipeline->pipeline_layout, 0, descriptor_buffers.data()); - if (num_pack_args_ > 0 && num_pack_args_ <= MAX_PUSHCONSTANTS) { + if (num_pack_args_ > 0 && num_pack_args_ * sizeof(ArgUnion64) <= m_->MaxPushConstantsSize()) { vkCmdPushConstants(state->cmd_buffer_, pipeline->pipeline_layout, VK_SHADER_STAGE_COMPUTE_BIT, 0, num_pack_args_ * sizeof(ArgUnion64), pack_args); @@ -1210,7 +1217,7 @@ void VulkanWrappedFunc::operator()(TVMArgs args, TVMRetValue* rv, vkCmdBindDescriptorSets(state->cmd_buffer_, VK_PIPELINE_BIND_POINT_COMPUTE, pipeline->pipeline_layout, 0, 1, &(pipeline->descriptor_set), 0, nullptr); - if (num_pack_args_ > 0 && num_pack_args_ <= MAX_PUSHCONSTANTS) { + if (num_pack_args_ > 0 && num_pack_args_ * sizeof(ArgUnion64) <= m_->MaxPushConstantsSize()) { vkCmdPushConstants(state->cmd_buffer_, pipeline->pipeline_layout, VK_SHADER_STAGE_COMPUTE_BIT, 0, pack_args_storage.size() * sizeof(ArgUnion64), pack_args_storage.data()); @@ -1265,6 +1272,12 @@ Module
VulkanModuleLoadBinary(void* strm) { return VulkanModuleCreate(smap, fmap, ""); } +uint32_t GetMaxPushConstantsSize() { + int device_id = VulkanThreadEntry::ThreadLocal()->device.device_id; + const auto& vctx = VulkanDeviceAPI::Global()->context(device_id); + return vctx.phy_device_prop.limits.maxPushConstantsSize; +} + TVM_REGISTER_GLOBAL("runtime.module.loadfile_vulkan").set_body_typed(VulkanModuleLoadFile); TVM_REGISTER_GLOBAL("runtime.module.loadbinary_vulkan").set_body_typed(VulkanModuleLoadBinary); diff --git a/src/runtime/vulkan/vulkan_common.h b/src/runtime/vulkan/vulkan_common.h index ab38da84b3df..9ee28fee41c6 100644 --- a/src/runtime/vulkan/vulkan_common.h +++ b/src/runtime/vulkan/vulkan_common.h @@ -145,6 +145,8 @@ struct VulkanContext { bool UseImmediate() const { return descriptor_template_khr_functions.get() != nullptr; } }; +uint32_t GetMaxPushConstantsSize(); + } // namespace vulkan } // namespace runtime } // namespace tvm diff --git a/src/target/spirv/codegen_spirv.cc b/src/target/spirv/codegen_spirv.cc index d4c29799bee9..4d55f4c49a5f 100644 --- a/src/target/spirv/codegen_spirv.cc +++ b/src/target/spirv/codegen_spirv.cc @@ -74,7 +74,8 @@ std::vector CodeGenSPIRV::BuildFunction(const PrimFunc& f, const std:: for (size_t i = 0; i < pod_args.size(); ++i) { value_types.push_back(builder_->GetSType(pod_args[i].dtype())); } - if (pod_args.size() * sizeof(runtime::ArgUnion64) <= MAX_PUSHCONSTANTS) { + const auto max_push_constants = runtime::vulkan::GetMaxPushConstantsSize(); + if (pod_args.size() * sizeof(runtime::ArgUnion64) <= max_push_constants) { spirv::Value ptr = builder_->DeclarePushConstant(value_types); for (size_t i = 0; i < pod_args.size(); ++i) { spirv::Value value = @@ -82,6 +83,7 @@ std::vector CodeGenSPIRV::BuildFunction(const PrimFunc& f, const std:: var_map_[pod_args[i].get()] = value; } } else { + // If we need to pass more arguments than push constants could handle, we use UBO.
spirv::Value ptr = builder_->DeclareUniformBuffer(value_types, num_buffer); for (size_t i = 0; i < pod_args.size(); ++i) { spirv::Value value = builder_->GetUniform(ptr, value_types[i], static_cast(i));