Skip to content

Commit

Permalink
add tunable 4k test, 36 TFLOPS
Browse files Browse the repository at this point in the history
  • Loading branch information
masahi committed May 17, 2022
1 parent b9f7eae commit a9640f4
Show file tree
Hide file tree
Showing 3 changed files with 421 additions and 7 deletions.
3 changes: 0 additions & 3 deletions src/tir/transforms/lower_warp_memory.cc
Original file line number Diff line number Diff line change
Expand Up @@ -467,13 +467,10 @@ Pass LowerWarpMemory() {
auto pass_func = [](PrimFunc f, IRModule m, PassContext ctx) {
auto* n = f.CopyOnWrite();
auto target = f->GetAttr<Target>(tvm::attr::kTarget);
LOG(INFO)<< "Before LowerWarpMemory \n" << f;
int warp_size = target.value()->GetAttr<Integer>("thread_warp_size", 1).value();
WarpMemoryRewriter warp_memory_rewriter(warp_size);
auto stmt = warp_memory_rewriter.Rewrite(std::move(n->body));
n->body = UpdatePointerStorageScope(warp_memory_rewriter.new_storage_scopes_)(stmt);
LOG(INFO)<< f;

return f;
};
return CreatePrimFuncPass(pass_func, 0, "tir.LowerWarpMemory", {});
Expand Down
8 changes: 4 additions & 4 deletions tests/python/unittest/test_mma_16x8x8_4k.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,8 +71,8 @@ def ldmatrix_a_impl(a: T.handle, c: T.handle) -> None:
".b16",
A_warp.data,
4 * tx,
A_shared.data,
8 * (tx % 16),
A_shared.access_ptr("r"),
s1 * (tx % 16),
dtype="float16",
)
)
Expand Down Expand Up @@ -128,8 +128,8 @@ def ldmatrix_b_impl(a: T.handle, c: T.handle) -> None:
".b16",
B_warp.data,
2 * tx,
B_shared.data,
8 * (tx % 8),
B_shared.access_ptr("r"),
s1 * (tx % 8),
dtype="float16",
)
)
Expand Down
Loading

0 comments on commit a9640f4

Please sign in to comment.