From 05705e0b3b93736442875851dc25c5428fcb2064 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Mateusz=20Sok=C3=B3=C5=82?=
Date: Tue, 5 Mar 2024 15:46:29 +0100
Subject: [PATCH] API: Lazy API draft

---
 src/finch/__init__.py |  26 ++++++
 src/finch/compiled.py |  34 +++++++
 src/finch/julia.py    |   2 +-
 src/finch/tensor.py   | 206 +++++++++++++++++++++++++++++++++++++++---
 tests/test_ops.py     |  88 ++++++++++++++++++
 5 files changed, 343 insertions(+), 13 deletions(-)
 create mode 100644 src/finch/compiled.py
 create mode 100644 tests/test_ops.py

diff --git a/src/finch/__init__.py b/src/finch/__init__.py
index 0232cfa..bdfc2c1 100644
--- a/src/finch/__init__.py
+++ b/src/finch/__init__.py
@@ -16,6 +16,20 @@
     astype,
     fsprand,
     permute_dims,
+    multiply,
+    sum,
+    prod,
+    add,
+    subtract,
+    multiply,
+    divide,
+    positive,
+    negative,
+)
+from .compiled import (
+    lazy,
+    compiled,
+    compute,
 )
 from .dtypes import (
     int_,
@@ -68,4 +82,16 @@
     "complex64",
     "complex128",
     "bool",
+    "multiply",
+    "lazy",
+    "compiled",
+    "compute",
+    "sum",
+    "prod",
+    "add",
+    "subtract",
+    "multiply",
+    "divide",
+    "positive",
+    "negative",
 ]
diff --git a/src/finch/compiled.py b/src/finch/compiled.py
new file mode 100644
index 0000000..bae0d20
--- /dev/null
+++ b/src/finch/compiled.py
@@ -0,0 +1,34 @@
+from functools import wraps
+
+from .julia import jl
+from .tensor import Tensor
+
+
+def compiled(func):
+    @wraps(func)
+    def wrapper_func(*args, **kwargs):
+        new_args = []
+        for arg in args:
+            if isinstance(arg, Tensor) and not jl.isa(arg._obj, jl.Finch.LazyTensor):
+                new_args.append(Tensor(jl.Finch.LazyTensor(arg._obj)))
+            else:
+                new_args.append(arg)
+
+        result = func(*new_args, **kwargs)
+        result_tensor = Tensor(jl.Finch.compute(result._obj))
+
+        return result_tensor
+
+    return wrapper_func
+
+
+def lazy(tensor: Tensor):
+    if tensor.is_computed():
+        return Tensor(jl.Finch.LazyTensor(tensor._obj))
+    return tensor
+
+
+def compute(tensor: Tensor):
+    if not tensor.is_computed():
+        return Tensor(jl.Finch.compute(tensor._obj))
+    return tensor
diff --git a/src/finch/julia.py b/src/finch/julia.py
index c691d36..5880f7f 100644
--- a/src/finch/julia.py
+++ b/src/finch/julia.py
@@ -1,6 +1,6 @@
 import juliapkg
 
-juliapkg.add("Finch", "9177782c-1635-4eb9-9bfb-d9dfa25e6bce", version="0.6.16")
+juliapkg.add("Finch", "9177782c-1635-4eb9-9bfb-d9dfa25e6bce", version="0.6.19")
 import juliacall  # noqa
 
 juliapkg.resolve()
diff --git a/src/finch/tensor.py b/src/finch/tensor.py
index 5321a64..7afdfc2 100644
--- a/src/finch/tensor.py
+++ b/src/finch/tensor.py
@@ -1,4 +1,4 @@
-from typing import Any, Optional, Union
+from typing import Callable, Optional, Union
 
 import numpy as np
 from numpy.core.numeric import normalize_axis_index, normalize_axis_tuple
@@ -21,7 +21,7 @@ class Tensor(_Display):
     Tensor(Storage)
         Initialize a Tensor with a `storage` description. `storage` can already hold data.
     Tensor(julia_object)
-        Tensor created from a compatible raw Julia object. Must be a `SwizzleArray`.
+        Tensor created from a compatible raw Julia object. Must be a `SwizzleArray` or `LazyTensor`.
         This is a no-copy operation.
 
     Parameters
@@ -57,8 +57,8 @@ class Tensor(_Display):
     array([[0, 1, 2],
            [3, 4, 5]])
     """
-    row_major = "C"
-    column_major = "F"
+    row_major: str = "C"
+    column_major: str = "F"
 
     def __init__(
         self,
@@ -76,7 +76,9 @@ def __init__(
         elif isinstance(obj, Storage):  # from-storage constructor
             order = self.preprocess_order(obj.order, self.get_lvl_ndim(obj.levels_descr._obj))
             self._obj = jl.swizzle(jl.Tensor(obj.levels_descr._obj), *order)
-        elif jl.isa(obj, jl.Finch.SwizzleArray):  # raw-Julia-object constructor
+        elif jl.isa(obj, jl.Finch.Tensor):  # raw-Julia-object constructors
+            self._obj = jl.swizzle(obj, *tuple(range(1, jl.ndims(obj) + 1)))
+        elif jl.isa(obj, jl.Finch.SwizzleArray) or jl.isa(obj, jl.Finch.LazyTensor):
             self._obj = obj
         else:
             raise ValueError(
@@ -84,19 +86,71 @@ def __init__(
             )
 
     def __pos__(self):
-        return Tensor(jl.Base.broadcast(jl.seval("+"), self._obj))
+        return self._elemwise_op("+")
+
+    def __neg__(self):
+        return self._elemwise_op("-")
 
     def __add__(self, other):
-        return Tensor(jl.Base.broadcast(jl.seval("+"), self._obj, other._obj))
+        return self._elemwise_op(".+", other)
 
     def __mul__(self, other):
-        return Tensor(jl.Base.broadcast(jl.seval("*"), self._obj, other._obj))
+        return self._elemwise_op(".*", other)
 
     def __sub__(self, other):
-        return Tensor(jl.Base.broadcast(jl.seval("-"), self._obj, other._obj))
+        return self._elemwise_op(".-", other)
 
     def __truediv__(self, other):
-        return Tensor(jl.Base.broadcast(jl.seval("/"), self._obj, other._obj))
+        return self._elemwise_op("./", other)
+
+    def __floordiv__(self, other):
+        return self._elemwise_op(".//", other)
+
+    def __mod__(self, other):
+        return self._elemwise_op("rem", other)
+
+    def __pow__(self, other):
+        return self._elemwise_op(".^", other)
+
+    def __matmul__(self, other):
+        raise NotImplementedError
+
+    def __abs__(self):
+        return self._elemwise_op("abs")
+
+    def __invert__(self):
+        return self._elemwise_op("~")
+
+    def __and__(self, other):
+        return self._elemwise_op("&", other)
+
+    def __or__(self, other):
+        return self._elemwise_op("|", other)
+
+    def __xor__(self, other):
+        return self._elemwise_op("xor", other)
+
+    def __lshift__(self, other):
+        return self._elemwise_op("<<", other)
+
+    def __rshift__(self, other):
+        return self._elemwise_op(">>", other)
+
+    def _elemwise_op(self, op: str, other: Optional["Tensor"] = None) -> "Tensor":
+        if other is None:
+            result = jl.broadcast(jl.seval(op), self._obj)
+        else:
+            axis_x1, axis_x2 = range(self.ndim, 0, -1), range(other.ndim, 0, -1)
+            # inverse swizzle, so `broadcast` appends new dims to the front
+            result = jl.broadcast(
+                jl.seval(op),
+                jl.permutedims(self._obj, tuple(axis_x1)),
+                jl.permutedims(other._obj, tuple(axis_x2)),
+            )
+            # swizzle back to the original order
+            result = jl.permutedims(result, tuple(range(jl.ndims(result), 0, -1)))
+
+        return Tensor(result)
 
     def __getitem__(self, key):
         if not isinstance(key, tuple):
@@ -106,7 +160,7 @@ def __getitem__(self, key):
         key = _add_plus_one(key, self.shape)
 
         result = self._obj[key]
-        if jl.isa(result, jl.Finch.SwizzleArray):
+        if jl.isa(result, jl.Finch.SwizzleArray) or jl.isa(result, jl.Finch.LazyTensor):
             return Tensor(result)
         elif jl.isa(result, jl.Finch.Tensor):
             return Tensor(jl.swizzle(result, *range(1, jl.ndims(result) + 1)))
@@ -142,6 +196,9 @@ def _is_dense(self) -> bool:
     def _order(self) -> tuple[int, ...]:
         return jl.typeof(self._obj).parameters[1]
 
+    def is_computed(self) -> bool:
+        return not jl.isa(self._obj, jl.Finch.LazyTensor)
+
     @classmethod
     def preprocess_order(
         cls, order: OrderType, ndim: int
@@ -339,7 +396,7 @@ def construct_csf(cls, arg: TupleOf3Arrays, shape: tuple[int, ...]) -> "Tensor":
         return Tensor(cls.construct_csf_jl_object(arg, shape))
 
 
-def fsprand(*args, order=None):
+def fsprand(*args):
     return Tensor(jl.fsprand(*args))
 
 
@@ -361,6 +418,131 @@ def astype(x: Tensor, dtype: jl.DataType, /, *, copy: bool = True):
     return Tensor(jl.swizzle(result, *x.get_order(zero_indexing=False)))
 
 
+def _reduce(x: Tensor, fn: Callable, axis, dtype):
+    if axis is not None:
+        axis = normalize_axis_tuple(axis, x.ndim)
+        axis = tuple(i + 1 for i in axis)
+        result = fn(x._obj, dims=axis)
+    else:
+        result = fn(x._obj)
+
+    if jl.isa(result, jl.Finch.Tensor) or jl.isa(result, jl.Finch.LazyTensor):
+        result = Tensor(result)
+    else:
+        result = np.array(result)
+    return result
+
+
+def sum(
+    x: Tensor,
+    /,
+    *,
+    axis: Union[int, tuple[int, ...], None] = None,
+    dtype: Union[jl.DataType, None] = None,
+    keepdims: bool = False,
+) -> Tensor:
+    return _reduce(x, jl.sum, axis, dtype)
+
+
+def prod(
+    x: Tensor,
+    /,
+    *,
+    axis: Union[int, tuple[int, ...], None] = None,
+    dtype: Union[jl.DataType, None] = None,
+    keepdims: bool = False,
+) -> Tensor:
+    return _reduce(x, jl.prod, axis, dtype)
+
+
+def add(x1: Tensor, x2: Tensor, /) -> Tensor:
+    return x1 + x2
+
+
+def subtract(x1: Tensor, x2: Tensor, /) -> Tensor:
+    return x1 - x2
+
+
+def multiply(x1: Tensor, x2: Tensor, /) -> Tensor:
+    return x1 * x2
+
+
+def divide(x1: Tensor, x2: Tensor, /) -> Tensor:
+    return x1 / x2
+
+
+def floor_divide(x1: Tensor, x2: Tensor, /) -> Tensor:
+    return x1 // x2
+
+
+def pow(x1: Tensor, x2: Tensor, /) -> Tensor:
+    return x1 ** x2
+
+
+def positive(x: Tensor, /) -> Tensor:
+    return +x
+
+
+def negative(x: Tensor, /) -> Tensor:
+    return -x
+
+
+def abs(x: Tensor, /) -> Tensor:
+    return x.__abs__()
+
+
+def cos(x: Tensor, /) -> Tensor:
+    return x._elemwise_op("cos")
+
+
+def cosh(x: Tensor, /) -> Tensor:
+    return x._elemwise_op("cosh")
+
+
+def acos(x: Tensor, /) -> Tensor:
+    return x._elemwise_op("acos")
+
+
+def acosh(x: Tensor, /) -> Tensor:
+    return x._elemwise_op("acosh")
+
+
+def sin(x: Tensor, /) -> Tensor:
+    return x._elemwise_op("sin")
+
+
+def sinh(x: Tensor, /) -> Tensor:
+    return x._elemwise_op("sinh")
+
+
+def asin(x: Tensor, /) -> Tensor:
+    return x._elemwise_op("asin")
+
+
+def asinh(x: Tensor, /) -> Tensor:
+    return x._elemwise_op("asinh")
+
+
+def tan(x: Tensor, /) -> Tensor:
+    return x._elemwise_op("tan")
+
+
+def tanh(x: Tensor, /) -> Tensor:
+    return x._elemwise_op("tanh")
+
+
+def atan(x: Tensor, /) -> Tensor:
+    return x._elemwise_op("atan")
+
+
+def atanh(x: Tensor, /) -> Tensor:
+    return x._elemwise_op("atanh")
+
+
+def atan2(x: Tensor, other: Tensor, /) -> Tensor:
+    return x._elemwise_op("atand", other)
+
+
 def _is_scipy_sparse_obj(x):
     return hasattr(x, "__module__") and x.__module__.startswith("scipy.sparse")
 
diff --git a/tests/test_ops.py b/tests/test_ops.py
new file mode 100644
index 0000000..08338ba
--- /dev/null
+++ b/tests/test_ops.py
@@ -0,0 +1,88 @@
+import numpy as np
+from numpy.testing import assert_equal
+import pytest
+
+import finch
+
+
+arr2d = np.array([[1, 2, 0, 0], [0, 1, 0, 1]])
+
+arr1d = np.array([1, 1, 2, 3])
+
+
+def test_eager(arr3d):
+    A_finch = finch.Tensor(arr3d)
+    B_finch = finch.Tensor(arr2d)
+
+    result = finch.multiply(A_finch, B_finch)
+
+    assert_equal(result.todense(), np.multiply(arr3d, arr2d))
+
+
+def test_lazy_mode(arr3d):
+    A_finch = finch.Tensor(arr3d)
+    B_finch = finch.Tensor(arr2d)
+    C_finch = finch.Tensor(arr1d)
+
+    @finch.compiled
+    def my_custom_fun(arr1, arr2, arr3):
+        temp = finch.multiply(arr1, arr2)
+        temp = finch.divide(temp, arr3)
+        reduced = finch.sum(temp, axis=(0, 1))
+        return finch.add(temp, reduced)
+
+    result = my_custom_fun(A_finch, B_finch, C_finch)
+
+    temp = np.divide(np.multiply(arr3d, arr2d), arr1d)
+    expected = np.add(temp, np.sum(temp, axis=(0, 1)))
+    assert_equal(result.todense(), expected)
+
+    A_lazy = finch.lazy(A_finch)
+    B_lazy = finch.lazy(B_finch)
+    mul_lazy = finch.multiply(A_lazy, B_lazy)
+    result = finch.compute(mul_lazy)
+
+    assert_equal(result.todense(), np.multiply(arr3d, arr2d))
+
+
+@pytest.mark.parametrize(
+    "meth_name", ["__pos__", "__neg__", "__abs__"],
+)
+def test_elemwise_ops_1_arg(arr3d, meth_name):
+    A_finch = finch.Tensor(arr3d)
+
+    actual = getattr(A_finch, meth_name)()
+    expected = getattr(arr3d, meth_name)()
+
+    assert_equal(actual.todense(), expected)
+
+
+@pytest.mark.parametrize(
+    "meth_name",
+    ["__add__", "__mul__", "__sub__", "__truediv__",  # "__floordiv__", "__mod__",
+     "__pow__", "__and__", "__or__", "__xor__", "__lshift__", "__rshift__"],
+)
+def test_elemwise_ops_2_args(arr3d, meth_name):
+    arr2d = np.array([[2, 3, 2, 3], [3, 2, 3, 2]])
+    A_finch = finch.Tensor(arr3d)
+    B_finch = finch.Tensor(arr2d)
+
+    actual = getattr(A_finch, meth_name)(B_finch)
+    expected = getattr(arr3d, meth_name)(arr2d)
+
+    assert_equal(actual.todense(), expected)
+
+
+@pytest.mark.parametrize("func_name", ["sum", "prod"])
+@pytest.mark.parametrize("axis", [None, -1, 1, (0, 1), (0, 1, 2)])
+@pytest.mark.parametrize("dtype", [None])
+def test_reductions(arr3d, func_name, axis, dtype):
+    A_finch = finch.Tensor(arr3d)
+
+    actual = getattr(finch, func_name)(A_finch, axis=axis, dtype=dtype)
+    expected = getattr(np, func_name)(arr3d, axis=axis, dtype=dtype)
+
+    if isinstance(actual, finch.Tensor):
+        actual = actual.todense()
+
+    assert_equal(actual, expected)
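
Editor's note (not part of the patch): a minimal usage sketch of the lazy API introduced above, mirroring the new tests. It assumes an environment where `finch` is importable with this patch applied; the array values and the function name `fused` are illustrative only, not taken from the diff.

    import numpy as np
    import finch

    # Illustrative inputs, mirroring the arrays used in tests/test_ops.py.
    a = finch.Tensor(np.array([[1, 2, 0, 0], [0, 1, 0, 1]]))
    b = finch.Tensor(np.array([1, 1, 2, 3]))

    # Explicit lazy/compute round trip: wrap tensors as lazy, build a
    # deferred expression, then materialize it with compute().
    mul_lazy = finch.multiply(finch.lazy(a), finch.lazy(b))
    result = finch.compute(mul_lazy)
    print(result.todense())

    # The @finch.compiled decorator makes Tensor arguments lazy on entry
    # and computes the result before returning it.
    @finch.compiled
    def fused(x, y):
        temp = finch.multiply(x, y)
        return finch.add(temp, finch.sum(temp, axis=0))

    print(fused(a, b).todense())

The decorated path lets Finch fuse the elementwise and reduction steps into one scheduled computation instead of materializing each intermediate, which is the point of the lazy mode this patch drafts.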