mma 16x8x32 int8 working with ldmatrix b workaround
masahi committed May 17, 2022
1 parent 54f1cb7 commit 3ca8ca0
Showing 3 changed files with 157 additions and 7 deletions.
18 changes: 15 additions & 3 deletions src/target/source/codegen_cuda.cc
@@ -818,9 +818,21 @@ void CodeGenCUDA::VisitExpr_(const CallNode* op, std::ostream& os) {
    std::string local_ptr = this->PrintExpr(op->args[3]);
    std::string local_elem_offset = this->PrintExpr(op->args[4]);
    std::string smem_ptr = this->PrintExpr(op->args[5]);
    std::string smem_elem_offset = this->PrintExpr(op->args[6]);
    this->stream << PrintLoadMatrixAssembly(trans, num, type, local_ptr, local_elem_offset,
                                            smem_ptr, smem_elem_offset);

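    // Workaround for transposed 8-bit loads: ldmatrix with .trans swaps 16-bit elements,
    // so it cannot produce an element-wise transpose of int8/uint8 data. In that case,
    // emit a plain shared->local copy loop instead of the ldmatrix instruction.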
    if (trans && op->dtype.bits() == 8) {
      std::string smem_stride = this->PrintExpr(op->args[6]);
      LOG(INFO) << op->dtype;
      CHECK(num == 4);
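      // On this path args[6] is consumed as the shared-memory row stride (in elements)
      // rather than an element offset. Each of the 32 threads copies the 16 bytes of its
      // fragment from shared memory into the local buffer, indexed by threadIdx.x.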
      os << "for (int i = 0; i < 16; ++i) {\n";
      os << local_ptr << "[" + local_elem_offset + " + i] = " << smem_ptr
         << "[(i % 8) / 4 * " + smem_stride + " * 16 + (threadIdx.x % 4) * 4 * " + smem_stride +
                "+ (i % 4) * " + smem_stride + " + threadIdx.x / 4 + (i / 8) * 8];\n";
      os << "}\n";
    } else {
      std::string smem_elem_offset = this->PrintExpr(op->args[6]);
      this->stream << PrintLoadMatrixAssembly(trans, num, type, local_ptr, local_elem_offset,
                                              smem_ptr, smem_elem_offset);
    }
  } else if (op->op.same_as(builtin::mma_store())) {
    int m = Downcast<Integer>(op->args[0])->value;
    int n = Downcast<Integer>(op->args[1])->value;
5 changes: 3 additions & 2 deletions tests/python/unittest/test_mma_16x8x32_int8.py
@@ -106,7 +106,7 @@ def ldmatrix_b_impl(a: T.handle, c: T.handle) -> None:
B_warp.data,
16 * tx,
B_shared.data,
32 * (tx % 16) + 8 * (tx // 16),
s1,
dtype="int8",
)
)
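# The final ptx_ldmatrix operand is the stride s1 rather than an element offset; the
# transposed-int8 workaround in codegen_cuda.cc reads args[6] as a shared-memory stride.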
@@ -372,6 +372,7 @@ def shared_32x16_to_ldmatrix_32x16_layout(i, j):
b = tvm.nd.array(b_np, dev)
c = tvm.nd.array(np.zeros((M, N), dtype="int32"), dev)

# print(f.imported_modules[0].get_source())
print(f.imported_modules[0].get_source())
f(a, b, c)
np.testing.assert_equal(c.numpy(), c_np)
print("ok")
141 changes: 139 additions & 2 deletions tests/python/unittest/test_tir_ptx_ldmatrix.py
@@ -97,5 +97,142 @@ def test_ptx_ldmatrix():
    tvm.testing.assert_allclose(B_nd.numpy(), A_mask_np)


if __name__ == "__main__":
    test_ptx_ldmatrix()
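

# Transposed load of a 32x16 uint8 tile. With trans=1 and an 8-bit dtype, codegen takes
# the workaround path (a plain shared->local copy loop instead of ldmatrix.trans).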
@T.prim_func
def ptx_ldmatrix_uint8_B(
    A: T.Buffer[(32, 16), "uint8"], B: T.Buffer[(32, 16), "uint8"]) -> None:
    T.func_attr({"global_symbol": "default_function", "tir.noalias": True})
    bx = T.env_thread("blockIdx.x")
    tx = T.env_thread("threadIdx.x")
    T.launch_thread(bx, 1)
    T.launch_thread(tx, 32)
    with T.block():
        A_shared = T.alloc_buffer([32, 16], "uint8", scope="shared")
        A_local = T.alloc_buffer([16], "uint8", scope="local")

        for i in range(16):
            A_shared[tx, i] = A[tx, i]

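        # ptx_ldmatrix args: trans flag, number of 8x8 matrices, matrix type, destination
        # (local) buffer and element offset, source (shared) buffer, and, on the transposed
        # 8-bit path, the shared-memory stride (here 16, the row stride of A_shared).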
        T.evaluate(
            T.ptx_ldmatrix(
                1,
                4,
                ".b16",
                A_local.data,
                0,
                A_shared.data,
                16,
                dtype="uint8",
            )
        )

        for i in range(16):
            B[tx, i] = A_local[i]


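# Non-transposed ldmatrix.x4 load of a 16x32 uint8 tile (addressed as 16x16 .b16
# elements), i.e. the 8-bit case that can still use the real ldmatrix instruction.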
@T.prim_func
def ptx_ldmatrix_uint8_A(
    A: T.Buffer[(16, 32), "uint8"], B: T.Buffer[(16, 32), "uint8"]) -> None:
    T.func_attr({"global_symbol": "default_function", "tir.noalias": True})
    bx = T.env_thread("blockIdx.x")
    tx = T.env_thread("threadIdx.x")
    T.launch_thread(bx, 1)
    T.launch_thread(tx, 32)
    with T.block():
        A_shared = T.alloc_buffer([16, 32], "uint8", scope="shared")
        A_local = T.alloc_buffer([16], "uint8", scope="local")

        for i in range(16):
            A_shared[i, tx] = A[i, tx]

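        # Each thread supplies the shared-memory offset of one matrix row:
        # row tx % 16 (row stride 32 bytes), plus a 16-byte offset into the row
        # for threads 16..31.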
        T.evaluate(
            T.ptx_ldmatrix(
                0,
                4,
                ".b16",
                A_local.data,
                0,
                A_shared.data,
                32 * (tx % 16) + 16 * (tx // 16),
                dtype="uint8",
            )
        )

        for i in range(16):
            B[i, tx] = A_local[i]


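# Transposed ldmatrix.x4 load of a 16x16 float16 tile; since the element type is 16-bit,
# ldmatrix with .trans is used directly. Each thread receives 8 fp16 values.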
@T.prim_func
def ptx_ldmatrix_float16_A(
    A: T.Buffer[(16, 16), "float16"], B: T.Buffer[(32, 8), "float16"]) -> None:
    T.func_attr({"global_symbol": "default_function", "tir.noalias": True})
    bx = T.env_thread("blockIdx.x")
    tx = T.env_thread("threadIdx.x")
    T.launch_thread(bx, 1)
    T.launch_thread(tx, 32)
    with T.block():
        A_shared = T.alloc_buffer([16, 16], "float16", scope="shared")
        A_local = T.alloc_buffer([8], "float16", scope="local")

        for i in range(8):
            A_shared[i * 2 + tx // 16, tx % 16] = A[i * 2 + tx // 16, tx % 16]

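        # Shared-memory offset per thread: row tx % 16 (row stride 16 fp16 elements),
        # plus 8 elements into the row for threads 16..31.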
        T.evaluate(
            T.ptx_ldmatrix(
                1,
                4,
                ".b16",
                A_local.data,
                0,
                A_shared.data,
                16 * (tx % 16) + 8 * (tx // 16),
                dtype="float16",
            )
        )

        for i in range(8):
            B[tx, i] = A_local[i]


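# The def below is commented out, so this block runs at import time; the outputs are
# printed for manual inspection rather than asserted.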
# def test_ptx_ldmatrix_uint8_B():
f = ptx_ldmatrix_uint8_B
mod = tvm.build(f, target="cuda")
print(mod.imported_modules[0].get_source())
A_np = np.arange(32 * 16).reshape((32, 16)).astype("uint8")

B_np = np.zeros((32, 16)).astype("uint8")
dev = tvm.cuda(0)
A_nd = tvm.nd.array(A_np, device=dev)
B_nd = tvm.nd.array(B_np, device=dev)
mod(A_nd, B_nd)
print(A_np)
print(B_nd.numpy()[0])
print(B_nd.numpy()[1])
print(B_nd.numpy()[31])


def test_ptx_ldmatrix_uint8_A():
    f = ptx_ldmatrix_uint8_A
    mod = tvm.build(f, target="cuda")
    A_np = np.arange(32 * 16).reshape((16, 32)).astype("uint8")

    B_np = np.zeros((16, 32)).astype("uint8")
    dev = tvm.cuda(0)
    A_nd = tvm.nd.array(A_np, device=dev)
    B_nd = tvm.nd.array(B_np, device=dev)
    mod(A_nd, B_nd)
    print(A_np)
    print(B_nd.numpy()[:, 0])


def test_ptx_ldmatrix_float16_A():
    f = ptx_ldmatrix_float16_A
    mod = tvm.build(f, target="cuda")
    A_np = np.arange(16 * 16).reshape((16, 16)).astype("float16")

    B_np = np.zeros((32, 8)).astype("float16")
    dev = tvm.cuda(0)
    A_nd = tvm.nd.array(A_np, device=dev)
    B_nd = tvm.nd.array(B_np, device=dev)
    mod(A_nd, B_nd)
    print(A_np.astype("uint8"))
    print(B_nd.numpy()[0])
