From 9230a0f2fd9c29f1b8bad70564051312238d9527 Mon Sep 17 00:00:00 2001 From: Eric Lunderberg Date: Tue, 27 Jul 2021 11:41:13 -0500 Subject: [PATCH] [StorageRewrite] Updated info-gathering pass to allow for untyped pointers. Currently, some pointer variables in Let nodes are missing type annotations. Long-term, they should have type annotations added, but for the short-term the StorageRewrite pass can work in these cases by checking how an array is being accessed (previous behavior). --- src/tir/transforms/storage_rewrite.cc | 78 ++++++++++++++++++++++----- 1 file changed, 66 insertions(+), 12 deletions(-) diff --git a/src/tir/transforms/storage_rewrite.cc b/src/tir/transforms/storage_rewrite.cc index 7510e6eab034e..c9e76bc3dc3f0 100644 --- a/src/tir/transforms/storage_rewrite.cc +++ b/src/tir/transforms/storage_rewrite.cc @@ -900,7 +900,8 @@ struct BufferVarInfo { enum DeclarationLocation { kPrimFuncParam = (1 << 0), kPrimFuncBufferMap = (1 << 1), - kAllocateNode = (1 << 2) + kAllocateNode = (1 << 2), + kLetNode = (1 << 3), }; // The tir::Var that represents this buffer. @@ -956,8 +957,14 @@ class VectorTypeAccessChecker : public StmtExprVisitor { * @param params The parameters passed to a PrimFunc * * @param buffer_map The buffer_map associated with a PrimFunc + * + * @param allow_untyped_handles If a buffer or pointer variable is + * missing a type annotation, assume that it has the same underlying + * type as it is later accessed, with scalar element types. */ - VectorTypeAccessChecker(const Array& params, const Map& buffer_map) { + VectorTypeAccessChecker(const Array& params, const Map& buffer_map, + bool allow_untyped_pointers = false) + : allow_untyped_pointers_(allow_untyped_pointers) { // If a parameter is in the buffer map, we want to track the // version in the map. for (auto it : buffer_map) { @@ -1007,14 +1014,43 @@ class VectorTypeAccessChecker : public StmtExprVisitor { StmtExprVisitor::VisitStmt_(op); } + void VisitExpr_(const LetNode* op) final { + HandleLetNode(op->var); + StmtExprVisitor::VisitExpr_(op); + } + + void VisitStmt_(const LetStmtNode* op) final { + HandleLetNode(op->var); + StmtExprVisitor::VisitStmt_(op); + } + + void HandleLetNode(Var let_var) { + if (let_var->dtype.is_handle()) { + auto pointer_type = GetPointerType(let_var->type_annotation); + if (pointer_type.first) { + OnArrayDeclaration(let_var, pointer_type.second, 0, BufferVarInfo::kLetNode); + } else if (allow_untyped_pointers_) { + OnArrayDeclaration(let_var, let_var->dtype, 0, BufferVarInfo::kLetNode); + } else { + LOG(FATAL) << "Let statement of variable " << let_var->name_hint + << " is missing a type annotation, " + << "or type annotation is not a pointer to primitive"; + } + } + } + /* Update the type map for a buffer based on its declaration * * @param buffer The VarNode representing the buffer. * * @param element_dtype The dtype of a single element of the buffer. + * If unknown, when used with the allow_untyped_handles option, + * should be a handle dtype. + * + * @param extent The extent of the buffer. Zero if size is unknown. * - * @param allowed_to_rewrite If the element type of this array is - * allowed to be rewritten. + * @param declaration_location How the buffer was allocated, so that + * some locations can be rewritten without others. */ void OnArrayDeclaration(Var buffer, DataType element_dtype, PrimExpr extent, BufferVarInfo::DeclarationLocation declaration_location) { @@ -1049,7 +1085,14 @@ class VectorTypeAccessChecker : public StmtExprVisitor { if (value_dtype.element_of() == DataType::Bool()) { value_dtype = DataType::Int(8).with_lanes(value_dtype.lanes()); } - // Currently cannot check the element type being accessed. See comments in + + if (var_info.element_dtype.is_handle()) { + ICHECK(allow_untyped_pointers_) << "Variable " << buffer->name_hint + << " was missing a type annotation in its declaration"; + var_info.element_dtype = value_dtype.element_of(); + } + + // Currently cannot valid the element type being accessed. See comments in // Load::Load for details. // // ICHECK_EQ(var_info.element_dtype.element_of(), value_dtype.element_of()) @@ -1094,6 +1137,10 @@ class VectorTypeAccessChecker : public StmtExprVisitor { // Map of buffer variable information determined std::unordered_map info_map_; + + // + bool allow_untyped_pointers_{false}; + // internal analyzer arith::Analyzer analyzer_; }; @@ -1140,10 +1187,13 @@ class VectorTypeRewriter : public StmtExprMutator { * * @param rewrite_indices Whether the indices to the Load and Store nodes * should be rewritten to correspond to the new buffer_var type. + * + * @param rewrite_let_node Whether pointer declarations in let nodes + * should be re-written. */ VectorTypeRewriter(const VectorTypeAccessChecker& checker, bool rewrite_params = true, bool rewrite_buffer_map = true, bool rewrite_allocate_node = true, - bool rewrite_indices = true) + bool rewrite_indices = true, bool rewrite_let_node = true) : rewrite_indices_(rewrite_indices) { int rewrite_mask = 0; if (rewrite_params) { @@ -1155,6 +1205,9 @@ class VectorTypeRewriter : public StmtExprMutator { if (rewrite_allocate_node) { rewrite_mask |= BufferVarInfo::kAllocateNode; } + if (rewrite_let_node) { + rewrite_mask |= BufferVarInfo::kLetNode; + } // Rewrite any buffer variables whose preferred type isn't their current type. for (const auto& pair : checker.info_map_) { @@ -1349,14 +1402,15 @@ class VectorTypeRewriter : public StmtExprMutator { // Rewrite allocates, pointer parameters, and buffer map into vectorized versions // if each access into a buffer is the same vector type. -PrimFunc PointerValueTypeRewrite(PrimFunc f, bool rewrite_params = true, - bool rewrite_buffer_map = true, bool rewrite_allocate_node = true, - bool rewrite_indices = true) { - VectorTypeAccessChecker checker(f->params, f->buffer_map); +PrimFunc PointerValueTypeRewrite(PrimFunc f, bool allow_untyped_pointers = false, + bool rewrite_params = true, bool rewrite_buffer_map = true, + bool rewrite_allocate_node = true, bool rewrite_indices = true, + bool rewrite_let_node = true) { + VectorTypeAccessChecker checker(f->params, f->buffer_map, allow_untyped_pointers); checker(f->body); VectorTypeRewriter rewriter(checker, rewrite_params, rewrite_buffer_map, rewrite_allocate_node, - rewrite_indices); + rewrite_indices, rewrite_let_node); PrimFuncNode* n = f.CopyOnWrite(); n->body = rewriter(std::move(n->body)); rewriter.Finalize(&f); @@ -1370,7 +1424,7 @@ Pass StorageRewrite() { auto pass_func = [](PrimFunc f, IRModule m, PassContext ctx) { auto* n = f.CopyOnWrite(); n->body = StoragePlanRewriter().Rewrite(std::move(n->body), true); - PrimFunc output = PointerValueTypeRewrite(std::move(f), false, false, true, false); + PrimFunc output = PointerValueTypeRewrite(std::move(f), true, false, false, true, false, true); return output; }; return CreatePrimFuncPass(pass_func, 0, "tir.StorageRewrite", {});