Skip to content

Commit

Permalink
MAINT: Refine order support
Browse files Browse the repository at this point in the history
  • Loading branch information
mtsokol committed Jun 24, 2024
1 parent 3571918 commit ccca462
Show file tree
Hide file tree
Showing 3 changed files with 7 additions and 16 deletions.
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[tool.poetry]
name = "finch-tensor"
version = "0.1.30"
version = "0.1.31"
description = ""
authors = ["Willow Ahrens <willow.marie.ahrens@gmail.com>"]
readme = "README.md"
Expand Down
11 changes: 4 additions & 7 deletions src/finch/tensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -383,11 +383,12 @@ def to_storage(self, storage: Storage) -> "Tensor":
return Tensor(self._from_other_tensor(self, storage=storage))

@classmethod
def _from_other_tensor(cls, tensor: "Tensor", storage: Storage | None) -> JuliaObj:
def _from_other_tensor(cls, tensor: "Tensor", storage: Storage) -> JuliaObj:
order = cls.preprocess_order(storage.order, tensor.ndim)
return jl.swizzle(
jl.Tensor(storage.levels_descr._obj, tensor._obj.body), *order
result = jl.copyto_b(
jl.swizzle(jl.Tensor(storage.levels_descr._obj), *order), tensor._obj
)
return jl.dropfills(result)

@classmethod
def _from_numpy(cls, arr: np.ndarray, fill_value: np.number, copy: bool | None = None) -> JuliaObj:
Expand Down Expand Up @@ -664,12 +665,8 @@ def asarray(
if format == "coo":
storage = Storage(SparseCOO(tensor.ndim, Element(tensor.fill_value)), order)
elif format == "csr":
if order != (1, 0):
raise ValueError("Invalid order for csr")
storage = Storage(Dense(SparseList(Element(tensor.fill_value))), (2, 1))
elif format == "csc":
if order != (0, 1):
raise ValueError("Invalid order for csc")
storage = Storage(Dense(SparseList(Element(tensor.fill_value))), (1, 2))
elif format == "csf":
storage = Element(tensor.fill_value)
Expand Down
10 changes: 2 additions & 8 deletions tests/test_sparse.py
Original file line number Diff line number Diff line change
Expand Up @@ -197,12 +197,8 @@ def test_asarray(arr2d, arr3d, order, format):
arr = np.array(arr, order=order)
arr_finch = finch.Tensor(arr)

if (format, order) in [("csr", "F"), ("csc", "C")]:
with pytest.raises(ValueError, match="Invalid order for (csr|csc)"):
finch.asarray(arr_finch, format=format)
else:
result = finch.asarray(arr_finch, format=format)
assert_equal(result.todense(), arr)
result = finch.asarray(arr_finch, format=format)
assert_equal(result.todense(), arr)


@pytest.mark.parametrize(
Expand Down Expand Up @@ -302,8 +298,6 @@ def test_where(order_and_format):
)
def test_nonzero(order, format_shape):
format, shape = format_shape
if (format, order) in [("csr", "F"), ("csc", "C")]:
pytest.skip("invalid format+order")
rng = np.random.default_rng(0)
arr = rng.random(shape)
arr = np.array(arr, order=order)
Expand Down

0 comments on commit ccca462

Please sign in to comment.