ldmatrix intrin generation with meta programming
masahi committed May 17, 2022
1 parent fb62abb commit 5afb5f0
Showing 8 changed files with 246 additions and 739 deletions.
1 change: 0 additions & 1 deletion python/tvm/script/parser.py
@@ -617,7 +617,6 @@ def transform_Assign(self, node):
if node.ty is None and hasattr(value, "dtype"):
var_ty = value.dtype
else:
print(node.ty, ast_var)
var_ty = self.parse_type(node.ty, ast_var)

var = tvm.te.var(
193 changes: 193 additions & 0 deletions python/tvm/tir/tensor_intrin/cuda.py
@@ -0,0 +1,193 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
# pylint: disable=invalid-name,missing-function-docstring
"""Intrinsics for tensorization on NVIDIA GPU."""
from ..._ffi import register_func
from ...runtime import convert
from .. import TensorIntrin
from tvm.script import tir as T


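# Each helper below maps a shared tile coordinate (i, j) to the (thread_id, local_id)
# pair that owns that element after an ldmatrix load: the tile is distributed across a
# 32-thread warp, with 8 (fp16) or 16 (int8) register elements held per thread.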
def shared_16x16_to_ldmatrix_32x8_layout(i, j):
thread_id = 4 * (i % 8) + (j % 8) // 2
return thread_id, 4 * (j // 8) + (i // 8) * 2 + (j % 2)


def shared_16x32_to_ldmatrix_32x16_layout(i, j):
thread_id = 4 * (i % 8) + (j % 16) // 4
return thread_id, 8 * (j // 16) + (i // 8) * 4 + j % 4


def shared_32x16_to_ldmatrix_32x16_layout(i, j):
thread_id = (i % 16) // 4 + 4 * (j % 8)
return thread_id, 8 * (j // 8) + (i // 16) * 4 + i % 4


@register_func("tir.index_map.shared_16x16_to_ldmatrix_32x8_layout")
def index_map_shared_16x16_to_ldmatrix_32x8_layout(i, j):
thread_id, local_id = shared_16x16_to_ldmatrix_32x8_layout(i, j)
return convert([thread_id, local_id])


lift = convert

M_DIM = 16
WARP_SIZE = 32
HALF_WARP = WARP_SIZE // 2
HALF_WARP_expr = lift(HALF_WARP)


def get_ldmatrix_intrin(k_dim, dtype, is_b, transposed):
local_size = (M_DIM * k_dim) // WARP_SIZE
shared_offset = None
index_map = None
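# index_map : shared tile coordinate (i, j) -> (thread_id, local_id) in the warp buffer.
# shared_offset(tx, stride): element offset into the shared buffer that thread tx passes
# to ptx_ldmatrix, with stride being the leading dimension of the shared buffer.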

if transposed:
assert is_b, "Transposed A matrix not supported"

ldmatrix_col_major = is_b and not transposed

if k_dim == 16:
assert dtype == "float16"

index_map = shared_16x16_to_ldmatrix_32x8_layout

if transposed:
shared_offset = (
lambda tx, stride: stride * 8 * (tx // HALF_WARP_expr)
+ stride * (tx % 8)
+ 8 * ((tx % HALF_WARP_expr) // 8)
)
else:
shared_offset = lambda tx, stride: stride * (tx % HALF_WARP_expr) + 8 * (
tx // HALF_WARP_expr
)

elif k_dim == 32:
assert dtype == "int8"

if ldmatrix_col_major:
print("foo")
index_map = shared_32x16_to_ldmatrix_32x16_layout
shared_offset = (
lambda _, stride: stride
) # dummy offset, ldmatrix cannot be used for int8 + trans case
elif is_b and transposed:
index_map = shared_16x32_to_ldmatrix_32x16_layout
shared_offset = (
lambda tx, stride: stride * 8 * (tx // HALF_WARP_expr)
+ (tx % 8) * stride
+ 16 * ((tx % HALF_WARP_expr) // 8)
)
else:
index_map = shared_16x32_to_ldmatrix_32x16_layout
shared_offset = lambda tx, stride: stride * (tx % 16) + 16 * (tx // 16)

else:
assert False, "Unsupported k dim"

assert index_map and shared_offset

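# Non-transposed B is stored as a (k_dim, M_DIM) tile in shared memory;
# A and transposed B use (M_DIM, k_dim).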
if is_b and not transposed:
row_dim = k_dim
col_dim = M_DIM
else:
row_dim = M_DIM
col_dim = k_dim

shmem_shape = (row_dim, col_dim)

@T.prim_func
def ldmatrix_desc(warp_handle: T.handle, shared_handle: T.handle) -> None:
shared = T.match_buffer(
shared_handle, shmem_shape, dtype, align=128, offset_factor=16, scope="shared"
)
warp = T.match_buffer(
warp_handle, (WARP_SIZE, local_size), dtype, align=128, offset_factor=16, scope="warp"
)

with T.block("root"):
T.reads(shared[0:row_dim, 0:col_dim])
T.writes(warp[0:WARP_SIZE, 0:local_size])

for ax0, ax1 in T.grid(row_dim, col_dim):
with T.block("shared_warp"):
v0, v1 = T.axis.remap("SS", [ax0, ax1])
T.reads(shared[v0, v1])

thread_id, local_id = index_map(v0, v1)
T.writes(warp[thread_id, local_id])
warp[thread_id, local_id] = shared[v0, v1]

@T.prim_func
def ldmatrix_impl(warp_handle: T.handle, shared_handle: T.handle) -> None:
s0 = T.var("int32")
s1 = T.var("int32")
shared = T.match_buffer(
shared_handle,
shmem_shape,
dtype,
align=128,
offset_factor=16,
scope="shared",
strides=[s0, s1],
)
warp = T.match_buffer(
warp_handle, (WARP_SIZE, local_size), dtype, align=128, offset_factor=16, scope="warp"
)

with T.block("root"):
T.reads(shared[0:row_dim, 0:col_dim])
T.writes(warp[0:WARP_SIZE, 0:local_size])
tx = T.env_thread("threadIdx.x")
T.launch_thread(tx, WARP_SIZE)

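# A single ldmatrix per warp loads four 8x8 matrices: thread tx writes its
# local_size elements starting at warp.elem_offset + local_size * tx and reads
# shared memory at the per-thread offset computed by shared_offset(tx, s0).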
T.evaluate(
T.ptx_ldmatrix(
ldmatrix_col_major,
4, # Always load 4 matrices
".b16",
warp.data,
warp.elem_offset + lift(local_size) * tx,
shared.access_ptr("r"),
shared_offset(tx, s0),
dtype=dtype,
)
)

return ldmatrix_desc, ldmatrix_impl


LDMATRIX_16x16_A_INTRIN = "mma.ldmatrix_16x16_a"
TensorIntrin.register(LDMATRIX_16x16_A_INTRIN, *get_ldmatrix_intrin(16, "float16", False, False))

LDMATRIX_16x16_B_INTRIN = "mma.ldmatrix_16x16_b"
TensorIntrin.register(LDMATRIX_16x16_B_INTRIN, *get_ldmatrix_intrin(16, "float16", True, False))

LDMATRIX_16x16_B_TRANS_INTRIN = "mma.ldmatrix_16x16_b_trans"
TensorIntrin.register(
LDMATRIX_16x16_B_TRANS_INTRIN, *get_ldmatrix_intrin(16, "float16", True, True)
)

LDMATRIX_16x32_A_INTRIN = "mma.ldmatrix_16x32_a"
TensorIntrin.register(LDMATRIX_16x32_A_INTRIN, *get_ldmatrix_intrin(32, "int8", False, False))

LDMATRIX_32x16_B_INTRIN = "mma.ldmatrix_32x16_b"
TensorIntrin.register(LDMATRIX_32x16_B_INTRIN, *get_ldmatrix_intrin(32, "int8", True, False))

LDMATRIX_16x32_B_TRANS_INTRIN = "mma.ldmatrix_16x32_b_trans"
TensorIntrin.register(LDMATRIX_16x32_B_TRANS_INTRIN, *get_ldmatrix_intrin(32, "int8", True, True))
127 changes: 7 additions & 120 deletions tests/python/unittest/test_mma_16x8x16_4k_tune.py
@@ -4,126 +4,15 @@
import tvm.meta_schedule.testing.te_workload as te_workload
from tvm import te, tir
from tvm import meta_schedule as ms
from tvm.tir.tensor_intrin.cuda import (
LDMATRIX_16x16_A_INTRIN,
LDMATRIX_16x16_B_INTRIN,
shared_16x16_to_ldmatrix_32x8_layout,
)
import tvm.testing
import numpy as np


def shared_16x16_to_ldmatrix_32x8_layout(i, j):
thread_id = 4 * (i % 8) + (j % 8) // 2
return thread_id, 4 * (j // 8) + (i // 8) * 2 + (j % 2)


@tvm._ffi.register_func("tir.index_map.shared_16x16_to_ldmatrix_32x8_layout")
def index_map_shared_16x16_to_ldmatrix_32x8_layout(i, j):
thread_id, local_id = shared_16x16_to_ldmatrix_32x8_layout(i, j)
return tvm.runtime.convert([thread_id, local_id])


@T.prim_func
def ldmatrix_a_desc(a: T.handle, c: T.handle) -> None:
A_shared = T.match_buffer(a, (16, 16), "float16", align=128, offset_factor=16, scope="shared")
A_warp = T.match_buffer(c, (32, 8), "float16", align=128, offset_factor=16, scope="warp")

with T.block("root"):
T.reads(A_shared[0:16, 0:16])
T.writes(A_warp[0:32, 0:8])

for ax0, ax1 in T.grid(16, 16):
with T.block("A_shared_warp"):
v0, v1 = T.axis.remap("SS", [ax0, ax1])
T.reads(A_shared[v0, v1])

thread_id, local_id = shared_16x16_to_ldmatrix_32x8_layout(v0, v1)
T.writes(A_warp[thread_id, local_id])
A_warp[thread_id, local_id] = A_shared[v0, v1]


@T.prim_func
def ldmatrix_a_impl(a: T.handle, c: T.handle) -> None:
s1 = T.var("int32")
s0 = T.var("int32")
A_shared = T.match_buffer(
a,
(16, 16),
"float16",
align=128,
offset_factor=16,
scope="shared",
strides=[s1, s0],
)
A_warp = T.match_buffer(c, (32, 8), "float16", align=128, offset_factor=16, scope="warp")
with T.block("root"):
T.reads(A_shared[0:16, 0:16])
T.writes(A_warp[0:32, 0:8])
tx = T.env_thread("threadIdx.x")
T.launch_thread(tx, 32)

T.evaluate(
T.ptx_ldmatrix(
0,
4,
".b16",
A_warp.data,
A_warp.elem_offset + 8 * tx,
A_shared.access_ptr("r"),
s1 * (tx % 16) + 8 * (tx // 16),
dtype="float16",
)
)


@T.prim_func
def ldmatrix_b_desc(a: T.handle, c: T.handle) -> None:
B_shared = T.match_buffer(a, (16, 16), "float16", align=128, offset_factor=16, scope="shared")
B_warp = T.match_buffer(c, (32, 8), "float16", align=128, offset_factor=16, scope="warp")

with T.block("root"):
T.reads(B_shared[0:16, 0:16])
T.writes(B_warp[0:32, 0:8])

for ax0, ax1 in T.grid(16, 16):
with T.block("B_shared_warp"):
v0, v1 = T.axis.remap("SS", [ax0, ax1])
T.reads(B_shared[v0, v1])
thread_id, local_id = shared_16x16_to_ldmatrix_32x8_layout(v0, v1)
T.writes(B_warp[thread_id, local_id])
B_warp[thread_id, local_id] = B_shared[v0, v1]


@T.prim_func
def ldmatrix_b_impl(a: T.handle, c: T.handle) -> None:
s1 = T.var("int32")
s0 = T.var("int32")
B_shared = T.match_buffer(
a,
(16, 16),
"float16",
align=128,
offset_factor=16,
scope="shared",
strides=[s1, s0],
)
B_warp = T.match_buffer(c, (32, 8), "float16", align=128, offset_factor=16, scope="warp")
with T.block("root"):
T.reads(B_shared[0:16, 0:16])
T.writes(B_warp[0:32, 0:8])
tx = T.env_thread("threadIdx.x")
T.launch_thread(tx, 32)

T.evaluate(
T.ptx_ldmatrix(
1,
4,
".b16",
B_warp.data,
B_warp.elem_offset + 8 * tx,
B_shared.access_ptr("r"),
s1 * (tx % 16) + 8 * (tx // 16),
dtype="float16",
)
)


@T.prim_func
def mma_sync_desc(a: T.handle, b: T.handle, c: T.handle) -> None:
A = T.match_buffer(a, (32, 8), "float16", align=128, offset_factor=16, scope="warp")
@@ -271,8 +160,6 @@ def mma_fill_impl(a: T.handle) -> None:
T.evaluate(T.mma_fill(8, C_warp.data, C_warp.elem_offset, dtype="float32"))


tir.TensorIntrin.register("mma.ldmatrix_a", ldmatrix_a_desc, ldmatrix_a_impl)
tir.TensorIntrin.register("mma.ldmatrix_b", ldmatrix_b_desc, ldmatrix_b_impl)
tir.TensorIntrin.register("mma_sync", mma_sync_desc, mma_sync_impl)
tir.TensorIntrin.register("mma_store", mma_store_desc, mma_store_impl)
tir.TensorIntrin.register("mma_fill", mma_fill_desc, mma_fill_impl)
@@ -402,8 +289,8 @@ def index_map(i, j):
sch.transform_layout(B_warp, 0, "write", index_map)
sch.transform_layout(C_warp, 0, "read", index_map)

sch.tensorize(loop_a, "mma.ldmatrix_a")
sch.tensorize(loop_b, "mma.ldmatrix_b")
sch.tensorize(loop_a, LDMATRIX_16x16_A_INTRIN)
sch.tensorize(loop_b, LDMATRIX_16x16_B_INTRIN)
sch.tensorize(sch.get_loops(block_inner)[-3], "mma_sync")
sch.tensorize(sch.get_loops(block_init_c)[-2], "mma_fill")
sch.tensorize(sch.get_loops(C_warp)[-2], "mma_store")