From 3fa442c646397e8116760e0d28c4e8eab2204260 Mon Sep 17 00:00:00 2001
From: Jiawei Liu
Date: Wed, 13 Apr 2022 04:54:09 -0500
Subject: [PATCH] [FIX] resolve int64/32 for AttrStmtNode (#10983)

* resolve int64/32 for AttrStmtNode

* rm debug header

* refine

* add test case

* lint
---
 src/tir/transforms/narrow_datatype.cc | 12 +++++++++++-
 tests/python/relay/test_op_level10.py | 17 +++++++++++++++++
 2 files changed, 28 insertions(+), 1 deletion(-)

diff --git a/src/tir/transforms/narrow_datatype.cc b/src/tir/transforms/narrow_datatype.cc
index c2bf27393173b..8df7b57eafde7 100644
--- a/src/tir/transforms/narrow_datatype.cc
+++ b/src/tir/transforms/narrow_datatype.cc
@@ -276,7 +276,17 @@ class DataTypeRewriter : public StmtExprMutator {
       PrimExpr e = VisitExpr(iv->var);
       Var var = Downcast<Var>(e);
       if (ivmap_.find(iv) == ivmap_.end()) {
-        ivmap_[iv] = IterVar(iv->dom, var, iv->iter_type, iv->thread_tag);
+        Range dom = iv->dom;
+        if (dom.defined()) {
+          PrimExpr extend = dom->extent;
+          if (extend.dtype().is_int() && var.dtype().is_int() &&
+              var.dtype().bits() != extend.dtype().bits()) {
+            int bits = std::max(extend.dtype().bits(), var.dtype().bits());
+            DataType dtype = var.dtype().with_bits(bits);
+            dom = Range(cast(dtype, dom->min), cast(dtype, extend), dom->span);
+          }
+        }
+        ivmap_[iv] = IterVar(dom, var, iv->iter_type, iv->thread_tag);
       }
       return AttrStmt(ivmap_[iv], op->attr_key, cast(var.dtype(), op->value), op->body);
     }
diff --git a/tests/python/relay/test_op_level10.py b/tests/python/relay/test_op_level10.py
index 0486ef40017b3..85a3dd5636f16 100644
--- a/tests/python/relay/test_op_level10.py
+++ b/tests/python/relay/test_op_level10.py
@@ -229,6 +229,23 @@ def test_broadcast_to():
             tvm.testing.assert_allclose(op_res.numpy(), ref_res, rtol=1e-5)
 
 
+@tvm.testing.uses_gpu
+def test_broadcast_to_const_shape_int64():
+    shape_like = relay.const(np.array([1, 5]), dtype="int64")
+    x = relay.var("x", shape=(1,), dtype="int64")
+    z = relay.broadcast_to(x, shape=shape_like)
+    z = relay.sum(z, axis=0)
+
+    f = relay.Function([x], z)
+
+    x = np.random.randint(10, size=(1,), dtype="int64")
+    ref_res = np.broadcast_to(x, (5,))
+    for target, dev in tvm.testing.enabled_targets():
+        for kind in ["graph", "debug"]:
+            op_res = relay.create_executor(kind, device=dev, target=target).evaluate(f)(x)
+            tvm.testing.assert_allclose(op_res.numpy(), ref_res)
+
+
 @tvm.testing.uses_gpu
 def test_broadcast_to_like():
     shape = (4, 1, 6)
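
For reference, a minimal standalone sketch of the pattern covered by the new test (int64 data broadcast to an int64 constant shape, then reduced) is given below. It is not part of the patch; the "llvm" target and "graph" executor kind are illustrative assumptions, while the in-tree test instead sweeps all enabled targets, where GPU targets exercise the thread_extent path touched by the C++ change.

    # Minimal sketch mirroring test_broadcast_to_const_shape_int64.
    # Assumption: the "llvm" target and "graph" executor are used only
    # for illustration; they are not prescribed by the patch.
    import numpy as np
    from tvm import relay

    # Broadcast an int64 tensor to a shape given as an int64 constant, then reduce.
    shape_like = relay.const(np.array([1, 5]), dtype="int64")
    x = relay.var("x", shape=(1,), dtype="int64")
    z = relay.sum(relay.broadcast_to(x, shape=shape_like), axis=0)
    f = relay.Function([x], z)

    x_np = np.random.randint(10, size=(1,), dtype="int64")
    op_res = relay.create_executor("graph", target="llvm").evaluate(f)(x_np)
    np.testing.assert_allclose(op_res.numpy(), np.broadcast_to(x_np, (5,)))

The intent reflected in the C++ hunk is to widen the rewritten IterVar's Range to the larger of the extent and loop-variable bit widths, so an int64 extent and a narrowed int32 thread variable stay consistent when the IterVar is rebuilt.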