Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Vulkan] Rewrote PointerValueTypeRewrite transform #8528

Merged
merged 3 commits into from
Jul 30, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
170 changes: 78 additions & 92 deletions src/target/spirv/codegen_spirv.cc
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@ runtime::VulkanShader CodeGenSPIRV::BuildFunction(const PrimFunc& f, const std::
this->InitFuncState();
ICHECK(f->HasNonzeroAttr(tir::attr::kNoAlias)) << "SPIRV only takes restricted memory model";
std::vector<Var> pod_args;
uint32_t num_buffer = 0;
uint32_t i_buffer = 0;

// Currently, all storage and uniform buffer arguments are passed as
// a single descriptor set at index 0. If ever non-zero, must
Expand All @@ -53,24 +53,25 @@ runtime::VulkanShader CodeGenSPIRV::BuildFunction(const PrimFunc& f, const std::
for (Var arg : f->params) {
DataType t = arg.dtype();
if (t.is_handle()) {
if (auto* ptr = arg->type_annotation.as<PointerTypeNode>()) {
auto* prim = ptr->element_type.as<PrimTypeNode>();
ICHECK(prim);
DataType value_storage_type = prim->dtype;
if (value_storage_type == DataType::UInt(1)) {
// We need a physically addressable buffer type to support boolean tensors.
// The loaded byte is cast to bool inside the LoadNode visitor below.
value_storage_type = DataType::UInt(8);
}
spirv::Value arg_value = builder_->BufferArgument(builder_->GetSType(value_storage_type),
descriptor_set, num_buffer);
builder_->SetName(arg_value, arg->name_hint);
storage_info_[arg.get()].UpdateContentType(value_storage_type);
var_map_[arg.get()] = arg_value;
} else {
LOG(FATAL) << "require all handles to be typed";
auto* ptr = arg->type_annotation.as<PointerTypeNode>();
ICHECK(ptr) << "All handles passed to the Vulkan codegen must have a type_annotation as a "
"PointerType, "
<< "and must point to a PrimType";
auto* prim = ptr->element_type.as<PrimTypeNode>();
ICHECK(prim) << "All handles passed to the Vulkan codegen must have a type_annotation as a "
"PointerType, "
<< "and must point to a PrimType";
DataType value_storage_type = prim->dtype;
if (value_storage_type == DataType::Bool()) {
// We need a physically addressable buffer type to support boolean tensors.
// The loaded byte is cast to bool inside the LoadNode visitor below.
value_storage_type = boolean_storage_type_.with_lanes(value_storage_type.lanes());
}
++num_buffer;
spirv::Value arg_value = builder_->BufferArgument(builder_->GetSType(value_storage_type),
descriptor_set, i_buffer++);
builder_->SetName(arg_value, arg->name_hint);
storage_info_[arg.get()].SetContentType(value_storage_type, arg->name_hint);
var_map_[arg.get()] = arg_value;
} else {
pod_args.push_back(arg);
}
Expand All @@ -95,7 +96,7 @@ runtime::VulkanShader CodeGenSPIRV::BuildFunction(const PrimFunc& f, const std::
} else {
shader.flag |= 1 << runtime::vulkan::ShaderMetaDataFlagMask::kUseUBO;
// If we need to pass more arguments than push constants could handle, we use UBO.
spirv::Value ptr = builder_->DeclareUniformBuffer(value_types, descriptor_set, num_buffer);
spirv::Value ptr = builder_->DeclareUniformBuffer(value_types, descriptor_set, i_buffer++);
for (size_t i = 0; i < pod_args.size(); ++i) {
spirv::Value value = builder_->GetUniform(ptr, value_types[i], static_cast<uint32_t>(i));
var_map_[pod_args[i].get()] = value;
Expand Down Expand Up @@ -404,62 +405,58 @@ spirv::Value CodeGenSPIRV::VisitExpr_(const BroadcastNode* op) {

spirv::Value CodeGenSPIRV::VisitExpr_(const LoadNode* op) {
ICHECK(is_one(op->predicate));
auto it = storage_info_.find(op->buffer_var.get());

DataType desired_read_type = op->dtype;
if (desired_read_type == DataType::Bool()) {
desired_read_type = boolean_storage_type_.with_lanes(desired_read_type.lanes());
}

const VarNode* buffer_var = op->buffer_var.get();
auto it = storage_info_.find(buffer_var);
ICHECK(it != storage_info_.end());
StorageInfo& info = it->second;
if (!info.content_fixed) {
info.UpdateContentType(op->dtype);
}
info.CheckContentType(desired_read_type, op->index.dtype().lanes());

spirv::SType content_type = builder_->GetSType(info.content_type);
spirv::SType content_type = builder_->GetSType(info.element_type);
spirv::Value buffer = MakeValue(op->buffer_var);
spirv::SType ptr_type = builder_->GetPointerType(content_type, buffer.stype.storage_class);

uint32_t mask = spv::MemoryAccessMaskNone;
if (info.is_volatile) {
mask |= spv::MemoryAccessVolatileMask;
}
if (op->dtype.lanes() == 1) {

if (desired_read_type == info.element_type) {
// Requested a single value from an array. This may be a scalar load
// or a vectorized load, based on the array element type.
spirv::Value index = MakeValue(op->index);
spirv::Value ptr = builder_->StructArrayAccess(ptr_type, buffer, index);
spirv::Value loaded = builder_->MakeValue(spv::OpLoad, content_type, ptr, mask);
if (op->dtype == DataType::UInt(1)) {
// A bool tensor is backed by a byte buffer, we cast to bool here.
auto bool_ty = builder_->GetSType(DataType::UInt(1));
return builder_->Cast(bool_ty, loaded);
} else {
ICHECK_EQ(info.content_type, op->dtype)
<< "Vulkan only allow one type access to the same buffer";
return loaded;
// OpTypeBool have no physical address/storage. Here, cast from
// the storage type to an OpTypeBool.
if (op->dtype == DataType::Bool()) {
auto spirv_bool = builder_->GetSType(DataType::Bool());
loaded = builder_->Cast(spirv_bool, loaded);
}
return loaded;

} else if (desired_read_type.element_of() == info.element_type) {
// Requested several elements returned as an array. Read out each
// element and concatenate into the result.
std::vector<spirv::Value> values;
auto f = [&](int i, spirv::Value index) {
spirv::Value ptr = builder_->StructArrayAccess(ptr_type, buffer, index);
values.emplace_back(builder_->MakeValue(spv::OpLoad, content_type, ptr, mask));
};
this->Scalarize(op->index, f);
return builder_->Concat(values);

} else {
if (op->dtype.element_of() == info.content_type) {
// because content type is element type, we can only do scalarize load.
std::vector<spirv::Value> values;
auto f = [&](int i, spirv::Value index) {
spirv::Value ptr = builder_->StructArrayAccess(ptr_type, buffer, index);
values.emplace_back(builder_->MakeValue(spv::OpLoad, content_type, ptr, mask));
};
this->Scalarize(op->index, f);
return builder_->Concat(values);
} else {
if (const RampNode* ramp = op->index.as<RampNode>()) {
if (is_one(ramp->stride)) {
ICHECK_EQ(ramp->lanes, op->dtype.lanes());
arith::ModularSet me = analyzer_->modular_set(ramp->base);
ICHECK((me->coeff % ramp->lanes) == 0 && (me->base % ramp->lanes) == 0)
<< "Only aligned vector access is allowed in SPIRV";
PrimExpr vec_index =
analyzer_->Simplify(ramp->base / make_const(ramp->base.dtype(), ramp->lanes));
spirv::Value ptr = builder_->StructArrayAccess(ptr_type, buffer, MakeValue(vec_index));
return builder_->MakeValue(spv::OpLoad, content_type, ptr, mask);
}
}
}
LOG(FATAL) << "Only aligned continuous vector access is allowed in SPIRV";
LOG(FATAL) << "Cannot perform buffer access of buffer variable '" << buffer_var->name_hint
<< "' with element type " << info.element_type << " using index of type "
<< op->index->dtype << " to produce output of type " << op->dtype;
return spirv::Value();
}
LOG(FATAL) << "Only aligned continuous vector access is allowed in SPIRV";
return spirv::Value();
}

void CodeGenSPIRV::Scalarize(const PrimExpr& e, std::function<void(int i, spirv::Value v)> f) {
Expand All @@ -482,12 +479,9 @@ void CodeGenSPIRV::VisitStmt_(const StoreNode* op) {
auto it = storage_info_.find(op->buffer_var.get());
ICHECK(it != storage_info_.end());
StorageInfo& info = it->second;
info.CheckContentType(op->value.dtype(), op->index.dtype().lanes());

if (!info.content_fixed) {
info.UpdateContentType(op->value.dtype());
}

spirv::SType content_type = builder_->GetSType(info.content_type);
spirv::SType content_type = builder_->GetSType(info.element_type);
spirv::Value buffer = MakeValue(op->buffer_var);
spirv::Value value = MakeValue(op->value);
spirv::SType ptr_type = builder_->GetPointerType(content_type, buffer.stype.storage_class);
Expand All @@ -497,37 +491,29 @@ void CodeGenSPIRV::VisitStmt_(const StoreNode* op) {
mask |= spv::MemoryAccessVolatileMask;
}

if (op->value.dtype().lanes() == 1) {
ICHECK_EQ(info.content_type, op->value.dtype())
if (op->value.dtype() == info.element_type) {
// Requested store of a single value. This may be a scalar store
// or a vectorized store, based on the array element type.
ICHECK_EQ(info.element_type, op->value.dtype())
<< "Vulkan only allow one type access to the same buffer";
spirv::Value index = MakeValue(op->index);
spirv::Value ptr = builder_->StructArrayAccess(ptr_type, buffer, index);
builder_->MakeInst(spv::OpStore, ptr, value, mask);

} else if (op->value.dtype().element_of() == info.element_type) {
// Requested store of several arbitrarily located values. Extract
// each value from the composite, then assign to the buffer.
auto f = [&](int i, spirv::Value index) {
spirv::Value elem = builder_->MakeValue(spv::OpCompositeExtract, content_type, value, i);
spirv::Value ptr = builder_->StructArrayAccess(ptr_type, buffer, index);
builder_->MakeInst(spv::OpStore, ptr, elem, mask);
};
this->Scalarize(op->index, f);

} else {
if (op->value.dtype().element_of() == info.content_type) {
// because content type is element type, we can only do scalarize load.
auto f = [&](int i, spirv::Value index) {
spirv::Value elem = builder_->MakeValue(spv::OpCompositeExtract, content_type, value, i);
spirv::Value ptr = builder_->StructArrayAccess(ptr_type, buffer, index);
builder_->MakeInst(spv::OpStore, ptr, elem, mask);
};
this->Scalarize(op->index, f);
} else {
if (const RampNode* ramp = op->index.as<RampNode>()) {
if (is_one(ramp->stride)) {
ICHECK_EQ(ramp->lanes, op->value.dtype().lanes());
arith::ModularSet me = analyzer_->modular_set(ramp->base);
ICHECK((me->coeff % ramp->lanes) == 0 && (me->base % ramp->lanes) == 0)
<< "Only aligned vector access is allowed in SPIRV";
PrimExpr vec_index =
analyzer_->Simplify(ramp->base / make_const(ramp->base.dtype(), ramp->lanes));
spirv::Value ptr = builder_->StructArrayAccess(ptr_type, buffer, MakeValue(vec_index));
builder_->MakeInst(spv::OpStore, ptr, value, mask);
return;
}
}
LOG(FATAL) << "Only aligned continuous vector access is allowed in SPIRV";
}
LOG(FATAL) << "Cannot store value of type " << op->value.dtype() << " into buffer variable '"
<< op->buffer_var->name_hint << "' with element type " << info.element_type
<< " using index of type " << op->index->dtype;
}
}

Expand Down Expand Up @@ -663,8 +649,8 @@ void CodeGenSPIRV::VisitStmt_(const AllocateNode* op) {
builder_->SetName(buf, op->buffer_var->name_hint);

StorageInfo& info = storage_info_[op->buffer_var.get()];
ICHECK(!info.content_fixed);
info.UpdateContentType(op->dtype);
ICHECK(!info.element_type_known);
info.SetContentType(op->dtype, op->buffer_var->name_hint);

ICHECK(!var_map_.count(op->buffer_var.get()));
var_map_[op->buffer_var.get()] = buf;
Expand Down
85 changes: 71 additions & 14 deletions src/target/spirv/codegen_spirv.h
Original file line number Diff line number Diff line change
Expand Up @@ -114,47 +114,104 @@ class CodeGenSPIRV : public ExprFunctor<spirv::Value(const PrimExpr&)>,
void VisitStmt_(const EvaluateNode* op) override;

protected:
/*! \brief The storage information */
/*! \brief Storage information for a buffer */
struct StorageInfo {
/*! \brief The name of the tir::Var for the buffer
*
* Used for error messages.
*/
std::string name_hint;

/*! \brief Whether it is volatile */
bool is_volatile{false};
/*! \brief Whether it is volatile */
bool content_fixed{false};
/*! \brief Current content type */
DataType content_type{DataType::Handle()};

// Update content type if it hasn't beenupdated.
void UpdateContentType(DataType type) {
if (content_fixed) {
ICHECK_EQ(type, content_type) << "Cannot use two different content type in GLSL model";
} else {
this->content_type = type;
content_fixed = true;
}

/*! \brief Whether the element type of the buffer is known.
*
* This value is determined based on the type_annotation of the
* buffer variable (AllocateNode) or of the parameter (shader
* arguments).
*/
bool element_type_known{false};

/*! \brief The known element type of the buffer.
*
* This value is determined based on the type_annotation of the
* buffer variable (AllocateNode) or of the parameter (shader
* arguments).
*/
DataType element_type{DataType()};

/* \brief Check that the access type matches the known type
*
* Asserts that the type given is the same as the type previously
* stored in this array.
*
* @param type The data type being stored/loaded in the buffer
*
* @param index_lanes The number of lanes of the index. The
* number of lanes in the value being stored/loaded should be the
* product of the number of lanes of the buffer element type and
* the number of lanes of the index.
*/
void CheckContentType(DataType type, int index_lanes = 1) {
ICHECK(element_type_known) << "Cannot check element type of buffer " << name_hint
<< " no previous element type defined";
DataType expected_type = element_type.with_lanes(index_lanes * element_type.lanes());
ICHECK_EQ(type, expected_type) << "Attempted to access buffer " << name_hint
<< " as element type " << type << " using an index of size "
<< index_lanes << " when the element type is " << element_type;
}

// Update content type if it hasn't been updated.
void SetContentType(DataType type, std::string name_hint) {
ICHECK(!element_type_known) << "Cannot set element type of buffer " << name_hint
<< " a second time.";
this->element_type = type;
this->name_hint = name_hint;
element_type_known = true;
}
};
// Reset the state so it works for a new function.
void InitFuncState();
// Get the thread index
spirv::Value GetThreadIndex(const IterVar& iv, const PrimExpr& extent);

spirv::Value CreateStorageSync(const CallNode* op);
void Scalarize(const PrimExpr& e, std::function<void(int i, spirv::Value v)> f);

// SPIRV-related capabilities of the target
SPIRVSupport spirv_support_;

// The builder
std::unique_ptr<spirv::IRBuilder> builder_;

// Work group size of three
uint32_t workgroup_size_[3];

// Likely branch
uint32_t weight_likely_branch_{128};

/* The data type used for the backing array for booleans.
*
* Currently matched to the data type used in Buffer::vstore and
* Buffer::vload. In the future, this should be the smallest
* integer type supported by the device, as not all Vulkan
* implementations support int8.
*/
DataType boolean_storage_type_{DataType::Int(8)};

// the storage scope of allocation
std::unordered_map<const VarNode*, StorageInfo> storage_info_;

// The definition of local variable.
std::unordered_map<const VarNode*, spirv::Value> var_map_;

// The analyzer.
std::unique_ptr<arith::Analyzer> analyzer_;

// deep comparison of PrimExpr
ExprDeepEqual deep_equal_;

// binding of let variables. Enables duplicate var defs that map to same value
std::unordered_map<Var, const LetNode*, ObjectPtrHash, ObjectPtrEqual> let_binding_;
};
Expand Down
Loading