
Inconsistent result with NumPy and PyTorch when consecutively casting a float tensor to int32 and then to bool #11994

Open
lazycal opened this issue Jun 26, 2022 · 2 comments
Labels
core runtime (issues related to core runtime)

Comments

@lazycal

lazycal commented Jun 26, 2022

Describe the bug
[-0.2, -0.1, 0, 0.1, 0.2].cast(int32).cast(bool) returns [ True, True, False, True, True] in ORT, but should be [ False, False, False, False, False].

Some datapoints:

  • When I also return the intermediate result of cast(int32) as the model output, the problem disappears.
  • ORT's result is the same as directly casting to bool.

So I guess ORT does a fusion optimization here that mistakenly rewrites cast(int32).cast(bool) to cast(bool).
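For reference, a minimal NumPy sketch of why folding the two casts changes the result (the input mirrors the one above; any float that truncates to zero behaves this way):

import numpy as np

x = np.array([-0.2, -0.1, 0, 0.1, 0.2], dtype=np.float32)

# Two-step cast: float -> int32 truncates toward zero, so every element
# becomes 0, and 0 then casts to False.
two_step = x.astype(np.int32).astype(bool)  # [False False False False False]

# Single-step cast: any nonzero float maps directly to True.
one_step = x.astype(bool)                   # [ True  True False  True  True]

The single-step result matches what ORT appears to return when the intermediate cast is removed.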

Urgency
None

System information

  • OS Platform and Distribution (e.g., Linux Ubuntu 16.04): Ubuntu 18.04
  • ONNX Runtime installed from (source or binary): binary
  • ONNX Runtime version: 1.11.1
  • Python version: 3.7.11
  • Visual Studio version (if applicable):
  • GCC/Compiler version (if compiling from source):
  • CUDA/cuDNN version: 11.2
  • GPU model and memory: RTX 2080, 8GB

To Reproduce
Run the following script.

import torch
import onnx
import onnxruntime as ort
import numpy as np


for ret_intermediate in [True, False]:
    print(f'----->Testing model with ret_intermediate={ret_intermediate}...')

    class Model(torch.nn.Module):
        def forward(self, x):
            y = x.to(torch.int32)
            if ret_intermediate:
                return y.to(torch.bool), y
            else:
                return (y.to(torch.bool), )

    model = Model()
    x = torch.tensor([-0.2, -0.1, 0, 0.1, 0.2])
    output_names = ['o0', 'o1'] if ret_intermediate else ['o0']
    model_name = "output.onnx"
    torch.onnx.export(model, (x,), model_name,
                      input_names=["i0"], output_names=output_names, opset_version=14)
    onnx_model = onnx.load(model_name)
    onnx.checker.check_model(onnx_model, full_check=True)

    sess_options = ort.SessionOptions()
    sess_options.graph_optimization_level = ort.GraphOptimizationLevel.ORT_DISABLE_ALL
    sess = ort.InferenceSession("output.onnx", sess_options=sess_options, providers=[
        'CPUExecutionProvider'])
    y_ort = sess.run(output_names, {'i0': x.numpy()})[0]

    y_tch = model(x)[0]
    y_np = x.numpy().astype(np.int32).astype(bool)
    np.testing.assert_allclose(y_np, y_tch, rtol=1e-02,
                               atol=1e-02, err_msg='numpy vs torch')  # always passes
    np.testing.assert_allclose(y_ort, y_tch, rtol=1e-02,
                               atol=1e-02, err_msg='ort vs torch')  # fails
    print('----->Passed')

Expected behavior
Be consistent with NumPy and PyTorch.

Screenshots
(screenshot attached)
Additional context

@fefe982
Contributor

fefe982 commented Jun 30, 2022

This may be caused by RemoveDuplicateCastTransformer, which is called at the end of InsertCastTransformer::ApplyImpl, which in turn is called by InferenceSession::TransformGraph.

@lazycal
Author

lazycal commented Jul 2, 2022

Thanks @fefe982 for the datapoint. That does sound likely. I read the code, and I guess ORT treats my case as high -> low -> lower precision, so it removes the first cast?

// Other cases are OK for this optimization, including below two cases,
// which are not actual loss of precision:
// - (low precision -> high precision ->low precision)
// - (high precision -> low precision -> lower precision)

So maybe a better function for determining whether a cast loses precision could solve this problem? The current check does not make sense to me:
if (src_type_group > dst_type_group) {
  loss_precision_cast = true;
}

And these are the type groups:
enum TypeGroup {
  Unknown = -1,
  Bool = 0,
  Integer = 1,
  Float = 2,
};

It does not account for things like float64 -> float32, and casting int32 to float32 loses precision but is not considered lossy here.
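For illustration, here is a hypothetical Python sketch of the kind of check I have in mind; the function name and the bit-width table are mine, not ORT's (the real transformer is C++ and works on ONNX element types rather than NumPy dtypes):

import numpy as np

# Illustrative only: approximate number of value bits each element type can
# represent exactly. Floats get their mantissa width (11/24/53), bool gets 1.
INFO_BITS = {
    np.bool_: 1,
    np.int8: 8, np.uint8: 8,
    np.int16: 16, np.uint16: 16,
    np.int32: 32, np.uint32: 32,
    np.int64: 64, np.uint64: 64,
    np.float16: 11, np.float32: 24, np.float64: 53,
}

def cast_loses_information(src, dst) -> bool:
    """Hypothetical replacement for the src_type_group > dst_type_group test."""
    src_is_float = np.issubdtype(src, np.floating)
    dst_is_float = np.issubdtype(dst, np.floating)
    # float -> integer/bool always discards the fractional part.
    if src_is_float and not dst_is_float:
        return True
    # Otherwise the cast is lossy when the destination represents fewer values.
    return INFO_BITS[dst] < INFO_BITS[src]

# float32 -> int32 now counts as lossy, so the first cast in
# float32 -> int32 -> bool would not be removed; int32 -> float32 and
# float64 -> float32 are also flagged, unlike with the type-group check.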

@sophies927 added the core runtime label and removed the component:operator label on Aug 12, 2022