Merge pull request #17 from mtsokol/lazy-api
API: Lazy API
mtsokol committed Mar 19, 2024
2 parents 1bf21a2 + 05705e0 commit 86f4429
Showing 5 changed files with 343 additions and 13 deletions.
26 changes: 26 additions & 0 deletions src/finch/__init__.py
@@ -16,6 +16,20 @@
astype,
fsprand,
permute_dims,
sum,
prod,
add,
subtract,
multiply,
divide,
positive,
negative,
)
from .compiled import (
lazy,
compiled,
compute,
)
from .dtypes import (
int_,
@@ -68,4 +82,16 @@
"complex64",
"complex128",
"bool",
"multiply",
"lazy",
"compiled",
"compute",
"sum",
"prod",
"add",
"subtract",
"multiply",
"divide",
"positive",
"negative",
]
34 changes: 34 additions & 0 deletions src/finch/compiled.py
@@ -0,0 +1,34 @@
from functools import wraps

from .julia import jl
from .tensor import Tensor


def compiled(func):
@wraps(func)
def wrapper_func(*args, **kwargs):
        # Promote eager Tensor arguments to LazyTensors so the wrapped
        # function records operations instead of executing them.
        # (Keyword arguments are passed through unchanged.)
        new_args = []
for arg in args:
if isinstance(arg, Tensor) and not jl.isa(arg._obj, jl.Finch.LazyTensor):
new_args.append(Tensor(jl.Finch.LazyTensor(arg._obj)))
else:
new_args.append(arg)

result = func(*new_args, **kwargs)
        # Materialize the traced expression and wrap it as an eager Tensor.
        result_tensor = Tensor(jl.Finch.compute(result._obj))

return result_tensor

return wrapper_func


def lazy(tensor: Tensor):
if tensor.is_computed():
return Tensor(jl.Finch.LazyTensor(tensor._obj))
return tensor


def compute(tensor: Tensor):
if not tensor.is_computed():
return Tensor(jl.Finch.compute(tensor._obj))
return tensor
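
Taken together, `compiled.py` gives two entry points to the lazy API: convert tensors explicitly with `lazy`/`compute`, or decorate a whole function with `@compiled` so the conversion and the final `compute` happen automatically. A minimal usage sketch, assuming `Tensor` also accepts a NumPy array (that constructor branch is elided from the `tensor.py` diff below):

    import numpy as np
    import finch

    a = finch.Tensor(np.arange(6).reshape(2, 3))
    b = finch.Tensor(np.ones((2, 3), dtype=np.int64))

    # Explicit style: mark the inputs lazy, build the expression, materialize once.
    result = finch.compute(finch.sum(finch.multiply(finch.lazy(a), finch.lazy(b)), axis=0))

    # Decorator style: Tensor arguments are wrapped in LazyTensors, the body is
    # traced, and the result is computed before being returned.
    @finch.compiled
    def fused(x, y):
        return finch.sum(finch.multiply(x, y), axis=0)

    result2 = fused(a, b)  # same value as `result`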
2 changes: 1 addition & 1 deletion src/finch/julia.py
@@ -1,6 +1,6 @@
import juliapkg

juliapkg.add("Finch", "9177782c-1635-4eb9-9bfb-d9dfa25e6bce", version="0.6.16")
juliapkg.add("Finch", "9177782c-1635-4eb9-9bfb-d9dfa25e6bce", version="0.6.19")
import juliacall # noqa

juliapkg.resolve()
206 changes: 194 additions & 12 deletions src/finch/tensor.py
@@ -1,4 +1,4 @@
from typing import Any, Optional, Union
from typing import Callable, Optional, Union

import numpy as np
from numpy.core.numeric import normalize_axis_index, normalize_axis_tuple
@@ -21,7 +21,7 @@ class Tensor(_Display):
Tensor(Storage)
Initialize a Tensor with a `storage` description. `storage` can already hold data.
Tensor(julia_object)
Tensor created from a compatible raw Julia object. Must be a `SwizzleArray`.
Tensor created from a compatible raw Julia object. Must be a `SwizzleArray` or `LazyTensor`.
This is a no-copy operation.
Parameters
@@ -57,8 +57,8 @@ class Tensor(_Display):
array([[0, 1, 2],
[3, 4, 5]])
"""
row_major = "C"
column_major = "F"
row_major: str = "C"
column_major: str = "F"

def __init__(
self,
@@ -76,27 +76,81 @@ def __init__(
elif isinstance(obj, Storage): # from-storage constructor
order = self.preprocess_order(obj.order, self.get_lvl_ndim(obj.levels_descr._obj))
self._obj = jl.swizzle(jl.Tensor(obj.levels_descr._obj), *order)
elif jl.isa(obj, jl.Finch.SwizzleArray): # raw-Julia-object constructor
elif jl.isa(obj, jl.Finch.Tensor): # raw-Julia-object constructors
self._obj = jl.swizzle(obj, *tuple(range(1, jl.ndims(obj) + 1)))
elif jl.isa(obj, jl.Finch.SwizzleArray) or jl.isa(obj, jl.Finch.LazyTensor):
self._obj = obj
else:
raise ValueError(
"Either `arr`, `storage` or a raw julia object should be provided."
)

def __pos__(self):
return Tensor(jl.Base.broadcast(jl.seval("+"), self._obj))
return self._elemwise_op("+")

def __neg__(self):
return self._elemwise_op("-")

def __add__(self, other):
return Tensor(jl.Base.broadcast(jl.seval("+"), self._obj, other._obj))
return self._elemwise_op(".+", other)

def __mul__(self, other):
return Tensor(jl.Base.broadcast(jl.seval("*"), self._obj, other._obj))
return self._elemwise_op(".*", other)

def __sub__(self, other):
return Tensor(jl.Base.broadcast(jl.seval("-"), self._obj, other._obj))
return self._elemwise_op(".-", other)

def __truediv__(self, other):
return Tensor(jl.Base.broadcast(jl.seval("/"), self._obj, other._obj))
return self._elemwise_op("./", other)

def __floordiv__(self, other):
return self._elemwise_op(".//", other)

def __mod__(self, other):
return self._elemwise_op("rem", other)

def __pow__(self, other):
return self._elemwise_op(".^", other)

def __matmul__(self, other):
raise NotImplementedError

def __abs__(self):
return self._elemwise_op("abs")

def __invert__(self):
return self._elemwise_op("~")

def __and__(self, other):
return self._elemwise_op("&", other)

def __or__(self, other):
return self._elemwise_op("|", other)

def __xor__(self, other):
return self._elemwise_op("xor", other)

def __lshift__(self, other):
return self._elemwise_op("<<", other)

def __rshift__(self, other):
return self._elemwise_op(">>", other)

def _elemwise_op(self, op: str, other: Optional["Tensor"] = None) -> "Tensor":
if other is None:
result = jl.broadcast(jl.seval(op), self._obj)
else:
            axis_x1, axis_x2 = range(self.ndim, 0, -1), range(other.ndim, 0, -1)
            # Reverse the axis order (inverse swizzle): Julia's `broadcast`
            # aligns leading dimensions, while NumPy aligns trailing ones.
            result = jl.broadcast(
                jl.seval(op),
                jl.permutedims(self._obj, tuple(axis_x1)),
                jl.permutedims(other._obj, tuple(axis_x2)),
            )
            # Swizzle back so the result keeps the original axis order.
            result = jl.permutedims(result, tuple(range(jl.ndims(result), 0, -1)))

return Tensor(result)
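
    # Worked example of the axis reversal above (hypothetical shapes, not part
    # of this diff): x1 of shape (2, 3) and x2 of shape (3,). NumPy aligns
    # trailing axes, so the expected result shape is (2, 3):
    #   reversed:        (3, 2) and (3,)
    #   Julia broadcast: (3, 2)    # the missing dim is padded at the end
    #   reversed back:   (2, 3)    # NumPy-style axis order restored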

def __getitem__(self, key):
if not isinstance(key, tuple):
@@ -106,7 +160,7 @@ def __getitem__(self, key):
key = _add_plus_one(key, self.shape)

result = self._obj[key]
if jl.isa(result, jl.Finch.SwizzleArray):
if jl.isa(result, jl.Finch.SwizzleArray) or jl.isa(result, jl.Finch.LazyTensor):
return Tensor(result)
elif jl.isa(result, jl.Finch.Tensor):
return Tensor(jl.swizzle(result, *range(1, jl.ndims(result) + 1)))
@@ -142,6 +196,9 @@ def _is_dense(self) -> bool:
def _order(self) -> tuple[int, ...]:
return jl.typeof(self._obj).parameters[1]

def is_computed(self) -> bool:
return not jl.isa(self._obj, jl.Finch.LazyTensor)

@classmethod
def preprocess_order(
cls, order: OrderType, ndim: int
@@ -339,7 +396,7 @@ def construct_csf(cls, arg: TupleOf3Arrays, shape: tuple[int, ...]) -> "Tensor":
return Tensor(cls.construct_csf_jl_object(arg, shape))


def fsprand(*args, order=None):
def fsprand(*args):
return Tensor(jl.fsprand(*args))
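
# A hypothetical call, assuming Finch.jl's `fsprand(dims..., density)`
# argument order:
#
#   s = fsprand(10, 10, 0.3)   # 10x10 sparse tensor, ~30% density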


Expand All @@ -361,6 +418,131 @@ def astype(x: Tensor, dtype: jl.DataType, /, *, copy: bool = True):
return Tensor(jl.swizzle(result, *x.get_order(zero_indexing=False)))


def _reduce(x: Tensor, fn: Callable, axis, dtype):
    # NOTE: `dtype` here and `keepdims` in the callers below are accepted
    # for array API compatibility but are not forwarded to Julia yet.
    if axis is not None:
        # Map normalized 0-based NumPy axes to 1-based Julia dims.
        axis = normalize_axis_tuple(axis, x.ndim)
        axis = tuple(i + 1 for i in axis)
        result = fn(x._obj, dims=axis)
    else:
        result = fn(x._obj)

if jl.isa(result, jl.Finch.Tensor) or jl.isa(result, jl.Finch.LazyTensor):
result = Tensor(result)
else:
result = np.array(result)
return result
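
# Example of the axis mapping above (hypothetical values): for x.ndim == 3
# and axis=(-1, 0),
#   normalize_axis_tuple(axis, 3)  ->  (2, 0)   # 0-based NumPy axes
#   tuple(i + 1 for i in axis)     ->  (3, 1)   # 1-based Julia dims
# so the reduction runs as fn(x._obj, dims=(3, 1)). A full reduction
# (axis=None) returns a Julia scalar, wrapped as a 0-d numpy array by the
# branch above.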


def sum(
x: Tensor,
/,
*,
axis: Union[int, tuple[int, ...], None] = None,
dtype: Union[jl.DataType, None] = None,
keepdims: bool = False,
) -> Tensor:
return _reduce(x, jl.sum, axis, dtype)


def prod(
x: Tensor,
/,
*,
axis: Union[int, tuple[int, ...], None] = None,
dtype: Union[jl.DataType, None] = None,
keepdims: bool = False,
) -> Tensor:
return _reduce(x, jl.prod, axis, dtype)


def add(x1: Tensor, x2: Tensor, /) -> Tensor:
return x1 + x2


def subtract(x1: Tensor, x2: Tensor, /) -> Tensor:
return x1 - x2


def multiply(x1: Tensor, x2: Tensor, /) -> Tensor:
return x1 * x2


def divide(x1: Tensor, x2: Tensor, /) -> Tensor:
return x1 / x2


def floor_divide(x1: Tensor, x2: Tensor, /) -> Tensor:
return x1 // x2


def pow(x1: Tensor, x2: Tensor, /) -> Tensor:
return x1 ** x2


def positive(x: Tensor, /) -> Tensor:
return +x


def negative(x: Tensor, /) -> Tensor:
return -x


def abs(x: Tensor, /) -> Tensor:
return x.__abs__()


def cos(x: Tensor, /) -> Tensor:
return x._elemwise_op("cos")


def cosh(x: Tensor, /) -> Tensor:
return x._elemwise_op("cosh")


def acos(x: Tensor, /) -> Tensor:
return x._elemwise_op("acos")


def acosh(x: Tensor, /) -> Tensor:
return x._elemwise_op("acosh")


def sin(x: Tensor, /) -> Tensor:
return x._elemwise_op("sin")


def sinh(x: Tensor, /) -> Tensor:
return x._elemwise_op("sinh")


def asin(x: Tensor, /) -> Tensor:
return x._elemwise_op("asin")


def asinh(x: Tensor, /) -> Tensor:
return x._elemwise_op("asinh")


def tan(x: Tensor, /) -> Tensor:
return x._elemwise_op("tan")


def tanh(x: Tensor, /) -> Tensor:
return x._elemwise_op("tanh")


def atan(x: Tensor, /) -> Tensor:
return x._elemwise_op("atan")


def atanh(x: Tensor, /) -> Tensor:
return x._elemwise_op("atanh")


def atan2(x: Tensor, other: Tensor, /) -> Tensor:
    # Julia's two-argument `atan` is the quadrant-aware arctangent in radians
    # (`atand` would return degrees).
    return x._elemwise_op("atan", other)


def _is_scipy_sparse_obj(x):
return hasattr(x, "__module__") and x.__module__.startswith("scipy.sparse")

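The operator overloads above mean that ordinary Python expressions lower to the same Julia broadcast calls in both eager and lazy mode. A short sketch, reusing the hypothetical NumPy-backed tensors from the earlier example:

    import numpy as np
    import finch

    x = finch.Tensor(np.array([[1, 2], [3, 4]]))
    y = finch.Tensor(np.array([[10, 20], [30, 40]]))

    z = (x + y) * x   # dispatches to Julia's .+ and .* via _elemwise_op
    w = abs(-z)       # __neg__ and __abs__ map to Julia's - and abs

    # In lazy mode the same expression only records a graph until compute:
    lz = (finch.lazy(x) + finch.lazy(y)) * finch.lazy(x)
    lz = finch.compute(lz)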