Skip to content

Commit

Permalink
Merge pull request #23 from mtsokol/finch-random
Browse files Browse the repository at this point in the history
API: Add `finch.random` function
  • Loading branch information
mtsokol committed Mar 22, 2024
2 parents 86f4429 + 06bd810 commit d1b33c2
Show file tree
Hide file tree
Showing 5 changed files with 21 additions and 4 deletions.
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[tool.poetry]
name = "finch-tensor"
version = "0.1.6"
version = "0.1.7"
description = ""
authors = ["Willow Ahrens <willow.marie.ahrens@gmail.com>"]
readme = "README.md"
Expand Down
4 changes: 2 additions & 2 deletions src/finch/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
from .tensor import (
Tensor,
astype,
fsprand,
random,
permute_dims,
multiply,
sum,
Expand Down Expand Up @@ -64,7 +64,7 @@
"Storage",
"DenseStorage",
"astype",
"fsprand",
"random",
"permute_dims",
"int_",
"int8",
Expand Down
1 change: 1 addition & 0 deletions src/finch/julia.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,3 +7,4 @@
from juliacall import Main as jl # noqa

jl.seval("using Finch")
jl.seval("using Random: default_rng")
10 changes: 9 additions & 1 deletion src/finch/tensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -396,7 +396,15 @@ def construct_csf(cls, arg: TupleOf3Arrays, shape: tuple[int, ...]) -> "Tensor":
return Tensor(cls.construct_csf_jl_object(arg, shape))


def fsprand(*args):
def random(shape, density=0.01, random_state=None):
args = [*shape, density]
if random_state is not None:
if isinstance(random_state, np.random.Generator):
seed = random_state.integers(np.iinfo(np.int32).max)
else:
seed = random_state
rng = jl.default_rng(seed)
args = [rng] + args
return Tensor(jl.fsprand(*args))


Expand Down
8 changes: 8 additions & 0 deletions tests/test_sparse.py
Original file line number Diff line number Diff line change
Expand Up @@ -136,3 +136,11 @@ def test_astype(arr3d, order):

with pytest.raises(ValueError, match="Unable to avoid a copy while casting in no-copy mode."):
finch.astype(arr_finch, finch.float64, copy=False)


@pytest.mark.parametrize("random_state", [42, np.random.default_rng(42)])
def test_random(random_state):
result = finch.random((10, 20, 30), density=0.0, random_state=random_state)
expected = sparse.random((10, 20, 30), density=0.0, random_state=random_state)

assert_equal(result.todense(), expected.todense())

0 comments on commit d1b33c2

Please sign in to comment.