Skip to content

Commit

Permalink
Merge pull request #43 from willow-ahrens/scipy-canonical
Browse files Browse the repository at this point in the history
Verify SciPy canonical input
  • Loading branch information
mtsokol committed May 10, 2024
2 parents 0c3a202 + 26134a7 commit 477b5cd
Show file tree
Hide file tree
Showing 2 changed files with 24 additions and 2 deletions.
14 changes: 12 additions & 2 deletions src/finch/tensor.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import builtins
from typing import Any, Callable, Optional, Iterable, Literal
import warnings

import numpy as np
from numpy.core.numeric import normalize_axis_index, normalize_axis_tuple
Expand Down Expand Up @@ -357,12 +358,21 @@ def _from_numpy(cls, arr: np.ndarray, fill_value: np.number) -> JuliaObj:
def from_scipy_sparse(cls, x) -> "Tensor":
if not _is_scipy_sparse_obj(x):
raise ValueError("{x} is not a SciPy sparse object.")
if x.format not in ("coo", "csr", "csc"):
x = x.asformat("coo")
return Tensor(x)

@classmethod
def _from_scipy_sparse(cls, x) -> JuliaObj:
if x.format not in ("coo", "csr", "csc"):
x = x.asformat("coo")
if not x.has_canonical_format:
warnings.warn(
"SciPy sparse input must be in a canonical format. "
"Calling `sum_duplicates`."
)
x = x.copy()
x.sum_duplicates()
assert x.has_canonical_format

if x.format == "coo":
return cls.construct_coo_jl_object(
coords=(x.col, x.row),
Expand Down
12 changes: 12 additions & 0 deletions tests/test_scipy_constructors.py
Original file line number Diff line number Diff line change
Expand Up @@ -87,3 +87,15 @@ def test_from_scipy_sparse(format_with_pattern):

result = finch.Tensor.from_scipy_sparse(sp_arr)
assert pattern in str(result)


@pytest.mark.parametrize("format", ["coo", "bsr"])
def test_non_canonical_format(format):
sp_arr = sp.random(3, 4, density=0.5, format=format)

with pytest.warns(
UserWarning, match="SciPy sparse input must be in a canonical format."
):
finch_arr = finch.asarray(sp_arr)

assert_equal(finch_arr.todense(), sp_arr.toarray())

0 comments on commit 477b5cd

Please sign in to comment.