Skip to content

Commit

Permalink
Use jax.image.scale_and_translate instead of jax.image.resize for _up…
Browse files Browse the repository at this point in the history
…sample_bilinear2d_aa

The JAX implementation's output does not match torch's output.
We are reimplementing scale_and_translate without zeroing weights.
related Jax bug: jax-ml/jax#24106
  • Loading branch information
barney-s committed Oct 4, 2024
1 parent cc631b9 commit 17134a7
Show file tree
Hide file tree
Showing 3 changed files with 192 additions and 22 deletions.
2 changes: 1 addition & 1 deletion experimental/torch_xla2/test/test_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
# such as 0 to negative power.
"_segment_reduce",
"bincount", # NOTE: dtype for int input torch gives float. This is weird.
"_upsample_bilinear2d_aa", # test passing scales_h, scales_w is failing.
#"_upsample_bilinear2d_aa", # test passing scales_h, scales_w is failing.
"byte",
"cat",
"cauchy",
Expand Down
44 changes: 23 additions & 21 deletions experimental/torch_xla2/torch_xla2/ops/jaten.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
from torch_xla2.ops import ops_registry
from torch_xla2.ops import op_base, mappings
from torch_xla2 import interop
from torch_xla2.ops import jax_reimplement

# Keys are OpOverload, value is a callable that takes
# XLATensor2
Expand Down Expand Up @@ -4192,38 +4193,39 @@ def _aten_upsample_bilinear2d_aa(input, output_size, align_corners, scale_factor
shape[-1] = output_size[-1]
shape[-2] = output_size[-2]

# align_corners is not supported in resize()
# https://github.com/jax-ml/jax/issues/11206
if align_corners:
return resize_with_aligned_corners2d(image, shape, scale_factors, method, antialias=True)
return jax.image.resize(image, shape, method, antialias) # precision=Precision.HIGHEST

# From: https://github.com/jax-ml/jax/issues/11206
def resize_with_aligned_corners2d(
image: jax.Array,
shape: Tuple[int, ...],
scale: Tuple[int, ...],
method: Union[str, jax.image.ResizeMethod],
antialias: bool,
):
"""Alternative to jax.image.resize(), which emulates align_corners=True in PyTorch's
interpolation functions."""

# pytorch upsample_bilinear returns the input as is when the shape is the same as input
if shape == list(image.shape):
return image

spatial_dims = (2,3)
if len(shape) == 3:
spatial_dims = (1,2)

scale = jnp.array([(shape[i] - 1.0) / (image.shape[i] - 1.0) for i in spatial_dims])
scale = list([shape[i] / image.shape[i] for i in spatial_dims])
if scale_factors:
scale = scale_factors
if scales_h:
scale[0] = scales_h
if scales_w:
scale[1] = scales_w
scale = jnp.array(scale)

# align_corners is not supported in resize()
# https://github.com/jax-ml/jax/issues/11206
if align_corners:
scale = jnp.array([(shape[i] - 1.0) / (image.shape[i] - 1.0) for i in spatial_dims])

translation = jnp.array([0 for i in spatial_dims])
#translation = (scale / 2.0 - 0.5)
translation = (scale * 0.0 )

return jax.image.scale_and_translate(
#return jax.image.scale_and_translate(
# locally copied, fixed implementation of scale_and_translate
return jax_reimplement.scale_and_translate(
image,
shape,
method=method,
scale=scale,
spatial_dims=spatial_dims,
translation=translation,
antialias=antialias,
)
)
168 changes: 168 additions & 0 deletions experimental/torch_xla2/torch_xla2/ops/jax_reimplement.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,168 @@

from collections.abc import Sequence
from jax._src.numpy.util import promote_dtypes_inexact
import numpy as np
import jax
from jax import numpy as jnp
from jax._src.util import canonicalize_axis
from jax._src import core
from jax._src.image.scale import _kernels, ResizeMethod
from jax import lax
from typing import Callable

# TODO: This block of code needs to be revisited based on https://github.com/jax-ml/jax/issues/24106
# START ----------------- JAX code copied for fixing scale_and_translate -----------------------------

# JAX Link: https://github.com/jax-ml/jax/blob/18f48bd52abe907ff9818da52f3d195d32910c1b/jax/_src/image/scale.py#L52

def compute_weight_mat(input_size: core.DimSize,
                       output_size: core.DimSize,
                       scale,
                       translation,
                       kernel: Callable,
                       antialias: bool):
  """Build the normalized resampling weight matrix for one spatial dimension.

  Adapted from ``jax/_src/image/scale.py``. The upstream version finishes by
  zeroing every column whose sample location lies outside the input range;
  this copy deliberately skips that step and returns the normalized weights
  as they are (see jax-ml/jax#24106).

  Args:
    input_size: number of input samples along the dimension.
    output_size: number of output samples along the dimension.
    scale: scale factor for this dimension; its reciprocal maps output
      positions back to input positions.
    translation: translation for this dimension, in output coordinates.
    kernel: callable invoked on the matrix of absolute distances between
      sample locations and input positions; returns the unnormalized weights.
    antialias: when true, distances are divided by ``max(1 / scale, 1)`` so
      the kernel is widened when downsampling; otherwise they are used as-is.

  Returns:
    An array of shape ``(input_size, output_size)``. Each column is divided
    by its own sum; columns whose absolute sum does not exceed
    ``1000 * float32 eps`` are set to zero instead.
  """
  dtype = jnp.result_type(scale, translation)
  inv_scale = 1. / scale
  # When downsampling the kernel should be scaled since we want to low pass
  # filter and interpolate, but when upsampling it should not be since we only
  # want to interpolate.
  kernel_scale = jnp.maximum(inv_scale, 1.) if antialias else 1.
  # Input-space location sampled by each output index (half-pixel centres,
  # with the 0.5 offset already removed again).
  sample_f = ((jnp.arange(output_size, dtype=dtype) + 0.5) * inv_scale -
              translation * inv_scale - 0.5)
  x = (
      jnp.abs(sample_f[jnp.newaxis, :] -
              jnp.arange(input_size, dtype=dtype)[:, jnp.newaxis]) /
      kernel_scale)
  weights = kernel(x)

  total_weight_sum = jnp.sum(weights, axis=0, keepdims=True)
  # The inner jnp.where keeps the division well defined for all-zero columns;
  # the outer one replaces columns with a negligible sum by zeros.
  weights = jnp.where(
      jnp.abs(total_weight_sum) > 1000. * float(np.finfo(np.float32).eps),
      jnp.divide(weights, jnp.where(total_weight_sum != 0, total_weight_sum, 1)),
      0)
  # NOTE: intentionally no masking of out-of-range sample locations here
  # (upstream zeroes columns with sample_f outside [-0.5, input_size - 0.5]).
  return weights

# JAX Link: https://github.com/jax-ml/jax/blob/18f48bd52abe907ff9818da52f3d195d32910c1b/jax/_src/image/scale.py#L86

def _scale_and_translate(x, output_shape: core.Shape,
                         spatial_dims: Sequence[int], scale, translation,
                         kernel, antialias: bool, precision):
  """Resample ``x`` along ``spatial_dims`` with one einsum contraction.

  For every spatial dimension a weight matrix is obtained from
  ``compute_weight_mat`` and contracted against the matching axis of ``x``.
  ``scale`` and ``translation`` are indexed by position in ``spatial_dims``.
  With no spatial dimensions ``x`` is returned unchanged.
  """
  src_shape = x.shape
  assert len(src_shape) == len(output_shape)
  assert len(spatial_dims) == len(scale)
  assert len(spatial_dims) == len(translation)
  if len(spatial_dims) == 0:
    return x

  rank = len(output_shape)
  operand_axes = list(range(rank))
  result_axes = list(range(rank))
  # Extra einsum operands: (weight matrix, its axis labels) per spatial
  # dimension, followed by the output axis labels.
  extra_operands = []
  for pos, axis in enumerate(spatial_dims):
    axis = canonicalize_axis(axis, x.ndim)
    # Fresh label for the resampled axis, outside the range used by x.
    new_label = rank + pos
    weight = compute_weight_mat(src_shape[axis], output_shape[axis],
                                scale[pos], translation[pos],
                                kernel, antialias).astype(x.dtype)
    extra_operands.extend([weight, [axis, new_label]])
    result_axes[axis] = new_label
  extra_operands.append(result_axes)
  return jnp.einsum(x, operand_axes, *extra_operands, precision=precision)


# JAX Link: https://github.com/jax-ml/jax/blob/18f48bd52abe907ff9818da52f3d195d32910c1b/jax/_src/image/scale.py#L172

# scale and translation here are scalar elements of an np.array, what is the
# correct type annotation?
def scale_and_translate(image, shape: core.Shape,
spatial_dims: Sequence[int],
scale, translation,
method: str | ResizeMethod,
antialias: bool = True,
precision=lax.Precision.HIGHEST):
"""Apply a scale and translation to an image.
Generates a new image of shape 'shape' by resampling from the input image
using the sampling method corresponding to method. For 2D images, this
operation transforms a location in the input images, (x, y), to a location
in the output image according to::
(x * scale[1] + translation[1], y * scale[0] + translation[0])
(Note the *inverse* warp is used to generate the sample locations.)
Assumes half-centered pixels, i.e the pixel at integer location ``row, col``
has coordinates ``y, x = row + 0.5, col + 0.5``, and similarly for other input
image dimensions.
If an output location(pixel) maps to an input sample location that is outside
the input boundaries then the value for the output location will be set to
zero.
The ``method`` argument expects one of the following resize methods:
``ResizeMethod.LINEAR``, ``"linear"``, ``"bilinear"``, ``"trilinear"``,
``"triangle"`` `Linear interpolation`_. If ``antialias`` is ``True``, uses a
triangular filter when downsampling.
``ResizeMethod.CUBIC``, ``"cubic"``, ``"bicubic"``, ``"tricubic"``
`Cubic interpolation`_, using the Keys cubic kernel.
``ResizeMethod.LANCZOS3``, ``"lanczos3"``
`Lanczos resampling`_, using a kernel of radius 3.
``ResizeMethod.LANCZOS5``, ``"lanczos5"``
`Lanczos resampling`_, using a kernel of radius 5.
.. _Linear interpolation: https://en.wikipedia.org/wiki/Bilinear_interpolation
.. _Cubic interpolation: https://en.wikipedia.org/wiki/Bicubic_interpolation
.. _Lanczos resampling: https://en.wikipedia.org/wiki/Lanczos_resampling
Args:
image: a JAX array.
shape: the output shape, as a sequence of integers with length equal to the
number of dimensions of `image`.
spatial_dims: A length K tuple specifying the spatial dimensions that the
passed scale and translation should be applied to.
scale: A [K] array with the same number of dimensions as image, containing
the scale to apply in each dimension.
translation: A [K] array with the same number of dimensions as image,
containing the translation to apply in each dimension.
method: the resizing method to use; either a ``ResizeMethod`` instance or a
string. Available methods are: LINEAR, LANCZOS3, LANCZOS5, CUBIC.
antialias: Should an antialiasing filter be used when downsampling? Defaults
to ``True``. Has no effect when upsampling.
Returns:
The scale and translated image.
"""
shape = core.canonicalize_shape(shape)
if len(shape) != image.ndim:
msg = ('shape must have length equal to the number of dimensions of x; '
f' {shape} vs {image.shape}')
raise ValueError(msg)
if isinstance(method, str):
method = ResizeMethod.from_string(method)
if method == ResizeMethod.NEAREST:
# Nearest neighbor is currently special-cased for straight resize, so skip
# for now.
raise ValueError('Nearest neighbor resampling is not currently supported '
'for scale_and_translate.')
assert isinstance(method, ResizeMethod)

kernel = _kernels[method]
image, = promote_dtypes_inexact(image)
scale, translation = promote_dtypes_inexact(scale, translation)
return _scale_and_translate(image, shape, spatial_dims, scale, translation,
kernel, antialias, precision)

# END ----------------- END JAX code copied for testing -----------------------------

0 comments on commit 17134a7

Please sign in to comment.