Adding `align_corners` to `jax.image.resize` #11206
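@younesbelkada FYI, I have an implementation of resizing with aligned corners here: I just adapted

```python
from typing import Tuple, Union

import jax
import jax.numpy as jnp


def resize_with_aligned_corners(
    image: jax.Array,
    shape: Tuple[int, ...],
    method: Union[str, jax.image.ResizeMethod],
    antialias: bool,
):
    """Alternative to jax.image.resize(), which emulates align_corners=True in PyTorch's
    interpolation functions."""
    # Resize only the dimensions whose size changes (as jax.image.resize does).
    # Plain `!=` assumes concrete shapes; the original used
    # jax.core.symbolic_equal_dim so that symbolic shapes were handled too.
    spatial_dims = tuple(
        i for i in range(len(shape)) if image.shape[i] != shape[i]
    )
    # Map the first/last input pixel centers onto the first/last output pixel
    # centers. Note: raises ZeroDivisionError if a resized input dim has size 1.
    scale = jnp.array([(shape[i] - 1.0) / (image.shape[i] - 1.0) for i in spatial_dims])
    # scale_and_translate assumes half-centered pixels, hence the half-pixel shift.
    translation = -(scale / 2.0 - 0.5)
    return jax.image.scale_and_translate(
        image,
        shape,
        method=method,
        scale=scale,
        spatial_dims=spatial_dims,
        translation=translation,
        antialias=antialias,
    )
```

PS: are you still interested in getting a Flax implementation of DPT merged? This would be really useful to me.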
FYI, `align_corners` aligns the centers of the corner pixels. It is only really useful because it happens to implement the translation needed to align features from strided convolutions.
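To make the two conventions concrete, here is a small pure-Python sketch of which input coordinate each output pixel samples along one axis, plus a check that the `scale`/`translation` pair used in `resize_with_aligned_corners` reproduces the `align_corners=True` mapping. The helper names are hypothetical (not JAX or PyTorch APIs), and the sketch assumes the half-centered pixel convention described in the `jax.image.scale_and_translate` docstring (input point `x` maps to `x * scale + translation`).

```python
# Hypothetical helpers (not JAX or PyTorch APIs). Each returns the continuous
# input coordinate sampled by output pixel j along one axis, using
# center-indexed coordinates (input pixel i has its center at i).


def src_align_corners(j: int, n_in: int, n_out: int) -> float:
    # align_corners=True: first/last pixel centers coincide, so samples are
    # equally spaced from 0 to n_in - 1.
    if n_out == 1:
        return 0.0
    return j * (n_in - 1) / (n_out - 1)


def src_half_pixel(j: int, n_in: int, n_out: int) -> float:
    # align_corners=False (half-pixel convention): the outer pixel *edges*
    # coincide instead. No edge clamping is applied here.
    return (j + 0.5) * n_in / n_out - 0.5


def src_scale_translate(j: int, scale: float, translation: float) -> float:
    # Assumed scale_and_translate convention: input point x maps to
    # x * scale + translation, with the center of pixel i at i + 0.5.
    # Invert that map for the center of output pixel j, then shift back
    # to center indexing.
    return (j + 0.5 - translation) / scale - 0.5


n_in, n_out = 3, 5
scale = (n_out - 1.0) / (n_in - 1.0)  # same formula as resize_with_aligned_corners
translation = -(scale / 2.0 - 0.5)

print([src_align_corners(j, n_in, n_out) for j in range(n_out)])
# [0.0, 0.5, 1.0, 1.5, 2.0]
print([round(src_half_pixel(j, n_in, n_out), 2) for j in range(n_out)])
# [-0.2, 0.4, 1.0, 1.6, 2.2]
print([src_scale_translate(j, scale, translation) for j in range(n_out)])
# [0.0, 0.5, 1.0, 1.5, 2.0]
```

With `align_corners=True` the first and last samples land exactly on the first and last input pixel centers; the half-pixel convention samples slightly outside them (here at -0.2 and 2.2), which implementations typically handle by clamping to the border.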
Hi there!

I would like to reproduce the operations done by `torch.nn.Upsample()` with `flax` 🎉. In PyTorch, it seems that the flag `align_corners` doesn't mean "align the corners" but rather "sample with equal spacing". cc @cgarciae! Can we add this feature in `jax.image.resize`? 🙏 (Initially posted on google/flax.)
Problem description:
Ideally, I would like to match `output_torch_upsample_with_align_corners` and flax's `output`. Here is what I get:

Motivation
I am converting Dense Prediction Transformers into `flax` and would like to match the output between PyTorch's model and the flax implementation! huggingface/transformers#17779

cc @cgarciae, moved the issue here :)
I would love to contribute as well, but I may need more guidance on how to do so!