Adding align_corners to jax.image.resize #11206

Open
younesbelkada opened this issue Jun 22, 2022 · 3 comments
Labels: enhancement (New feature or request)

Comments

@younesbelkada

younesbelkada commented Jun 22, 2022

Hi there!

I would like to reproduce the operations performed by torch.nn.Upsample() in Flax 🎉. In PyTorch, the align_corners flag seems to mean not "align the corners" but rather "sample with equal spacing" (cc @cgarciae). Could this feature be added to jax.image.resize? 🙏 (Initially posted on google/flax.)
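To make the two conventions concrete, here is a minimal sketch of the source coordinate that each output index samples along one axis. It assumes the usual PyTorch definitions for linear modes: align_corners=True spreads the samples evenly from the first to the last input pixel center, while align_corners=False uses half-pixel centers, effectively clamped to [0, n_in - 1]. The helper name `source_coords` is hypothetical.

```python
import jax.numpy as jnp


def source_coords(n_in: int, n_out: int, align_corners: bool) -> jnp.ndarray:
    """Hypothetical helper: input coordinate sampled by each output index (assumes n_out > 1)."""
    i = jnp.arange(n_out, dtype=jnp.float32)
    if align_corners:
        # First and last pixel centers coincide; samples are equally spaced over [0, n_in - 1].
        return i * (n_in - 1) / (n_out - 1)
    # Half-pixel convention; out-of-range coordinates effectively clamp to the border pixels.
    return jnp.clip((i + 0.5) * n_in / n_out - 0.5, 0.0, n_in - 1)


print(source_coords(2, 4, align_corners=True))   # ≈ [0, 1/3, 2/3, 1]
print(source_coords(2, 4, align_corners=False))  # ≈ [0, 0.25, 0.75, 1]
```

Along the height axis of the example below (2 rows upsampled to 4), these are the row positions that the two PyTorch outputs interpolate at.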

Problem description:

Ideally:

import torch
import torch.nn as nn
import jax.numpy as jnp
import jax
import numpy as np

# 2x5 input with values in [0, 1]
input_arr = jnp.array([
    [0, 1, 2, 3, 4],
    [5, 6, 7, 8, 9],
]) / 9

# PyTorch expects NCHW, so add batch and channel dimensions: (1, 1, 2, 5)
torch_input_arr = torch.from_numpy(np.array(input_arr)).unsqueeze(0).unsqueeze(0)
upsample_with_align_corners = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True)
upsample_without_align_corners = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=False)

# jax.image.resize takes the full target shape rather than a scale factor
output = jax.image.resize(
    image=input_arr,
    shape=(4, 10),
    method="bilinear")

output_torch_upsample_with_align_corners = upsample_with_align_corners(torch_input_arr)
output_torch_upsample_without_align_corners = upsample_without_align_corners(torch_input_arr)

print("jax: ", output)
print("torch upsample align corners:", output_torch_upsample_with_align_corners)
print("torch upsample without align corners:", output_torch_upsample_without_align_corners)

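I would like Flax's output to match output_torch_upsample_with_align_corners. Here is what I get (note that the default jax.image.resize output matches PyTorch's align_corners=False result, not the align_corners=True one):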

# jax default output
DeviceArray([[0.        , 0.02777778, 0.08333334, 0.1388889 , 0.19444445,
              0.25      , 0.30555555, 0.3611111 , 0.4166667 , 0.44444445],
             [0.1388889 , 0.16666667, 0.22222222, 0.2777778 , 0.33333334,
              0.3888889 , 0.44444442, 0.5       , 0.5555556 , 0.5833334 ],
             [0.4166667 , 0.44444448, 0.5       , 0.5555555 , 0.6111111 ,
              0.6666667 , 0.7222222 , 0.7777778 , 0.8333333 , 0.8611111 ],
             [0.5555556 , 0.5833334 , 0.6388889 , 0.6944444 , 0.75      ,
              0.8055556 , 0.8611111 , 0.9166667 , 0.9722222 , 1.        ]],            dtype=float32)
# torch with align_corners=True
tensor([[[[0.0000, 0.0494, 0.0988, 0.1481, 0.1975, 0.2469, 0.2963, 0.3457,
           0.3951, 0.4444],
          [0.1852, 0.2346, 0.2840, 0.3333, 0.3827, 0.4321, 0.4815, 0.5309,
           0.5802, 0.6296],
          [0.3704, 0.4198, 0.4691, 0.5185, 0.5679, 0.6173, 0.6667, 0.7160,
           0.7654, 0.8148],
          [0.5556, 0.6049, 0.6543, 0.7037, 0.7531, 0.8025, 0.8519, 0.9012,
           0.9506, 1.0000]]]])
# torch with align_corners=False
tensor([[[[0.0000, 0.0278, 0.0833, 0.1389, 0.1944, 0.2500, 0.3056, 0.3611,
           0.4167, 0.4444],
          [0.1389, 0.1667, 0.2222, 0.2778, 0.3333, 0.3889, 0.4444, 0.5000,
           0.5556, 0.5833],
          [0.4167, 0.4444, 0.5000, 0.5556, 0.6111, 0.6667, 0.7222, 0.7778,
           0.8333, 0.8611],
          [0.5556, 0.5833, 0.6389, 0.6944, 0.7500, 0.8056, 0.8611, 0.9167,
           0.9722, 1.0000]]]])
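As a cross-check, the align_corners=True tensor above can be reproduced in JAX by sampling at equally spaced coordinates directly. This sketch swaps in jax.scipy.ndimage.map_coordinates (order=1, i.e. linear interpolation) instead of jax.image.resize; `resize_bilinear_align_corners` is a hypothetical helper that only handles 2D arrays.

```python
import jax.numpy as jnp
from jax.scipy.ndimage import map_coordinates


def resize_bilinear_align_corners(image, shape):
    """Hypothetical helper: 2D bilinear resize with align_corners=True semantics."""
    (h_in, w_in), (h_out, w_out) = image.shape, shape
    # Equally spaced sample positions from the first to the last pixel center.
    rows = jnp.linspace(0.0, h_in - 1, h_out)
    cols = jnp.linspace(0.0, w_in - 1, w_out)
    rr, cc = jnp.meshgrid(rows, cols, indexing="ij")
    # order=1 is linear interpolation; every coordinate lies inside the input grid.
    return map_coordinates(image, [rr, cc], order=1)


input_arr = jnp.array([[0, 1, 2, 3, 4], [5, 6, 7, 8, 9]]) / 9
out = resize_bilinear_align_corners(input_arr, (4, 10))
print(out)
```

For this input, out[0, 1] should be 4/81 ≈ 0.0494 and out[1, 0] should be 5/27 ≈ 0.1852, matching the corresponding entries of the PyTorch align_corners=True tensor.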

Motivation

I am porting Dense Prediction Transformers (DPT) to Flax and would like the outputs of the PyTorch model and the Flax implementation to match! huggingface/transformers#17779

cc @cgarciae, moved the issue here :)

I would love to contribute as well, but I may need some guidance on how to do so!

@brentyi

brentyi commented Feb 8, 2023

@younesbelkada FYI, I have an implementation of resizing with aligned corners here:
https://github.com/brentyi/pips-jax/blob/8ae798a31caf03dcdd1c237ac39ca5ff13d0ea4d/src/pips_jax/utils_bilerp.py#L60-L83

I just adapted jax.image.scale_and_translate:

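from typing import Tuple, Union

import jax
import jax.numpy as jnp


def resize_with_aligned_corners(
    image: jax.Array,
    shape: Tuple[int, ...],
    method: Union[str, jax.image.ResizeMethod],
    antialias: bool,
):
    """Alternative to jax.image.resize(), which emulates align_corners=True in PyTorch's
    interpolation functions."""
    # Resize only the dimensions whose size changes. Plain != assumes static shapes;
    # the original used jax.core.symbolic_equal_dim, which newer JAX releases may not provide.
    spatial_dims = tuple(
        i
        for i in range(len(shape))
        if image.shape[i] != shape[i]
    )
    # Map the first and last input pixel centers onto the first and last output pixel
    # centers (assumes input and output sizes > 1 along every resized dimension).
    scale = jnp.array([(shape[i] - 1.0) / (image.shape[i] - 1.0) for i in spatial_dims])
    translation = -(scale / 2.0 - 0.5)
    return jax.image.scale_and_translate(
        image,
        shape,
        method=method,
        scale=scale,
        spatial_dims=spatial_dims,
        translation=translation,
        antialias=antialias,
    )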
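Where that translation comes from, as a short derivation: it assumes the convention documented for jax.image.scale_and_translate, where an input location x maps to the output location s·x + t (scale s, translation t) and pixel k is centered at k + 1/2. Requiring the first and last pixel centers to coincide gives:

```latex
\begin{aligned}
\tfrac{1}{2}\, s + t &= \tfrac{1}{2} && \text{(first pixel centers coincide)} \\
\left(n_{\text{in}} - \tfrac{1}{2}\right) s + t &= n_{\text{out}} - \tfrac{1}{2} && \text{(last pixel centers coincide)} \\
\Rightarrow\quad s &= \frac{n_{\text{out}} - 1}{n_{\text{in}} - 1},
\qquad t = \tfrac{1}{2} - \tfrac{s}{2} = -\left(\tfrac{s}{2} - \tfrac{1}{2}\right)
\end{aligned}
```

Subtracting the first equation from the second eliminates t, which yields the `scale` and `translation` expressions used in the function above.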

P.S. Are you still interested in getting a Flax implementation of DPT merged? It would be really useful to me.

@Haotian-Zhang

Haotian-Zhang commented Apr 26, 2023

Hi, just wondering whether there are any updates on this? @hawkinsp @froystig. The outputs of jax.image.resize may also differ from F.interpolate when using "bilinear" on 2D inputs.
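One likely contributor to such differences (an assumption; the thread does not confirm it): when downsampling, jax.image.resize antialiases by default (antialias=True), whereas torch.nn.functional.interpolate defaults to antialias=False. The JAX-only sketch below shows that the flag changes the result when shrinking an image but not when enlarging it, so passing antialias=False is worth trying when comparing downsampled outputs against PyTorch.

```python
import jax
import jax.numpy as jnp

# Linear ramp: x[r, c] = 8 * r + c
x = jnp.arange(64.0).reshape(8, 8)

# Downsampling 8x8 -> 2x2: with antialias=True the bilinear kernel is widened 4x.
down_aa = jax.image.resize(x, (2, 2), method="bilinear", antialias=True)
down_no_aa = jax.image.resize(x, (2, 2), method="bilinear", antialias=False)

# Upsampling 8x8 -> 16x16: the kernel is not widened, so the flag has no effect.
up_aa = jax.image.resize(x, (16, 16), method="bilinear", antialias=True)
up_no_aa = jax.image.resize(x, (16, 16), method="bilinear", antialias=False)

print("max |difference| when downsampling:", jnp.abs(down_aa - down_no_aa).max())
print("max |difference| when upsampling:  ", jnp.abs(up_aa - up_no_aa).max())
```

Without antialiasing, down_no_aa[0, 0] is plain bilinear interpolation at input coordinate (1.5, 1.5), i.e. 13.5 for this ramp.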

@johnpjf
Contributor

johnpjf commented Feb 1, 2024

FYI: align_corners aligns the centers of the corner pixels. This is only really useful because it happens to implement the translation needed to align features from strided convolutions.
The better, more general way to do this is to use jax.image.scale_and_translate, as @brentyi shows above (jax.image.resize uses jax.image.scale_and_translate under the hood anyway).
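A small sketch of that relationship, using the 2x5 input from the issue: bilinear jax.image.resize should agree with scale_and_translate using scale = n_out / n_in and zero translation, and switching to scale = (n_out - 1) / (n_in - 1) with translation = 0.5 - scale / 2 gives the align_corners=True behavior. The scale and translation values are worked out here, not taken from the JAX docs.

```python
import jax
import jax.numpy as jnp

image = jnp.array([[0, 1, 2, 3, 4], [5, 6, 7, 8, 9]]) / 9
shape = (4, 10)
spatial_dims = (0, 1)
in_sizes = jnp.array(image.shape, dtype=jnp.float32)
out_sizes = jnp.array(shape, dtype=jnp.float32)

# Plain resize: stretch the outer image edges onto each other, no shift.
scale = out_sizes / in_sizes
via_resize = jax.image.resize(image, shape, method="linear")
via_st = jax.image.scale_and_translate(
    image, shape, spatial_dims, scale, jnp.zeros(2), method="linear")

# align_corners=True: align the first and last pixel centers instead.
scale_ac = (out_sizes - 1) / (in_sizes - 1)
aligned = jax.image.scale_and_translate(
    image, shape, spatial_dims, scale_ac, 0.5 - scale_ac / 2, method="linear")

print("resize vs. scale_and_translate:", jnp.abs(via_resize - via_st).max())
print(aligned)
```

Here `aligned` should reproduce the PyTorch align_corners=True tensor from the issue (e.g. 4/81 ≈ 0.0494 at [0, 1]), while `via_resize` reproduces the align_corners=False one (e.g. 5/36 ≈ 0.1389 at [1, 0]).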

Development

No branches or pull requests

5 participants