Skip to content

Commit

Permalink
[StorageRewrite] Updated info-gathering pass to allow for untyped poi…
Browse files Browse the repository at this point in the history
…nters.

Currently, some pointer variables in Let nodes are missing type
annotations.  Long-term, they should have type annotations added, but
for the short-term the StorageRewrite pass can work in these cases by
checking how an array is being accessed (previous behavior).
  • Loading branch information
Lunderberg committed Jul 28, 2021
1 parent ec2c861 commit 9230a0f
Showing 1 changed file with 66 additions and 12 deletions.
78 changes: 66 additions & 12 deletions src/tir/transforms/storage_rewrite.cc
Original file line number Diff line number Diff line change
Expand Up @@ -900,7 +900,8 @@ struct BufferVarInfo {
enum DeclarationLocation {
kPrimFuncParam = (1 << 0),
kPrimFuncBufferMap = (1 << 1),
kAllocateNode = (1 << 2)
kAllocateNode = (1 << 2),
kLetNode = (1 << 3),
};

// The tir::Var that represents this buffer.
Expand Down Expand Up @@ -956,8 +957,14 @@ class VectorTypeAccessChecker : public StmtExprVisitor {
* @param params The parameters passed to a PrimFunc
*
* @param buffer_map The buffer_map associated with a PrimFunc
*
* @param allow_untyped_handles If a buffer or pointer variable is
* missing a type annotation, assume that it has the same underlying
* type as it is later accessed, with scalar element types.
*/
VectorTypeAccessChecker(const Array<tir::Var>& params, const Map<Var, Buffer>& buffer_map) {
VectorTypeAccessChecker(const Array<tir::Var>& params, const Map<Var, Buffer>& buffer_map,
bool allow_untyped_pointers = false)
: allow_untyped_pointers_(allow_untyped_pointers) {
// If a parameter is in the buffer map, we want to track the
// version in the map.
for (auto it : buffer_map) {
Expand Down Expand Up @@ -1007,14 +1014,43 @@ class VectorTypeAccessChecker : public StmtExprVisitor {
StmtExprVisitor::VisitStmt_(op);
}

void VisitExpr_(const LetNode* op) final {
HandleLetNode(op->var);
StmtExprVisitor::VisitExpr_(op);
}

void VisitStmt_(const LetStmtNode* op) final {
HandleLetNode(op->var);
StmtExprVisitor::VisitStmt_(op);
}

void HandleLetNode(Var let_var) {
if (let_var->dtype.is_handle()) {
auto pointer_type = GetPointerType(let_var->type_annotation);
if (pointer_type.first) {
OnArrayDeclaration(let_var, pointer_type.second, 0, BufferVarInfo::kLetNode);
} else if (allow_untyped_pointers_) {
OnArrayDeclaration(let_var, let_var->dtype, 0, BufferVarInfo::kLetNode);
} else {
LOG(FATAL) << "Let statement of variable " << let_var->name_hint
<< " is missing a type annotation, "
<< "or type annotation is not a pointer to primitive";
}
}
}

/* Update the type map for a buffer based on its declaration
*
* @param buffer The VarNode representing the buffer.
*
* @param element_dtype The dtype of a single element of the buffer.
* If unknown, when used with the allow_untyped_handles option,
* should be a handle dtype.
*
* @param extent The extent of the buffer. Zero if size is unknown.
*
* @param allowed_to_rewrite If the element type of this array is
* allowed to be rewritten.
* @param declaration_location How the buffer was allocated, so that
* some locations can be rewritten without others.
*/
void OnArrayDeclaration(Var buffer, DataType element_dtype, PrimExpr extent,
BufferVarInfo::DeclarationLocation declaration_location) {
Expand Down Expand Up @@ -1049,7 +1085,14 @@ class VectorTypeAccessChecker : public StmtExprVisitor {
if (value_dtype.element_of() == DataType::Bool()) {
value_dtype = DataType::Int(8).with_lanes(value_dtype.lanes());
}
// Currently cannot check the element type being accessed. See comments in

if (var_info.element_dtype.is_handle()) {
ICHECK(allow_untyped_pointers_) << "Variable " << buffer->name_hint
<< " was missing a type annotation in its declaration";
var_info.element_dtype = value_dtype.element_of();
}

// Currently cannot valid the element type being accessed. See comments in
// Load::Load for details.
//
// ICHECK_EQ(var_info.element_dtype.element_of(), value_dtype.element_of())
Expand Down Expand Up @@ -1094,6 +1137,10 @@ class VectorTypeAccessChecker : public StmtExprVisitor {

// Map of buffer variable information determined
std::unordered_map<const VarNode*, BufferVarInfo> info_map_;

//
bool allow_untyped_pointers_{false};

// internal analyzer
arith::Analyzer analyzer_;
};
Expand Down Expand Up @@ -1140,10 +1187,13 @@ class VectorTypeRewriter : public StmtExprMutator {
*
* @param rewrite_indices Whether the indices to the Load and Store nodes
* should be rewritten to correspond to the new buffer_var type.
*
* @param rewrite_let_node Whether pointer declarations in let nodes
* should be re-written.
*/
VectorTypeRewriter(const VectorTypeAccessChecker& checker, bool rewrite_params = true,
bool rewrite_buffer_map = true, bool rewrite_allocate_node = true,
bool rewrite_indices = true)
bool rewrite_indices = true, bool rewrite_let_node = true)
: rewrite_indices_(rewrite_indices) {
int rewrite_mask = 0;
if (rewrite_params) {
Expand All @@ -1155,6 +1205,9 @@ class VectorTypeRewriter : public StmtExprMutator {
if (rewrite_allocate_node) {
rewrite_mask |= BufferVarInfo::kAllocateNode;
}
if (rewrite_let_node) {
rewrite_mask |= BufferVarInfo::kLetNode;
}

// Rewrite any buffer variables whose preferred type isn't their current type.
for (const auto& pair : checker.info_map_) {
Expand Down Expand Up @@ -1349,14 +1402,15 @@ class VectorTypeRewriter : public StmtExprMutator {

// Rewrite allocates, pointer parameters, and buffer map into vectorized versions
// if each access into a buffer is the same vector type.
PrimFunc PointerValueTypeRewrite(PrimFunc f, bool rewrite_params = true,
bool rewrite_buffer_map = true, bool rewrite_allocate_node = true,
bool rewrite_indices = true) {
VectorTypeAccessChecker checker(f->params, f->buffer_map);
PrimFunc PointerValueTypeRewrite(PrimFunc f, bool allow_untyped_pointers = false,
bool rewrite_params = true, bool rewrite_buffer_map = true,
bool rewrite_allocate_node = true, bool rewrite_indices = true,
bool rewrite_let_node = true) {
VectorTypeAccessChecker checker(f->params, f->buffer_map, allow_untyped_pointers);
checker(f->body);

VectorTypeRewriter rewriter(checker, rewrite_params, rewrite_buffer_map, rewrite_allocate_node,
rewrite_indices);
rewrite_indices, rewrite_let_node);
PrimFuncNode* n = f.CopyOnWrite();
n->body = rewriter(std::move(n->body));
rewriter.Finalize(&f);
Expand All @@ -1370,7 +1424,7 @@ Pass StorageRewrite() {
auto pass_func = [](PrimFunc f, IRModule m, PassContext ctx) {
auto* n = f.CopyOnWrite();
n->body = StoragePlanRewriter().Rewrite(std::move(n->body), true);
PrimFunc output = PointerValueTypeRewrite(std::move(f), false, false, true, false);
PrimFunc output = PointerValueTypeRewrite(std::move(f), true, false, false, true, false, true);
return output;
};
return CreatePrimFuncPass(pass_func, 0, "tir.StorageRewrite", {});
Expand Down

0 comments on commit 9230a0f

Please sign in to comment.