Skip to content

Commit

Permalink
Merge pull request #18706 from jakevdp:i0-grad
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 586079527
  • Loading branch information
jax authors committed Nov 28, 2023
2 parents 8020e7d + a8723ec commit 1d269ed
Show file tree
Hide file tree
Showing 2 changed files with 12 additions and 0 deletions.
7 changes: 7 additions & 0 deletions jax/_src/numpy/lax_numpy.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,7 @@
from jax._src import api_util
from jax._src import config
from jax._src import core
from jax._src.custom_derivatives import custom_jvp
from jax._src import dispatch
from jax._src import dtypes
from jax._src.api_util import _ensure_index_tuple
Expand Down Expand Up @@ -2644,6 +2645,7 @@ def meshgrid(*xi: ArrayLike, copy: bool = True, sparse: bool = False,
return output


@custom_jvp
@util._wraps(np.i0)
@jit
def i0(x: ArrayLike) -> Array:
Expand All @@ -2653,6 +2655,11 @@ def i0(x: ArrayLike) -> Array:
x_arr = lax.abs(x_arr)
return lax.mul(lax.exp(x_arr), lax.bessel_i0e(x_arr))

@i0.defjvp
def _i0_jvp(primals, tangents):
  """Custom JVP rule for ``i0`` (modified Bessel function of the first kind).

  Delegates to the automatic JVP of the undecorated implementation
  (``i0.fun``), then patches the tangent at ``x == 0``: the implementation
  takes ``lax.abs(x)`` first, and ``abs`` is not differentiable at zero, so
  the automatic tangent there would be ill-defined (nan). Since i0 is even,
  its true derivative at 0 is 0, which is what we substitute.
  """
  # primals/tangents are the 1-tuples ((x,), (x_dot,)) for i0's single arg.
  primal_out, tangent_out = jax.jvp(i0.fun, primals, tangents)
  # Mask only the problematic point; elsewhere the automatic tangent is used.
  return primal_out, where(primals[0] == 0, 0.0, tangent_out)


@util._wraps(np.ix_)
def ix_(*args: ArrayLike) -> tuple[Array, ...]:
Expand Down
5 changes: 5 additions & 0 deletions tests/lax_numpy_operators_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -641,6 +641,11 @@ def __rmul__(self, other):
self.assertIsInstance(b * a, MyArray)
self.assertIsInstance(jax.jit(operator.mul)(b, a), MyArray)

def testI0Grad(self):
  """grad(i0) evaluated at 0.0 must be exactly 0 (not nan)."""
  # Regression test for https://github.com/google/jax/issues/11479
  grad_at_zero = jax.grad(jax.numpy.i0)(0.0)
  self.assertArraysEqual(grad_at_zero, 0.0)


if __name__ == "__main__":
  # Run the test suite with jax's custom loader (sharding/filtering support)
  # when this file is executed directly.
  absltest.main(testLoader=jtu.JaxTestLoader())

0 comments on commit 1d269ed

Please sign in to comment.