diff --git a/python/tvm/relay/frontend/onnx.py b/python/tvm/relay/frontend/onnx.py index b1b01b87f715..897c6a022594 100644 --- a/python/tvm/relay/frontend/onnx.py +++ b/python/tvm/relay/frontend/onnx.py @@ -1451,7 +1451,7 @@ def _impl_v1(cls, inputs, attr, params): axis = attr.get("axis", 0) keepdims = attr.get("keepdims", True) attr = {"axis": axis, "keepdims": keepdims} - return AttrCvt("argmax")(inputs, attr) + return _op.cast(AttrCvt("argmax")(inputs, attr), "int64") class ArgMin(OnnxOpConverter): @@ -1462,7 +1462,7 @@ def _impl_v1(cls, inputs, attr, params): axis = attr.get("axis", 0) keepdims = attr.get("keepdims", True) attr = {"axis": axis, "keepdims": keepdims} - return AttrCvt("argmin")(inputs, attr) + return _op.cast(AttrCvt("argmin")(inputs, attr), "int64") class Softmax(OnnxOpConverter): @@ -2000,7 +2000,7 @@ def _impl_v1(cls, inputs, attr, params): if largest == 0: raise ValueError("TVM only supports finding TopK largest elements") - return _op.topk(inputs[0], inputs[1], axis=axis) + return _op.topk(inputs[0], inputs[1], axis=axis, dtype="int64") class Range(OnnxOpConverter): diff --git a/tests/python/frontend/onnx/test_forward.py b/tests/python/frontend/onnx/test_forward.py index c666604d0e89..56d1dd5a5265 100644 --- a/tests/python/frontend/onnx/test_forward.py +++ b/tests/python/frontend/onnx/test_forward.py @@ -163,6 +163,7 @@ def verify_with_ort_with_inputs( ort_val = scipy.special.softmax(ort_val) tvm_val = scipy.special.softmax(tvm_val) tvm.testing.assert_allclose(ort_val, tvm_val, rtol=rtol, atol=atol) + assert ort_val.dtype == tvm_val.dtype def verify_with_ort(