Skip to content

Commit

Permalink
allow constant value let binding in script (apache#11115)
Browse files Browse the repository at this point in the history
  • Loading branch information
wrongtest-intellif authored and juda committed Jun 21, 2022
1 parent 2de418e commit fd7faf4
Show file tree
Hide file tree
Showing 2 changed files with 41 additions and 24 deletions.
49 changes: 25 additions & 24 deletions python/tvm/script/parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -578,32 +578,33 @@ def transform_Assign(self, node):
arg_list = self.parse_arg_list(func, node.rhs)
func.handle(node, self.context, arg_list, node.rhs.func_name.span)
return self.parse_body(node)
else:
value = self.transform(node.rhs)
if len(node.lhs) == 1 and not isinstance(node.lhs[0], ast.Var):
# This is a little confusing because it only is true when
# we have taken this branch. We might need to clarify what
# exectly is allowed in Assignments in tvmscript.
self.report_error(
"Left hand side of assignment must be an unqualified variable",
node.span,
)
ast_var = node.lhs[0]
if isinstance(node.rhs, (ast.Call, ast.Constant)):
# Pattern 4 of let binding
value = self.transform(node.rhs)
if len(node.lhs) == 1 and not isinstance(node.lhs[0], ast.Var):
# This is a little confusing because it only is true when
# we have taken this branch. We might need to clarify what
# exectly is allowed in Assignments in tvmscript.
self.report_error(
"Left hand side of assignment must be an unqualified variable",
node.span,
)
ast_var = node.lhs[0]

if node.ty is None and hasattr(value, "dtype"):
var_ty = value.dtype
else:
var_ty = self.parse_type(node.ty, ast_var)
if node.ty is None and hasattr(value, "dtype"):
var_ty = value.dtype
else:
var_ty = self.parse_type(node.ty, ast_var)

var = tvm.te.var(
ast_var.id.name,
var_ty,
span=tvm_span_from_synr(ast_var.span),
)
self.context.update_symbol(var.name, var, node)
body = self.parse_body(node)
self.context.remove_symbol(var.name)
return tvm.tir.LetStmt(var, value, body, span=tvm_span_from_synr(node.span))
var = tvm.te.var(
ast_var.id.name,
var_ty,
span=tvm_span_from_synr(ast_var.span),
)
self.context.update_symbol(var.name, var, node)
body = self.parse_body(node)
self.context.remove_symbol(var.name)
return tvm.tir.LetStmt(var, value, body, span=tvm_span_from_synr(node.span))

self.report_error(
"""Assignments should be either
Expand Down
16 changes: 16 additions & 0 deletions tests/python/unittest/test_tvmscript_syntax_sugar.py
Original file line number Diff line number Diff line change
Expand Up @@ -249,5 +249,21 @@ def func_without_type_annotation(A: T.Buffer[(1,), "int32"]):
T.evaluate(x)


def test_letstmt_bind_with_constant():
@T.prim_func
def constant_binds():
x = 1
y = 42.0
T.evaluate(T.cast(x, "float32") + y)

@T.prim_func
def constant_binds_wrapped():
x = T.int32(1)
y = T.float32(42.0)
T.evaluate(T.cast(x, "float32") + y)

assert_structural_equal(constant_binds, constant_binds_wrapped)


if __name__ == "__main__":
sys.exit(pytest.main([__file__] + sys.argv[1:]))

0 comments on commit fd7faf4

Please sign in to comment.