Skip to content

Commit

Permalink
GH-44229: [Docs] Add PyArrow to JAX example to the docs (#44230)
Browse files Browse the repository at this point in the history
### Rationale for this change

Explicitly mention in the docs a way that PyArrow can interop with [JAX](https://github.com/jax-ml/jax).

### What changes are included in this PR?

 - Tweaks to the phrasing
 - Two JAX examples: one for `jax.numpy` and another for `jax.dlpack`

### Are these changes tested?

N/A
* GitHub Issue: #44229

Authored-by: Felipe Oliveira Carvalho <felipekde@gmail.com>
Signed-off-by: Felipe Oliveira Carvalho <felipekde@gmail.com>
  • Loading branch information
felipecrv committed Sep 26, 2024
1 parent c557fe5 commit 9fa78d0
Showing 1 changed file with 14 additions and 4 deletions.
18 changes: 14 additions & 4 deletions docs/source/python/dlpack.rst
Original file line number Diff line number Diff line change
Expand Up @@ -63,10 +63,10 @@ PyArrow implements the second part of the protocol
(``__dlpack__(self, stream=None)`` and ``__dlpack_device__``) and can
thus be consumed by libraries implementing ``from_dlpack``.

Example
-------
Examples
--------

Convert a PyArrow CPU array to NumPy array:
Convert a PyArrow CPU array into a NumPy array:

.. code-block::
Expand All @@ -84,10 +84,20 @@ Convert a PyArrow CPU array to NumPy array:
>>> np.from_dlpack(array)
array([2, 0, 2, 4])
Convert a PyArrow CPU array to PyTorch tensor:
Convert a PyArrow CPU array into a PyTorch tensor:

.. code-block::
>>> import torch
>>> torch.from_dlpack(array)
tensor([2, 0, 2, 4])
Convert a PyArrow CPU array into a JAX array:

.. code-block::
>>> import jax
>>> jax.numpy.from_dlpack(array)
Array([2, 0, 2, 4], dtype=int32)
>>> jax.dlpack.from_dlpack(array)
Array([2, 0, 2, 4], dtype=int32)

0 comments on commit 9fa78d0

Please sign in to comment.