From a5f23fe5496d3f72a77fd8628b314d618d255ef2 Mon Sep 17 00:00:00 2001 From: Eric Lunderberg Date: Thu, 29 Jul 2021 12:52:53 -0500 Subject: [PATCH] [Vulkan] Allow for pointer rewrites that change base type. A single memory allocation may have more than one type of data stored within it. This allows the PointerTypeRewrite pass to recognize if a function only uses the pointer as a particular base type. This wasn't an issue in C-based codegen, but is required for Vulkan. Since Vulkan shaders do not permit type-casting, the cast must be done when passing the pointer argument into the shader. --- src/tir/transforms/storage_rewrite.cc | 82 ++++++++++++++++----------- 1 file changed, 50 insertions(+), 32 deletions(-) diff --git a/src/tir/transforms/storage_rewrite.cc b/src/tir/transforms/storage_rewrite.cc index 80dca3289fe7..eab750b2af2b 100644 --- a/src/tir/transforms/storage_rewrite.cc +++ b/src/tir/transforms/storage_rewrite.cc @@ -921,29 +921,46 @@ struct BufferVarInfo { // Where the buffer was declared DeclarationLocation declaration_location; - // When accessed, how many lanes of data are used. - std::set lanes_used; + // When accessed, which element type is it accessed as. This may + // differ both in base type (e.g. int32* cast to float32* after + // packing in StorageRewrite) or in number of lanes (e.g. float16* + // cast to float16x4*). + std::unordered_set access_dtype; + + DataType get_preferred_dtype() const { + std::unordered_set base_access_dtype; + for (auto dtype : access_dtype) { + base_access_dtype.insert(dtype.element_of()); + } + // If the array is accessed as multiple base types within a + // function, no point in changing the declared type. CodeGenC can + // handle this with a type-cast prior to indexing. Vulkan will + // raise an error at code-gen time, if a later pass doesn't split + // it out. + if (base_access_dtype.size() != 1) { + return element_dtype; + } + + DataType preferred_base_type = *base_access_dtype.begin(); - int get_preferred_lanes() const { // If there is only one vectorizable size used to access the // buffer, and if that access size is compatible with the array // size, then the buffer is vectorizable. In the future, this // could be improved to allow vectorized buffer access of size // GCD(*lanes_used), if necessary. - if ((element_dtype.lanes() == 1) && (lanes_used.size() == 1)) { + int preferred_lanes = element_dtype.lanes(); + if ((element_dtype.lanes() == 1) && (access_dtype.size() == 1)) { arith::Analyzer analyzer_; arith::ModularSet me = analyzer_.modular_set(extent); - int lanes = *lanes_used.begin(); + int lanes = access_dtype.begin()->lanes(); if ((me->coeff % lanes == 0) && (me->base % lanes == 0)) { - return lanes; + preferred_lanes = lanes; } } - return element_dtype.lanes(); + return preferred_base_type.with_lanes(preferred_lanes); } - - DataType get_preferred_dtype() const { return element_dtype.with_lanes(get_preferred_lanes()); } }; /* Checks whether buffers are accessed as scalar or vector parameters in a @@ -1092,12 +1109,7 @@ class VectorTypeAccessChecker : public StmtExprVisitor { var_info.element_dtype = value_dtype.element_of(); } - // Currently cannot valid the element type being accessed. See comments in - // Load::Load for details. - // - // ICHECK_EQ(var_info.element_dtype.element_of(), value_dtype.element_of()) - // << "Attempting to access buffer of type " << var_info.element_dtype << " as type " - // << value_dtype; + DataType access_dtype = value_dtype; int lanes_used = var_info.element_dtype.lanes(); @@ -1113,6 +1125,10 @@ class VectorTypeAccessChecker : public StmtExprVisitor { var_info.element_dtype = var_info.element_dtype.with_lanes(1); } + // TODO(Lunderberg): Uncomment this check once it can be applied. + // See https://discuss.tvm.apache.org/t/pre-rfc-vectorized-tir-buffers/10615 + // for discussion. + // ICHECK_EQ(index.dtype().lanes() * var_info.element_dtype.lanes(), value_dtype.lanes()) // << "Attempting to retrieve " << value_dtype.lanes() << " lanes of data with " // << index.dtype().lanes() << " indices into an array whose elements have " @@ -1132,7 +1148,7 @@ class VectorTypeAccessChecker : public StmtExprVisitor { } } - var_info.lanes_used.insert(lanes_used); + var_info.access_dtype.insert(access_dtype.with_lanes(lanes_used)); } // Map of buffer variable information determined @@ -1240,14 +1256,17 @@ class VectorTypeRewriter : public StmtExprMutator { } const auto& info = it->second; - const RampNode* ramp_index = op->index.as(); - ICHECK(ramp_index) << "Incorrect rewrite, index of LoadNode should be Ramp"; - ICHECK(is_one(ramp_index->stride)) - << "Incorrect rewrite, stride of Ramp index should be 1 for contiguous access"; + DataType out_dtype_base = info.new_element_dtype.element_of(); - PrimExpr new_index = ramp_index->base / make_const(ramp_index->base.dtype(), ramp_index->lanes); - return Load(op->dtype, info.new_buffer_var, new_index, const_true(new_index.dtype().lanes()), - op->span); + const RampNode* ramp_index = op->index.as(); + if (ramp_index && is_one(ramp_index->stride)) { + PrimExpr new_index = + ramp_index->base / make_const(ramp_index->base.dtype(), ramp_index->lanes); + return Load(out_dtype_base.with_lanes(op->dtype.lanes()), info.new_buffer_var, new_index, + const_true(new_index.dtype().lanes()), op->span); + } else { + return Load(out_dtype_base, info.new_buffer_var, op->index, op->predicate); + } } Stmt VisitStmt_(const StoreNode* op) final { @@ -1265,13 +1284,14 @@ class VectorTypeRewriter : public StmtExprMutator { const auto& info = it->second; const RampNode* ramp_index = op->index.as(); - ICHECK(ramp_index) << "Incorrect rewrite, index of StoreNode should be Ramp"; - ICHECK(is_one(ramp_index->stride)) - << "Incorrect rewrite, stride of Ramp index should be 1 for contiguous access"; - - PrimExpr new_index = ramp_index->base / make_const(ramp_index->base.dtype(), ramp_index->lanes); - return Store(info.new_buffer_var, op->value, new_index, const_true(new_index.dtype().lanes()), - op->span); + if (ramp_index && is_one(ramp_index->stride)) { + PrimExpr new_index = + ramp_index->base / make_const(ramp_index->base.dtype(), ramp_index->lanes); + return Store(info.new_buffer_var, op->value, new_index, const_true(new_index.dtype().lanes()), + op->span); + } else { + return Store(info.new_buffer_var, op->value, op->index, op->predicate, op->span); + } } PrimExpr VisitExpr_(const CallNode* op) final { @@ -1356,8 +1376,6 @@ class VectorTypeRewriter : public StmtExprMutator { } else { const auto& info = it->second; new_params.push_back(info.new_buffer_var); - // n->buffer_map.Set(info.new_buffer_var, n->buffer_map[old_param]); - // n->buffer_map.erase(old_param); } } n->params = new_params;