diff --git a/python/tvm/script/parser.py b/python/tvm/script/parser.py index a2ee666c2f75..8edfb64fd6cf 100644 --- a/python/tvm/script/parser.py +++ b/python/tvm/script/parser.py @@ -578,32 +578,33 @@ def transform_Assign(self, node): arg_list = self.parse_arg_list(func, node.rhs) func.handle(node, self.context, arg_list, node.rhs.func_name.span) return self.parse_body(node) - else: - value = self.transform(node.rhs) - if len(node.lhs) == 1 and not isinstance(node.lhs[0], ast.Var): - # This is a little confusing because it only is true when - # we have taken this branch. We might need to clarify what - # exectly is allowed in Assignments in tvmscript. - self.report_error( - "Left hand side of assignment must be an unqualified variable", - node.span, - ) - ast_var = node.lhs[0] + if isinstance(node.rhs, (ast.Call, ast.Constant)): + # Pattern 4 of let binding + value = self.transform(node.rhs) + if len(node.lhs) == 1 and not isinstance(node.lhs[0], ast.Var): + # This is a little confusing because it only is true when + # we have taken this branch. We might need to clarify what + # exectly is allowed in Assignments in tvmscript. + self.report_error( + "Left hand side of assignment must be an unqualified variable", + node.span, + ) + ast_var = node.lhs[0] - if node.ty is None and hasattr(value, "dtype"): - var_ty = value.dtype - else: - var_ty = self.parse_type(node.ty, ast_var) + if node.ty is None and hasattr(value, "dtype"): + var_ty = value.dtype + else: + var_ty = self.parse_type(node.ty, ast_var) - var = tvm.te.var( - ast_var.id.name, - var_ty, - span=tvm_span_from_synr(ast_var.span), - ) - self.context.update_symbol(var.name, var, node) - body = self.parse_body(node) - self.context.remove_symbol(var.name) - return tvm.tir.LetStmt(var, value, body, span=tvm_span_from_synr(node.span)) + var = tvm.te.var( + ast_var.id.name, + var_ty, + span=tvm_span_from_synr(ast_var.span), + ) + self.context.update_symbol(var.name, var, node) + body = self.parse_body(node) + self.context.remove_symbol(var.name) + return tvm.tir.LetStmt(var, value, body, span=tvm_span_from_synr(node.span)) self.report_error( """Assignments should be either diff --git a/tests/python/unittest/test_tvmscript_syntax_sugar.py b/tests/python/unittest/test_tvmscript_syntax_sugar.py index 1d3c8ab1f105..a0964ea4d77c 100644 --- a/tests/python/unittest/test_tvmscript_syntax_sugar.py +++ b/tests/python/unittest/test_tvmscript_syntax_sugar.py @@ -249,5 +249,21 @@ def func_without_type_annotation(A: T.Buffer[(1,), "int32"]): T.evaluate(x) +def test_letstmt_bind_with_constant(): + @T.prim_func + def constant_binds(): + x = 1 + y = 42.0 + T.evaluate(T.cast(x, "float32") + y) + + @T.prim_func + def constant_binds_wrapped(): + x = T.int32(1) + y = T.float32(42.0) + T.evaluate(T.cast(x, "float32") + y) + + assert_structural_equal(constant_binds, constant_binds_wrapped) + + if __name__ == "__main__": sys.exit(pytest.main([__file__] + sys.argv[1:]))