Resize with mode linear always produces 0.5 on GPU regardless of the input #12091

Open
lazycal opened this issue Jul 5, 2022 · 7 comments
Labels: core runtime

Comments

@lazycal

lazycal commented Jul 5, 2022

Describe the bug
In a model with a Linear layer followed by a trilinear resize, like the graph below, the result on GPU is always 0.5 for any input. This differs from the result on CPU and from the PyTorch result, which agree with each other.

image
The corresponding torch code:

class Model(torch.nn.Module):
    def __init__(self) -> None:
        super().__init__()
        self.linear = torch.nn.Linear(1, 1)

    @torch.no_grad()
    def forward(self, x):
        x = self.linear(x)
        x = torch.nn.functional.interpolate(
            x, size=[511, 1, 1], mode='trilinear')
        return x

Maybe related to #12019?
cc the participants there, @diyessi and @hariharans29. However, I don't see the problem with 4D tensors (i.e., bilinear), and mine is on GPU whereas that one seems to be on CPU. nearest mode appears to be fine for me too. The problem also disappears after removing the Linear node.

Urgency
None

System information

  • OS Platform and Distribution (e.g., Linux Ubuntu 16.04): Ubuntu18.04
  • ONNX Runtime installed from (source or binary): binary
  • ONNX Runtime version: 1.11.1
  • Python version: 3.7.11
  • Visual Studio version (if applicable):
  • GCC/Compiler version (if compiling from source):
  • CUDA/cuDNN version: 11.2
  • GPU model and memory: RTX2080 8GB

To Reproduce
Run this code

import torch
from onnx import checker
import onnxruntime as ort
import numpy as np


class Model(torch.nn.Module):
    def __init__(self) -> None:
        super().__init__()
        self.linear = torch.nn.Linear(1, 1)

    @torch.no_grad()
    def forward(self, x):
        x = self.linear(x)
        x = torch.nn.functional.interpolate(
            x, size=[511, 1, 1], mode='trilinear')
        return x


x = torch.randn(1, 1, 1, 1, 1).to(torch.float32) * 100
model = Model()
model.eval()
torch.onnx.export(model, (x,), "output.onnx",
                  input_names=["x"], output_names=["y"], opset_version=14)
checker.check_model("output.onnx", full_check=True)
print('model checked')
b_tch = model(x)


sess_opt = ort.SessionOptions()
sess_opt.graph_optimization_level = ort.GraphOptimizationLevel.ORT_DISABLE_ALL
sess = ort.InferenceSession("output.onnx", sess_opt, providers=[
                            "CPUExecutionProvider"])
b_ort_cpu = sess.run(["y"], {"x": x.numpy()})[0]
np.testing.assert_allclose(
    b_ort_cpu, b_tch, err_msg="ort_cpu vs torch", atol=1e-2, rtol=1e-2)
print('-------------> ort_cpu is consistent to pytorch')


sess_opt = ort.SessionOptions()
sess_opt.graph_optimization_level = ort.GraphOptimizationLevel.ORT_DISABLE_ALL
sess = ort.InferenceSession("output.onnx", sess_opt, providers=[
                            "CUDAExecutionProvider",
                            "CPUExecutionProvider"])
b_ort_gpu = sess.run(["y"], {"x": x.numpy()})[0]
np.testing.assert_allclose(
    b_ort_gpu, b_tch, err_msg="ort_gpu vs torch", atol=1e-2, rtol=1e-2)
print('pass')

ONNX model: model.onnx.zip

Expected behavior
Generate consistent result.
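
For reference, here is a minimal pure-Python sketch of what "consistent" should mean for this model. It assumes linear interpolation with the half_pixel coordinate transform and coordinates clamped to the valid range at the borders, which is my understanding of how PyTorch (align_corners=False) and the ONNX Resize operator behave. The helper name resize_linear_1d is hypothetical, and the sketch handles one axis only. Because the resized axis has length 1 on input, every output sample should simply repeat the Linear layer's output, so a constant 0.5 is only correct if that output happens to be 0.5.

```python
import math


def resize_linear_1d(values, out_len):
    """Linearly resize a 1-D list using the half_pixel coordinate transform.

    Hypothetical reference helper, not ONNX Runtime code.
    """
    in_len = len(values)
    scale = out_len / in_len
    out = []
    for i in range(out_len):
        # half_pixel: map the output index back into input coordinates.
        x = (i + 0.5) / scale - 0.5
        # Assumption: coordinates outside the input are clamped to the edges.
        x = min(max(x, 0.0), float(in_len - 1))
        lo = int(math.floor(x))
        hi = min(lo + 1, in_len - 1)
        frac = x - lo
        out.append(values[lo] * (1.0 - frac) + values[hi] * frac)
    return out


# A length-1 axis resized to 511 samples: every sample repeats the input value.
single = resize_linear_1d([37.25], 511)
print(len(single), set(single))
```

With a length-1 input, the clamped coordinate is always 0, so the interpolation weights never mix in any other value.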

Screenshots
image

Additional context
None

** UPDATE **
Below I have pasted the code in pure ONNX, without PyTorch, as PyTorch may have bugs in resize-related nodes. The issue still remains.

import numpy as np
import onnxruntime as ort
import onnx
from onnx import helper, checker
from onnx import TensorProto

ash = [1, 1, 1, 1, 1]
bsh = [1, 1, 511, 1, 3]

a = helper.make_tensor_value_info('a', TensorProto.FLOAT, ash)
b = helper.make_tensor_value_info('b', TensorProto.FLOAT, bsh)
sizes = np.array(bsh, dtype=np.int64)
w = np.random.rand(1, 1).astype(np.float32)
bias = np.random.rand(1).astype(np.float32)


node_matmul = onnx.helper.make_node(
    'MatMul',
    inputs=['a', 'w'],
    outputs=['t0'],
)

node_add = onnx.helper.make_node(
    'Add',
    inputs=['t0', 'bias'],
    outputs=['t'],
)

node = onnx.helper.make_node(
    'Resize',
    inputs=['t', '', '', 'sizes'],
    outputs=['b'],
    mode='linear',
    coordinate_transformation_mode='half_pixel'
)

graph_def = helper.make_graph(
    [node_matmul, node_add, node],
    'test-model',
    [a],
    outputs=[b],
    initializer=[
        helper.make_tensor('sizes', TensorProto.INT64, [5], sizes),
        helper.make_tensor('w', TensorProto.FLOAT, [1, 1], w),
        helper.make_tensor('bias', TensorProto.FLOAT, [1], bias)
    ]
)
model_def = helper.make_model(
    graph_def, producer_name='onnx-example')

print('The model is:\n{}'.format(model_def))
checker.check_model(model_def, full_check=True)
print('The model is checked!')

onnx.save(model_def, 'output.onnx')

x = np.random.randn(1, 1, 1, 1, 1).astype(np.float32)
sess_opt = ort.SessionOptions()
sess_opt.graph_optimization_level = ort.GraphOptimizationLevel.ORT_DISABLE_ALL
sess = ort.InferenceSession("output.onnx", sess_opt, providers=[
                            "CPUExecutionProvider"])
b_ort_cpu = sess.run(["b"], {"a": x})[0]

sess_opt = ort.SessionOptions()
sess_opt.graph_optimization_level = ort.GraphOptimizationLevel.ORT_DISABLE_ALL
sess = ort.InferenceSession("output.onnx", sess_opt, providers=[
                            "CUDAExecutionProvider",
                            "CPUExecutionProvider"])
b_ort_gpu = sess.run(["b"], {"a": x})[0]
np.testing.assert_allclose(
    b_ort_gpu, b_ort_cpu, err_msg="ort_gpu vs ort_cpu", atol=1e-2, rtol=1e-2)
print('pass')
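
To tell the "everything is 0.5" failure apart from ordinary numerical drift, a small helper along the following lines may be useful. This is only a sketch: summarize_mismatch is a hypothetical name, it works on flat sequences of floats (for example b_ort_cpu.ravel().tolist()), and the sample values below are made up for illustration rather than taken from a real run.

```python
def summarize_mismatch(reference, candidate, stuck_value=0.5, tol=1e-6):
    """Compare two flat float sequences and report how the candidate differs.

    Hypothetical diagnostic helper; `stuck_value` is the constant the
    candidate is suspected of collapsing to.
    """
    ref = list(reference)
    cand = list(candidate)
    if not ref or len(ref) != len(cand):
        raise ValueError("inputs must be non-empty and of equal length")
    max_abs_diff = max(abs(r - c) for r, c in zip(ref, cand))
    stuck = sum(1 for c in cand if abs(c - stuck_value) <= tol)
    return {
        "max_abs_diff": max_abs_diff,
        "stuck_fraction": stuck / len(cand),
        "all_stuck": stuck == len(cand),
    }


# Made-up values standing in for flattened CPU and GPU outputs.
cpu_like = [4.0, 4.0, 4.0]
gpu_like = [0.5, 0.5, 0.5]
print(summarize_mismatch(cpu_like, gpu_like))
```

If all_stuck is true while the reference values vary with the input, the mismatch is the constant-output failure described in this issue rather than a tolerance problem.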

** UPDATE 2 **
So nearest mode is also problematic. See this model: [model.onnx.zip]

Use this code to reproduce:

** UPDATE 3 **
After looking at it closely, it seems to be a separate issue, reported in #12098.

@ytaous
Contributor

ytaous commented Jul 12, 2022

@hariharans29 @yihonglyu - any comment/feedback is more than welcome.

@shgoyal33

@lazycal I tried recreating the ONNX flow using your code, but I'm not getting the output from the Resize block when I look at it in Netron. I would really appreciate your help here.

@lazycal
Author

lazycal commented Dec 8, 2022

@shgoyal33 Do you mean that you did not get the 2-branch structure like below
image
from my pure ONNX code? That kind of structure is an artifact of PyTorch's conversion and is not related to this issue. On my end, my ONNX code produces the graph below, with the same error:
image
image

The onnx model:
output.onnx 2.zip

@shgoyal33

@lazycal Hi, the problem I'm facing is actually slightly different, but your code and the graph helped a lot. The problem is as follows: when I run your code and open the ONNX file in Netron, this is the output I get:
image

According to your code and the image of the graph you shared, you get the final shape at the Resize layer. When I open the properties, however, I get [unknown_dim_0,unknown_dim_1,unknown_dim_2,unknown_dim_3], whereas in your ONNX file the output of the Resize block is [1,1,511,1,1] and the graph looks like this:
image

I want the dimensions to be displayed between Resize and the output y. I ran the code you gave, so what else did you change in your code to get the output dimensions from the Resize block?

@lazycal
Author

lazycal commented Dec 8, 2022

@shgoyal33 No, I did not do anything else. I guess it's because of a PyTorch version difference. I am using this version: "1.13.0a0+git018d071". I forget how I installed it; it could have been built from source. Anyway, the other code snippet is pure ONNX, does not use PyTorch, and reproduces the same issue. Maybe you could use that instead.

@frankmanbb

Any update on this issue? I encountered the same problem when adding a resize layer in PyTorch:

    image = F.interpolate(image.unsqueeze(0),
                          size=(self.resize_width_height[1],
                                self.resize_width_height[0]),
                          mode='bilinear',
                          recompute_scale_factor=False,
                          align_corners=False)

The resulting ONNX model has a problem at inference time. CPU inference is totally fine, but GPU inference outputs 0.5 everywhere (unless the input size is the same as the target resize size, in which case I guess the code just copies the input and so avoids the all-0.5 output).

When I change to mode='bicubic', everything works well.

I guess there is a bug in the GPU implementation of the Resize layer.

@frankmanbb

Update: changing the image type from uint8 to float32 by calling image = image.float() also solves the issue. So apparently Resize on uint8 images under CUDA has some bug.
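
For anyone feeding images from NumPy, the same idea can be sketched on the input side. This is a hedged illustration, not a confirmed fix: it assumes the model was exported with a float32 image input, so the resize never sees uint8 data, and as_float32_image is a hypothetical helper name. The values are cast only, not rescaled.

```python
import numpy as np


def as_float32_image(image):
    """Return `image` as a float32 array; float32 input is returned unchanged.

    Hypothetical helper. Pixel values are cast, not rescaled to [0, 1].
    """
    image = np.asarray(image)
    if image.dtype != np.float32:
        image = image.astype(np.float32)
    return image


# Tiny stand-in for a uint8 image row.
frame = np.array([[0, 128, 255]], dtype=np.uint8)
converted = as_float32_image(frame)
print(converted.dtype, converted.tolist())
```

If the exported model itself takes uint8 input, the cast has to happen inside the model's forward, as in the image.float() call mentioned above.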
