Skip to content

Commit

Permalink
query push constant size using runtime API
Browse files Browse the repository at this point in the history
  • Loading branch information
masahi committed Apr 10, 2021
1 parent 95ec1db commit 9a67f4a
Show file tree
Hide file tree
Showing 3 changed files with 25 additions and 8 deletions.
27 changes: 20 additions & 7 deletions src/runtime/vulkan/vulkan.cc
Original file line number Diff line number Diff line change
Expand Up @@ -786,7 +786,9 @@ class VulkanModuleNode final : public runtime::ModuleNode {
public:
explicit VulkanModuleNode(std::unordered_map<std::string, VulkanShader> smap,
std::unordered_map<std::string, FunctionInfo> fmap, std::string source)
: smap_(smap), fmap_(fmap), source_(source) {}
: smap_(smap), fmap_(fmap), source_(source), max_push_constants_(GetMaxPushConstantsSize()) {
LOG(INFO) << "VulkanModuleNode, max_push_constants: " << max_push_constants_;
}

const char* type_key() const final { return "vulkan"; }

Expand Down Expand Up @@ -896,7 +898,8 @@ class VulkanModuleNode final : public runtime::ModuleNode {
}

size_t nbytes_scalars = num_pod * sizeof(ArgUnion64);
if (nbytes_scalars > MAX_PUSHCONSTANTS) {
if (nbytes_scalars > max_push_constants_) {
LOG(INFO) << "Using ubo";
push_arg_info(num_buffer, VK_DESCRIPTOR_TYPE_UNIFORM_BUFFER);
}

Expand Down Expand Up @@ -951,7 +954,7 @@ class VulkanModuleNode final : public runtime::ModuleNode {
playout_cinfo.setLayoutCount = 1;
playout_cinfo.pSetLayouts = &(pe->descriptor_set_layout);

if (0 < nbytes_scalars && nbytes_scalars <= MAX_PUSHCONSTANTS) {
if (0 < nbytes_scalars && nbytes_scalars <= max_push_constants_) {
playout_cinfo.pushConstantRangeCount = 1;
playout_cinfo.pPushConstantRanges = &crange;
ICHECK_LE(crange.size, vctx.phy_device_prop.limits.maxPushConstantsSize);
Expand Down Expand Up @@ -980,7 +983,7 @@ class VulkanModuleNode final : public runtime::ModuleNode {
VULKAN_CALL(vkCreateComputePipelines(vctx.device, VK_NULL_HANDLE, 1, &pipeline_cinfo, nullptr,
&(pe->pipeline)));

if (nbytes_scalars > MAX_PUSHCONSTANTS) {
if (nbytes_scalars > max_push_constants_) {
// Allocate, bind and map UBO
UniformBuffer& ubo = pe->ubo;
ubo.host_buf = new ArgUnion64[num_pod];
Expand Down Expand Up @@ -1031,6 +1034,8 @@ class VulkanModuleNode final : public runtime::ModuleNode {
return source_;
}

uint32_t MaxPushConstantsSize() const { return max_push_constants_; }

private:
// function information table.
std::unordered_map<std::string, VulkanShader> smap_;
Expand All @@ -1040,6 +1045,8 @@ class VulkanModuleNode final : public runtime::ModuleNode {
std::string fmt_{"vulkan"};
// The source
std::string source_;
// The maximum size of push constants in bytes
uint32_t max_push_constants_;

// Guards accesses to `ecache_`
std::mutex mutex_;
Expand Down Expand Up @@ -1142,7 +1149,7 @@ void VulkanWrappedFunc::operator()(TVMArgs args, TVMRetValue* rv,
descriptor_buffers[i] = binfo;
}
const size_t nbytes_scalars = num_pack_args_ * sizeof(ArgUnion64);
bool use_ubo = num_pack_args_ != 0 && nbytes_scalars > MAX_PUSHCONSTANTS;
bool use_ubo = num_pack_args_ != 0 && nbytes_scalars > m_->MaxPushConstantsSize();
if (use_ubo) {
CHECK(pipeline->ubo.host_buf) << "The UBO host buffer is not allocated";
memcpy(pipeline->ubo.host_buf, pack_args, nbytes_scalars);
Expand All @@ -1160,7 +1167,7 @@ void VulkanWrappedFunc::operator()(TVMArgs args, TVMRetValue* rv,
vctx.descriptor_template_khr_functions->vkCmdPushDescriptorSetWithTemplateKHR(
state->cmd_buffer_, pipeline->descriptor_update_template, pipeline->pipeline_layout, 0,
descriptor_buffers.data());
if (num_pack_args_ > 0 && num_pack_args_ <= MAX_PUSHCONSTANTS) {
if (num_pack_args_ > 0 && num_pack_args_ <= m_->MaxPushConstantsSize()) {
vkCmdPushConstants(state->cmd_buffer_, pipeline->pipeline_layout,
VK_SHADER_STAGE_COMPUTE_BIT, 0, num_pack_args_ * sizeof(ArgUnion64),
pack_args);
Expand Down Expand Up @@ -1210,7 +1217,7 @@ void VulkanWrappedFunc::operator()(TVMArgs args, TVMRetValue* rv,
vkCmdBindDescriptorSets(state->cmd_buffer_, VK_PIPELINE_BIND_POINT_COMPUTE,
pipeline->pipeline_layout, 0, 1, &(pipeline->descriptor_set), 0,
nullptr);
if (num_pack_args_ > 0 && num_pack_args_ <= MAX_PUSHCONSTANTS) {
if (num_pack_args_ > 0 && num_pack_args_ <= m_->MaxPushConstantsSize()) {
vkCmdPushConstants(state->cmd_buffer_, pipeline->pipeline_layout, VK_SHADER_STAGE_COMPUTE_BIT,
0, pack_args_storage.size() * sizeof(ArgUnion64),
pack_args_storage.data());
Expand Down Expand Up @@ -1265,6 +1272,12 @@ Module VulkanModuleLoadBinary(void* strm) {
return VulkanModuleCreate(smap, fmap, "");
}

uint32_t GetMaxPushConstantsSize() {
int device_id = VulkanThreadEntry::ThreadLocal()->device.device_id;
const auto& vctx = VulkanDeviceAPI::Global()->context(device_id);
return vctx.phy_device_prop.limits.maxPushConstantsSize;
}

TVM_REGISTER_GLOBAL("runtime.module.loadfile_vulkan").set_body_typed(VulkanModuleLoadFile);

TVM_REGISTER_GLOBAL("runtime.module.loadbinary_vulkan").set_body_typed(VulkanModuleLoadBinary);
Expand Down
2 changes: 2 additions & 0 deletions src/runtime/vulkan/vulkan_common.h
Original file line number Diff line number Diff line change
Expand Up @@ -145,6 +145,8 @@ struct VulkanContext {
bool UseImmediate() const { return descriptor_template_khr_functions.get() != nullptr; }
};

uint32_t GetMaxPushConstantsSize();

} // namespace vulkan
} // namespace runtime
} // namespace tvm
Expand Down
4 changes: 3 additions & 1 deletion src/target/spirv/codegen_spirv.cc
Original file line number Diff line number Diff line change
Expand Up @@ -74,14 +74,16 @@ std::vector<uint32_t> CodeGenSPIRV::BuildFunction(const PrimFunc& f, const std::
for (size_t i = 0; i < pod_args.size(); ++i) {
value_types.push_back(builder_->GetSType(pod_args[i].dtype()));
}
if (pod_args.size() * sizeof(runtime::ArgUnion64) <= MAX_PUSHCONSTANTS) {
const auto max_push_constants = runtime::vulkan::GetMaxPushConstantsSize();
if (pod_args.size() * sizeof(runtime::ArgUnion64) <= max_push_constants) {
spirv::Value ptr = builder_->DeclarePushConstant(value_types);
for (size_t i = 0; i < pod_args.size(); ++i) {
spirv::Value value =
builder_->GetPushConstant(ptr, value_types[i], static_cast<uint32_t>(i));
var_map_[pod_args[i].get()] = value;
}
} else {
// If we need to pass more arguments than push constants could handle, we use UBO.
spirv::Value ptr = builder_->DeclareUniformBuffer(value_types, num_buffer);
for (size_t i = 0; i < pod_args.size(); ++i) {
spirv::Value value = builder_->GetUniform(ptr, value_types[i], static_cast<uint32_t>(i));
Expand Down

0 comments on commit 9a67f4a

Please sign in to comment.