Skip to content

Commit

Permalink
fix: raise error for invalid object in from_dlpack (#2662)
Browse files Browse the repository at this point in the history
* fix: raise error for invalid object

* feat: allow Index objects to be converted to dlpack
  • Loading branch information
agoose77 committed Aug 21, 2023
1 parent 01cbec2 commit 3b322a2
Show file tree
Hide file tree
Showing 3 changed files with 21 additions and 3 deletions.
11 changes: 10 additions & 1 deletion src/awkward/index.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
from awkward._nplikes.numpy import Numpy
from awkward._nplikes.numpylike import NumpyLike, NumpyMetadata
from awkward._nplikes.typetracer import TypeTracer
from awkward._typing import Final, Self
from awkward._typing import Any, Final, Self

np: Final = NumpyMetadata.instance()
numpy: Final = Numpy.instance()
Expand Down Expand Up @@ -145,6 +145,15 @@ def __cuda_array_interface__(self):
def __array_interface__(self):
return self._data.__array_interface__

def __dlpack__(self) -> tuple[int, int]:
return self._data.__dlpack_device__()

def __dlpack_device__(self, stream: Any = None) -> Any:
if stream is None:
return self._data.__dlpack__()
else:
return self._data.__dlpack__(stream)

def __repr__(self):
return self._repr("", "", "")

Expand Down
6 changes: 4 additions & 2 deletions src/awkward/operations/ak_from_dlpack.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,8 +39,10 @@ def from_dlpack(
"""
try:
dlpack_info_func = array.__dlpack_device__
except AttributeError:
...
except AttributeError as err:
raise TypeError(
f"Expected an object that implements the DLPack protocol, received {type(array)}"
) from err
device_type, device_id = dlpack_info_func()

# Only a subset of known devices are supported.
Expand Down
7 changes: 7 additions & 0 deletions tests/test_2649_dlpack_support.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,3 +62,10 @@ def test_to_layout():

np_from_ak = ak.to_numpy(layout)
assert np.shares_memory(np_array, np_from_ak)


def test_invalid_argument():
with pytest.raises(
TypeError, match=r"Expected an object that implements the DLPack protocol"
):
ak.from_dlpack([1, 2, 3])

0 comments on commit 3b322a2

Please sign in to comment.