Commit: poking with the parser
masahi committed May 17, 2022
1 parent 596582c commit dd8ccf9
Showing 2 changed files with 112 additions and 58 deletions.
28 changes: 28 additions & 0 deletions python/tvm/script/parser.py
@@ -554,12 +554,14 @@ def transform_Assign(self, node):
4.1 var = T.allocate()
"""

print("parsing ", node.rhs.func_name)
if isinstance(node.rhs, ast.Call):
# Pattern 1 & Pattern 4
if isinstance(node.rhs.func_name, ast.Op):
func = None
else:
func = self.transform(node.rhs.func_name)
print(func)

if isinstance(func, WithScopeHandler):
if not func.concise_scope or not func.def_symbol:
@@ -577,6 +579,31 @@ def transform_Assign(self, node):
arg_list = self.parse_arg_list(func, node.rhs)
func.handle(node, self.context, arg_list, node.rhs.func_name.span)
return self.parse_body(node)
elif callable(func):
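                # The RHS is a call to a plain Python function (e.g. a layout-mapping
                # helper): evaluate it eagerly, bind each returned value to the
                # matching LHS variable, and wrap the parsed body in LetStmts below.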
args = [self.transform(arg) for arg in node.rhs.params]
out = func(*args)
print(out)
print(node.lhs)
assert len(out) == len(node.lhs)

lhs_vars = []
for ast_var, value in zip(node.lhs, out):
var = tvm.te.var(
ast_var.id.name,
"int32",
span=tvm_span_from_synr(ast_var.span),
)
self.context.update_symbol(var.name, var, node)
lhs_vars.append(var)

body = self.parse_body(node)

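                # Wrap the body from the inside out so that the first LHS variable
                # becomes the outermost LetStmt binding.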
for var, value in reversed(list(zip(lhs_vars, out))):
self.context.remove_symbol(var.name)
body = tvm.tir.LetStmt(var, value, body, span=tvm_span_from_synr(node.span))

return body

if isinstance(node.rhs, (ast.Call, ast.Constant)):
# Pattern 4 of let binding
value = self.transform(node.rhs)
@@ -593,6 +620,7 @@ def transform_Assign(self, node):
if node.ty is None and hasattr(value, "dtype"):
var_ty = value.dtype
else:
print(node.ty, ast_var)
var_ty = self.parse_type(node.ty, ast_var)

var = tvm.te.var(
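For reference, the new "elif callable(func)" branch lowers an assignment like x, y = f(...) into nested tvm.tir.LetStmt bindings. A minimal standalone sketch of that construction, with made-up values and a trivial body rather than the parser's actual inputs:

from tvm import te, tir

# Stand-ins for the values returned by the evaluated callable (hypothetical).
values = [tir.IntImm("int32", 1), tir.IntImm("int32", 2)]
lhs_vars = [te.var("x", "int32"), te.var("y", "int32")]

# Innermost statement that uses the bound variables.
body = tir.Evaluate(lhs_vars[0] + lhs_vars[1])

# Wrap from the inside out, as in transform_Assign above, so the first
# variable ends up as the outermost binding.
for var, value in reversed(list(zip(lhs_vars, values))):
    body = tir.LetStmt(var, value, body)

print(body)  # nested let bindings: x bound to 1, then y bound to 2, around the Evaluate node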
142 changes: 84 additions & 58 deletions tests/python/unittest/test_mma_16x8x16_4k_tune.py
@@ -8,6 +8,11 @@
import numpy as np


def shared_16x16_to_ldmatrix_32x8_layout(i, j):
thread_id = 4 * (i % 8) + (j % 8) // 2
return thread_id, 4 * (j // 8) + (i // 8) * 2 + (j % 2)
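
# Sanity check: the mapping above should send the 256 elements of a 16x16 tile
# to 256 distinct (thread_id, local_id) slots of a 32x8 warp-level buffer.
# Standalone illustration; nothing below depends on it.
_slots = {
    shared_16x16_to_ldmatrix_32x8_layout(i, j) for i in range(16) for j in range(16)
}
assert len(_slots) == 32 * 8
assert all(0 <= t < 32 and 0 <= e < 8 for t, e in _slots)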


@T.prim_func
def ldmatrix_a_desc(a: T.handle, c: T.handle) -> None:
A_shared = T.match_buffer(a, (16, 16), "float16", align=128, offset_factor=16, scope="shared")
@@ -21,11 +26,15 @@ def ldmatrix_a_desc(a: T.handle, c: T.handle) -> None:
with T.block("A_shared_warp"):
v0, v1 = T.axis.remap("SS", [ax0, ax1])
T.reads(A_shared[v0, v1])
T.writes(A_warp[v0 % 8 * 4 + v1 % 8 // 2, v1 // 8 * 4 + v0 // 8 * 2 + v1 % 2])
A_warp[v0 % 8 * 4 + v1 % 8 // 2, v1 // 8 * 4 + v0 // 8 * 2 + v1 % 2] = A_shared[
v0, v1
]

thread_id, y = shared_16x16_to_ldmatrix_32x8_layout(v0, v1)
T.writes(A_warp[thread_id, y])
A_warp[thread_id, y] = A_shared[v0, v1]

# T.writes(A_warp[v0 % 8 * 4 + v1 % 8 // 2, v1 // 8 * 4 + v0 // 8 * 2 + v1 % 2])
# A_warp[v0 % 8 * 4 + v1 % 8 // 2, v1 // 8 * 4 + v0 // 8 * 2 + v1 % 2] = A_shared[
# v0, v1
# ]

@T.prim_func
def ldmatrix_a_impl(a: T.handle, c: T.handle) -> None:
@@ -390,22 +399,39 @@ def tile_wmma_fragment(block_read, height):
sch.reorder(i0, j0, i1, j1)
return i1

def shared_16x16_to_ldmatrix_32x8_layout(i, j):
i_0 = i // 16
j_0 = j // 16

i = i % 16
j = j % 16

thread_id = 4 * (i % 8) + (j % 8) // 2
return i_0, j_0, thread_id, 4 * (j // 8) + (i // 8) * 2 + (j % 8) % 2

loop_a = tile_wmma_fragment(A_warp, 16)
loop_b = tile_wmma_fragment(B_warp, 16)

sch.transform_layout(A_warp, 0, "write", index_map=shared_16x16_to_ldmatrix_32x8_layout)
sch.transform_layout(B_warp, 0, "write", index_map=shared_16x16_to_ldmatrix_32x8_layout)
sch.transform_layout(C_warp, 0, "read", index_map=shared_16x16_to_ldmatrix_32x8_layout)
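    # Rewrite each warp buffer's 16x16 fragment coordinates into
    # (tile row, tile column, thread id, per-thread element index),
    # matching the ldmatrix/mma register layout defined above.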
sch.transform_layout(
A_warp,
0,
"write",
index_map=lambda i, j: (
i // 16,
j // 16,
*shared_16x16_to_ldmatrix_32x8_layout(i % 16, j % 16),
),
)
sch.transform_layout(
B_warp,
0,
"write",
index_map=lambda i, j: (
i // 16,
j // 16,
*shared_16x16_to_ldmatrix_32x8_layout(i % 16, j % 16),
),
)
sch.transform_layout(
C_warp,
0,
"read",
index_map=lambda i, j: (
i // 16,
j // 16,
*shared_16x16_to_ldmatrix_32x8_layout(i % 16, j % 16),
),
)

sch.tensorize(loop_a, "mma.ldmatrix_a")
sch.tensorize(loop_b, "mma.ldmatrix_b")
@@ -438,44 +464,44 @@ def shared_16x16_to_ldmatrix_32x8_layout(i, j):
schedule(sch)
print(sch.mod.script())

if tune:
with tempfile.TemporaryDirectory() as work_dir:
sch = ms.tune_tir(
mod=workload,
target=tvm.target.Target("nvidia/geforce-rtx-3070"),
config=ms.TuneConfig(
strategy="evolutionary",
num_trials_per_iter=32,
max_trials_per_task=128,
max_trials_global=128,
),
work_dir=work_dir,
space=ms.space_generator.ScheduleFn(schedule),
)
if sch is None:
print("No valid schedule found!")
else:
print(sch.mod.script())
print(sch.trace)
else:
target = "cuda"
f = tvm.build(sch.mod["main"], target=target, name="dense")

dev = tvm.device("cuda", 0)
a_np = np.random.uniform(size=(N, K)).astype("float16")
b_np = np.random.uniform(size=(K, M)).astype("float16")
c_np = np.dot(a_np.astype("float32"), b_np.astype("float32"))
a = tvm.nd.array(a_np, dev)
b = tvm.nd.array(b_np, dev)
c = tvm.nd.array(np.zeros((M, N), dtype="float32"), dev)
f = tvm.build(sch.mod["main"], target="cuda", name="dense")

print(f.imported_modules[0].get_source())
f(a, b, c)
tvm.testing.assert_allclose(c.numpy(), c_np, rtol=1e-3)
print("ok")

evaluator = f.time_evaluator(f.entry_name, dev, number=1000)
gflops = (N * M * K) * 2 / 1e9
time_ms = evaluator(a, b, c).mean * 1e3
print("matmul with tensor core: %f ms, %f GFLOPS" % (time_ms, gflops / (time_ms / 1e3)))
# if tune:
# with tempfile.TemporaryDirectory() as work_dir:
# sch = ms.tune_tir(
# mod=workload,
# target=tvm.target.Target("nvidia/geforce-rtx-3070"),
# config=ms.TuneConfig(
# strategy="evolutionary",
# num_trials_per_iter=32,
# max_trials_per_task=128,
# max_trials_global=128,
# ),
# work_dir=work_dir,
# space=ms.space_generator.ScheduleFn(schedule),
# )
# if sch is None:
# print("No valid schedule found!")
# else:
# print(sch.mod.script())
# print(sch.trace)
# else:
# target = "cuda"
# f = tvm.build(sch.mod["main"], target=target, name="dense")

# dev = tvm.device("cuda", 0)
# a_np = np.random.uniform(size=(N, K)).astype("float16")
# b_np = np.random.uniform(size=(K, M)).astype("float16")
# c_np = np.dot(a_np.astype("float32"), b_np.astype("float32"))
# a = tvm.nd.array(a_np, dev)
# b = tvm.nd.array(b_np, dev)
# c = tvm.nd.array(np.zeros((M, N), dtype="float32"), dev)
# f = tvm.build(sch.mod["main"], target="cuda", name="dense")

# print(f.imported_modules[0].get_source())
# f(a, b, c)
# tvm.testing.assert_allclose(c.numpy(), c_np, rtol=1e-3)
# print("ok")

# evaluator = f.time_evaluator(f.entry_name, dev, number=1000)
# gflops = (N * M * K) * 2 / 1e9
# time_ms = evaluator(a, b, c).mean * 1e3
# print("matmul with tensor core: %f ms, %f GFLOPS" % (time_ms, gflops / (time_ms / 1e3)))
