Skip to content

Commit

Permalink
[StorageRewrite] Updates as recommended in review.
Browse files Browse the repository at this point in the history
- Added explicit TODO(Lunderberg) for follow-ups

- Pass `checker.info_map_` instead of `checker` to
  `VectorTypeRewriter`
  • Loading branch information
Lunderberg committed Jul 29, 2021
1 parent 9230a0f commit 4a69ee6
Show file tree
Hide file tree
Showing 3 changed files with 15 additions and 9 deletions.
4 changes: 4 additions & 0 deletions src/tir/ir/expr.cc
Original file line number Diff line number Diff line change
Expand Up @@ -635,6 +635,10 @@ Load::Load(DataType dtype, Var buffer_var, PrimExpr index, PrimExpr predicate, S
// the same scope, regardless of element type. The codegen is
// then responsible for casting to the output type.

// TODO(Lunderberg): Uncomment this check once it can be applied.
// See https://discuss.tvm.apache.org/t/pre-rfc-vectorized-tir-buffers/10615
// for discussion.

// ICHECK(dtype.element_of() == pointer_type.second.element_of())
// << "Type mismatch, cannot load type " << dtype << " from buffer " <<
// buffer_var->name_hint
Expand Down
7 changes: 4 additions & 3 deletions src/tir/ir/stmt.cc
Original file line number Diff line number Diff line change
Expand Up @@ -255,15 +255,16 @@ Store::Store(Var buffer_var, PrimExpr value, PrimExpr index, PrimExpr predicate,
// Currently cannot check element type of array, see Load::Load
// for details.

// TODO(Lunderberg): Uncomment this check once it can be applied.
// See https://discuss.tvm.apache.org/t/pre-rfc-vectorized-tir-buffers/10615
// for discussion.

// ICHECK_EQ(value.dtype().element_of(), pointer_type.second.element_of())
// << "Type mismatch, cannot store type " << value.dtype() << " into buffer "
// << buffer_var->name_hint << " of type " << pointer_type.second;
element_lanes = pointer_type.second.lanes();
}

// ICHECK_EQ(value.dtype().lanes(), element_lanes * index.dtype().lanes());
// ICHECK_EQ(value.dtype().lanes(), element_lanes * predicate.dtype().lanes());

ICHECK((value.dtype().lanes() == element_lanes * index.dtype().lanes()) ||
(value.dtype().lanes() == index.dtype().lanes()));
ICHECK((value.dtype().lanes() == element_lanes * predicate.dtype().lanes()) ||
Expand Down
13 changes: 7 additions & 6 deletions src/tir/transforms/storage_rewrite.cc
Original file line number Diff line number Diff line change
Expand Up @@ -1191,9 +1191,10 @@ class VectorTypeRewriter : public StmtExprMutator {
* @param rewrite_let_node Whether pointer declarations in let nodes
* should be re-written.
*/
VectorTypeRewriter(const VectorTypeAccessChecker& checker, bool rewrite_params = true,
bool rewrite_buffer_map = true, bool rewrite_allocate_node = true,
bool rewrite_indices = true, bool rewrite_let_node = true)
VectorTypeRewriter(const std::unordered_map<const VarNode*, BufferVarInfo>& info_map,
bool rewrite_params = true, bool rewrite_buffer_map = true,
bool rewrite_allocate_node = true, bool rewrite_indices = true,
bool rewrite_let_node = true)
: rewrite_indices_(rewrite_indices) {
int rewrite_mask = 0;
if (rewrite_params) {
Expand All @@ -1210,7 +1211,7 @@ class VectorTypeRewriter : public StmtExprMutator {
}

// Rewrite any buffer variables whose preferred type isn't their current type.
for (const auto& pair : checker.info_map_) {
for (const auto& pair : info_map) {
const auto& var_info = pair.second;
DataType preferred = var_info.get_preferred_dtype();
if (preferred != var_info.element_dtype && (rewrite_mask & var_info.declaration_location)) {
Expand Down Expand Up @@ -1409,8 +1410,8 @@ PrimFunc PointerValueTypeRewrite(PrimFunc f, bool allow_untyped_pointers = false
VectorTypeAccessChecker checker(f->params, f->buffer_map, allow_untyped_pointers);
checker(f->body);

VectorTypeRewriter rewriter(checker, rewrite_params, rewrite_buffer_map, rewrite_allocate_node,
rewrite_indices, rewrite_let_node);
VectorTypeRewriter rewriter(checker.info_map_, rewrite_params, rewrite_buffer_map,
rewrite_allocate_node, rewrite_indices, rewrite_let_node);
PrimFuncNode* n = f.CopyOnWrite();
n->body = rewriter(std::move(n->body));
rewriter.Finalize(&f);
Expand Down

0 comments on commit 4a69ee6

Please sign in to comment.