Skip to content

Commit

Permalink
[UnitTests] Parametrized test_topi_argwhere.py
Browse files Browse the repository at this point in the history
Refactored while debugging breakage of tests in
apache#11646.  Submitting as a separate
PR, as it isn't necessary or related to the primary changes in that
PR.
  • Loading branch information
Lunderberg committed Jun 10, 2022
1 parent 609d6af commit bf16c6e
Showing 1 changed file with 34 additions and 38 deletions.
72 changes: 34 additions & 38 deletions tests/python/topi/python/test_topi_argwhere.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,8 +16,10 @@
# under the License.
"""Test for argwhere operator"""
import numpy as np
import pytest

import tvm
import tvm.testing
from tvm import te
from tvm import topi
import tvm.topi.testing
Expand All @@ -29,56 +31,50 @@

_argwhere_compute = {"llvm": topi.argwhere, "cuda": topi.cuda.argwhere}

data_shape = tvm.testing.parameter(
(1,),
(100,),
(1, 1),
(5, 3),
(32, 64),
(128, 65),
(200, 500),
(6, 5, 3),
(1, 1, 1),
(1, 1, 1, 1),
(6, 4, 5, 3),
(1, 1, 1, 1, 1),
(6, 4, 5, 3, 7),
)

def verify_argwhere(data_shape):

@tvm.testing.parametrize_targets("llvm", "cuda")
def test_argwhere(target, dev, data_shape):
dtype = "int32"
np_data = np.random.choice([0, 1, 2, 3], size=data_shape).astype(dtype)
np_out = np.argwhere(np_data)
out_shape = np_out.shape[0]

np_shape = np.ones(shape=(out_shape, len(data_shape)), dtype=dtype)

out_shape = te.placeholder(shape=(out_shape, len(data_shape)), name="out_shape", dtype=dtype)
condition = te.placeholder(shape=data_shape, name="condition", dtype=dtype)

def check_device(target):
dev = tvm.device(target, 0)
if not dev.exist or target not in _argwhere_compute:
return

with tvm.target.Target(target):
out = _argwhere_compute[target](out_shape, condition)
s_func = tvm.topi.testing.dispatch(target, _argwhere_schedule)
sch = s_func(out)

func = tvm.build(sch, [out_shape, condition, out], target, name="argwhere")

args = [tvm.nd.array(np_shape, dev)]
args.append(tvm.nd.array(np_data, dev))
args.append(tvm.nd.empty(out.shape, device=dev, dtype=condition.dtype))
func(*args)
np.set_printoptions(threshold=np.inf)
tvm.testing.assert_allclose(args[-1].numpy(), np.array(np_out))

for target, _ in tvm.testing.enabled_targets():
check_device(target)
with tvm.target.Target(target):
out = _argwhere_compute[target](out_shape, condition)
s_func = tvm.topi.testing.dispatch(target, _argwhere_schedule)
sch = s_func(out)

func = tvm.build(sch, [out_shape, condition, out], target, name="argwhere")

@tvm.testing.uses_gpu
def test_argwhere():
verify_argwhere((1,))
verify_argwhere((100,))
verify_argwhere((1, 1))
verify_argwhere((5, 3))
verify_argwhere((32, 64))
verify_argwhere((128, 65))
verify_argwhere((200, 500))
verify_argwhere((6, 5, 3))
verify_argwhere((1, 1, 1))
verify_argwhere((1, 1, 1, 1))
verify_argwhere((6, 4, 5, 3))
verify_argwhere((1, 1, 1, 1, 1))
verify_argwhere((6, 4, 5, 3, 7))
args = [tvm.nd.array(np_shape, dev)]
args.append(tvm.nd.array(np_data, dev))
args.append(tvm.nd.empty(out.shape, device=dev, dtype=condition.dtype))
func(*args)
np.set_printoptions(threshold=np.inf)
tvm_out = args[-1].numpy()
tvm.testing.assert_allclose(tvm_out, np_out)


if __name__ == "__main__":
test_argwhere()
tvm.testing.main()

0 comments on commit bf16c6e

Please sign in to comment.