
Noticed a difference between jax.image.resize and F.interpolate when using "bicubic" #15768

Open
Haotian-Zhang opened this issue Apr 27, 2023 · 10 comments

@Haotian-Zhang commented Apr 27, 2023

Description

Hi, I am trying to convert a PyTorch model to a JAX model, and I found that the "bicubic" implementations differ.

Here is a script I used to confirm that the outputs are different.

import torch
import jax
import jax.numpy as jnp
import numpy as np

# A small 2x5 ramp image, normalized to [0, 1].
input_arr = jnp.array([
    [0, 1, 2, 3, 4],
    [5, 6, 7, 8, 9],
]) / 9

# PyTorch's interpolate expects (N, C, H, W), hence the two unsqueezes.
torch_input_arr = torch.from_numpy(np.array(input_arr)).unsqueeze(0).unsqueeze(0)

# Upsample 2x5 -> 4x10 with bicubic interpolation in both frameworks.
output = jax.image.resize(
    image=input_arr,
    shape=(4, 10),
    method="bicubic")

interpolate_without_align_corners = torch.nn.functional.interpolate(
    torch_input_arr.float(),
    size=(4, 10),
    mode="bicubic",
    align_corners=False,
)

print("jax: ", output)
print("torch interpolate without align corners: ", interpolate_without_align_corners)

From the prints, I get:

jax:  [[-0.05882353 -0.03036592  0.0298608   0.08986928  0.14542481  0.20098035
   0.25653598  0.31654444  0.37677112  0.40522873]
 [ 0.10527544  0.13373306  0.19395977  0.25396827  0.30952382  0.36507937
   0.42063496  0.48064342  0.54087013  0.5693277 ]
 [ 0.4306723   0.45912993  0.5193566   0.57936513  0.6349207   0.69047624
   0.74603176  0.8060403   0.86626697  0.89472455]
 [ 0.5947712   0.62322885  0.6834555   0.74346405  0.79901963  0.85457516
   0.9101307   0.9701392   1.0303658   1.0588235 ]]
torch interpolate without align corners tensor([[[[-0.0703, -0.0373,  0.0156,  0.0855,  0.1306,  0.1966,  0.2418,
            0.3116,  0.3646,  0.3976],
          [ 0.1141,  0.1471,  0.2001,  0.2700,  0.3151,  0.3811,  0.4262,
            0.4961,  0.5490,  0.5820],
          [ 0.4180,  0.4510,  0.5039,  0.5738,  0.6189,  0.6849,  0.7300,
            0.7999,  0.8529,  0.8859],
          [ 0.6024,  0.6354,  0.6884,  0.7582,  0.8034,  0.8694,  0.9145,
            0.9844,  1.0373,  1.0703]]]])

I tried some other resizing methods, e.g., linear and bilinear, and they look fine. Does anyone have a workaround for this?

What jax/jaxlib version are you using?

No response

Which accelerator(s) are you using?

No response

Additional system info

No response

NVIDIA GPU info

No response

@Haotian-Zhang added the bug label Apr 27, 2023
@hawkinsp (Collaborator)

I'm not sure what's up here. The output of JAX exactly matches Pillow and TensorFlow for the same bicubic resize, so I don't think we're doing anything wrong here, but PyTorch must be using a different convention somewhere, as opposed to all the other systems.
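For anyone who wants to check that claim, a minimal comparison against Pillow might look like the following (the test array is arbitrary, and Pillow is assumed to be installed; this is a sanity-check sketch, not part of any test suite):

import numpy as np
from PIL import Image
import jax
import jax.numpy as jnp

# Arbitrary float32 test image (an "F"-mode image in Pillow terms).
arr = np.linspace(0.0, 1.0, 10, dtype=np.float32).reshape(2, 5)

jax_out = np.asarray(jax.image.resize(jnp.asarray(arr), (4, 10), method="bicubic"))
# Note: PIL's size argument is (width, height).
pil_out = np.asarray(Image.fromarray(arr).resize((10, 4), Image.BICUBIC))

print(np.abs(jax_out - pil_out).max())  # expected to be tiny (float rounding only)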

@hawkinsp (Collaborator) commented May 1, 2023

@johnpjf I could use your input here.

I looked into this a bit more. There are two groups of behaviors for cubic upsampling:

  • pillow, JAX, Tensorflow all agree (almost exactly)
  • PyTorch (with align_corners=False) and OpenCV agree (almost exactly)

I think there are two different things happening here:

JAX and PyTorch use different cubic interpolation kernels

JAX/pillow/TF use the Keys cubic kernel with A = -0.5 (https://en.wikipedia.org/wiki/Bicubic_interpolation#Bicubic_convolution_algorithm).

PyTorch/OpenCV uses a cubic kernel with A = -0.75. At least some users consider this a bug (opencv/opencv#17720), and that set of parameters suffers from ringing artifacts [1].
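For context, both kernels are instances of the same one-parameter Keys family; here is a minimal sketch of the general form (the function name is mine, not JAX's):

import jax.numpy as jnp

def keys_cubic_kernel(x, a):
  # Keys cubic convolution kernel with free parameter `a`:
  # a = -0.5 for JAX/Pillow/TF, a = -0.75 for PyTorch/OpenCV.
  x = jnp.abs(x)
  out = ((a + 2.0) * x - (a + 3.0)) * x * x + 1.0
  out = jnp.where(x >= 1.0, ((a * x - 5.0 * a) * x + 8.0 * a) * x - 4.0 * a, out)
  return jnp.where(x >= 2.0, 0.0, out)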

It's actually quite easy to get JAX to use the PyTorch/OpenCV kernel, if we use:

def _fill_pytorch_cubic_kernel(x):
  # Keys kernel with A = -0.75, evaluated at nonnegative distances x.
  # |x| <= 1: 1.25|x|^3 - 2.25|x|^2 + 1
  out = ((1.25 * x - 2.25) * x) * x + 1.
  # 1 <= |x| < 2: -0.75|x|^3 + 3.75|x|^2 - 6|x| + 3
  out = jnp.where(x >= 1., ((-0.75 * x + 5*0.75) * x - 8 * 0.75) * x + 4*0.75, out)
  # Zero outside the support [-2, 2].
  return jnp.where(x >= 2., 0., out)

in place of https://github.com/google/jax/blob/e51d12cdef0fd118e6f8cc3357cd8f95a68a79fa/jax/_src/image/scale.py#L36

This solves the mismatch except for a band of 3-4 pixels near the edge of the image. This is because:

JAX and PyTorch use different approaches for edge padding

Cubic interpolation requires sampling pixel values from outside the original input image. JAX/pillow/TF handle this by truncating the convolutional kernel and rescaling it to keep its total weight 1. I believe PyTorch handles this by edge-padding (i.e., repeating the value on the edge). This leads to slightly different values in a 3-4 pixel band near each edge.
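To make the difference concrete, here is a minimal 1D sketch of the two conventions (illustrative only, not either library's actual code; it reuses the hypothetical keys_cubic_kernel helper sketched above and models upsampling only, not the antialiasing JAX applies when downsampling):

import jax.numpy as jnp

def cubic_sample_1d(signal, pos, a=-0.75, edge_pad=True):
  # Interpolate `signal` at fractional position `pos` using 4 cubic taps.
  n = signal.shape[0]
  base = jnp.floor(pos).astype(jnp.int32)
  taps = base + jnp.arange(-1, 3)              # the 4 neighboring indices
  weights = keys_cubic_kernel(pos - taps, a)   # kernel from the sketch above
  if not edge_pad:
    # JAX/Pillow/TF-style: drop out-of-range taps and renormalize the
    # remaining weights so they still sum to 1.
    inside = (taps >= 0) & (taps < n)
    weights = jnp.where(inside, weights, 0.0)
    weights = weights / weights.sum()
  # Clamping indices implements PyTorch/OpenCV-style edge replication when
  # edge_pad=True, and merely makes the gather safe when edge_pad=False.
  taps = jnp.clip(taps, 0, n - 1)
  return jnp.sum(weights * signal[taps])

With edge_pad=True and a=-0.75 this follows the PyTorch/OpenCV convention; with edge_pad=False and a=-0.5 it follows the JAX/Pillow/TF one.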

I'll look into this a bit further; we can probably replicate the PyTorch behavior (optionally) with a bit more work.

Footnotes

  1. Mitchell, D.P. and Netravali, A.N., 1988. Reconstruction filters in computer graphics. ACM SIGGRAPH Computer Graphics, 22(4), pp. 221-228.

@hawkinsp self-assigned this May 1, 2023
@hawkinsp (Collaborator) commented May 1, 2023

I'm curious whether it's actually important to you that we replicate PyTorch's behavior exactly, or whether you merely noticed it is different and were curious about it.

@Haotian-Zhang (Author)

Hi @hawkinsp, thanks for digging into these differences! I was actually trying to reproduce some research work in JAX and noticed some performance degradation, so I started looking into the lower-level implementations and found this : )

@johnpjf (Contributor) commented May 2, 2023

@hawkinsp I don't know about PyTorch, but your description of JAX's resize sounds correct, and yes, there are different choices for the bicubic kernel. The other thing to keep in mind is that when downsampling, JAX's resize will do anti-aliasing by default.
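For what it's worth, jax.image.resize exposes this through its antialias argument, so disabling it should bring downsampling results closer to PyTorch's default (non-antialiased) behavior; the kernel and edge differences discussed above still apply. A minimal sketch with an arbitrary example image:

import jax
import jax.numpy as jnp

img = jnp.arange(64.0).reshape(8, 8)  # arbitrary example image

# JAX anti-aliases on downsampling by default; disable it to better match
# PyTorch's F.interpolate, which does not antialias unless asked to
# (recent PyTorch versions accept antialias=True).
down = jax.image.resize(img, shape=(4, 4), method="bicubic", antialias=False)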

@hawkinsp
Copy link
Collaborator

hawkinsp commented May 2, 2023

@johnpjf I'm particularly curious about the edge padding behavior. How do the OpenCV/PyTorch behaviors compare with what JAX/TF/pillow do?

@hawkinsp
Copy link
Collaborator

hawkinsp commented May 2, 2023

I should add: it's actually slightly annoying to mimic the PyTorch edge padding behavior because of the antialiasing we do on downsampling. The antialiasing means the kernel can be of an arbitrary width measured in input pixels.
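Concretely, a back-of-the-envelope illustration (hypothetical sizes, not library code): with antialiasing, the kernel is dilated by the inverse scale factor when downsampling, so its support measured in input pixels grows with the downsampling ratio.

# The plain cubic kernel has support [-2, 2]. For a 17 -> 5 resize along
# one axis, antialiasing stretches it by 1/scale = 17/5, so each output
# sample depends on roughly 13.6 input pixels:
scale = 5 / 17
half_width = 2.0 / scale
print(2 * half_width)  # ~13.6 input pixels of support per output sample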

@johnpjf (Contributor) commented May 2, 2023

TF is the same:
https://chromium.googlesource.com/external/github.com/tensorflow/tensorflow/+/refs/heads/master/tensorflow/core/kernels/image/scale_and_translate_op.cc#103
I think we could mimic the PyTorch edge behavior if we wanted to; we'd just shift the kernel weight to the last pixel and make it zero outside, no? That said, I'm skeptical this would cause large enough differences to break models.

@gokul-uf commented Mar 23, 2024

Hi @hawkinsp, thanks for digging into this and the great write-up.
I have a similar issue, but I want to replicate JAX's behaviour in Torch.
Based on your findings, is it possible to do this with an appropriate set of args to jax.image.resize?

Thanks!

@MHRDYN7 commented Jul 24, 2024

> [Quotes @hawkinsp's May 1, 2023 analysis above in full.]

Hi @hawkinsp, it would be really nice if this work were continued. I'm trying to convert a PyTorch model to Flax, and this bicubic interpolation layer is the only part preventing an exact replication of the model.
Thank you.
