Notice a difference between jax.image.resize and F.interpolate when using "bicubic" #15768
Comments
I'm not sure what's up here. The output of JAX exactly matches Pillow and TensorFlow for the same bicubic resize, so I don't think we're doing something wrong here, but PyTorch must be using a different convention for something, as opposed to all the other systems.
@johnpjf I could use your input here. I looked into this a bit more. There are two groups of behaviors for cubic upsampling; I think there are two different things happening here:

1. JAX and PyTorch use different cubic interpolation kernels. JAX/Pillow/TF use the Keys cubic kernel with A = -0.5 (https://en.wikipedia.org/wiki/Bicubic_interpolation#Bicubic_convolution_algorithm). PyTorch/OpenCV use a cubic kernel with A = -0.75. At least some users consider this a bug (opencv/opencv#17720), and it's actually quite easy to get JAX to use the PyTorch/OpenCV kernel, if we use:
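The exact snippet isn't reproduced here, but a minimal pure-JAX sketch of the idea — a 1-D cubic resize that exposes the kernel parameter A and uses PyTorch-style edge clamping — could look like the following (`cubic_kernel` and `resize_1d` are hypothetical helper names, not part of `jax.image`):

```python
import jax.numpy as jnp

def cubic_kernel(x, a=-0.75):
    # Keys-style cubic convolution kernel. a=-0.5 gives the JAX/Pillow/TF
    # behavior; a=-0.75 matches PyTorch/OpenCV.
    x = jnp.abs(x)
    near = (a + 2) * x**3 - (a + 3) * x**2 + 1         # |x| <= 1
    far = a * x**3 - 5 * a * x**2 + 8 * a * x - 4 * a  # 1 < |x| < 2
    return jnp.where(x <= 1, near, jnp.where(x < 2, far, 0.0))

def resize_1d(signal, out_len, a=-0.75):
    # 1-D cubic resize with half-pixel sample centers (the
    # align_corners=False convention) and clamped indices at the borders
    # (i.e., edge padding).
    in_len = signal.shape[0]
    pos = (jnp.arange(out_len) + 0.5) * (in_len / out_len) - 0.5
    base = jnp.floor(pos).astype(jnp.int32)
    idx = base[:, None] + jnp.arange(-1, 3)[None, :]   # 4-tap support
    weights = cubic_kernel(pos[:, None] - idx, a)
    idx = jnp.clip(idx, 0, in_len - 1)                 # repeat edge values
    return (signal[idx] * weights).sum(axis=-1)
```

Applying this along each spatial axis with a=-0.75 should track PyTorch's upsampling path much more closely than the default kernel; it does not handle downsampling with antialiasing.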
This solves the mismatch except for a band of 3-4 pixels near the edge of the image. This is because:

2. JAX and PyTorch use different approaches for edge padding. Cubic interpolation requires sampling pixel values from outside the original input image. JAX/Pillow/TF handle this by truncating the convolutional kernel and rescaling it to keep its total weight 1. I believe PyTorch handles this by edge-padding (i.e., repeating the value on the edge). This leads to slightly different values in a 3-4 pixel band near each edge. I'll look into this a bit further; we can probably replicate the PyTorch behavior (optionally) with a bit more work.
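For contrast, here is a sketch of the truncate-and-renormalize boundary rule described above, reusing `cubic_kernel` from the previous sketch; again a hypothetical helper, not JAX's actual implementation:

```python
def resize_1d_renorm(signal, out_len, a=-0.5):
    # Same as resize_1d, but out-of-bounds kernel weights are dropped and
    # the remaining weights rescaled to sum to 1 (the JAX/Pillow/TF-style
    # rule described in the comment above).
    in_len = signal.shape[0]
    pos = (jnp.arange(out_len) + 0.5) * (in_len / out_len) - 0.5
    base = jnp.floor(pos).astype(jnp.int32)
    idx = base[:, None] + jnp.arange(-1, 3)[None, :]
    weights = cubic_kernel(pos[:, None] - idx, a)
    inside = (idx >= 0) & (idx < in_len)
    weights = jnp.where(inside, weights, 0.0)                # truncate
    weights = weights / weights.sum(axis=-1, keepdims=True)  # renormalize
    idx = jnp.clip(idx, 0, in_len - 1)                       # safe gather
    return (signal[idx] * weights).sum(axis=-1)
```

The two rules only disagree where the 4-tap window crosses the image border, which is consistent with the narrow edge band described above.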
I'm curious if it's actually important to you that we replicate PyTorch's behavior exactly, or you merely noticed it is different and were curious about it.
Hi @hawkinsp, thanks for digging into these differences! I actually tried to reproduce some research work in JAX and noticed some performance degradation, then I started to look into the lower-level implementations and found this :)
@hawkinsp I don't know about PyTorch, but your description of JAX's resize sounds correct, and yes, there are different choices for the bicubic kernel. The other thing to keep in mind is that when downsampling, JAX's resize will do anti-aliasing by default.
@johnpjf I'm particularly curious about the edge padding behavior. How do the OpenCV/PyTorch behaviors compare with what JAX/TF/pillow do?
I should add: it's actually slightly annoying to mimic the PyTorch edge-padding behavior because of the antialiasing we do on downsampling. The antialiasing means the kernel can be of an arbitrary width measured in input pixels.
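To illustrate that point (this is a standard formulation of antialiased resampling, not JAX's internal code): when downsampling by a factor s > 1, the kernel is dilated by s, so its support grows from 4 taps to roughly 4*s input pixels, and the fixed 4-tap gather in the sketches above no longer applies.

```python
def antialiased_cubic_weight(pos, idx, scale, a=-0.5):
    # Dilate the kernel by the downsampling factor; its support becomes
    # about 2 * scale input pixels on each side of `pos`.
    s = jnp.maximum(scale, 1.0)  # no dilation when upsampling
    return cubic_kernel((pos - idx) / s, a) / s
```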
TF is the same:
Hi @hawkinsp, thanks for digging into this and the great write-up!
Hi @hawkinsp, it would be really nice if this work were continued. I'm trying to convert a PyTorch model to Flax, and this bicubic interpolation layer is the only part preventing an exact replication of the model.
Description
Hi, I am trying to convert a PyTorch model to a JAX model, and I found that the "bicubic" implementations differ.
Here is a script I have used to confirm that the outputs are different.
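The script itself isn't preserved in this copy of the issue; a minimal comparison along the lines described (assuming both jax and torch are installed, with made-up input shapes) would be something like:

```python
import jax
import jax.numpy as jnp
import numpy as np
import torch
import torch.nn.functional as F

# Fixed random input so both libraries resize the same data.
x = np.random.RandomState(0).rand(1, 3, 8, 8).astype(np.float32)

# 2x bicubic upsampling in each framework.
out_jax = jax.image.resize(jnp.asarray(x), (1, 3, 16, 16), method="bicubic")
out_torch = F.interpolate(torch.from_numpy(x), size=(16, 16),
                          mode="bicubic", align_corners=False)

# Maximum elementwise difference between the two results.
print(np.abs(np.asarray(out_jax) - out_torch.numpy()).max())
```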
From the prints I get:
I tried some other resizing methods, e.g., linear and bilinear, and they look fine. Does anyone have a workaround for this?
What jax/jaxlib version are you using?
No response
Which accelerator(s) are you using?
No response
Additional system info
No response
NVIDIA GPU info
No response