From 460237616647e1ce999a59d0a0607e3e72ef554e Mon Sep 17 00:00:00 2001 From: Hameer Abbasi <2190658+hameerabbasi@users.noreply.github.com> Date: Tue, 24 Sep 2024 13:28:56 +0200 Subject: [PATCH] Update with suggestions from @pearu. --- .../numba_backend/_compressed/compressed.py | 18 +++++----------- sparse/numba_backend/_coo/core.py | 11 +++------- sparse/numba_backend/_dok.py | 7 ++----- sparse/numba_backend/_sparse_array.py | 21 ++++++------------- 4 files changed, 16 insertions(+), 41 deletions(-) diff --git a/sparse/numba_backend/_compressed/compressed.py b/sparse/numba_backend/_compressed/compressed.py index 8ac5a036..c504da26 100644 --- a/sparse/numba_backend/_compressed/compressed.py +++ b/sparse/numba_backend/_compressed/compressed.py @@ -847,11 +847,8 @@ def isnan(self): # `GCXS` is a reshaped/transposed `CSR`, but it can't (usually) # be expressed in the `binsparse` 0.1 language. # We are missing index maps. - def __binsparse_descriptor__(self) -> dict: - return super().__binsparse_descriptor__() - - def __binsparse_dlpack__(self) -> dict[str, np.ndarray]: - return super().__binsparse_dlpack__() + def __binsparse__(self) -> tuple[dict, list[np.ndarray]]: + return super().__binsparse__() class _Compressed2d(GCXS): @@ -892,13 +889,13 @@ def from_numpy(cls, x, fill_value=0, idx_dtype=None): coo = COO.from_numpy(x, fill_value=fill_value, idx_dtype=idx_dtype) return cls.from_coo(coo, cls.class_compressed_axes, idx_dtype) - def __binsparse_descriptor__(self) -> dict: + def __binsparse__(self) -> tuple[dict, list[np.ndarray]]: from sparse._version import __version__ data_dt = str(self.data.dtype) if np.issubdtype(data_dt, np.complexfloating): data_dt = f"complex[float{self.data.dtype.itemsize // 2}]" - return { + descriptor = { "binsparse": { "version": "0.1", "format": self.format.upper(), @@ -913,12 +910,7 @@ def __binsparse_descriptor__(self) -> dict: "original_source": f"`sparse`, version {__version__}", } - def __binsparse_dlpack__(self) -> dict[str, np.ndarray]: - return { - "pointers_to_1": self.indices, - "indices_1": self.indptr, - "values": self.data, - } + return descriptor, [self.indices, self.indptr, self.data] class CSR(_Compressed2d): diff --git a/sparse/numba_backend/_coo/core.py b/sparse/numba_backend/_coo/core.py index a787fed8..fb39aaab 100644 --- a/sparse/numba_backend/_coo/core.py +++ b/sparse/numba_backend/_coo/core.py @@ -1537,13 +1537,13 @@ def isnan(self): prune=True, ) - def __binsparse_descriptor__(self) -> dict: + def __binsparse__(self) -> tuple[dict, list[np.ndarray]]: from sparse._version import __version__ data_dt = str(self.data.dtype) if np.issubdtype(data_dt, np.complexfloating): data_dt = f"complex[float{self.data.dtype.itemsize // 2}]" - return { + descriptor = { "binsparse": { "version": "0.1", "format": { @@ -1568,12 +1568,7 @@ def __binsparse_descriptor__(self) -> dict: "original_source": f"`sparse`, version {__version__}", } - def __binsparse_dlpack__(self) -> dict[str, np.ndarray]: - return { - "pointers_to_1": np.array([0, self.nnz], dtype=np.uint8), - "indices_1": self.coords, - "values": self.data, - } + return descriptor, [np.array([0, self.nnz], dtype=np.uint8), self.coords, self.data] def as_coo(x, shape=None, fill_value=None, idx_dtype=None): diff --git a/sparse/numba_backend/_dok.py b/sparse/numba_backend/_dok.py index 809e926e..b3a8e747 100644 --- a/sparse/numba_backend/_dok.py +++ b/sparse/numba_backend/_dok.py @@ -548,11 +548,8 @@ def reshape(self, shape, order="C"): return DOK.from_coo(self.to_coo().reshape(shape)) - def __binsparse_descriptor__(self) -> dict: - raise RuntimeError("`DOK` doesn't support the `__binsparse_descriptor__` protocol.") - - def __binsparse_dlpack__(self) -> dict[str, np.ndarray]: - raise RuntimeError("`DOK` doesn't support the `__binsparse_dlpack__` protocol.") + def __binsparse__(self) -> tuple[dict, list[np.ndarray]]: + raise RuntimeError("`DOK` doesn't support the `__binsparse__` protocol.") def to_slice(k): diff --git a/sparse/numba_backend/_sparse_array.py b/sparse/numba_backend/_sparse_array.py index e776a347..7a7b84cc 100644 --- a/sparse/numba_backend/_sparse_array.py +++ b/sparse/numba_backend/_sparse_array.py @@ -219,27 +219,18 @@ def _str_impl(self, summary): return summary @abstractmethod - def __binsparse_descriptor__(self) -> dict: - """Return a `dict` equivalent to a parsed JSON [`binsparse` descriptor](https://graphblas.org/binsparse-specification/#descriptor) + def __binsparse__(self) -> tuple[dict, list[np.ndarray]]: + """Return a 2-tuple: + * First element is a `dict` equivalent to a parsed JSON [`binsparse` descriptor](https://graphblas.org/binsparse-specification/#descriptor) of this array. + * Second element is a `list[np.ndarray]` of the constituent arrays. Returns ------- dict Parsed `binsparse` descriptor. - """ - raise NotImplementedError - - @abstractmethod - def __binsparse_dlpack__(self) -> dict[str, np.ndarray]: - """A `dict` containing the constituent arrays of this sparse array. The keys are compatible with the - [`binsparse`](https://graphblas.org/binsparse-specification/) scheme, and the values are [`__dlpack__`](https://data-apis.org/array-api/latest/API_specification/generated/array_api.array.__dlpack__.html) - compatible objects. - - Returns - ------- - dict[str, np.ndarray] - The constituent arrays. + list[np.ndarray] + The constituent arrays """ raise NotImplementedError