From 65a58d622ce42fcc1321adf0ab237fc05dfd5301 Mon Sep 17 00:00:00 2001 From: carlosgmartin Date: Mon, 30 Sep 2024 18:27:39 -0400 Subject: [PATCH] Edit implementation of jax.numpy.ldexp to get correct gradient. --- CHANGELOG.md | 1 + jax/_src/numpy/ufuncs.py | 42 ++++------------------------------------ tests/lax_numpy_test.py | 9 +++++++++ 3 files changed, 14 insertions(+), 38 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 079e055aa994..e8f8d7c6d96d 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -60,6 +60,7 @@ When releasing, please add the new-release-boilerplate to docs/pallas/CHANGELOG. * Bug fixes * Fixed a bug where {func}`jax.numpy.cumsum` would produce incorrect outputs if a non-boolean input was provided and `dtype=bool` was specified. + * Edit implementation of {func}`jax.numpy.ldexp` to get correct gradient. ## jax 0.4.33 (September 16, 2024) diff --git a/jax/_src/numpy/ufuncs.py b/jax/_src/numpy/ufuncs.py index 8c05671dc429..1ae7b6052b39 100644 --- a/jax/_src/numpy/ufuncs.py +++ b/jax/_src/numpy/ufuncs.py @@ -2441,44 +2441,10 @@ def ldexp(x1: ArrayLike, x2: ArrayLike, /) -> Array: if (dtypes.issubdtype(x1_dtype, np.complexfloating) or dtypes.issubdtype(x2_dtype, np.inexact)): raise ValueError(f"ldexp not supported for input types {(x1_dtype, x2_dtype)}") - - x1, x2 = promote_shapes("ldexp", x1, x2) - - dtype = dtypes.canonicalize_dtype(dtypes.to_inexact_dtype(x1_dtype)) - info = dtypes.finfo(dtype) - int_type = _INT_DTYPES[info.bits] - - x1 = lax.convert_element_type(x1, dtype) - x2 = lax.convert_element_type(x2, int_type) - - mask = (1 << info.nexp) - 1 - bias = 1 - info.minexp - x, e = _normalize_float(x1) - x2 += e + ((x >> info.nmant) & mask) - bias - - # find underflow/overflow before denormalization - underflow_cond = less(x2, -(bias + info.nmant)) - overflow_cond = greater(x2, bias) - - m = lax.full_like(x, 1, dtype=dtype) - - # denormals - cond = less(x2, -bias + 1) - x2 = _where(cond, x2 + info.nmant, x2) - m = _where(cond, m / (1 << info.nmant), m) - - x2 = lax.convert_element_type(x2, np.int32) - x &= ~(mask << info.nmant) - x |= ((lax.convert_element_type(x2, int_type) + bias) << info.nmant) - - x = lax.convert_element_type(m, dtype) * lax.bitcast_convert_type(x, dtype) - - # underflow - x = _where(underflow_cond, lax.full_like(x, 0, dtype=dtype), x) - # overflow - x = _where(overflow_cond, lax.sign(x1) * lax.full_like(x, np.inf), x) - # ldexp(x1, x2) = x1 for x1 = inf, -inf, nan, 0 - return _where(isinf(x1) | isnan(x1) | (x1 == 0), x1, x) + x1, = promote_args_inexact("ldexp", x1) + x2 = lax.convert_element_type(x2, dtypes.dtype(x1)) + x = x1 * (2 ** x2) + return _where(isinf(x1) | (x1 == 0), x1, x) @implements(np.frexp, module='numpy') diff --git a/tests/lax_numpy_test.py b/tests/lax_numpy_test.py index 4e81a35003f4..c38ceeb4f803 100644 --- a/tests/lax_numpy_test.py +++ b/tests/lax_numpy_test.py @@ -6142,6 +6142,15 @@ def testGradLogaddexp2Complex(self, shapes, dtype): tol = 3e-2 check_grads(jnp.logaddexp2, args, 1, ["fwd", "rev"], tol, tol) + @jtu.sample_product( + n=range(-4, 5), + dtype=[jnp.float32, jnp.float64], + ) + def testGradLdexp(self, n, dtype): + rng = jtu.rand_default(self.rng()) + x = rng((), dtype) + check_grads(lambda x: jnp.ldexp(x, n), (x,), 1) + class NumpySignaturesTest(jtu.JaxTestCase):