Skip to content

Commit

Permalink
[Vulkan] Allow for pointer rewrites that change base type.
Browse files Browse the repository at this point in the history
A single memory allocation may have more than one type of data stored
within it.  This change allows the PointerTypeRewrite pass to recognize
when a function only accesses the pointer as a single base type.  This
wasn't an issue in C-based codegen, but is required for Vulkan.  Since
Vulkan shaders do not permit type-casting, the cast must be done when
passing the pointer argument into the shader.
  • Loading branch information
Lunderberg committed Jul 29, 2021
1 parent 83d2066 commit a5f23fe
Showing 1 changed file with 50 additions and 32 deletions.
82 changes: 50 additions & 32 deletions src/tir/transforms/storage_rewrite.cc
Original file line number Diff line number Diff line change
Expand Up @@ -921,29 +921,46 @@ struct BufferVarInfo {
// Where the buffer was declared
DeclarationLocation declaration_location;

// When accessed, how many lanes of data are used.
std::set<int> lanes_used;
// The element types that the buffer is accessed as. These may
// differ in base type (e.g. int32* cast to float32* after
// packing in StorageRewrite), in number of lanes (e.g. float16*
// cast to float16x4*), or both.
std::unordered_set<DataType> access_dtype;

DataType get_preferred_dtype() const {
std::unordered_set<DataType> base_access_dtype;
for (auto dtype : access_dtype) {
base_access_dtype.insert(dtype.element_of());
}
// If the array is accessed as multiple base types within a
// function, no point in changing the declared type. CodeGenC can
// handle this with a type-cast prior to indexing. Vulkan will
// raise an error at code-gen time, if a later pass doesn't split
// it out.
if (base_access_dtype.size() != 1) {
return element_dtype;
}

DataType preferred_base_type = *base_access_dtype.begin();

int get_preferred_lanes() const {
// If there is only one vectorizable size used to access the
// buffer, and if that access size is compatible with the array
// size, then the buffer is vectorizable. In the future, this
// could be improved to allow vectorized buffer access whose size
// is the GCD of all lane counts used, if necessary.
if ((element_dtype.lanes() == 1) && (lanes_used.size() == 1)) {
int preferred_lanes = element_dtype.lanes();
if ((element_dtype.lanes() == 1) && (access_dtype.size() == 1)) {
arith::Analyzer analyzer_;
arith::ModularSet me = analyzer_.modular_set(extent);

int lanes = *lanes_used.begin();
int lanes = access_dtype.begin()->lanes();
if ((me->coeff % lanes == 0) && (me->base % lanes == 0)) {
return lanes;
preferred_lanes = lanes;
}
}

return element_dtype.lanes();
return preferred_base_type.with_lanes(preferred_lanes);
}

DataType get_preferred_dtype() const { return element_dtype.with_lanes(get_preferred_lanes()); }
};

/* Checks whether buffers are accessed as scalar or vector parameters in a
Expand Down Expand Up @@ -1092,12 +1109,7 @@ class VectorTypeAccessChecker : public StmtExprVisitor {
var_info.element_dtype = value_dtype.element_of();
}

// Currently cannot validate the element type being accessed. See comments in
// Load::Load for details.
//
// ICHECK_EQ(var_info.element_dtype.element_of(), value_dtype.element_of())
// << "Attempting to access buffer of type " << var_info.element_dtype << " as type "
// << value_dtype;
DataType access_dtype = value_dtype;

int lanes_used = var_info.element_dtype.lanes();

Expand All @@ -1113,6 +1125,10 @@ class VectorTypeAccessChecker : public StmtExprVisitor {
var_info.element_dtype = var_info.element_dtype.with_lanes(1);
}

// TODO(Lunderberg): Uncomment this check once it can be applied.
// See https://discuss.tvm.apache.org/t/pre-rfc-vectorized-tir-buffers/10615
// for discussion.

// ICHECK_EQ(index.dtype().lanes() * var_info.element_dtype.lanes(), value_dtype.lanes())
// << "Attempting to retrieve " << value_dtype.lanes() << " lanes of data with "
// << index.dtype().lanes() << " indices into an array whose elements have "
Expand All @@ -1132,7 +1148,7 @@ class VectorTypeAccessChecker : public StmtExprVisitor {
}
}

var_info.lanes_used.insert(lanes_used);
var_info.access_dtype.insert(access_dtype.with_lanes(lanes_used));
}

// Map of buffer variable information determined
Expand Down Expand Up @@ -1240,14 +1256,17 @@ class VectorTypeRewriter : public StmtExprMutator {
}
const auto& info = it->second;

const RampNode* ramp_index = op->index.as<RampNode>();
ICHECK(ramp_index) << "Incorrect rewrite, index of LoadNode should be Ramp";
ICHECK(is_one(ramp_index->stride))
<< "Incorrect rewrite, stride of Ramp index should be 1 for contiguous access";
DataType out_dtype_base = info.new_element_dtype.element_of();

PrimExpr new_index = ramp_index->base / make_const(ramp_index->base.dtype(), ramp_index->lanes);
return Load(op->dtype, info.new_buffer_var, new_index, const_true(new_index.dtype().lanes()),
op->span);
const RampNode* ramp_index = op->index.as<RampNode>();
if (ramp_index && is_one(ramp_index->stride)) {
PrimExpr new_index =
ramp_index->base / make_const(ramp_index->base.dtype(), ramp_index->lanes);
return Load(out_dtype_base.with_lanes(op->dtype.lanes()), info.new_buffer_var, new_index,
const_true(new_index.dtype().lanes()), op->span);
} else {
return Load(out_dtype_base, info.new_buffer_var, op->index, op->predicate);
}
}

Stmt VisitStmt_(const StoreNode* op) final {
Expand All @@ -1265,13 +1284,14 @@ class VectorTypeRewriter : public StmtExprMutator {
const auto& info = it->second;

const RampNode* ramp_index = op->index.as<RampNode>();
ICHECK(ramp_index) << "Incorrect rewrite, index of StoreNode should be Ramp";
ICHECK(is_one(ramp_index->stride))
<< "Incorrect rewrite, stride of Ramp index should be 1 for contiguous access";

PrimExpr new_index = ramp_index->base / make_const(ramp_index->base.dtype(), ramp_index->lanes);
return Store(info.new_buffer_var, op->value, new_index, const_true(new_index.dtype().lanes()),
op->span);
if (ramp_index && is_one(ramp_index->stride)) {
PrimExpr new_index =
ramp_index->base / make_const(ramp_index->base.dtype(), ramp_index->lanes);
return Store(info.new_buffer_var, op->value, new_index, const_true(new_index.dtype().lanes()),
op->span);
} else {
return Store(info.new_buffer_var, op->value, op->index, op->predicate, op->span);
}
}

PrimExpr VisitExpr_(const CallNode* op) final {
Expand Down Expand Up @@ -1356,8 +1376,6 @@ class VectorTypeRewriter : public StmtExprMutator {
} else {
const auto& info = it->second;
new_params.push_back(info.new_buffer_var);
// n->buffer_map.Set(info.new_buffer_var, n->buffer_map[old_param]);
// n->buffer_map.erase(old_param);
}
}
n->params = new_params;
Expand Down

0 comments on commit a5f23fe

Please sign in to comment.