From 9fa78d0d1e2cc22086e8d79afba518710a1e657e Mon Sep 17 00:00:00 2001 From: Felipe Oliveira Carvalho Date: Wed, 25 Sep 2024 21:20:47 -0300 Subject: [PATCH] GH-44229: [Docs] Add PyArrow to JAX example to the docs (#44230) ### Rationale for this change Explicitly mention in the docs a way that PyArrow can interop with [JAX](https://github.com/jax-ml/jax). ### What changes are included in this PR? - Tweaks to the phrasing - Two JAX examples: one for `jax.numpy` and another for `jax.dlpack` ### Are these changes tested? N/A * GitHub Issue: #44229 Authored-by: Felipe Oliveira Carvalho Signed-off-by: Felipe Oliveira Carvalho --- docs/source/python/dlpack.rst | 18 ++++++++++++++---- 1 file changed, 14 insertions(+), 4 deletions(-) diff --git a/docs/source/python/dlpack.rst b/docs/source/python/dlpack.rst index 024c2800e1107..9f0d3b58aa6e5 100644 --- a/docs/source/python/dlpack.rst +++ b/docs/source/python/dlpack.rst @@ -63,10 +63,10 @@ PyArrow implements the second part of the protocol (``__dlpack__(self, stream=None)`` and ``__dlpack_device__``) and can thus be consumed by libraries implementing ``from_dlpack``. -Example -------- +Examples +-------- -Convert a PyArrow CPU array to NumPy array: +Convert a PyArrow CPU array into a NumPy array: .. code-block:: @@ -84,10 +84,20 @@ Convert a PyArrow CPU array to NumPy array: >>> np.from_dlpack(array) array([2, 0, 2, 4]) -Convert a PyArrow CPU array to PyTorch tensor: +Convert a PyArrow CPU array into a PyTorch tensor: .. code-block:: >>> import torch >>> torch.from_dlpack(array) tensor([2, 0, 2, 4]) + +Convert a PyArrow CPU array into a JAX array: + +.. code-block:: + + >>> import jax + >>> jax.numpy.from_dlpack(array) + Array([2, 0, 2, 4], dtype=int32) + >>> jax.dlpack.from_dlpack(array) + Array([2, 0, 2, 4], dtype=int32)