Skip to content

Commit

Permalink
fixed warp_coeff
Browse files Browse the repository at this point in the history
  • Loading branch information
masahi committed May 17, 2022
1 parent a0afb56 commit 7a962cd
Showing 1 changed file with 2 additions and 3 deletions.
5 changes: 2 additions & 3 deletions src/tir/transforms/lower_warp_memory.cc
Original file line number Diff line number Diff line change
Expand Up @@ -115,8 +115,7 @@ class WarpStoreCoeffFinder : private StmtExprVisitor {
/// Visitor implementation
void VisitExpr_(const CallNode* op) final {
if (op->op.same_as(builtin::ptx_ldmatrix()) && op->args[3].as<VarNode>() == buffer_) {
int num_matrix = op->args[1].as<IntImmNode>()->value;
warp_coeff_ = num_matrix * 2;
UpdatePattern(op->args[4]);
} else if (op->op.same_as(builtin::mma_fill()) && op->args[1].as<VarNode>() == buffer_) {
auto* ptr = op->args[0].as<IntImmNode>();
CHECK(ptr);
Expand Down Expand Up @@ -499,7 +498,7 @@ Pass LowerWarpMemory() {
WarpMemoryRewriter warp_memory_rewriter(warp_size);
auto stmt = warp_memory_rewriter.Rewrite(std::move(n->body));
n->body = UpdatePointerStorageScope(warp_memory_rewriter.new_storage_scopes_)(stmt);
// LOG(INFO) << f;
LOG(INFO) << f;
return f;
};
return CreatePrimFuncPass(pass_func, 0, "tir.LowerWarpMemory", {});
Expand Down

0 comments on commit 7a962cd

Please sign in to comment.