Skip to content

Commit

Permalink
Squashed commit of the following:
Browse files Browse the repository at this point in the history
commit 48eef49
Author: Masahiro Masuda <masahi129@gmail.com>
Date:   Mon May 16 17:40:48 2022 +0900

    more comment

commit 8f67fc8
Author: Masahiro Masuda <masahi129@gmail.com>
Date:   Mon May 16 17:11:27 2022 +0900

    update test

commit ad85036
Author: Masahiro Masuda <masahi129@gmail.com>
Date:   Mon May 16 16:54:01 2022 +0900

    add test

commit 4a5dc3f
Author: Masahiro Masuda <masahi129@gmail.com>
Date:   Mon May 16 16:40:47 2022 +0900

    [TVMScript] Support function call to help construct AST
  • Loading branch information
masahi committed May 17, 2022
1 parent 76c1bcf commit 12a376a
Show file tree
Hide file tree
Showing 2 changed files with 72 additions and 4 deletions.
17 changes: 13 additions & 4 deletions python/tvm/script/parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -543,7 +543,7 @@ def transform_Assign(self, node):
AST abstract grammar:
Assign(expr* targets, expr value, string? type_comment)
By now 3 patterns of Assign is supported:
By now 5 patterns of Assign is supported:
1. special stmts with return value
1.1 Buffer = T.match_buffer()/T.buffer_decl()
1.2 Var = T.var()
Expand All @@ -552,6 +552,9 @@ def transform_Assign(self, node):
3. (Store) Var[PrimExpr] = PrimExpr
4. with scope handlers with concise scoping and var def
4.1 var = T.allocate()
5. A call to a pure python function, consuming and producing TVMScript values.
The outputs are inlined into the following body (no variable is created).
x, y = f(...)
"""

if isinstance(node.rhs, ast.Call):
Expand All @@ -578,14 +581,20 @@ def transform_Assign(self, node):
func.handle(node, self.context, arg_list, node.rhs.func_name.span)
return self.parse_body(node)
elif callable(func):
# Pattern 5
args = [self.transform(arg) for arg in node.rhs.params]
out = func(*args)
assert len(out) == len(node.lhs)

for ast_var, value in zip(node.lhs, out):
self.context.update_symbol(ast_var.id.name, value, node)
for var, value in zip(node.lhs, out):
self.context.update_symbol(var.id.name, value, node)

return self.parse_body(node)
body = self.parse_body(node)

for var, value in zip(node.lhs, out):
self.context.remove_symbol(var.id.name)

return body

if isinstance(node.rhs, (ast.Call, ast.Constant)):
# Pattern 4 of let binding
Expand Down
59 changes: 59 additions & 0 deletions tests/python/unittest/test_tvmscript_syntax_sugar.py
Original file line number Diff line number Diff line change
Expand Up @@ -265,5 +265,64 @@ def constant_binds_wrapped():
assert_structural_equal(constant_binds, constant_binds_wrapped)


def test_func_call():
def shared_16x16_to_ldmatrix_32x8_layout(i, j):
thread_id = (i % 8) * 4 + (j % 8) // 2
return thread_id, (j // 8) * 4 + (i // 8) * 2 + (j % 2)

@T.prim_func
def mma_sync_m16n16k16_desc(a: T.handle, b: T.handle, c: T.handle) -> None:
A = T.match_buffer(a, (32, 8), "float16", align=128, offset_factor=16, scope="warp")
B = T.match_buffer(b, (32, 8), "float16", align=128, offset_factor=16, scope="warp")
C = T.match_buffer(c, (32, 8), "float16", align=128, offset_factor=16, scope="warp")

with T.block("root"):
T.reads(C[0:32, 0:8], A[0:32, 0:8], B[0:32, 0:8])
T.writes(C[0:32, 0:8])
for i, j, k in T.grid(16, 16, 16):
with T.block("C"):
i, j, k = T.axis.remap("SSR", [i, j, k])
thread_id_C, local_id_C = shared_16x16_to_ldmatrix_32x8_layout(i, j)
thread_id_A, local_id_A = shared_16x16_to_ldmatrix_32x8_layout(i, k)
thread_id_B, local_id_B = shared_16x16_to_ldmatrix_32x8_layout(k, j)

T.reads(
C[thread_id_C, local_id_C],
A[thread_id_A, local_id_A],
B[thread_id_B, local_id_B],
)
T.writes(C[thread_id_C, local_id_C])

C[thread_id_C, local_id_C] += (
A[thread_id_A, local_id_A] * B[thread_id_B, local_id_B]
)

@T.prim_func
def mma_sync_m16n16k16_desc_manual(a: T.handle, b: T.handle, c: T.handle) -> None:
A = T.match_buffer(a, (32, 8), "float16", align=128, offset_factor=16, scope="warp")
B = T.match_buffer(b, (32, 8), "float16", align=128, offset_factor=16, scope="warp")
C = T.match_buffer(c, (32, 8), "float16", align=128, offset_factor=16, scope="warp")

with T.block("root"):
T.reads(C[0:32, 0:8], A[0:32, 0:8], B[0:32, 0:8])
T.writes(C[0:32, 0:8])
for i, j, k in T.grid(16, 16, 16):
with T.block("C"):
i, j, k = T.axis.remap("SSR", [i, j, k])
T.reads(
C[i % 8 * 4 + j % 8 // 2, j // 8 * 4 + i // 8 * 2 + j % 2],
A[i % 8 * 4 + k % 8 // 2, k // 8 * 4 + i // 8 * 2 + k % 2],
B[k % 8 * 4 + j % 8 // 2, j // 8 * 4 + k // 8 * 2 + j % 2],
)
T.writes(C[i % 8 * 4 + j % 8 // 2, j // 8 * 4 + i // 8 * 2 + j % 2])
C[i % 8 * 4 + j % 8 // 2, j // 8 * 4 + i // 8 * 2 + j % 2] = (
C[i % 8 * 4 + j % 8 // 2, j // 8 * 4 + i // 8 * 2 + j % 2]
+ A[i % 8 * 4 + k % 8 // 2, k // 8 * 4 + i // 8 * 2 + k % 2]
* B[k % 8 * 4 + j % 8 // 2, j // 8 * 4 + k // 8 * 2 + j % 2]
)

assert_structural_equal(mma_sync_m16n16k16_desc, mma_sync_m16n16k16_desc_manual)


if __name__ == "__main__":
sys.exit(pytest.main([__file__] + sys.argv[1:]))

0 comments on commit 12a376a

Please sign in to comment.