
[Relay] fix: add compute tag for trilu (apache#13120)
ganler authored and xinetzone committed Nov 25, 2022
1 parent ad76c34 commit 5feddb2
Showing 2 changed files with 20 additions and 1 deletion.
2 changes: 1 addition & 1 deletion python/tvm/topi/transform.py
@@ -1058,4 +1058,4 @@ def _apply_trilu(*indices):
        value = data(*other_indices, row_index, col_index)
        return tvm.tir.Select(check_position, value, tvm.tir.const(0, data.dtype))

-    return te.compute(data.shape, _apply_trilu, name="trilu")
+    return te.compute(data.shape, _apply_trilu, name="trilu", tag=topi.tag.ELEMWISE)
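
Note (not part of the commit): Relay's TE schedules dispatch on the compute tag string when they traverse producer ops, which is why adding tag=topi.tag.ELEMWISE above is enough for trilu to be treated like any other elementwise producer. A minimal sketch, assuming a stock TVM build with this commit applied:

# Minimal sketch (assumption: a stock TVM install with this commit applied).
# topi.tag.ELEMWISE is just the string "elemwise", and the topi tag helpers
# classify it as broadcast/injective, so trilu can be inlined into consumers.
from tvm import topi

assert topi.tag.ELEMWISE == "elemwise"
assert topi.tag.is_broadcast(topi.tag.ELEMWISE)
assert topi.tag.is_injective(topi.tag.ELEMWISE)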
19 changes: 19 additions & 0 deletions tests/python/relay/test_op_level3.py
@@ -2264,5 +2264,24 @@ def verify_trilu(data_shape, upper=True, k=0):
    verify_trilu((8, 6, 6), False, -2)


+def test_trilu_reduce():
+    data_i0 = np.ones((2, 2), dtype="int32")
+    k = 0
+
+    i0 = relay.var("i0", shape=[2, 2], dtype="int32")
+    i1 = relay.var("i1", shape=(), dtype="int64")
+    v0 = relay.trilu(i0, i1)
+    v1 = relay.argmin(v0, axis=[0])
+    f = relay.Function([i0, i1], v1)
+    tvm_res = (
+        relay.create_executor("graph", device=tvm.cpu(), target="llvm")
+        .evaluate(f)(data_i0, k)
+        .numpy()
+    )
+
+    np_res = np.triu(data_i0, k).argmin(axis=0)
+    tvm.testing.assert_allclose(tvm_res, np_res)
+
+
if __name__ == "__main__":
    tvm.testing.main()
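
As a companion to the regression test above, a hedged sketch of the path it exercises (an assumed reconstruction, not text from the commit): the generic and x86 reduction schedules walk the producers of a reduce op and dispatch on their compute tags, so an untagged trilu feeding argmin could not be scheduled; with the tag in place, the same trilu-then-argmin pattern compiles:

# Sketch assuming this commit is applied; mirrors the trilu -> argmin pattern
# from test_trilu_reduce above.
import tvm
from tvm import relay

i0 = relay.var("i0", shape=(2, 2), dtype="int32")
i1 = relay.var("i1", shape=(), dtype="int64")
body = relay.argmin(relay.trilu(i0, i1), axis=[0])
mod = tvm.IRModule.from_expr(relay.Function([i0, i1], body))

# relay.build runs the schedule path that previously rejected the untagged
# trilu compute when it fed a reduction.
lib = relay.build(mod, target="llvm")
print("compiled:", lib)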
