Skip to content

Commit

Permalink
[FIX] resolve int64/32 for AttrStmtNode (apache#10983)
Browse files Browse the repository at this point in the history
* resolve int64/32 for AttrStmtNode

* rm debug header

* refine

* add test case

* lint
  • Loading branch information
ganler authored and Lucien0 committed Apr 19, 2022
1 parent 1f8129d commit 87a2a9d
Show file tree
Hide file tree
Showing 2 changed files with 28 additions and 1 deletion.
12 changes: 11 additions & 1 deletion src/tir/transforms/narrow_datatype.cc
Original file line number Diff line number Diff line change
Expand Up @@ -276,7 +276,17 @@ class DataTypeRewriter : public StmtExprMutator {
PrimExpr e = VisitExpr(iv->var);
Var var = Downcast<Var>(e);
if (ivmap_.find(iv) == ivmap_.end()) {
ivmap_[iv] = IterVar(iv->dom, var, iv->iter_type, iv->thread_tag);
Range dom = iv->dom;
if (dom.defined()) {
PrimExpr extend = dom->extent;
if (extend.dtype().is_int() && var.dtype().is_int() &&
var.dtype().bits() != extend.dtype().bits()) {
int bits = std::max(extend.dtype().bits(), var.dtype().bits());
DataType dtype = var.dtype().with_bits(bits);
dom = Range(cast(dtype, dom->min), cast(dtype, extend), dom->span);
}
}
ivmap_[iv] = IterVar(dom, var, iv->iter_type, iv->thread_tag);
}
return AttrStmt(ivmap_[iv], op->attr_key, cast(var.dtype(), op->value), op->body);
}
Expand Down
17 changes: 17 additions & 0 deletions tests/python/relay/test_op_level10.py
Original file line number Diff line number Diff line change
Expand Up @@ -229,6 +229,23 @@ def test_broadcast_to():
tvm.testing.assert_allclose(op_res.numpy(), ref_res, rtol=1e-5)


@tvm.testing.uses_gpu
def test_broadcast_to_const_shape_int64():
shape_like = relay.const(np.array([1, 5]), dtype="int64")
x = relay.var("x", shape=(1,), dtype="int64")
z = relay.broadcast_to(x, shape=shape_like)
z = relay.sum(z, axis=0)

f = relay.Function([x], z)

x = np.random.randint(10, size=(1,), dtype="int64")
ref_res = np.broadcast_to(x, (5,))
for target, dev in tvm.testing.enabled_targets():
for kind in ["graph", "debug"]:
op_res = relay.create_executor(kind, device=dev, target=target).evaluate(f)(x)
tvm.testing.assert_allclose(op_res.numpy(), ref_res)


@tvm.testing.uses_gpu
def test_broadcast_to_like():
shape = (4, 1, 6)
Expand Down

0 comments on commit 87a2a9d

Please sign in to comment.