diff --git a/python/tvm/relay/frontend/mxnet.py b/python/tvm/relay/frontend/mxnet.py index 0c9d2c4381ac..5415c77097a2 100644 --- a/python/tvm/relay/frontend/mxnet.py +++ b/python/tvm/relay/frontend/mxnet.py @@ -1234,7 +1234,7 @@ def _mx_topk(inputs, attrs): new_attrs = {} new_attrs["k"] = attrs.get_int("k", 1) new_attrs["axis"] = attrs.get_int("axis", -1) - new_attrs["is_ascend"] = attrs.get_bool("is_ascend", True) + new_attrs["is_ascend"] = attrs.get_bool("is_ascend", False) ret_type = attrs.get_str("ret_typ", "indices") if ret_type == "mask": raise tvm.error.OpAttributeUnimplemented( diff --git a/tests/python/frontend/mxnet/test_forward.py b/tests/python/frontend/mxnet/test_forward.py index 3e652cfc69e3..4eb7f6139e8f 100644 --- a/tests/python/frontend/mxnet/test_forward.py +++ b/tests/python/frontend/mxnet/test_forward.py @@ -1064,14 +1064,23 @@ def verify(shape, axis, is_ascend, dtype="float32"): @tvm.testing.uses_gpu def test_forward_topk(): - def verify(shape, k, axis, ret_type, is_ascend=False, dtype="float32"): + def verify(shape, k, axis, ret_type, is_ascend=None, dtype="float32"): x_np = np.random.uniform(size=shape).astype("float32") - ref_res = mx.nd.topk( - mx.nd.array(x_np), k=k, axis=axis, ret_typ=ret_type, is_ascend=is_ascend, dtype=dtype - ) - mx_sym = mx.sym.topk( - mx.sym.var("x"), k=k, axis=axis, ret_typ=ret_type, is_ascend=is_ascend, dtype=dtype - ) + if is_ascend is None: + ref_res = mx.nd.topk(mx.nd.array(x_np), k=k, axis=axis, ret_typ=ret_type, dtype=dtype) + mx_sym = mx.sym.topk(mx.sym.var("x"), k=k, axis=axis, ret_typ=ret_type, dtype=dtype) + else: + ref_res = mx.nd.topk( + mx.nd.array(x_np), + k=k, + axis=axis, + ret_typ=ret_type, + is_ascend=is_ascend, + dtype=dtype, + ) + mx_sym = mx.sym.topk( + mx.sym.var("x"), k=k, axis=axis, ret_typ=ret_type, is_ascend=is_ascend, dtype=dtype + ) mod, _ = relay.frontend.from_mxnet(mx_sym, {"x": shape}) for target, ctx in tvm.testing.enabled_targets(): for kind in ["graph", "debug"]: @@ -1086,7 +1095,7 @@ def verify(shape, k, axis, ret_type, is_ascend=False, dtype="float32"): verify((3, 4), k=1, axis=0, ret_type="both") verify((3, 4), k=1, axis=-1, ret_type="indices") - verify((3, 5, 6), k=2, axis=2, ret_type="value") + verify((3, 5, 6), k=2, axis=2, ret_type="value", is_ascend=False) verify((3, 5, 6), k=2, axis=1, ret_type="value", is_ascend=True) verify((3, 5, 6), k=0, axis=2, ret_type="both", dtype="int32")