From 12a376ae2f44aa6660121e64e0358f2866624f7f Mon Sep 17 00:00:00 2001 From: Masahiro Masuda Date: Mon, 16 May 2022 17:54:58 +0900 Subject: [PATCH] Squashed commit of the following: commit 48eef4981d1a55aaf3b0ac935f2a10347cb1ac2d Author: Masahiro Masuda Date: Mon May 16 17:40:48 2022 +0900 more comment commit 8f67fc87038834e9f7e2c5cd3dfe61fabf442206 Author: Masahiro Masuda Date: Mon May 16 17:11:27 2022 +0900 update test commit ad85036621c005b733763e67ceffae39c356ec99 Author: Masahiro Masuda Date: Mon May 16 16:54:01 2022 +0900 add test commit 4a5dc3ffd5d0bb4a1700e57897c9e0f26e3d2a88 Author: Masahiro Masuda Date: Mon May 16 16:40:47 2022 +0900 [TVMScript] Support function call to help construct AST --- python/tvm/script/parser.py | 17 ++++-- .../unittest/test_tvmscript_syntax_sugar.py | 59 +++++++++++++++++++ 2 files changed, 72 insertions(+), 4 deletions(-) diff --git a/python/tvm/script/parser.py b/python/tvm/script/parser.py index 2a19dfc33dc2..f46cddce989e 100644 --- a/python/tvm/script/parser.py +++ b/python/tvm/script/parser.py @@ -543,7 +543,7 @@ def transform_Assign(self, node): AST abstract grammar: Assign(expr* targets, expr value, string? type_comment) - By now 3 patterns of Assign is supported: + By now 5 patterns of Assign is supported: 1. special stmts with return value 1.1 Buffer = T.match_buffer()/T.buffer_decl() 1.2 Var = T.var() @@ -552,6 +552,9 @@ def transform_Assign(self, node): 3. (Store) Var[PrimExpr] = PrimExpr 4. with scope handlers with concise scoping and var def 4.1 var = T.allocate() + 5. A call to a pure python function, consuming and producing TVMScript values. + The outputs are inlined into the following body (no variable is created). + x, y = f(...) """ if isinstance(node.rhs, ast.Call): @@ -578,14 +581,20 @@ def transform_Assign(self, node): func.handle(node, self.context, arg_list, node.rhs.func_name.span) return self.parse_body(node) elif callable(func): + # Pattern 5 args = [self.transform(arg) for arg in node.rhs.params] out = func(*args) assert len(out) == len(node.lhs) - for ast_var, value in zip(node.lhs, out): - self.context.update_symbol(ast_var.id.name, value, node) + for var, value in zip(node.lhs, out): + self.context.update_symbol(var.id.name, value, node) - return self.parse_body(node) + body = self.parse_body(node) + + for var, value in zip(node.lhs, out): + self.context.remove_symbol(var.id.name) + + return body if isinstance(node.rhs, (ast.Call, ast.Constant)): # Pattern 4 of let binding diff --git a/tests/python/unittest/test_tvmscript_syntax_sugar.py b/tests/python/unittest/test_tvmscript_syntax_sugar.py index a0964ea4d77c..c9566c07c329 100644 --- a/tests/python/unittest/test_tvmscript_syntax_sugar.py +++ b/tests/python/unittest/test_tvmscript_syntax_sugar.py @@ -265,5 +265,64 @@ def constant_binds_wrapped(): assert_structural_equal(constant_binds, constant_binds_wrapped) +def test_func_call(): + def shared_16x16_to_ldmatrix_32x8_layout(i, j): + thread_id = (i % 8) * 4 + (j % 8) // 2 + return thread_id, (j // 8) * 4 + (i // 8) * 2 + (j % 2) + + @T.prim_func + def mma_sync_m16n16k16_desc(a: T.handle, b: T.handle, c: T.handle) -> None: + A = T.match_buffer(a, (32, 8), "float16", align=128, offset_factor=16, scope="warp") + B = T.match_buffer(b, (32, 8), "float16", align=128, offset_factor=16, scope="warp") + C = T.match_buffer(c, (32, 8), "float16", align=128, offset_factor=16, scope="warp") + + with T.block("root"): + T.reads(C[0:32, 0:8], A[0:32, 0:8], B[0:32, 0:8]) + T.writes(C[0:32, 0:8]) + for i, j, k in T.grid(16, 16, 16): + with T.block("C"): + i, j, k = T.axis.remap("SSR", [i, j, k]) + thread_id_C, local_id_C = shared_16x16_to_ldmatrix_32x8_layout(i, j) + thread_id_A, local_id_A = shared_16x16_to_ldmatrix_32x8_layout(i, k) + thread_id_B, local_id_B = shared_16x16_to_ldmatrix_32x8_layout(k, j) + + T.reads( + C[thread_id_C, local_id_C], + A[thread_id_A, local_id_A], + B[thread_id_B, local_id_B], + ) + T.writes(C[thread_id_C, local_id_C]) + + C[thread_id_C, local_id_C] += ( + A[thread_id_A, local_id_A] * B[thread_id_B, local_id_B] + ) + + @T.prim_func + def mma_sync_m16n16k16_desc_manual(a: T.handle, b: T.handle, c: T.handle) -> None: + A = T.match_buffer(a, (32, 8), "float16", align=128, offset_factor=16, scope="warp") + B = T.match_buffer(b, (32, 8), "float16", align=128, offset_factor=16, scope="warp") + C = T.match_buffer(c, (32, 8), "float16", align=128, offset_factor=16, scope="warp") + + with T.block("root"): + T.reads(C[0:32, 0:8], A[0:32, 0:8], B[0:32, 0:8]) + T.writes(C[0:32, 0:8]) + for i, j, k in T.grid(16, 16, 16): + with T.block("C"): + i, j, k = T.axis.remap("SSR", [i, j, k]) + T.reads( + C[i % 8 * 4 + j % 8 // 2, j // 8 * 4 + i // 8 * 2 + j % 2], + A[i % 8 * 4 + k % 8 // 2, k // 8 * 4 + i // 8 * 2 + k % 2], + B[k % 8 * 4 + j % 8 // 2, j // 8 * 4 + k // 8 * 2 + j % 2], + ) + T.writes(C[i % 8 * 4 + j % 8 // 2, j // 8 * 4 + i // 8 * 2 + j % 2]) + C[i % 8 * 4 + j % 8 // 2, j // 8 * 4 + i // 8 * 2 + j % 2] = ( + C[i % 8 * 4 + j % 8 // 2, j // 8 * 4 + i // 8 * 2 + j % 2] + + A[i % 8 * 4 + k % 8 // 2, k // 8 * 4 + i // 8 * 2 + k % 2] + * B[k % 8 * 4 + j % 8 // 2, j // 8 * 4 + k // 8 * 2 + j % 2] + ) + + assert_structural_equal(mma_sync_m16n16k16_desc, mma_sync_m16n16k16_desc_manual) + + if __name__ == "__main__": sys.exit(pytest.main([__file__] + sys.argv[1:]))