diff --git a/CHANGELOG.md b/CHANGELOG.md index 079e055aa994..e8f8d7c6d96d 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -60,6 +60,7 @@ When releasing, please add the new-release-boilerplate to docs/pallas/CHANGELOG. * Bug fixes * Fixed a bug where {func}`jax.numpy.cumsum` would produce incorrect outputs if a non-boolean input was provided and `dtype=bool` was specified. + * Edit implementation of {func}`jax.numpy.ldexp` to get correct gradient. ## jax 0.4.33 (September 16, 2024) diff --git a/jax/_src/numpy/ufuncs.py b/jax/_src/numpy/ufuncs.py index 8c05671dc429..e9c496e4b654 100644 --- a/jax/_src/numpy/ufuncs.py +++ b/jax/_src/numpy/ufuncs.py @@ -2451,34 +2451,7 @@ def ldexp(x1: ArrayLike, x2: ArrayLike, /) -> Array: x1 = lax.convert_element_type(x1, dtype) x2 = lax.convert_element_type(x2, int_type) - mask = (1 << info.nexp) - 1 - bias = 1 - info.minexp - x, e = _normalize_float(x1) - x2 += e + ((x >> info.nmant) & mask) - bias - - # find underflow/overflow before denormalization - underflow_cond = less(x2, -(bias + info.nmant)) - overflow_cond = greater(x2, bias) - - m = lax.full_like(x, 1, dtype=dtype) - - # denormals - cond = less(x2, -bias + 1) - x2 = _where(cond, x2 + info.nmant, x2) - m = _where(cond, m / (1 << info.nmant), m) - - x2 = lax.convert_element_type(x2, np.int32) - x &= ~(mask << info.nmant) - x |= ((lax.convert_element_type(x2, int_type) + bias) << info.nmant) - - x = lax.convert_element_type(m, dtype) * lax.bitcast_convert_type(x, dtype) - - # underflow - x = _where(underflow_cond, lax.full_like(x, 0, dtype=dtype), x) - # overflow - x = _where(overflow_cond, lax.sign(x1) * lax.full_like(x, np.inf), x) - # ldexp(x1, x2) = x1 for x1 = inf, -inf, nan, 0 - return _where(isinf(x1) | isnan(x1) | (x1 == 0), x1, x) + return x1 * lax.convert_element_type(2 ** x2, dtype) @implements(np.frexp, module='numpy') diff --git a/tests/lax_numpy_test.py b/tests/lax_numpy_test.py index 4e81a35003f4..c38ceeb4f803 100644 --- a/tests/lax_numpy_test.py +++ b/tests/lax_numpy_test.py @@ -6142,6 +6142,15 @@ def testGradLogaddexp2Complex(self, shapes, dtype): tol = 3e-2 check_grads(jnp.logaddexp2, args, 1, ["fwd", "rev"], tol, tol) + @jtu.sample_product( + n=range(-4, 5), + dtype=[jnp.float32, jnp.float64], + ) + def testGradLdexp(self, n, dtype): + rng = jtu.rand_default(self.rng()) + x = rng((), dtype) + check_grads(lambda x: jnp.ldexp(x, n), (x,), 1) + class NumpySignaturesTest(jtu.JaxTestCase):