
Commit

revert some change
masahi committed May 17, 2022
1 parent e599a55 commit 5a80adc
Showing 1 changed file with 2 additions and 139 deletions.
141 changes: 2 additions & 139 deletions tests/python/unittest/test_tir_ptx_ldmatrix.py
@@ -97,142 +97,5 @@ def test_ptx_ldmatrix():
    tvm.testing.assert_allclose(B_nd.numpy(), A_mask_np)


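# The three prim_funcs below exercise the ptx_ldmatrix intrinsic for uint8
# fragments in B and A layout and for float16 fragments in A layout. The
# reading here assumes the positional arguments are (trans, num, matrix type,
# local destination pointer, local offset, shared source pointer, shared
# element offset); the TIR builtin definition of ptx_ldmatrix is the
# authoritative reference for the exact signature.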
@T.prim_func
def ptx_ldmatrix_uint8_B(
    A: T.Buffer[(32, 16), "uint8"], B: T.Buffer[(32, 16), "uint8"]) -> None:
    T.func_attr({"global_symbol": "default_function", "tir.noalias": True})
    bx = T.env_thread("blockIdx.x")
    tx = T.env_thread("threadIdx.x")
    T.launch_thread(bx, 1)
    T.launch_thread(tx, 32)
    with T.block():
        A_shared = T.alloc_buffer([32, 16], "uint8", scope="shared")
        A_local = T.alloc_buffer([16], "uint8", scope="local")

        for i in range(16):
            A_shared[tx, i] = A[tx, i]

        T.evaluate(
            T.ptx_ldmatrix(
                1,
                4,
                ".b16",
                A_local.data,
                0,
                A_shared.data,
                16,
                dtype="uint8",
            )
        )

        for i in range(16):
            B[tx, i] = A_local[i]


@T.prim_func
def ptx_ldmatrix_uint8_A(
    A: T.Buffer[(16, 32), "uint8"], B: T.Buffer[(16, 32), "uint8"]) -> None:
    T.func_attr({"global_symbol": "default_function", "tir.noalias": True})
    bx = T.env_thread("blockIdx.x")
    tx = T.env_thread("threadIdx.x")
    T.launch_thread(bx, 1)
    T.launch_thread(tx, 32)
    with T.block():
        A_shared = T.alloc_buffer([16, 32], "uint8", scope="shared")
        A_local = T.alloc_buffer([16], "uint8", scope="local")

        for i in range(16):
            A_shared[i, tx] = A[i, tx]

        T.evaluate(
            T.ptx_ldmatrix(
                0,
                4,
                ".b16",
                A_local.data,
                0,
                A_shared.data,
                32 * (tx % 16) + 16 * (tx // 16),
                dtype="uint8",
            )
        )

        for i in range(16):
            B[i, tx] = A_local[i]


@T.prim_func
def ptx_ldmatrix_float16_A(
    A: T.Buffer[(16, 16), "float16"], B: T.Buffer[(32, 8), "float16"]) -> None:
    T.func_attr({"global_symbol": "default_function", "tir.noalias": True})
    bx = T.env_thread("blockIdx.x")
    tx = T.env_thread("threadIdx.x")
    T.launch_thread(bx, 1)
    T.launch_thread(tx, 32)
    with T.block():
        A_shared = T.alloc_buffer([16, 16], "float16", scope="shared")
        A_local = T.alloc_buffer([8], "float16", scope="local")

        for i in range(8):
            A_shared[i * 2 + tx // 16, tx % 16] = A[i * 2 + tx // 16, tx % 16]

        T.evaluate(
            T.ptx_ldmatrix(
                1,
                4,
                ".b16",
                A_local.data,
                0,
                A_shared.data,
                16 * (tx % 16) + 8 * (tx // 16),
                dtype="float16",
            )
        )

        for i in range(8):
            B[tx, i] = A_local[i]


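# The drivers below follow the same pattern: build the prim_func for the CUDA
# target, stage the numpy inputs on the GPU with tvm.nd.array, run the kernel,
# and print slices of the output for manual inspection. Unlike
# test_ptx_ldmatrix above, none of them asserts against a reference result.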
def test_ptx_ldmatrix_uint8_B():
    f = ptx_ldmatrix_uint8_B
    mod = tvm.build(f, target="cuda")
    print(mod.imported_modules[0].get_source())
    A_np = np.arange(32 * 16).reshape((32, 16)).astype("uint8")

    B_np = np.zeros((32, 16)).astype("uint8")
    dev = tvm.cuda(0)
    A_nd = tvm.nd.array(A_np, device=dev)
    B_nd = tvm.nd.array(B_np, device=dev)
    mod(A_nd, B_nd)
    print(A_np)
    print(B_nd.numpy()[0])
    print(B_nd.numpy()[1])
    print(B_nd.numpy()[31])


def test_ptx_ldmatrix_uint8_A():
    f = ptx_ldmatrix_uint8_A
    mod = tvm.build(f, target="cuda")
    A_np = np.arange(32 * 16).reshape((16, 32)).astype("uint8")

    B_np = np.zeros((16, 32)).astype("uint8")
    dev = tvm.cuda(0)
    A_nd = tvm.nd.array(A_np, device=dev)
    B_nd = tvm.nd.array(B_np, device=dev)
    mod(A_nd, B_nd)
    print(A_np)
    print(B_nd.numpy()[:, 0])


def test_ptx_ldmatrix_float16_A():
    f = ptx_ldmatrix_float16_A
    mod = tvm.build(f, target="cuda")
    A_np = np.arange(16 * 16).reshape((16, 16)).astype("float16")

    B_np = np.zeros((32, 8)).astype("float16")
    dev = tvm.cuda(0)
    A_nd = tvm.nd.array(A_np, device=dev)
    B_nd = tvm.nd.array(B_np, device=dev)
    mod(A_nd, B_nd)
    print(A_np.astype("uint8"))
    print(B_nd.numpy()[0])

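# Only test_ptx_ldmatrix() is invoked under __main__; the uint8 and float16
# drivers above are presumably only run when the file goes through pytest.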
if __name__ == "__main__":
    test_ptx_ldmatrix()
