Skip to content

Commit

Permalink
allocate and bind ubo
Browse files Browse the repository at this point in the history
  • Loading branch information
masahi committed Apr 10, 2021
1 parent 7d2ed2b commit e1788b8
Showing 1 changed file with 98 additions and 71 deletions.
169 changes: 98 additions & 71 deletions src/runtime/vulkan/vulkan.cc
Original file line number Diff line number Diff line change
Expand Up @@ -91,6 +91,11 @@ struct VulkanBuffer {
VkDeviceMemory memory{VK_NULL_HANDLE};
};

struct UniformBuffer {
VulkanBuffer* vk_buf;
ArgUnion64* host_buf;
};

struct VulkanPipeline {
VulkanContext* vctx_{nullptr};
VkShaderModule shader{VK_NULL_HANDLE};
Expand All @@ -100,11 +105,80 @@ struct VulkanPipeline {
VkPipelineLayout pipeline_layout{VK_NULL_HANDLE};
VkPipeline pipeline{VK_NULL_HANDLE};
VkDescriptorUpdateTemplateKHR descriptor_update_template{VK_NULL_HANDLE};
VulkanBuffer ubo;
UniformBuffer ubo;
};

typedef dmlc::ThreadLocalStore<VulkanThreadEntry> VulkanThreadStore;

VulkanBuffer* CreateBuffer(const VulkanContext& vctx, size_t nbytes, VkBufferUsageFlags usage) {
VkBufferCreateInfo info;
info.sType = VK_STRUCTURE_TYPE_BUFFER_CREATE_INFO;
info.pNext = nullptr;
info.flags = 0;
info.size = nbytes;
info.queueFamilyIndexCount = 1;
info.pQueueFamilyIndices = &(vctx.queue_family_index);
info.sharingMode = VK_SHARING_MODE_EXCLUSIVE;
info.usage = usage;
// create buffer
VkBuffer buffer;
VULKAN_CALL(vkCreateBuffer(vctx.device, &info, nullptr, &buffer));

// bind to memory
bool dedicated_allocation = false;
VkMemoryRequirements2KHR req2;

if (vctx.get_buffer_memory_requirements_2_functions) {
VkBufferMemoryRequirementsInfo2KHR req_info2;
req_info2.sType = VK_STRUCTURE_TYPE_BUFFER_MEMORY_REQUIREMENTS_INFO_2_KHR;
req_info2.pNext = 0;
req_info2.buffer = buffer;

req2.sType = VK_STRUCTURE_TYPE_MEMORY_REQUIREMENTS_2_KHR;
req2.pNext = 0;

VkMemoryDedicatedRequirementsKHR dedicated_req;
dedicated_req.sType = VK_STRUCTURE_TYPE_MEMORY_DEDICATED_REQUIREMENTS_KHR;
dedicated_req.pNext = 0;
req2.pNext = &dedicated_req;

vctx.get_buffer_memory_requirements_2_functions->vkGetBufferMemoryRequirements2KHR(
vctx.device, &req_info2, &req2);
dedicated_allocation =
dedicated_req.requiresDedicatedAllocation || dedicated_req.prefersDedicatedAllocation;
}

VkDeviceMemory memory;
// TODO: revisit memoryTypeIndex
if (!dedicated_allocation) {
VkMemoryAllocateInfo minfo;
minfo.sType = VK_STRUCTURE_TYPE_MEMORY_ALLOCATE_INFO;
minfo.pNext = nullptr;
minfo.allocationSize = nbytes;
minfo.memoryTypeIndex = vctx.compute_mtype_index;
VULKAN_CALL(vkAllocateMemory(vctx.device, &minfo, nullptr, &memory));
} else {
VkMemoryAllocateInfo minfo;
minfo.sType = VK_STRUCTURE_TYPE_MEMORY_ALLOCATE_INFO;
minfo.pNext = nullptr;
minfo.allocationSize = req2.memoryRequirements.size;
minfo.memoryTypeIndex = vctx.compute_mtype_index;

VkMemoryDedicatedAllocateInfoKHR mdinfo;
mdinfo.sType = VK_STRUCTURE_TYPE_MEMORY_DEDICATED_ALLOCATE_INFO_KHR;
mdinfo.pNext = 0;
mdinfo.image = 0;
mdinfo.buffer = buffer;
minfo.pNext = &mdinfo;
VULKAN_CALL(vkAllocateMemory(vctx.device, &minfo, nullptr, &memory));
}
VULKAN_CALL(vkBindBufferMemory(vctx.device, buffer, memory, 0));
VulkanBuffer* pbuf = new VulkanBuffer();
pbuf->memory = memory;
pbuf->buffer = buffer;
return pbuf;
}

class VulkanDeviceAPI final : public DeviceAPI {
public:
VulkanDeviceAPI();
Expand All @@ -125,70 +199,9 @@ class VulkanDeviceAPI final : public DeviceAPI {
nbytes = 1;
}
const auto& vctx = context(dev.device_id);
VkBufferCreateInfo info;
info.sType = VK_STRUCTURE_TYPE_BUFFER_CREATE_INFO;
info.pNext = nullptr;
info.flags = 0;
info.size = nbytes;
info.queueFamilyIndexCount = 1;
info.pQueueFamilyIndices = &(vctx.queue_family_index);
info.sharingMode = VK_SHARING_MODE_EXCLUSIVE;
info.usage = VK_BUFFER_USAGE_TRANSFER_SRC_BIT | VK_BUFFER_USAGE_TRANSFER_DST_BIT |
auto usage = VK_BUFFER_USAGE_TRANSFER_SRC_BIT | VK_BUFFER_USAGE_TRANSFER_DST_BIT |
VK_BUFFER_USAGE_STORAGE_BUFFER_BIT;
// create buffer
VkBuffer buffer;
VULKAN_CALL(vkCreateBuffer(vctx.device, &info, nullptr, &buffer));
// bind to memory
VkBufferMemoryRequirementsInfo2KHR req_info2;
req_info2.sType = VK_STRUCTURE_TYPE_BUFFER_MEMORY_REQUIREMENTS_INFO_2_KHR;
req_info2.pNext = 0;
req_info2.buffer = buffer;

VkMemoryRequirements2KHR req2;
req2.sType = VK_STRUCTURE_TYPE_MEMORY_REQUIREMENTS_2_KHR;
req2.pNext = 0;

VkMemoryDedicatedRequirementsKHR dedicated_req;
dedicated_req.sType = VK_STRUCTURE_TYPE_MEMORY_DEDICATED_REQUIREMENTS_KHR;
dedicated_req.pNext = 0;
req2.pNext = &dedicated_req;

bool dedicated_allocation = false;
if (vctx.get_buffer_memory_requirements_2_functions) {
vctx.get_buffer_memory_requirements_2_functions->vkGetBufferMemoryRequirements2KHR(
vctx.device, &req_info2, &req2);
dedicated_allocation =
dedicated_req.requiresDedicatedAllocation || dedicated_req.prefersDedicatedAllocation;
}

VkDeviceMemory memory;
if (!dedicated_allocation) {
VkMemoryAllocateInfo minfo;
minfo.sType = VK_STRUCTURE_TYPE_MEMORY_ALLOCATE_INFO;
minfo.pNext = nullptr;
minfo.allocationSize = nbytes;
minfo.memoryTypeIndex = vctx.compute_mtype_index;
VULKAN_CALL(vkAllocateMemory(vctx.device, &minfo, nullptr, &memory));
} else {
VkMemoryAllocateInfo minfo;
minfo.sType = VK_STRUCTURE_TYPE_MEMORY_ALLOCATE_INFO;
minfo.pNext = nullptr;
minfo.allocationSize = req2.memoryRequirements.size;
minfo.memoryTypeIndex = vctx.compute_mtype_index;

VkMemoryDedicatedAllocateInfoKHR mdinfo;
mdinfo.sType = VK_STRUCTURE_TYPE_MEMORY_DEDICATED_ALLOCATE_INFO_KHR;
mdinfo.pNext = 0;
mdinfo.image = 0;
mdinfo.buffer = buffer;
minfo.pNext = &mdinfo;
VULKAN_CALL(vkAllocateMemory(vctx.device, &minfo, nullptr, &memory));
}
VULKAN_CALL(vkBindBufferMemory(vctx.device, buffer, memory, 0));
VulkanBuffer* pbuf = new VulkanBuffer();
pbuf->memory = memory;
pbuf->buffer = buffer;
return pbuf;
return CreateBuffer(vctx, nbytes, usage);
}

void FreeDataSpace(Device dev, void* ptr) final {
Expand Down Expand Up @@ -784,6 +797,11 @@ class VulkanModuleNode final : public runtime::ModuleNode {
vkDestroyDescriptorPool(vctx.device, pe->descriptor_pool, nullptr);
vkDestroyDescriptorSetLayout(vctx.device, pe->descriptor_set_layout, nullptr);
vkDestroyShaderModule(vctx.device, pe->shader, nullptr);
// UBO
vkDestroyBuffer(vctx.device, pe->ubo.vk_buf->buffer, nullptr);
vkFreeMemory(vctx.device, pe->ubo.vk_buf->memory, nullptr);
delete pe->ubo.vk_buf;
delete[] pe->ubo.host_buf;
}
}
}
Expand Down Expand Up @@ -846,14 +864,14 @@ class VulkanModuleNode final : public runtime::ModuleNode {
}
}

if (num_pod != 0 && num_pod * 8 > 120) {
size_t nbytes_scalars = num_pod * sizeof(ArgUnion64);
if (nbytes_scalars > 120) {
ICHECK(num_pod == num_pack_args);
// UBO
// TODO: allocate ubo
{
VkDescriptorSetLayoutBinding bd;
bd.binding = num_buffer;
bd.descriptorType = VK_DESCRIPTOR_TYPE_STORAGE_BUFFER;
bd.descriptorType = VK_DESCRIPTOR_TYPE_UNIFORM_BUFFER;
bd.descriptorCount = 1;
bd.stageFlags = VK_SHADER_STAGE_COMPUTE_BIT;
bd.pImmutableSamplers = nullptr;
Expand All @@ -864,7 +882,7 @@ class VulkanModuleNode final : public runtime::ModuleNode {
tpl.dstBinding = num_buffer;
tpl.dstArrayElement = 0;
tpl.descriptorCount = 1;
tpl.descriptorType = VK_DESCRIPTOR_TYPE_STORAGE_BUFFER;
tpl.descriptorType = VK_DESCRIPTOR_TYPE_UNIFORM_BUFFER;
tpl.offset = num_buffer * sizeof(VkDescriptorBufferInfo);
tpl.stride = sizeof(VkDescriptorBufferInfo);
arg_template.push_back(tpl);
Expand Down Expand Up @@ -951,6 +969,15 @@ class VulkanModuleNode final : public runtime::ModuleNode {
VULKAN_CALL(vkCreateComputePipelines(vctx.device, VK_NULL_HANDLE, 1, &pipeline_cinfo, nullptr,
&(pe->pipeline)));

if (nbytes_scalars > 120) {
// Allocate, bind and map UBO
UniformBuffer ubo = pe->ubo;
ubo.host_buf = new ArgUnion64[nbytes_scalars];
ubo.vk_buf = CreateBuffer(vctx, nbytes_scalars, VK_BUFFER_USAGE_UNIFORM_BUFFER_BIT);
void* host_ptr = ubo.host_buf;
vkMapMemory(vctx.device, ubo.vk_buf->memory, 0, nbytes_scalars, 0, &host_ptr);
}

if (vctx.UseImmediate()) {
VkDescriptorUpdateTemplateCreateInfoKHR descrip_template_cinfo;
descrip_template_cinfo.sType = VK_STRUCTURE_TYPE_DESCRIPTOR_UPDATE_TEMPLATE_CREATE_INFO_KHR;
Expand Down Expand Up @@ -1104,11 +1131,11 @@ void VulkanWrappedFunc::operator()(TVMArgs args, TVMRetValue* rv,
binfo.range = VK_WHOLE_SIZE;
descriptor_buffers[i] = binfo;
}
if (num_pack_args_ != 0 && num_pack_args_ * 8 > 120) {
if (num_pack_args_ != 0 && num_pack_args_ * sizeof(ArgUnion64) > 120) {
// UBO
// TODO: copy pack_args
memcpy(pipeline->ubo.host_buf, pack_args, num_pack_args_ * sizeof(ArgUnion64));
VkDescriptorBufferInfo binfo;
binfo.buffer = pipeline->ubo.buffer;
binfo.buffer = pipeline->ubo.vk_buf->buffer;
binfo.offset = 0;
binfo.range = VK_WHOLE_SIZE;
descriptor_buffers.push_back(binfo);
Expand Down

0 comments on commit e1788b8

Please sign in to comment.