Skip to content

Commit

Permalink
[TVMScript] Support function call to help construct AST
Browse files Browse the repository at this point in the history
  • Loading branch information
masahi committed May 16, 2022
1 parent 02d57bb commit 4a5dc3f
Showing 1 changed file with 19 additions and 1 deletion.
20 changes: 19 additions & 1 deletion python/tvm/script/parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -543,7 +543,7 @@ def transform_Assign(self, node):
AST abstract grammar:
Assign(expr* targets, expr value, string? type_comment)
By now 3 patterns of Assign is supported:
By now 5 patterns of Assign is supported:
1. special stmts with return value
1.1 Buffer = T.match_buffer()/T.buffer_decl()
1.2 Var = T.var()
Expand All @@ -552,6 +552,8 @@ def transform_Assign(self, node):
3. (Store) Var[PrimExpr] = PrimExpr
4. with scope handlers with concise scoping and var def
4.1 var = T.allocate()
5. An invocation of an arbitrary python callable
x, y = f(...)
"""

if isinstance(node.rhs, ast.Call):
Expand All @@ -577,6 +579,22 @@ def transform_Assign(self, node):
arg_list = self.parse_arg_list(func, node.rhs)
func.handle(node, self.context, arg_list, node.rhs.func_name.span)
return self.parse_body(node)
elif callable(func):
# Pattern 5
args = [self.transform(arg) for arg in node.rhs.params]
out = func(*args)
assert len(out) == len(node.lhs)

for var, value in zip(node.lhs, out):
self.context.update_symbol(var.id.name, value, node)

body = self.parse_body(node)

for var, value in zip(node.lhs, out):
self.context.remove_symbol(var.name)

return body

if isinstance(node.rhs, (ast.Call, ast.Constant)):
# Pattern 4 of let binding
value = self.transform(node.rhs)
Expand Down

0 comments on commit 4a5dc3f

Please sign in to comment.