
Resize with nearest mode has inconsistent results compared to PyTorch and TVM #12098

Open
lazycal opened this issue Jul 6, 2022 · 4 comments
Labels
core runtime

Comments

lazycal commented Jul 6, 2022

Describe the bug
The result is one pixel off:

ORT=   [... 12. 12. 12. 13. ...]
Torch= [... 12. 12. 13. 13. ...]

when resizing a tensor a with shape [1, 1, 1, 1, 26] and values a[...,i]=i to [1, 1, 1, 1, 64] with nearest mode.
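
A minimal sketch of one plausible source of the off-by-one, assuming ORT follows the ONNX spec's asymmetric mapping (x_original = x_resized / scale) with a float32 scale, while PyTorch and TVM effectively compute x_resized * (input_size / output_size); with these shapes the two mappings disagree exactly at output index 32:

import numpy as np

in_size, out_size = 26, 64
x = 32  # output index where the two backends disagree

# ONNX-spec asymmetric mapping with a float32 scale (assumed ORT-style):
# x_original = x_resized / scale, where scale = out/in rounded to float32.
scale = np.float32(out_size / in_size)   # rounds slightly above the exact 64/26
print(x / scale, np.floor(x / scale))    # just under 13 -> floor gives 12 (matches ORT)

# Direct ratio mapping (PyTorch/TVM-style): x_original = x_resized * (in/out).
ratio = in_size / out_size               # 0.40625, exactly representable
print(x * ratio, np.floor(x * ratio))    # exactly 13.0 -> floor gives 13 (matches PyTorch)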

Urgency
None

System information

  • OS Platform and Distribution (e.g., Linux Ubuntu 16.04): Ubuntu 18.04
  • ONNX Runtime installed from (source or binary): binary
  • ONNX Runtime version: 1.11.1
  • Python version: 3.7.11
  • Visual Studio version (if applicable):
  • GCC/Compiler version (if compiling from source):
  • CUDA/cuDNN version:
  • GPU model and memory:

To Reproduce

import torch
from onnx import checker
import onnxruntime as ort
import numpy as np
ash = [1, 1, 1, 1, 26]
bsh = [1, 1, 1, 1, 64]


class Model(torch.nn.Module):
    def __init__(self) -> None:
        super().__init__()

    @torch.no_grad()
    def forward(self, a):
        b = torch.nn.functional.interpolate(
            a, size=bsh[2:], mode='nearest')
        return b


anp = np.reshape(np.arange(26), ash).astype(np.float32)
model = Model()
model.eval()
torch.onnx.export(model, (torch.from_numpy(anp),), "output.onnx",
                  input_names=["i0"], output_names=["o0"], opset_version=14)
checker.check_model("output.onnx", full_check=True)
print('model checked')
b_tch = model(torch.from_numpy(anp)).numpy()

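# Run the exported model with ONNX Runtime on the CPU execution provider
# (despite the name, b_ort_gpu below holds the CPU result).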
sess_opt = ort.SessionOptions()
sess_opt.graph_optimization_level = ort.GraphOptimizationLevel.ORT_DISABLE_ALL
sess = ort.InferenceSession("output.onnx", sess_opt, providers=[
                            "CPUExecutionProvider"])
b_ort_gpu = sess.run(["o0"], {"i0": anp})[0]


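# Run the same model through TVM's Relay ONNX frontend and compare all three outputs.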
from tvm import relay
import onnx
mod, param = relay.frontend.from_onnx(onnx.load("output.onnx"))
b_tvm = relay.create_executor(mod=mod).evaluate()(
    **{"i0": anp}).numpy()
np.testing.assert_allclose(
    b_tch, b_tvm, err_msg="torch vs tvm", atol=1e-2, rtol=1e-2)
print('----------------> torch is consistent with tvm')
print('ORT=  ', b_ort_gpu[0, 0, 0, 0, 30:34])
print('Torch=', b_tch[0, 0, 0, 0, 30:34])
torch.testing.assert_allclose(
    b_ort_gpu, b_tvm, atol=1e-2, rtol=1e-2)
print('pass')

The model file: model.onnx.zip

Expected behavior
They should be consistent: either TVM and PyTorch are both wrong, or ORT is wrong.

Screenshots
[screenshot of the mismatched output attached]

Additional context
Initially I thought this was related to #12091, but it now looks like a separate issue, since the inconsistent value is not a constant; it is just the value taken from the neighboring pixel.
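
In case it helps triage, a small snippet (using only the standard onnx Python API) to print the Resize node's attributes from output.onnx, so the coordinate_transformation_mode and nearest_mode that PyTorch exported are explicit:

import onnx
from onnx import helper

model = onnx.load("output.onnx")
for node in model.graph.node:
    if node.op_type == "Resize":
        # Print every attribute, e.g. coordinate_transformation_mode and nearest_mode.
        print({a.name: helper.get_attribute_value(a) for a in node.attribute})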

@lazycal lazycal changed the title from "Nearest result inconsistent to PyTorch and TVM" to "Resize with nearest mode has inconsistent results compared to PyTorch and TVM" on Jul 6, 2022
@YUNQIUGUO YUNQIUGUO added the core runtime and component:operator labels on Jul 6, 2022
ytaous (Contributor) commented Jul 12, 2022

@hariharans29 @yihonglyu - any comment/feedback is more than welcome.

yihonglyu (Contributor) commented Jul 12, 2022

@lazycal Is this issue on GPU? If so, could you provide system information?

lazycal (Author) commented Jul 12, 2022

Hi @yihonglyu, this is on CPU:

sess = ort.InferenceSession("output.onnx", sess_opt, providers=[
                            "CPUExecutionProvider"])

Sorry for my bad variable naming.

AsiaCao commented Feb 1, 2023

Would appreciate any update on this issue. I also encountered a similar issue.
