From 5feddb2e05faececbffbe057d37cf9792e4378c5 Mon Sep 17 00:00:00 2001 From: Jiawei Liu Date: Tue, 18 Oct 2022 14:26:46 -0500 Subject: [PATCH] [Relay] fix: add compute tag for trilu (#13120) fix: add compute tag for trilu --- python/tvm/topi/transform.py | 2 +- tests/python/relay/test_op_level3.py | 19 +++++++++++++++++++ 2 files changed, 20 insertions(+), 1 deletion(-) diff --git a/python/tvm/topi/transform.py b/python/tvm/topi/transform.py index 44263e131182..0347473f83b7 100644 --- a/python/tvm/topi/transform.py +++ b/python/tvm/topi/transform.py @@ -1058,4 +1058,4 @@ def _apply_trilu(*indices): value = data(*other_indices, row_index, col_index) return tvm.tir.Select(check_position, value, tvm.tir.const(0, data.dtype)) - return te.compute(data.shape, _apply_trilu, name="trilu") + return te.compute(data.shape, _apply_trilu, name="trilu", tag=topi.tag.ELEMWISE) diff --git a/tests/python/relay/test_op_level3.py b/tests/python/relay/test_op_level3.py index 400f7dcf0b42..9becfc12671d 100644 --- a/tests/python/relay/test_op_level3.py +++ b/tests/python/relay/test_op_level3.py @@ -2264,5 +2264,24 @@ def verify_trilu(data_shape, upper=True, k=0): verify_trilu((8, 6, 6), False, -2) +def test_trilu_reduce(): + data_i0 = np.ones((2, 2), dtype="int32") + k = 0 + + i0 = relay.var("i0", shape=[2, 2], dtype="int32") + i1 = relay.var("i1", shape=(), dtype="int64") + v0 = relay.trilu(i0, i1) + v1 = relay.argmin(v0, axis=[0]) + f = relay.Function([i0, i1], v1) + tvm_res = ( + relay.create_executor("graph", device=tvm.cpu(), target="llvm") + .evaluate(f)(data_i0, k) + .numpy() + ) + + np_res = np.triu(data_i0, k).argmin(axis=0) + tvm.testing.assert_allclose(tvm_res, np_res) + + if __name__ == "__main__": tvm.testing.main()