diff --git a/tests/python/topi/python/test_topi_argwhere.py b/tests/python/topi/python/test_topi_argwhere.py index 8592f57b74a4..bc43dbb2b051 100644 --- a/tests/python/topi/python/test_topi_argwhere.py +++ b/tests/python/topi/python/test_topi_argwhere.py @@ -16,8 +16,10 @@ # under the License. """Test for argwhere operator""" import numpy as np +import pytest import tvm +import tvm.testing from tvm import te from tvm import topi import tvm.topi.testing @@ -29,56 +31,50 @@ _argwhere_compute = {"llvm": topi.argwhere, "cuda": topi.cuda.argwhere} +data_shape = tvm.testing.parameter( + (1,), + (100,), + (1, 1), + (5, 3), + (32, 64), + (128, 65), + (200, 500), + (6, 5, 3), + (1, 1, 1), + (1, 1, 1, 1), + (6, 4, 5, 3), + (1, 1, 1, 1, 1), + (6, 4, 5, 3, 7), +) -def verify_argwhere(data_shape): + +@tvm.testing.parametrize_targets("llvm", "cuda") +def test_argwhere(target, dev, data_shape): dtype = "int32" np_data = np.random.choice([0, 1, 2, 3], size=data_shape).astype(dtype) np_out = np.argwhere(np_data) out_shape = np_out.shape[0] + np_shape = np.ones(shape=(out_shape, len(data_shape)), dtype=dtype) out_shape = te.placeholder(shape=(out_shape, len(data_shape)), name="out_shape", dtype=dtype) condition = te.placeholder(shape=data_shape, name="condition", dtype=dtype) - def check_device(target): - dev = tvm.device(target, 0) - if not dev.exist or target not in _argwhere_compute: - return - - with tvm.target.Target(target): - out = _argwhere_compute[target](out_shape, condition) - s_func = tvm.topi.testing.dispatch(target, _argwhere_schedule) - sch = s_func(out) - - func = tvm.build(sch, [out_shape, condition, out], target, name="argwhere") - - args = [tvm.nd.array(np_shape, dev)] - args.append(tvm.nd.array(np_data, dev)) - args.append(tvm.nd.empty(out.shape, device=dev, dtype=condition.dtype)) - func(*args) - np.set_printoptions(threshold=np.inf) - tvm.testing.assert_allclose(args[-1].numpy(), np.array(np_out)) - - for target, _ in tvm.testing.enabled_targets(): - check_device(target) + with tvm.target.Target(target): + out = _argwhere_compute[target](out_shape, condition) + s_func = tvm.topi.testing.dispatch(target, _argwhere_schedule) + sch = s_func(out) + func = tvm.build(sch, [out_shape, condition, out], target, name="argwhere") -@tvm.testing.uses_gpu -def test_argwhere(): - verify_argwhere((1,)) - verify_argwhere((100,)) - verify_argwhere((1, 1)) - verify_argwhere((5, 3)) - verify_argwhere((32, 64)) - verify_argwhere((128, 65)) - verify_argwhere((200, 500)) - verify_argwhere((6, 5, 3)) - verify_argwhere((1, 1, 1)) - verify_argwhere((1, 1, 1, 1)) - verify_argwhere((6, 4, 5, 3)) - verify_argwhere((1, 1, 1, 1, 1)) - verify_argwhere((6, 4, 5, 3, 7)) + args = [tvm.nd.array(np_shape, dev)] + args.append(tvm.nd.array(np_data, dev)) + args.append(tvm.nd.empty(out.shape, device=dev, dtype=condition.dtype)) + func(*args) + np.set_printoptions(threshold=np.inf) + tvm_out = args[-1].numpy() + tvm.testing.assert_allclose(tvm_out, np_out) if __name__ == "__main__": - test_argwhere() + tvm.testing.main()