Merge pull request #35 from willow-ahrens/more-array-api-funcs
API: More elemwise and reduction functions
mtsokol committed Apr 26, 2024
2 parents 27ceecc + 0a31757 commit 874fd1b
Showing 6 changed files with 154 additions and 14 deletions.
pyproject.toml (2 changes: 1 addition & 1 deletion)
@@ -1,6 +1,6 @@
 [tool.poetry]
 name = "finch-tensor"
-version = "0.1.11"
+version = "0.1.12"
 description = ""
 authors = ["Willow Ahrens <willow.marie.ahrens@gmail.com>"]
 readme = "README.md"
src/finch/__init__.py (30 changes: 30 additions & 0 deletions)
@@ -26,6 +26,10 @@
     multiply,
     sum,
     prod,
+    max,
+    min,
+    all,
+    any,
     add,
     subtract,
     divide,
@@ -47,6 +51,17 @@
     atan,
     atanh,
     atan2,
+    log,
+    log10,
+    log1p,
+    log2,
+    sqrt,
+    exp,
+    expm1,
+    sign,
+    round,
+    floor,
+    ceil,
 )
 from .compiled import (
     lazy,
@@ -116,6 +131,10 @@
     "compute",
     "sum",
     "prod",
+    "max",
+    "min",
+    "all",
+    "any",
     "add",
     "subtract",
     "multiply",
@@ -138,4 +157,15 @@
     "atan",
     "atanh",
     "atan2",
+    "log",
+    "log10",
+    "log1p",
+    "log2",
+    "sqrt",
+    "exp",
+    "expm1",
+    "sign",
+    "round",
+    "floor",
+    "ceil",
 ]
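
With these re-exports, the new element-wise and reduction functions are available straight from the package namespace. A minimal usage sketch (array values are illustrative):

    import numpy as np
    import finch

    A = finch.Tensor(np.array([[1.0, 0.0], [0.0, 4.0]]))
    finch.sqrt(A)           # element-wise square root
    finch.max(A, axis=0)    # reduction over the first axis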
src/finch/julia.py (2 changes: 1 addition & 1 deletion)
@@ -1,6 +1,6 @@
 import juliapkg

-_FINCH_VERSION = "0.6.21"
+_FINCH_VERSION = "0.6.22"
 _FINCH_HASH = "9177782c-1635-4eb9-9bfb-d9dfa25e6bce"

 deps = juliapkg.deps.load_cur_deps()
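
The bump pins Finch.jl 0.6.22 by UUID. For context, a rough sketch of what this pin amounts to — the `add`/`resolve` calls below are assumed from juliapkg's public API, and this is a hypothetical manual equivalent rather than what the module actually does (it builds the dependency list from `load_cur_deps()`):

    import juliapkg

    # Pin the Finch.jl package by UUID to the required version, then resolve.
    juliapkg.add("Finch", "9177782c-1635-4eb9-9bfb-d9dfa25e6bce", version="0.6.22")
    juliapkg.resolve()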
src/finch/tensor.py (94 changes: 91 additions & 3 deletions)
@@ -1,3 +1,4 @@
+import builtins
 from typing import Callable, Iterable, Optional, Union, Literal

 import numpy as np
@@ -83,6 +84,9 @@ def __init__(
         *,
         fill_value: np.number = 0.0,
     ):
+        if isinstance(obj, (int, float, complex, bool)):
+            obj = np.array(obj)
+
         if _is_scipy_sparse_obj(obj):  # scipy constructor
             jl_data = self._from_scipy_sparse(obj)
             self._obj = jl_data
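
With this guard, the constructor accepts bare Python scalars by promoting them to 0-dimensional NumPy arrays before dispatch, e.g.:

    import finch

    t = finch.Tensor(3.5)   # now equivalent to finch.Tensor(np.array(3.5))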
@@ -232,11 +236,11 @@ def preprocess_order(
         elif order == cls.row_major or order is None:
             permutation = tuple(range(1, ndim + 1)[::-1])
         elif isinstance(order, tuple):
-            if min(order) == 0:
+            if builtins.min(order) == 0:
                 order = tuple(i + 1 for i in order)
             if (
                 len(order) == ndim and
-                all([i in order for i in range(1, ndim + 1)])
+                builtins.all([i in order for i in range(1, ndim + 1)])
             ):
                 permutation = order
             else:
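
The `builtins.` prefix is required because `tensor.py` now defines module-level `min` and `all` reductions (added later in this file), which shadow the Python built-ins throughout the module. A stripped-down illustration of the pitfall:

    import builtins

    def min(x, /, *, axis=None, keepdims=False):  # module-level reduction shadows the built-in
        ...

    order = (1, 0, 2)
    builtins.min(order)   # 0 -- the built-in; a bare min(order) would hit the reduction above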
@@ -550,7 +554,7 @@ def nonzero(x: Tensor, /) -> tuple[np.ndarray, ...]:
     return tuple(i[sort_order] for i in indices)


-def _reduce(x: Tensor, fn: Callable, axis, dtype):
+def _reduce(x: Tensor, fn: Callable, axis, dtype=None):
     if axis is not None:
         axis = normalize_axis_tuple(axis, x.ndim)
         axis = tuple(i + 1 for i in axis)
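
The body first resolves negative axes against `x.ndim`, then shifts them to Julia's 1-based dimension numbering; roughly:

    from numpy.core.numeric import normalize_axis_tuple  # import location varies across NumPy versions

    axis = normalize_axis_tuple(-1, 3)   # (2,): resolve axis -1 for a 3-D tensor
    axis = tuple(i + 1 for i in axis)    # (3,): Julia dims are 1-based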
@@ -587,6 +591,46 @@ def prod(
     return _reduce(x, jl.prod, axis, dtype)


+def max(
+    x: Tensor,
+    /,
+    *,
+    axis: Union[int, tuple[int, ...], None] = None,
+    keepdims: bool = False,
+) -> Tensor:
+    return _reduce(x, jl.maximum, axis)
+
+
+def min(
+    x: Tensor,
+    /,
+    *,
+    axis: Union[int, tuple[int, ...], None] = None,
+    keepdims: bool = False,
+) -> Tensor:
+    return _reduce(x, jl.minimum, axis)
+
+
+def any(
+    x: Tensor,
+    /,
+    *,
+    axis: Union[int, tuple[int, ...], None] = None,
+    keepdims: bool = False,
+) -> Tensor:
+    return _reduce(x, jl.any, axis)
+
+
+def all(
+    x: Tensor,
+    /,
+    *,
+    axis: Union[int, tuple[int, ...], None] = None,
+    keepdims: bool = False,
+) -> Tensor:
+    return _reduce(x, jl.all, axis)
+
+
 def eye(
     n_rows: int,
     n_cols: Optional[int] = None,
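
Note that `keepdims` is accepted for array-API compatibility but not yet forwarded to `_reduce`, and `dtype` is not exposed at all (the tests below parametrize it as `None` only). A usage sketch:

    import numpy as np
    import finch

    A = finch.Tensor(np.array([[1, 5], [3, 2]]))
    finch.max(A, axis=0)   # columnwise maxima -> [3, 5]
    finch.min(A)           # full reduction; may come back as a scalar rather than a Tensor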
@@ -657,6 +701,50 @@ def pow(x1: Tensor, x2: Tensor, /) -> Tensor:
     return x1 ** x2


+def log(x: Tensor, /) -> Tensor:
+    return x._elemwise_op("log")
+
+
+def log10(x: Tensor, /) -> Tensor:
+    return x._elemwise_op("log10")
+
+
+def log1p(x: Tensor, /) -> Tensor:
+    return x._elemwise_op("log1p")
+
+
+def log2(x: Tensor, /) -> Tensor:
+    return x._elemwise_op("log2")
+
+
+def sqrt(x: Tensor, /) -> Tensor:
+    return x._elemwise_op("sqrt")
+
+
+def sign(x: Tensor, /) -> Tensor:
+    return x._elemwise_op("sign")
+
+
+def round(x: Tensor, /) -> Tensor:
+    return x._elemwise_op("round")
+
+
+def exp(x: Tensor, /) -> Tensor:
+    return x._elemwise_op("exp")
+
+
+def expm1(x: Tensor, /) -> Tensor:
+    return x._elemwise_op("expm1")
+
+
+def floor(x: Tensor, /) -> Tensor:
+    return x._elemwise_op("floor")
+
+
+def ceil(x: Tensor, /) -> Tensor:
+    return x._elemwise_op("ceil")
+
+
 def positive(x: Tensor, /) -> Tensor:
     return +x
tests/test_ops.py (29 changes: 22 additions & 7 deletions)
@@ -1,5 +1,5 @@
 import numpy as np
-from numpy.testing import assert_equal
+from numpy.testing import assert_equal, assert_allclose
 import pytest

 import finch
@@ -45,10 +45,23 @@ def my_custom_fun(arr1, arr2, arr3):
     assert_equal(result.todense(), np.multiply(arr3d, arr2d))


+@pytest.mark.parametrize(
+    "func_name", ["log", "log10", "log1p", "log2", "sqrt", "sign", "round", "exp", "expm1", "floor", "ceil"],
+)
+def test_elemwise_ops_1_arg(arr3d, func_name):
+    arr = arr3d + 1.6
+    A_finch = finch.Tensor(arr)
+
+    actual = getattr(finch, func_name)(A_finch)
+    expected = getattr(np, func_name)(arr)
+
+    assert_allclose(actual.todense(), expected)
+
+
 @pytest.mark.parametrize(
     "meth_name", ["__pos__", "__neg__", "__abs__"],
 )
-def test_elemwise_ops_1_arg(arr3d, meth_name):
+def test_elemwise_tensor_ops_1_arg(arr3d, meth_name):
     A_finch = finch.Tensor(arr3d)

     actual = getattr(A_finch, meth_name)()
@@ -62,7 +75,7 @@ def test_elemwise_ops_1_arg(arr3d, meth_name):
     ["__add__", "__mul__", "__sub__", "__truediv__",  # "__floordiv__", "__mod__",
      "__pow__", "__and__", "__or__", "__xor__", "__lshift__", "__rshift__"],
 )
-def test_elemwise_ops_2_args(arr3d, meth_name):
+def test_elemwise_tensor_ops_2_args(arr3d, meth_name):
     arr2d = np.array([[2, 3, 2, 3], [3, 2, 3, 2]])
     A_finch = finch.Tensor(arr3d)
     B_finch = finch.Tensor(arr2d)
@@ -73,14 +86,16 @@ def test_elemwise_ops_2_args(arr3d, meth_name):
     assert_equal(actual.todense(), expected)


-@pytest.mark.parametrize("func_name", ["sum", "prod"])
+@pytest.mark.parametrize("func_name", ["sum", "prod", "max", "min", "any", "all"])
 @pytest.mark.parametrize("axis", [None, -1, 1, (0, 1), (0, 1, 2)])
-@pytest.mark.parametrize("dtype", [None])
+@pytest.mark.parametrize("dtype", [None])  # not supported yet
 def test_reductions(arr3d, func_name, axis, dtype):
+    if func_name in ("any", "all"):
+        arr3d = arr3d.astype(bool)
     A_finch = finch.Tensor(arr3d)

-    actual = getattr(finch, func_name)(A_finch, axis=axis, dtype=dtype)
-    expected = getattr(np, func_name)(arr3d, axis=axis, dtype=dtype)
+    actual = getattr(finch, func_name)(A_finch, axis=axis)
+    expected = getattr(np, func_name)(arr3d, axis=axis)

     if isinstance(actual, finch.Tensor):
         actual = actual.todense()
tests/test_sparse.py (11 changes: 9 additions & 2 deletions)
@@ -6,9 +6,12 @@
 import finch


-@pytest.mark.parametrize("dtype", [np.int64, np.float64, np.complex128])
+@pytest.mark.parametrize(
+    "dtype,jl_dtype",
+    [(np.int64, finch.int64), (np.float64, finch.float64), (np.complex128, finch.complex128)]
+)
 @pytest.mark.parametrize("order", ["C", "F", None])
-def test_wrappers(dtype, order):
+def test_wrappers(dtype, jl_dtype, order):
     A = np.array([[0, 0, 4], [1, 0, 0], [2, 0, 5], [3, 0, 0]], dtype=dtype, order=order)
     B = np.array(np.stack([A, A], axis=2, dtype=dtype), order=order)
@@ -19,13 +22,17 @@ def test_wrappers(dtype, jl_dtype, order):
     )
     B_finch = B_finch.to_device(storage)

+    assert B_finch.shape == B.shape
+    assert B_finch.dtype == jl_dtype
     assert_equal(B_finch.todense(), B)

     storage = finch.Storage(
         finch.Dense(finch.Dense(finch.Element(dtype(1.0)))), order=order
     )
     A_finch = finch.Tensor(A).to_device(storage)

+    assert A_finch.shape == A.shape
+    assert A_finch.dtype == jl_dtype
     assert_equal(A_finch.todense(), A)
     assert A_finch.todense().dtype == A.dtype and B_finch.todense().dtype == B.dtype
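
The paired `jl_dtype` values let the test check that a wrapped tensor reports the finch dtype matching its NumPy source. The same check in isolation (a sketch following the mapping in the parametrization above; the test itself asserts this after `to_device`):

    import numpy as np
    import finch

    A_finch = finch.Tensor(np.zeros((2, 2), dtype=np.float64))
    assert A_finch.dtype == finch.float64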

Expand Down

0 comments on commit 874fd1b

Please sign in to comment.