diff --git a/python/tvm/script/parser.py b/python/tvm/script/parser.py index fe71b064320f9..3e0292e6332ec 100644 --- a/python/tvm/script/parser.py +++ b/python/tvm/script/parser.py @@ -543,7 +543,7 @@ def transform_Assign(self, node): AST abstract grammar: Assign(expr* targets, expr value, string? type_comment) - By now 3 patterns of Assign is supported: + By now 5 patterns of Assign is supported: 1. special stmts with return value 1.1 Buffer = T.match_buffer()/T.buffer_decl() 1.2 Var = T.var() @@ -552,6 +552,8 @@ def transform_Assign(self, node): 3. (Store) Var[PrimExpr] = PrimExpr 4. with scope handlers with concise scoping and var def 4.1 var = T.allocate() + 5. An invocation of an arbitrary python callable + x, y = f(...) """ if isinstance(node.rhs, ast.Call): @@ -577,6 +579,22 @@ def transform_Assign(self, node): arg_list = self.parse_arg_list(func, node.rhs) func.handle(node, self.context, arg_list, node.rhs.func_name.span) return self.parse_body(node) + elif callable(func): + # Pattern 5 + args = [self.transform(arg) for arg in node.rhs.params] + out = func(*args) + assert len(out) == len(node.lhs) + + for var, value in zip(node.lhs, out): + self.context.update_symbol(var.id.name, value, node) + + body = self.parse_body(node) + + for var, value in zip(node.lhs, out): + self.context.remove_symbol(var.name) + + return body + if isinstance(node.rhs, (ast.Call, ast.Constant)): # Pattern 4 of let binding value = self.transform(node.rhs)