Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

API: Add set functions [Array API] #619

Merged
merged 3 commits into from
Jan 8, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions sparse/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,8 @@
roll,
tril,
triu,
unique_counts,
unique_values,
where,
)
from ._dok import DOK
Expand Down Expand Up @@ -114,6 +116,8 @@
"min",
"max",
"nanreduce",
"unique_counts",
"unique_values",
]

__array_api_version__ = "2022.12"
4 changes: 4 additions & 0 deletions sparse/_coo/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,8 @@
stack,
tril,
triu,
unique_counts,
unique_values,
where,
)
from .core import COO, as_coo
Expand Down Expand Up @@ -49,4 +51,6 @@
"result_type",
"diagonal",
"diagonalize",
"unique_counts",
"unique_values",
]
110 changes: 107 additions & 3 deletions sparse/_coo/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
import warnings
from collections.abc import Iterable
from functools import reduce
from typing import Optional, Tuple
from typing import NamedTuple, Optional, Tuple

import numba

Expand Down Expand Up @@ -1059,6 +1059,106 @@
return a.clip(a_min, a_max)


# Array API set functions


class UniqueCountsResult(NamedTuple):
    """Result container returned by `unique_counts`."""

    # Unique elements found in the input array.
    values: np.ndarray
    # Number of occurrences of each entry of `values`, position by position.
    counts: np.ndarray


def unique_counts(x, /):
    """
    Returns the unique elements of an input array `x`, and the corresponding
    counts for each unique element in `x`.

    Parameters
    ----------
    x : SparseArray or scipy.sparse.spmatrix
        Input array. Inputs that are not already COO are converted to COO.
        It will be flattened if it is not already 1-D.

    Returns
    -------
    out : namedtuple
        The result containing:
        * values - The unique elements of an input array, in ascending order.
        * counts - The corresponding counts for each unique element.

    Raises
    ------
    ValueError
        If the input is neither a ``SparseArray`` nor a
        ``scipy.sparse.spmatrix``.

    Examples
    --------
    >>> import sparse
    >>> x = sparse.COO.from_numpy([1, 0, 2, 1, 2, -3])
    >>> sparse.unique_counts(x)
    UniqueCountsResult(values=array([-3,  0,  1,  2]), counts=array([1, 1, 2, 2]))
    """
    from .core import COO

    if isinstance(x, scipy.sparse.spmatrix):
        x = COO.from_scipy_sparse(x)
    elif not isinstance(x, SparseArray):
        raise ValueError(f"Input must be an instance of SparseArray, but it's {type(x)}.")
    elif not isinstance(x, COO):
        x = x.asformat(COO)

    x = x.flatten()
    # `np.unique` returns the stored values already sorted in ascending order.
    values, counts = np.unique(x.data, return_counts=True)

    # Elements that are not stored explicitly all equal the fill value.
    n_implicit = x.size - x.nnz
    if n_implicit > 0:
        fill_value = x.fill_value
        # Locate where the fill value belongs in the sorted `values`, so the
        # result stays sorted without permuting the arrays afterwards.
        pos = np.searchsorted(values, fill_value)
        if pos < values.size and values[pos] == fill_value:
            # The fill value is also stored explicitly: merge both counts
            # instead of reporting the same value twice.
            counts[pos] += n_implicit
        else:
            values = np.insert(values, pos, fill_value)
            counts = np.insert(counts, pos, n_implicit)

    return UniqueCountsResult(values, counts)


def unique_values(x, /):
    """
    Returns the unique elements of an input array `x`.

    Parameters
    ----------
    x : SparseArray or scipy.sparse.spmatrix
        Input array. Inputs that are not already COO are converted to COO.
        It will be flattened if it is not already 1-D.

    Returns
    -------
    out : ndarray
        The unique elements of an input array, in ascending order.

    Raises
    ------
    ValueError
        If the input is neither a ``SparseArray`` nor a
        ``scipy.sparse.spmatrix``.

    Examples
    --------
    >>> import sparse
    >>> x = sparse.COO.from_numpy([1, 0, 2, 1, 2, -3])
    >>> sparse.unique_values(x)
    array([-3,  0,  1,  2])
    """
    from .core import COO

    if isinstance(x, scipy.sparse.spmatrix):
        x = COO.from_scipy_sparse(x)
    elif not isinstance(x, SparseArray):
        raise ValueError(f"Input must be an instance of SparseArray, but it's {type(x)}.")
    elif not isinstance(x, COO):
        x = x.asformat(COO)

    x = x.flatten()
    values = np.unique(x.data)
    if x.nnz < x.size:
        # Some elements are implicit, so the fill value occurs in the array.
        # Use `np.unique` rather than a plain sort so that a fill value which
        # is also stored explicitly in `x.data` is not reported twice.
        values = np.unique(np.concatenate([[x.fill_value], values]))
    return values


@numba.jit(nopython=True, nogil=True)
def _compute_minmax_args(
coords: np.ndarray,
Expand Down Expand Up @@ -1121,8 +1221,12 @@

from .core import COO

if not isinstance(x, COO):
raise ValueError(f"Only COO arrays are supported but {type(x)} was passed.")
if isinstance(x, scipy.sparse.spmatrix):
x = COO.from_scipy_sparse(x)

Check warning on line 1225 in sparse/_coo/common.py

View check run for this annotation

Codecov / codecov/patch

sparse/_coo/common.py#L1225

Added line #L1225 was not covered by tests
elif not isinstance(x, SparseArray):
raise ValueError(f"Input must be an instance of SparseArray, but it's {type(x)}.")

Check warning on line 1227 in sparse/_coo/common.py

View check run for this annotation

Codecov / codecov/patch

sparse/_coo/common.py#L1227

Added line #L1227 was not covered by tests
elif not isinstance(x, COO):
x = x.asformat(COO)

Check warning on line 1229 in sparse/_coo/common.py

View check run for this annotation

Codecov / codecov/patch

sparse/_coo/common.py#L1229

Added line #L1229 was not covered by tests

if not isinstance(axis, (int, type(None))):
raise ValueError(f"`axis` must be `int` or `None`, but it's: {type(axis)}.")
Expand Down
32 changes: 32 additions & 0 deletions sparse/tests/test_coo.py
Original file line number Diff line number Diff line change
Expand Up @@ -1745,3 +1745,35 @@ def test_squeeze_validation(self):

with pytest.raises(ValueError, match="Specified axis `0` has a size greater than one: 3"):
s_arr.squeeze(0)


class TestUnique:
    """Compare the sparse set functions against ``np.unique`` on dense data."""

    # Dense fixtures: a mix of zeros and non-zeros, all zeros, and no zeros.
    arr = np.array([[0, 0, 1, 5, 3, 0], [1, 0, 4, 0, 3, 0], [0, 1, 0, 1, 1, 0]], dtype=np.int64)
    arr_empty = np.zeros((5, 5))
    arr_full = np.arange(1, 10)

    @pytest.mark.parametrize("arr", [arr, arr_empty, arr_full])
    @pytest.mark.parametrize("fill_value", [-1, 0, 1])
    def test_unique_counts(self, arr, fill_value):
        """Values and counts must match ``np.unique(..., return_counts=True)``."""
        expected_values, expected_counts = np.unique(arr, return_counts=True)

        sparse_input = sparse.COO.from_numpy(arr, fill_value)
        actual_values, actual_counts = sparse.unique_counts(sparse_input)

        np.testing.assert_equal(actual_values, expected_values)
        np.testing.assert_equal(actual_counts, expected_counts)

    @pytest.mark.parametrize("arr", [arr, arr_empty, arr_full])
    @pytest.mark.parametrize("fill_value", [-1, 0, 1])
    def test_unique_values(self, arr, fill_value):
        """Unique values must match ``np.unique`` on the dense input."""
        expected = np.unique(arr)

        sparse_input = sparse.COO.from_numpy(arr, fill_value)
        actual = sparse.unique_values(sparse_input)

        np.testing.assert_equal(actual, expected)

    @pytest.mark.parametrize("func", [sparse.unique_counts, sparse.unique_values])
    def test_input_validation(self, func):
        """A plain NumPy array is rejected with a ``ValueError``."""
        dense_input = self.arr
        with pytest.raises(ValueError, match=r"Input must be an instance of SparseArray"):
            func(dense_input)
Loading