Skip to content

Commit

Permalink
Edit implementation of jax.numpy.ldexp to get correct gradient.
Browse files Browse the repository at this point in the history
  • Loading branch information
carlosgmartin committed Sep 30, 2024
1 parent 4a596ae commit acb47ab
Show file tree
Hide file tree
Showing 3 changed files with 11 additions and 28 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,7 @@ When releasing, please add the new-release-boilerplate to docs/pallas/CHANGELOG.
* Bug fixes
* Fixed a bug where {func}`jax.numpy.cumsum` would produce incorrect outputs
if a non-boolean input was provided and `dtype=bool` was specified.
* Edit implementation of {func}`jax.numpy.ldexp` to get correct gradient.

## jax 0.4.33 (September 16, 2024)

Expand Down
29 changes: 1 addition & 28 deletions jax/_src/numpy/ufuncs.py
Original file line number Diff line number Diff line change
Expand Up @@ -2451,34 +2451,7 @@ def ldexp(x1: ArrayLike, x2: ArrayLike, /) -> Array:
x1 = lax.convert_element_type(x1, dtype)
x2 = lax.convert_element_type(x2, int_type)

mask = (1 << info.nexp) - 1
bias = 1 - info.minexp
x, e = _normalize_float(x1)
x2 += e + ((x >> info.nmant) & mask) - bias

# find underflow/overflow before denormalization
underflow_cond = less(x2, -(bias + info.nmant))
overflow_cond = greater(x2, bias)

m = lax.full_like(x, 1, dtype=dtype)

# denormals
cond = less(x2, -bias + 1)
x2 = _where(cond, x2 + info.nmant, x2)
m = _where(cond, m / (1 << info.nmant), m)

x2 = lax.convert_element_type(x2, np.int32)
x &= ~(mask << info.nmant)
x |= ((lax.convert_element_type(x2, int_type) + bias) << info.nmant)

x = lax.convert_element_type(m, dtype) * lax.bitcast_convert_type(x, dtype)

# underflow
x = _where(underflow_cond, lax.full_like(x, 0, dtype=dtype), x)
# overflow
x = _where(overflow_cond, lax.sign(x1) * lax.full_like(x, np.inf), x)
# ldexp(x1, x2) = x1 for x1 = inf, -inf, nan, 0
return _where(isinf(x1) | isnan(x1) | (x1 == 0), x1, x)
return x1 * lax.convert_element_type(2 ** x2, dtype)


@implements(np.frexp, module='numpy')
Expand Down
9 changes: 9 additions & 0 deletions tests/lax_numpy_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -6142,6 +6142,15 @@ def testGradLogaddexp2Complex(self, shapes, dtype):
tol = 3e-2
check_grads(jnp.logaddexp2, args, 1, ["fwd", "rev"], tol, tol)

@jtu.sample_product(
n=range(-4, 5),
dtype=[jnp.float32, jnp.float64],
)
def testGradLdexp(self, n, dtype):
rng = jtu.rand_default(self.rng())
x = rng((), dtype)
check_grads(lambda x: jnp.ldexp(x, n), (x,), 1)


class NumpySignaturesTest(jtu.JaxTestCase):

Expand Down

0 comments on commit acb47ab

Please sign in to comment.