Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add custom_jvp to jax.numpy.ldexp. #23923

Merged
merged 1 commit into from
Sep 30, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,7 @@ When releasing, please add the new-release-boilerplate to docs/pallas/CHANGELOG.
* Bug fixes
* Fixed a bug where {func}`jax.numpy.cumsum` would produce incorrect outputs
if a non-boolean input was provided and `dtype=bool` was specified.
* Edit implementation of {func}`jax.numpy.ldexp` to get correct gradient.

## jax 0.4.33 (September 16, 2024)

Expand Down
42 changes: 4 additions & 38 deletions jax/_src/numpy/ufuncs.py
Original file line number Diff line number Diff line change
Expand Up @@ -2441,44 +2441,10 @@ def ldexp(x1: ArrayLike, x2: ArrayLike, /) -> Array:
if (dtypes.issubdtype(x1_dtype, np.complexfloating)
or dtypes.issubdtype(x2_dtype, np.inexact)):
raise ValueError(f"ldexp not supported for input types {(x1_dtype, x2_dtype)}")

x1, x2 = promote_shapes("ldexp", x1, x2)

dtype = dtypes.canonicalize_dtype(dtypes.to_inexact_dtype(x1_dtype))
info = dtypes.finfo(dtype)
int_type = _INT_DTYPES[info.bits]

x1 = lax.convert_element_type(x1, dtype)
x2 = lax.convert_element_type(x2, int_type)

mask = (1 << info.nexp) - 1
bias = 1 - info.minexp
x, e = _normalize_float(x1)
x2 += e + ((x >> info.nmant) & mask) - bias

# find underflow/overflow before denormalization
underflow_cond = less(x2, -(bias + info.nmant))
overflow_cond = greater(x2, bias)

m = lax.full_like(x, 1, dtype=dtype)

# denormals
cond = less(x2, -bias + 1)
x2 = _where(cond, x2 + info.nmant, x2)
m = _where(cond, m / (1 << info.nmant), m)

x2 = lax.convert_element_type(x2, np.int32)
x &= ~(mask << info.nmant)
x |= ((lax.convert_element_type(x2, int_type) + bias) << info.nmant)

x = lax.convert_element_type(m, dtype) * lax.bitcast_convert_type(x, dtype)

# underflow
x = _where(underflow_cond, lax.full_like(x, 0, dtype=dtype), x)
# overflow
x = _where(overflow_cond, lax.sign(x1) * lax.full_like(x, np.inf), x)
# ldexp(x1, x2) = x1 for x1 = inf, -inf, nan, 0
return _where(isinf(x1) | isnan(x1) | (x1 == 0), x1, x)
x1, = promote_args_inexact("ldexp", x1)
x2 = lax.convert_element_type(x2, dtypes.dtype(x1))
x = x1 * (2 ** x2)
return _where(isinf(x1) | (x1 == 0), x1, x)


@implements(np.frexp, module='numpy')
Expand Down
9 changes: 9 additions & 0 deletions tests/lax_numpy_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -6142,6 +6142,15 @@ def testGradLogaddexp2Complex(self, shapes, dtype):
tol = 3e-2
check_grads(jnp.logaddexp2, args, 1, ["fwd", "rev"], tol, tol)

@jtu.sample_product(
n=range(-4, 5),
dtype=[jnp.float32, jnp.float64],
)
def testGradLdexp(self, n, dtype):
rng = jtu.rand_default(self.rng())
x = rng((), dtype)
check_grads(lambda x: jnp.ldexp(x, n), (x,), 1)


class NumpySignaturesTest(jtu.JaxTestCase):

Expand Down