Skip to content

Commit

Permalink
API: Array API support - Part 1
Browse files Browse the repository at this point in the history
  • Loading branch information
mtsokol committed Dec 21, 2023
1 parent 0e283ff commit 1d6943c
Show file tree
Hide file tree
Showing 11 changed files with 591 additions and 21 deletions.
76 changes: 76 additions & 0 deletions sparse/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,3 +10,79 @@

__version__ = get_versions()["version"]
del get_versions

from numpy import (
bool_ as bool,
float16,
float32,
float64,
complex64,
complex128,
uint8,
uint16,
uint32,
uint64,
int8,
int16,
int32,
int64,
pi,
e,
nan,
inf,
newaxis,
sin,
sinh,
cos,
cosh,
tan,
tanh,
arcsin as asin,
arcsinh as asinh,
arccos as acos,
arccosh as acosh,
arctan as atan,
arctan2 as atan2,
arctanh as atanh,
log,
log2,
log1p,
log10,
logaddexp,
power as pow,
sign,
square,
sqrt,
logical_and,
logical_not,
logical_or,
logical_xor,
bitwise_and,
bitwise_or,
bitwise_xor,
bitwise_not,
trunc,
add,
subtract,
remainder,
positive,
not_equal,
negative,
multiply,
less_equal,
less,
greater_equal,
greater,
floor_divide,
floor,
exp,
expm1,
divide,
ceil,
left_shift as bitwise_left_shift,
right_shift as bitwise_right_shift,
invert as bitwise_invert,
finfo,
iinfo,
can_cast,
)
226 changes: 211 additions & 15 deletions sparse/_common.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,15 @@
import numpy as np
import numba
import scipy.sparse
import builtins
from collections.abc import Iterable
from functools import wraps, reduce
from itertools import chain
from operator import mul, index
from collections.abc import Iterable
import warnings

import numpy as np
import numba

Check notice

Code scanning / CodeQL

Module is imported with 'import' and 'import from' Note

Module 'numba' is imported with both 'import' and 'import from'.
import scipy.sparse

Check notice

Code scanning / CodeQL

Module is imported with 'import' and 'import from' Note

Module 'scipy.sparse' is imported with both 'import' and 'import from'.
from scipy.sparse import spmatrix
from numba import literal_unroll
import warnings

from ._sparse_array import SparseArray
from ._utils import (
Expand All @@ -33,6 +35,8 @@
roll,
kron,
argwhere,
argmax,
argmin,
isposinf,
isneginf,
result_type,
Expand Down Expand Up @@ -187,7 +191,7 @@ def tensordot(a, b, axes=2, *, return_type=None):
newshape_b = (N2, -1)
oldb = [bs[axis] for axis in notin]

if any(dim == 0 for dim in chain(newshape_a, newshape_b)):
if builtins.any(dim == 0 for dim in chain(newshape_a, newshape_b)):
res = asCOO(np.empty(olda + oldb), check=False)
if isinstance(a, np.ndarray) or isinstance(b, np.ndarray):
res = res.todense()
Expand Down Expand Up @@ -268,12 +272,12 @@ def _matmul_recurser(a, b):
if a.ndim == 2:
return dot(a, b)
res = []
for i in range(max(a.shape[0], b.shape[0])):
for i in range(builtins.max(a.shape[0], b.shape[0])):
a_i = a[0] if a.shape[0] == 1 else a[i]
b_i = b[0] if b.shape[0] == 1 else b[i]
res.append(_matmul_recurser(a_i, b_i))
mask = [isinstance(x, SparseArray) for x in res]
if all(mask):
if builtins.all(mask):
return stack(res)
else:
res = [x.todense() if isinstance(x, SparseArray) else x for x in res]
Expand Down Expand Up @@ -334,7 +338,7 @@ def _dot(a, b, return_type=None):
from ._sparse_array import SparseArray

out_shape = (a.shape[0], b.shape[1])
if all(isinstance(arr, SparseArray) for arr in [a, b]) and any(
if builtins.all(isinstance(arr, SparseArray) for arr in [a, b]) and builtins.any(
isinstance(arr, GCXS) for arr in [a, b]
):
a = a.asformat("gcxs")
Expand Down Expand Up @@ -1333,7 +1337,7 @@ def _parse_einsum_input(operands):
if operands[num].shape == ():
ellipse_count = 0
else:
ellipse_count = max(operands[num].ndim, 1)
ellipse_count = builtins.max(operands[num].ndim, 1)
ellipse_count -= len(sub) - 3

if ellipse_count > longest:
Expand Down Expand Up @@ -1573,7 +1577,7 @@ def stack(arrays, axis=0, compressed_axes=None):
"""
from ._compressed import GCXS

if not all(isinstance(arr, GCXS) for arr in arrays):
if not builtins.all(isinstance(arr, GCXS) for arr in arrays):
from ._coo import stack as coo_stack

return coo_stack(arrays, axis)
Expand Down Expand Up @@ -1612,7 +1616,7 @@ def concatenate(arrays, axis=0, compressed_axes=None):
"""
from ._compressed import GCXS

if not all(isinstance(arr, GCXS) for arr in arrays):
if not builtins.all(isinstance(arr, GCXS) for arr in arrays):
from ._coo import concatenate as coo_concat

return coo_concat(arrays, axis)
Expand All @@ -1622,6 +1626,9 @@ def concatenate(arrays, axis=0, compressed_axes=None):
return gcxs_concat(arrays, axis, compressed_axes)


concat = concatenate


def eye(N, M=None, k=0, dtype=float, format="coo", **kwargs):
"""Return a 2-D array in the specified format with ones on the diagonal and zeros elsewhere.
Expand Down Expand Up @@ -1665,14 +1672,14 @@ def eye(N, M=None, k=0, dtype=float, format="coo", **kwargs):
M = int(M)
k = int(k)

data_length = min(N, M)
data_length = builtins.min(N, M)

if k > 0:
data_length = max(min(data_length, M - k), 0)
data_length = builtins.max(builtins.min(data_length, M - k), 0)
n_coords = np.arange(data_length, dtype=np.intp)
m_coords = n_coords + k
elif k < 0:
data_length = max(min(data_length, N + k), 0)
data_length = builtins.max(builtins.min(data_length, N + k), 0)
m_coords = np.arange(data_length, dtype=np.intp)
n_coords = m_coords - k
else:
Expand Down Expand Up @@ -1905,6 +1912,20 @@ def ones_like(a, dtype=None, shape=None, format=None, **kwargs):
return full_like(a, 1, dtype=dtype, shape=shape, format=format, **kwargs)


def empty(shape, dtype=float, format="coo", **kwargs):
return full(shape, 0, np.dtype(dtype)).asformat(format, **kwargs)

Check warning on line 1916 in sparse/_common.py

View check run for this annotation

Codecov / codecov/patch

sparse/_common.py#L1916

Added line #L1916 was not covered by tests


empty.__doc__ = zeros.__doc__


def empty_like(a, dtype=None, shape=None, format=None, **kwargs):
return full_like(a, 0, dtype=dtype, shape=shape, format=format, **kwargs)

Check warning on line 1923 in sparse/_common.py

View check run for this annotation

Codecov / codecov/patch

sparse/_common.py#L1923

Added line #L1923 was not covered by tests


empty_like.__doc__ = zeros_like.__doc__


def outer(a, b, out=None):
"""
Return outer product of two sparse arrays.
Expand Down Expand Up @@ -2088,3 +2109,178 @@ def format_to_string(format):
return format

raise ValueError(f"invalid format: {format}")


def asarray(
obj, /, *, dtype=None, format="coo", backend="pydata", device=None, copy=False
):

Check notice

Code scanning / CodeQL

Explicit returns mixed with implicit (fall through) returns Note

Mixing implicit and explicit returns may indicate an error as implicit returns always return None.
"""
Convert the input to a sparse array.
Parameters
----------
obj : array_like
Object to be converted to an array.
dtype : dtype, optional
Output array data type.
format : str, optional
Output array sparse format.
backend : str, optional
Backend for the output array.
device : str, optional
Device on which to place the created array.
copy : bool, optional
Boolean indicating whether or not to copy the input.
Returns
-------
out : Union[SparseArray, numpy.ndarray]
Sparse or 0-D array containing the data from `obj`.
Examples
--------
>>> x = np.eye(8, dtype='i8')
>>> sparse.asarray(x, format="COO")
<COO: shape=(8, 8), dtype=int64, nnz=8, fill_value=0>
"""
if format not in ["coo", "dok", "gcxs"]:
raise ValueError(f"{format} format not supported.")

if backend not in ["pydata", "taco"]:
raise ValueError(f"{backend} backend not supported.")

from ._coo import COO
from ._dok import DOK
from ._compressed import GCXS

format_dict = {"coo": COO, "dok": DOK, "gcxs": GCXS}

if backend == "pydata":
if isinstance(obj, (COO, DOK, GCXS)):
# TODO: consider `format` argument
warnings.warn("`format` argument was ignored")
return obj

elif isinstance(obj, spmatrix):
return format_dict[format].from_scipy_sparse(
obj.astype(dtype=dtype, copy=copy)
)

# check for scalars and 0-D arrays
elif np.isscalar(obj) or (isinstance(obj, np.ndarray) and obj.shape == ()):
return np.asarray(obj, dtype=dtype)

elif isinstance(obj, np.ndarray):
return format_dict[format].from_numpy(obj).astype(dtype=dtype, copy=copy)

else:
raise ValueError(f"{type(obj)} not supported.")

elif backend == "taco":
raise ValueError("Taco not yet supported.")


def _support_numpy(func):
"""
In case a NumPy array is passed to `sparse` namespace function
we want to flag it and dispatch to NumPy.
"""

def wrapper_func(*args, **kwargs):
x = args[0]
if isinstance(x, (np.ndarray, np.number)):
warnings.warn(

Check warning on line 2189 in sparse/_common.py

View check run for this annotation

Codecov / codecov/patch

sparse/_common.py#L2189

Added line #L2189 was not covered by tests
f"Sparse {func.__name__} received dense NumPy array instead "
"of sparse array. Dispatching to NumPy function."
)
return getattr(np, func.__name__)(*args, **kwargs)

Check warning on line 2193 in sparse/_common.py

View check run for this annotation

Codecov / codecov/patch

sparse/_common.py#L2193

Added line #L2193 was not covered by tests
else:
return func(*args, **kwargs)

return wrapper_func


def all(x, /, *, axis=None, keepdims=False):
return x.all(axis=axis, keepdims=keepdims)


def any(x, /, *, axis=None, keepdims=False):
return x.any(axis=axis, keepdims=keepdims)

Check warning on line 2205 in sparse/_common.py

View check run for this annotation

Codecov / codecov/patch

sparse/_common.py#L2205

Added line #L2205 was not covered by tests


def permute_dims(x, /, axes=None):
return x.transpose(axes=axes)

Check warning on line 2209 in sparse/_common.py

View check run for this annotation

Codecov / codecov/patch

sparse/_common.py#L2209

Added line #L2209 was not covered by tests


def max(x, /, *, axis=None, keepdims=False):
return x.max(axis=axis, keepdims=keepdims)


def mean(x, /, *, axis=None, keepdims=False, dtype=None):
return x.mean(axis=axis, keepdims=keepdims, dtype=dtype)


def min(x, /, *, axis=None, keepdims=False):
return x.min(axis=axis, keepdims=keepdims)


def prod(x, /, *, axis=None, dtype=None, keepdims=False):
return x.prod(axis=axis, keepdims=keepdims, dtype=dtype)


def std(x, /, *, axis=None, correction=0.0, keepdims=False):
return x.std(axis=axis, ddof=correction, keepdims=keepdims)


def sum(x, /, *, axis=None, dtype=None, keepdims=False):
return x.sum(axis=axis, keepdims=keepdims, dtype=dtype)


def var(x, /, *, axis=None, correction=0.0, keepdims=False):
return x.var(axis=axis, ddof=correction, keepdims=keepdims)


def abs(x, /):
return x.__abs__()

Check warning on line 2241 in sparse/_common.py

View check run for this annotation

Codecov / codecov/patch

sparse/_common.py#L2241

Added line #L2241 was not covered by tests


def reshape(x, /, shape, *, copy=None):
return x.reshape(shape=shape)


def astype(x, dtype, /, *, copy=True):
return x.astype(dtype, copy=copy)

Check warning on line 2249 in sparse/_common.py

View check run for this annotation

Codecov / codecov/patch

sparse/_common.py#L2249

Added line #L2249 was not covered by tests


@_support_numpy
def broadcast_to(x, /, shape):
return x.broadcast_to(shape)


def broadcast_arrays(*arrays):
shape = np.broadcast_shapes(*[a.shape for a in arrays])
return [a.broadcast_to(shape) for a in arrays]

Check warning on line 2259 in sparse/_common.py

View check run for this annotation

Codecov / codecov/patch

sparse/_common.py#L2258-L2259

Added lines #L2258 - L2259 were not covered by tests


def equal(x1, x2, /):
return x1 == x2

Check warning on line 2263 in sparse/_common.py

View check run for this annotation

Codecov / codecov/patch

sparse/_common.py#L2263

Added line #L2263 was not covered by tests


@_support_numpy
def round(x, /, decimals=0, out=None):
return x.round(decimals=decimals, out=out)


@_support_numpy
def isinf(x, /):
return x.isinf()

Check warning on line 2273 in sparse/_common.py

View check run for this annotation

Codecov / codecov/patch

sparse/_common.py#L2273

Added line #L2273 was not covered by tests


@_support_numpy
def isnan(x, /):
return x.isnan()

Check warning on line 2278 in sparse/_common.py

View check run for this annotation

Codecov / codecov/patch

sparse/_common.py#L2278

Added line #L2278 was not covered by tests


def isfinite(x, /):
return ~isinf(x)

Check warning on line 2282 in sparse/_common.py

View check run for this annotation

Codecov / codecov/patch

sparse/_common.py#L2282

Added line #L2282 was not covered by tests


def nonzero(x, /):
return x.nonzero()

Check warning on line 2286 in sparse/_common.py

View check run for this annotation

Codecov / codecov/patch

sparse/_common.py#L2286

Added line #L2286 was not covered by tests
Loading

0 comments on commit 1d6943c

Please sign in to comment.