From 36d6bb901350cb877cbde043570aead2810a66ea Mon Sep 17 00:00:00 2001
From: Jake VanderPlas <jakevdp@google.com>
Date: Mon, 30 Sep 2024 13:07:52 -0700
Subject: [PATCH] Better docs for jnp.gradient

Also remove the skip_params option from util.implements, as this was its
last usage.
---
 jax/_src/numpy/lax_numpy.py | 59 ++++++++++++++++++++++++++++++++++++-
 jax/_src/numpy/util.py      |  6 +---
 tests/lax_numpy_test.py     |  3 +-
 3 files changed, 60 insertions(+), 8 deletions(-)

diff --git a/jax/_src/numpy/lax_numpy.py b/jax/_src/numpy/lax_numpy.py
index 1d1b3512fd3e..ebef1756659d 100644
--- a/jax/_src/numpy/lax_numpy.py
+++ b/jax/_src/numpy/lax_numpy.py
@@ -1451,7 +1451,6 @@ def ediff1d(ary: ArrayLike, to_end: ArrayLike | None = None,
     result = concatenate((result, ravel(asarray(to_end, dtype=arr.dtype))))
   return result
 
-@util.implements(np.gradient, skip_params=['edge_order'])
 @partial(jit, static_argnames=("axis", "edge_order"))
 def gradient(
     f: ArrayLike,
@@ -1459,6 +1458,64 @@ def gradient(
     axis: int | Sequence[int] | None = None,
     edge_order: int | None = None,
 ) -> Array | list[Array]:
+  """Compute the numerical gradient of a sampled function.
+
+  JAX implementation of :func:`numpy.gradient`.
+
+  The gradient in ``jnp.gradient`` is computed using second-order finite
+  differences across the array of sampled function values. This should not
+  be confused with :func:`jax.grad`, which computes a precise gradient of
+  a callable function via :ref:`automatic differentiation <automatic-differentiation>`.
+
+  Args:
+    f: *N*-dimensional array of function values.
+    varargs: optional list of scalars or arrays specifying spacing of
+      function evaluations. Options are:
+
+      - not specified: unit spacing in all dimensions.
+      - a single scalar: constant spacing in all dimensions.
+      - *N* values: specify different spacing in each dimension:
+
+        - scalar values indicate constant spacing in that dimension.
+        - array values must match the length of the corresponding dimension,
+          and specify the coordinates at which ``f`` is evaluated.
+
+    axis: integer or tuple of integers specifying the axis along which
+      to compute the gradient. If None (default), the gradient is computed
+      along all axes.
+    edge_order: not implemented in JAX.
+
+  Returns:
+    an array or list of arrays containing the numerical gradient along
+    each specified axis.
+
+  See also:
+    - :func:`jax.grad`: automatic differentiation of a function with a single output.
+
+  Examples:
+    Comparing numerical and automatic differentiation of a simple function:
+
+    >>> def f(x):
+    ...   return jnp.sin(x) * jnp.exp(-x / 4)
+    ...
+    >>> def gradf_exact(x):
+    ...   # exact analytical gradient of f(x)
+    ...   return -f(x) / 4 + jnp.cos(x) * jnp.exp(-x / 4)
+    ...
+    >>> x = jnp.linspace(0, 5, 10)
+
+    >>> with jnp.printoptions(precision=2, suppress=True):
+    ...   print("numerical gradient:", jnp.gradient(f(x), x))
+    ...   print("automatic gradient:", jax.vmap(jax.grad(f))(x))
+    ...   print("exact gradient:    ", gradf_exact(x))
+    ...
+    numerical gradient: [ 0.83  0.61  0.18 -0.2  -0.43 -0.49 -0.39 -0.21 -0.02  0.08]
+    automatic gradient: [ 1.    0.62  0.17 -0.23 -0.46 -0.51 -0.41 -0.21 -0.01  0.15]
+    exact gradient:     [ 1.    0.62  0.17 -0.23 -0.46 -0.51 -0.41 -0.21 -0.01  0.15]
+
+    Notice that, as expected, the numerical gradient has some approximation error
+    compared to the automatic gradient computed via :func:`jax.grad`.
+  """
   if edge_order is not None:
     raise NotImplementedError(
diff --git a/jax/_src/numpy/util.py b/jax/_src/numpy/util.py
index 9c9bc5d389e1..c5b1530ca215 100644
--- a/jax/_src/numpy/util.py
+++ b/jax/_src/numpy/util.py
@@ -115,7 +115,6 @@ def implements(
     original_fun: Callable[..., Any] | None,
     update_doc: bool = True,
     sections: Sequence[str] = ('Parameters', 'Returns', 'References'),
-    skip_params: Sequence[str] = (),
     module: str | None = None,
 ) -> Callable[[_T], _T]:
   """Decorator for JAX functions which implement a specified NumPy function.
@@ -133,8 +132,6 @@ def implements(
       If False, include the numpy docstring verbatim.
     sections: a list of sections to include in the docstring. The default
       is ["Parameters", "Returns", "References"]
-    skip_params: a list of strings containing names of parameters accepted by the
-      function that should be skipped in the parameter list.
     module: an optional string specifying the module from which the original
       function is imported. This is useful for objects such as ufuncs, where
       the module cannot be determined from the original function itself.
@@ -162,8 +159,7 @@ def decorator(wrapped_fun):
      # Remove unrecognized parameter descriptions.
      parameters = _parse_parameters(parsed.sections['Parameters'])
      parameters = {p: desc for p, desc in parameters.items()
-                   if (code is None or p in code.co_varnames)
-                   and p not in skip_params}
+                   if (code is None or p in code.co_varnames)}
      if parameters:
        parsed.sections['Parameters'] = (
          "Parameters\n"
diff --git a/tests/lax_numpy_test.py b/tests/lax_numpy_test.py
index 6f8167df9c29..4e81a35003f4 100644
--- a/tests/lax_numpy_test.py
+++ b/tests/lax_numpy_test.py
@@ -6381,14 +6381,13 @@ def wrapped(x, out=None):
     if jit:
       wrapped = jax.jit(wrapped)
-    wrapped = implements(orig, skip_params=['out'])(wrapped)
+    wrapped = implements(orig)(wrapped)
     doc = wrapped.__doc__
 
     self.assertStartsWith(doc, "Example Docstring")
     self.assertIn("Original docstring below", doc)
     self.assertIn("Parameters", doc)
     self.assertIn("Returns", doc)
-    self.assertNotIn('out', doc)
     self.assertNotIn('other_arg', doc)
     self.assertNotIn('versionadded', doc)