Skip to content

Commit

Permalink
Merge pull request #24025 from jakevdp:gradient-doc
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 680703792
  • Loading branch information
Google-ML-Automation committed Sep 30, 2024
2 parents 3766f88 + 36d6bb9 commit cdc7278
Show file tree
Hide file tree
Showing 3 changed files with 60 additions and 8 deletions.
59 changes: 58 additions & 1 deletion jax/_src/numpy/lax_numpy.py
Original file line number Diff line number Diff line change
Expand Up @@ -1451,14 +1451,71 @@ def ediff1d(ary: ArrayLike, to_end: ArrayLike | None = None,
result = concatenate((result, ravel(asarray(to_end, dtype=arr.dtype))))
return result

@util.implements(np.gradient, skip_params=['edge_order'])
@partial(jit, static_argnames=("axis", "edge_order"))
def gradient(
f: ArrayLike,
*varargs: ArrayLike,
axis: int | Sequence[int] | None = None,
edge_order: int | None = None,
) -> Array | list[Array]:
"""Compute the numerical gradient of a sampled function.
JAX implementation of :func:`numpy.gradient`.
The gradient in ``jnp.gradient`` is computed using second-order finite
differences across the array of sampled function values. This should not
be confused with :func:`jax.grad`, which computes a precise gradient of
a callable function via :ref:`automatic differentiation <automatic-differentiation>`.
Args:
f: *N*-dimensional array of function values.
varargs: optional list of scalars or arrays specifying spacing of
function evaluations. Options are:
- not specified: unit spacing in all dimensions.
- a single scalar: constant spacing in all dimensions.
- *N* values: specify different spacing in each dimension:
- scalar values indicate constant spacing in that dimension.
- array values must match the length of the corresponding dimension,
and specify the coordinates at which ``f`` is evaluated.
  edge_order: not implemented in JAX; passing any value other than ``None``
    raises :class:`NotImplementedError`.
axis: integer or tuple of integers specifying the axis along which
to compute the gradient. If None (default) calculates the gradient
along all axes.
  Returns:
    an array or list of arrays containing the numerical gradient along
    each specified axis.
See also:
- :func:`jax.grad`: automatic differentiation of a function with a single output.
Examples:
Comparing numerical and automatic differentiation of a simple function:
>>> def f(x):
... return jnp.sin(x) * jnp.exp(-x / 4)
...
>>> def gradf_exact(x):
... # exact analytical gradient of f(x)
... return -f(x) / 4 + jnp.cos(x) * jnp.exp(-x / 4)
...
>>> x = jnp.linspace(0, 5, 10)
>>> with jnp.printoptions(precision=2, suppress=True):
... print("numerical gradient:", jnp.gradient(f(x), x))
... print("automatic gradient:", jax.vmap(jax.grad(f))(x))
... print("exact gradient: ", gradf_exact(x))
...
numerical gradient: [ 0.83 0.61 0.18 -0.2 -0.43 -0.49 -0.39 -0.21 -0.02 0.08]
automatic gradient: [ 1. 0.62 0.17 -0.23 -0.46 -0.51 -0.41 -0.21 -0.01 0.15]
exact gradient: [ 1. 0.62 0.17 -0.23 -0.46 -0.51 -0.41 -0.21 -0.01 0.15]
Notice that, as expected, the numerical gradient has some approximation error
compared to the automatic gradient computed via :func:`jax.grad`.
"""

if edge_order is not None:
raise NotImplementedError(
Expand Down
6 changes: 1 addition & 5 deletions jax/_src/numpy/util.py
Original file line number Diff line number Diff line change
Expand Up @@ -115,7 +115,6 @@ def implements(
original_fun: Callable[..., Any] | None,
update_doc: bool = True,
sections: Sequence[str] = ('Parameters', 'Returns', 'References'),
skip_params: Sequence[str] = (),
module: str | None = None,
) -> Callable[[_T], _T]:
"""Decorator for JAX functions which implement a specified NumPy function.
Expand All @@ -133,8 +132,6 @@ def implements(
If False, include the numpy docstring verbatim.
sections: a list of sections to include in the docstring. The default is
["Parameters", "Returns", "References"]
skip_params: a list of strings containing names of parameters accepted by the
function that should be skipped in the parameter list.
module: an optional string specifying the module from which the original function
is imported. This is useful for objects such as ufuncs, where the module cannot
be determined from the original function itself.
Expand Down Expand Up @@ -162,8 +159,7 @@ def decorator(wrapped_fun):
# Remove unrecognized parameter descriptions.
parameters = _parse_parameters(parsed.sections['Parameters'])
parameters = {p: desc for p, desc in parameters.items()
if (code is None or p in code.co_varnames)
and p not in skip_params}
if (code is None or p in code.co_varnames)}
if parameters:
parsed.sections['Parameters'] = (
"Parameters\n"
Expand Down
3 changes: 1 addition & 2 deletions tests/lax_numpy_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -6381,14 +6381,13 @@ def wrapped(x, out=None):
if jit:
wrapped = jax.jit(wrapped)

wrapped = implements(orig, skip_params=['out'])(wrapped)
wrapped = implements(orig)(wrapped)
doc = wrapped.__doc__

self.assertStartsWith(doc, "Example Docstring")
self.assertIn("Original docstring below", doc)
self.assertIn("Parameters", doc)
self.assertIn("Returns", doc)
self.assertNotIn('out', doc)
self.assertNotIn('other_arg', doc)
self.assertNotIn('versionadded', doc)

Expand Down

0 comments on commit cdc7278

Please sign in to comment.