Skip to content

Commit

Permalink
Merge pull request #24024 from rajasekharporeddy:testbranch1
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 680669011
  • Loading branch information
Google-ML-Automation committed Sep 30, 2024
2 parents ff1c2ac + 5904fe1 commit bdae9ac
Showing 1 changed file with 24 additions and 1 deletion.
25 changes: 24 additions & 1 deletion jax/_src/numpy/ufuncs.py
Original file line number Diff line number Diff line change
Expand Up @@ -1054,9 +1054,32 @@ def sqrt(x: ArrayLike, /) -> Array:
"""
return lax.sqrt(*promote_args_inexact('sqrt', x))

@implements(np.cbrt, module='numpy')

@partial(jit, inline=True)
def cbrt(x: ArrayLike, /) -> Array:
"""Calculates element-wise cube root of the input array.
JAX implementation of :obj:`numpy.cbrt`.
Args:
x: input array or scalar. ``complex`` dtypes are not supported.
Returns:
An array containing the cube root of the elements of ``x``.
See also:
- :func:`jax.numpy.sqrt`: Calculates the element-wise non-negative square root
of the input.
- :func:`jax.numpy.square`: Calculates the element-wise square of the input.
Examples:
>>> x = jnp.array([[216, 125, 64],
... [-27, -8, -1]])
>>> with jnp.printoptions(precision=3, suppress=True):
... jnp.cbrt(x)
Array([[ 6., 5., 4.],
[-3., -2., -1.]], dtype=float32)
"""
return lax.cbrt(*promote_args_inexact('cbrt', x))

@partial(jit, inline=True)
Expand Down

0 comments on commit bdae9ac

Please sign in to comment.