scale_and_translate with output_size and scale returning zero values for last row and column #24106

Open
barney-s opened this issue Oct 3, 2024 · 1 comment

barney-s commented Oct 3, 2024

Description

We are trying to replicate the behavior of PyTorch's upsample_bilinear2d using jax.image.scale_and_translate.

This is part of the PyTorch/XLA project: pytorch/xla#7389
We investigated the problem further in pytorch/xla#8208.

The output of PyTorch's upsample_bilinear(align_corners=True, shape=something) differs from that of our implementation built on JAX's scale_and_translate. Because JAX has no implementation that supports align_corners=True, we are implementing it ourselves from existing functions. Existing JAX issue: #11206
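
For context, here is a minimal sketch of how align_corners=True bilinear resizing can be expressed with jax.image.scale_and_translate, assuming a 2-D (H, W) float image, H, W and the output sizes all greater than 1, the "linear" kernel and antialias=False. The helpers resize_align_corners and reference_align_corners are hypothetical, not part of JAX or PyTorch. The translation follows from the documented mapping x_out = x_in * scale + translation with half-integer pixel centres: output pixel o then samples input index (o + 0.5 - t) / s - 0.5, and choosing s = (out - 1) / (in - 1) and t = 0.5 * (1 - s) reduces this to o * (in - 1) / (out - 1), which lines up the corner pixels.

```python
import jax
import jax.numpy as jnp
import numpy as np


def resize_align_corners(image, out_h, out_w):
    """Hypothetical helper: align_corners=True bilinear resize of an (H, W) image.

    Assumes H, W, out_h and out_w are all greater than 1.
    """
    in_h, in_w = image.shape
    scale = jnp.array([(out_h - 1) / (in_h - 1), (out_w - 1) / (in_w - 1)],
                      dtype=jnp.float32)
    # Output pixel o samples input index (o + 0.5 - t) / s - 0.5;
    # t = 0.5 * (1 - s) turns that into o / s = o * (in - 1) / (out - 1).
    translation = 0.5 * (1.0 - scale)
    return jax.image.scale_and_translate(
        image, (out_h, out_w), (0, 1), scale, translation,
        method="linear", antialias=False)


def reference_align_corners(image, out_h, out_w):
    """Hypothetical plain-NumPy align_corners=True bilinear resize, for comparison."""
    in_h, in_w = image.shape
    ys = np.linspace(0, in_h - 1, out_h)  # corner-aligned sample rows
    xs = np.linspace(0, in_w - 1, out_w)  # corner-aligned sample columns
    # Bilinear interpolation is separable: interpolate along rows, then columns.
    cols = np.stack([np.interp(ys, np.arange(in_h), image[:, j])
                     for j in range(in_w)], axis=1)
    return np.stack([np.interp(xs, np.arange(in_w), cols[i])
                     for i in range(out_h)], axis=0)


image = np.random.default_rng(0).random((3, 4)).astype(np.float32)
ours = np.asarray(resize_align_corners(jnp.asarray(image), 5, 7))
ref = reference_align_corners(image, 5, 7)
print("max abs difference:", np.max(np.abs(ours - ref)))
```

With this choice every sample position falls inside [0, in - 1], so, on our reading of the weight computation, the out-of-range zeroing discussed in this issue should not be triggered; whether it matches the mapping used in the PyTorch/XLA implementation is not established here.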

See this comment for the difference in output and a script that reproduces it:
pytorch/xla#8208 (comment)

When I change jax._src.image.scale.compute_weight_mat to return the weights without zeroing them, the output matches.
Here is a script that copies the relevant JAX code to investigate the behaviour:
pytorch/xla#8208 (comment)
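
To make "zeroing" concrete, here is a simplified NumPy sketch of the per-axis weight matrix, based on our reading of jax._src.image.scale.compute_weight_mat for the linear kernel with antialiasing off. The function linear_weight_mat and its zero_outside flag are hypothetical, and JAX internals may differ between versions. With a -0.75 pixel shift, the last output pixel samples input coordinate 3.75, just past the 3.5 cutoff, so the zeroing rule empties that column; skipping the rule renormalises the remaining weight onto the edge pixel instead.

```python
import numpy as np


def linear_weight_mat(input_size, output_size, scale, translation,
                      zero_outside=True):
    """Hypothetical sketch of per-axis resampling weights (linear kernel).

    Column o holds the weights used to produce output sample o.
    `zero_outside` toggles the final zeroing step (not a JAX parameter).
    """
    inv_scale = 1.0 / scale
    # Input coordinate sampled by each output pixel (pixel centres at +0.5).
    sample_f = ((np.arange(output_size) + 0.5) * inv_scale
                - translation * inv_scale - 0.5)
    dist = np.abs(sample_f[np.newaxis, :] - np.arange(input_size)[:, np.newaxis])
    weights = np.maximum(0.0, 1.0 - dist)  # triangle kernel
    # Renormalise each column so the weights landing inside the input sum to 1.
    total = weights.sum(axis=0, keepdims=True)
    weights = np.where(np.abs(total) > 1000 * np.finfo(np.float32).eps,
                       weights / np.where(total != 0, total, 1), 0.0)
    if zero_outside:
        # Zero every column whose sample lies outside [-0.5, input_size - 0.5].
        inside = (sample_f >= -0.5) & (sample_f <= input_size - 0.5)
        weights = np.where(inside[np.newaxis, :], weights, 0.0)
    return weights


# Shift by -0.75 pixels: the last output samples input coordinate 3.75.
with_zeroing = linear_weight_mat(4, 4, 1.0, -0.75)
without_zeroing = linear_weight_mat(4, 4, 1.0, -0.75, zero_outside=False)
print(with_zeroing[:, -1])     # [0. 0. 0. 0.]
print(without_zeroing[:, -1])  # [0. 0. 0. 1.]
```

Applied along both axes, an all-zero last column in the weight matrix yields an all-zero last row and column in the resized image.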

We are not sure which behaviour is correct; we only know how to make JAX replicate PyTorch's behaviour.

If you think the JAX behaviour should be changed, I have a patch for that. Let me know your thoughts.

System info (python version, jaxlib version, accelerator, etc.)

% python
Python 3.10.14 (main, May  6 2024, 19:42:50) [GCC 11.2.0] on linux
Type "help", "copyright", "credits" or "license" for more information.
>>> import jax; jax.print_environment_info()
jax:    0.4.29
jaxlib: 0.4.29
numpy:  1.26.4
python: 3.10.14 (main, May  6 2024, 19:42:50) [GCC 11.2.0]
jax.devices (1 total, 1 local): [CpuDevice(id=0)]
process_count: 1
platform: uname_result(system='Linux', node='*********.com', release='6.9.10-1rodete5-amd64', version='#1 SMP PREEMPT_DYNAMIC Debian 6.9.10-1rodete5 (2024-09-04)', machine='x86_64')

>>> 
barney-s added the bug (Something isn't working) label Oct 3, 2024
barney-s added a commit to barney-s/xla that referenced this issue Oct 4, 2024
…sample_bilinear2d_aa

The JAX implementation's output does not match torch's output.
We are reimplementing scale_and_translate without zeroing weights.
related Jax bug: jax-ml/jax#24106
dfm self-assigned this Oct 4, 2024
Collaborator

dfm commented Oct 4, 2024

Thanks for your report! Related questions have come up a few times in the past. For example, this seems related to the comment here: #15768 (comment)

My understanding from that thread and your comments here is that the core issue is the handling of edge effects in the kernel. I'd say that we probably don't want to change the behavior of this function in JAX unilaterally, because there are almost certainly users depending on the specifics of this implementation. That being said, it does sound useful to add an option that changes how the edge effects are handled. Would you be up for trying to implement something like that in your PR? Thanks!
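
As a rough illustration of what such an edge-handling option might look like from the user side, here is a sketch of a 2-D resize with a hypothetical zero_outside flag. The names scale_and_translate_2d and _linear_weights are hypothetical and not part of JAX's API. The weight construction follows our reading of JAX's linear kernel with antialias=False, so with zero_outside=True the result is intended to match jax.image.scale_and_translate for 2-D inputs, while zero_outside=False clamps edge samples to the nearest input pixel.

```python
import jax
import jax.numpy as jnp
import numpy as np


def _linear_weights(in_size, out_size, scale, translation, zero_outside):
    """Hypothetical helper: (in_size, out_size) linear-kernel weights, antialias off."""
    inv_scale = 1.0 / scale
    # Input coordinate sampled by each output pixel (half-integer pixel centres).
    sample_f = (jnp.arange(out_size) + 0.5) * inv_scale - translation * inv_scale - 0.5
    dist = jnp.abs(sample_f[None, :] - jnp.arange(in_size)[:, None])
    w = jnp.maximum(0.0, 1.0 - dist)  # triangle kernel
    total = w.sum(axis=0, keepdims=True)
    w = jnp.where(jnp.abs(total) > 1000 * np.finfo(np.float32).eps,
                  w / jnp.where(total != 0, total, 1.0), 0.0)
    if zero_outside:
        # Drop samples whose location lies outside [-0.5, in_size - 0.5].
        inside = (sample_f >= -0.5) & (sample_f <= in_size - 0.5)
        w = jnp.where(inside[None, :], w, 0.0)
    return w


def scale_and_translate_2d(image, out_shape, scale, translation, zero_outside=True):
    """Hypothetical resize of an (H, W) image with an edge-handling flag."""
    wh = _linear_weights(image.shape[0], out_shape[0], scale[0], translation[0], zero_outside)
    ww = _linear_weights(image.shape[1], out_shape[1], scale[1], translation[1], zero_outside)
    # out[o, p] = sum over i, j of wh[i, o] * image[i, j] * ww[j, p]
    return wh.T @ image @ ww


img = jnp.asarray(np.random.default_rng(0).random((4, 4)), dtype=jnp.float32)
scale, translation = (1.0, 1.0), (-0.75, -0.75)
reference = jax.image.scale_and_translate(
    img, (4, 4), (0, 1), jnp.array(scale), jnp.array(translation),
    method="linear", antialias=False)
zeroed = scale_and_translate_2d(img, (4, 4), scale, translation)
clamped = scale_and_translate_2d(img, (4, 4), scale, translation, zero_outside=False)
print("matches JAX:", np.allclose(zeroed, reference, atol=1e-5))
print("last row with zeroing:", np.asarray(zeroed[-1]))
print("last row without zeroing:", np.asarray(clamped[-1]))
```

Keeping zeroing as the default preserves the existing behaviour for current users, while the flag lets callers opt into edge clamping.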
