parameterize over storage scope in mma store intrin
masahi committed May 18, 2022
1 parent 827ea4c commit 7a235b6
Showing 2 changed files with 23 additions and 17 deletions.
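
Before this change, the MMA store intrinsic hard-coded its destination buffer to the "global" storage scope. The commit threads a `scope` parameter through `get_mma_store_intrin` so store variants targeting other scopes can be registered the same way. As a minimal sketch (not part of this commit), a shared-memory variant could hypothetically be added alongside the global ones:

    # Hypothetical sketch inside python/tvm/tir/tensor_intrin/cuda.py; the
    # name "mma_store_16x16_f32_shared" is illustrative, not registered here.
    MMA_store_16x16_f32_shared_INTRIN = "mma_store_16x16_f32_shared"
    TensorIntrin.register(
        MMA_store_16x16_f32_shared_INTRIN, *get_mma_store_intrin("float32", 8, "shared")
    )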
22 changes: 14 additions & 8 deletions python/tvm/tir/tensor_intrin/cuda.py
@@ -354,14 +354,14 @@ def mma_fill_impl(a: T.handle) -> None:
     return mma_fill_desc, mma_fill_impl
 
 
-def get_mma_store_intrin(dtype, local_size):
+def get_mma_store_intrin(dtype, local_size, scope="global"):
     # Assume M = N = 16
     index_map = shared_16x16_to_ldmatrix_32x8_layout
 
     @T.prim_func
     def mma_store_desc(a: T.handle, c: T.handle) -> None:
         C_warp = T.match_buffer(a, [WARP_SIZE, local_size], dtype=dtype, scope="warp")
-        C = T.match_buffer(c, [M_DIM, N_DIM], dtype=dtype, scope="global")
+        C = T.match_buffer(c, [M_DIM, N_DIM], dtype=dtype, scope=scope)
 
         with T.block("root"):
             T.reads(C_warp[0:WARP_SIZE, 0:local_size])
@@ -454,11 +454,17 @@ def mma_store_impl(a: T.handle, c: T.handle) -> None:
 MMA_fill_16x16_i32_INTRIN = "mma_fill_16x16_i32"
 TensorIntrin.register(MMA_fill_16x16_i32_INTRIN, *get_mma_fill_intrin("int32", 8))
 
-MMA_store_16x16_f32_INTRIN = "mma_store_16x16_f32"
-TensorIntrin.register(MMA_store_16x16_f32_INTRIN, *get_mma_store_intrin("float32", 8))
+MMA_store_16x16_f32_global_INTRIN = "mma_store_16x16_f32_global_"
+TensorIntrin.register(
+    MMA_store_16x16_f32_global_INTRIN, *get_mma_store_intrin("float32", 8, "global")
+)
 
-MMA_store_16x16_f16_INTRIN = "mma_store_16x16_f16"
-TensorIntrin.register(MMA_store_16x16_f16_INTRIN, *get_mma_store_intrin("float16", 8))
+MMA_store_16x16_f16_global_INTRIN = "mma_store_16x16_f16_global_"
+TensorIntrin.register(
+    MMA_store_16x16_f16_global_INTRIN, *get_mma_store_intrin("float16", 8, "global")
+)
 
-MMA_store_16x16_i32_INTRIN = "mma_store_16x16_i32"
-TensorIntrin.register(MMA_store_16x16_i32_INTRIN, *get_mma_store_intrin("int32", 8))
+MMA_store_16x16_i32_global_INTRIN = "mma_store_16x16_i32_global_"
+TensorIntrin.register(
+    MMA_store_16x16_i32_global_INTRIN, *get_mma_store_intrin("int32", 8, "global")
+)
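
Registration keys the (desc, impl) pair on the intrin's string name, so downstream schedules refer to the new module-level constants. A quick lookup sketch, assuming `TensorIntrin.get` (the registry lookup in TVM's Python API):

    from tvm.tir import TensorIntrin

    # Sketch: retrieve the registered global-scope f32 store intrin by name.
    intrin = TensorIntrin.get(MMA_store_16x16_f32_global_INTRIN)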
@@ -33,9 +33,9 @@
     MMA_fill_16x16_f32_INTRIN,
     MMA_fill_16x16_f16_INTRIN,
     MMA_fill_16x16_i32_INTRIN,
-    MMA_store_16x16_f32_INTRIN,
-    MMA_store_16x16_f16_INTRIN,
-    MMA_store_16x16_i32_INTRIN,
+    MMA_store_16x16_f32_global_INTRIN,
+    MMA_store_16x16_f16_global_INTRIN,
+    MMA_store_16x16_i32_global_INTRIN,
     shared_16x16_to_ldmatrix_32x8_layout,
     shared_32x16_to_ldmatrix_32x16_layout,
     shared_16x32_to_ldmatrix_32x16_layout,
@@ -249,7 +249,7 @@ def index_map(i, j):
         LDMATRIX_16x16_B_INTRIN,
         MMA_f16f16f32_INTRIN,
         MMA_fill_16x16_f32_INTRIN,
-        MMA_store_16x16_f32_INTRIN,
+        MMA_store_16x16_f32_global_INTRIN,
     )
 
     if measure_perf:
@@ -270,7 +270,7 @@ def index_map(i, j):
         LDMATRIX_16x16_B_TRANS_INTRIN,
         MMA_f16f16f32_TRANS_INTRIN,
         MMA_fill_16x16_f32_INTRIN,
-        MMA_store_16x16_f32_INTRIN,
+        MMA_store_16x16_f32_global_INTRIN,
     )
 
     if measure_perf:
@@ -305,7 +305,7 @@ def index_map(i, j):
         LDMATRIX_16x16_B_INTRIN,
         MMA_f16f16f16_INTRIN,
         MMA_fill_16x16_f16_INTRIN,
-        MMA_store_16x16_f16_INTRIN,
+        MMA_store_16x16_f16_global_INTRIN,
     )
 
     if measure_perf:
@@ -326,7 +326,7 @@ def index_map(i, j):
         LDMATRIX_16x16_B_TRANS_INTRIN,
         MMA_f16f16f16_TRANS_INTRIN,
         MMA_fill_16x16_f16_INTRIN,
-        MMA_store_16x16_f16_INTRIN,
+        MMA_store_16x16_f16_global_INTRIN,
     )
 
     if measure_perf:
@@ -375,7 +375,7 @@ def index_map_C(i, j):
         LDMATRIX_32x16_B_INTRIN,
         MMA_i8i8i32_INTRIN,
         MMA_fill_16x16_i32_INTRIN,
-        MMA_store_16x16_i32_INTRIN,
+        MMA_store_16x16_i32_global_INTRIN,
     )
 
     if measure_perf:
@@ -396,7 +396,7 @@ def index_map_C(i, j):
         LDMATRIX_16x32_B_TRANS_INTRIN,
         MMA_i8i8i32_TRANS_INTRIN,
         MMA_fill_16x16_i32_INTRIN,
-        MMA_store_16x16_i32_INTRIN,
+        MMA_store_16x16_i32_global_INTRIN,
     )
 
     if measure_perf:
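
The tests consume the renamed intrins through `Schedule.tensorize`. A condensed sketch of that pattern, with an assumed schedule and block name (both hypothetical, not taken from the test file above):

    import tvm
    from tvm.tir.tensor_intrin.cuda import MMA_store_16x16_f32_global_INTRIN

    def tensorize_store(sch: tvm.tir.Schedule) -> None:
        # "C_warp" is a placeholder block name for the 16x16 warp-to-global store.
        block = sch.get_block("C_warp")
        i, _j = sch.get_loops(block)[-2:]  # assumed 16x16 tile loops
        sch.tensorize(i, MMA_store_16x16_f32_global_INTRIN)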
