Skip to content
This repository has been archived by the owner on Aug 26, 2024. It is now read-only.

Commit

Permalink
🤗 Modularized interfaces for #18.
Browse files Browse the repository at this point in the history
This makes type checking much faster.
  • Loading branch information
RenChu Wang committed Jan 2, 2022
1 parent a4d6657 commit 08b349b
Show file tree
Hide file tree
Showing 10 changed files with 285 additions and 302 deletions.
14 changes: 0 additions & 14 deletions koila/__init__.py
Original file line number Diff line number Diff line change
@@ -1,14 +0,0 @@
from . import constants, gpus
from .errors import UnsupportedError
from .immediate import ImmediateNumber, ImmediateTensor, immediate
from .lazy import DelayedTensor, LazyFunction, LazyTensor, lazy
from .prepasses import CallBack, MetaData, PrePass, PrePassFunc
from .runnables import (
BatchedPair,
BatchInfo,
Runnable,
RunnableTensor,
TensorMixin,
run,
)
from .tensors import TensorLike
2 changes: 2 additions & 0 deletions koila/interfaces/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
from .runnable import Runnable, RunnableTensor
from .tensorlike import TensorLike
4 changes: 4 additions & 0 deletions koila/interfaces/components/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
from .arithmetic import Arithmetic
from .datatypes import DataType
from .meminfo import MemoryInfo
from .multidim import MultiDimensional
173 changes: 173 additions & 0 deletions koila/interfaces/components/arithmetic.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,173 @@
from __future__ import annotations

from abc import abstractmethod
from typing import NoReturn, Protocol, Union, runtime_checkable


Numeric = Union[int, float, bool]


@runtime_checkable
class Arithmetic(Protocol):
"""
Arithmetic is a type that supports arithmetic operations.
Operations such as +-*/ etc are considered arithmetic, basically everything that can be used on a scalar.
Inheriting this class, requires half of the methods to be overwritten.
For example, either overload `add` or `__add__`.
If `__add__` is overwritten, `add` is implemented automatically using `__add__`, and vice versa.
The only exception is `eq` and `ne`. They must be manually implemented.
"""

def __invert__(self) -> Arithmetic:
return self.logical_not()

@abstractmethod
def logical_not(self) -> Arithmetic:
...

def __pos__(self) -> Arithmetic:
return self.pos()

def pos(self) -> Arithmetic:
return +self

def __neg__(self) -> Arithmetic:
return self.neg()

def neg(self) -> Arithmetic:
return -self

def __add__(self, other: Arithmetic) -> Arithmetic:
return Arithmetic.add(self, other)

def __radd__(self, other: Arithmetic) -> Arithmetic:
return Arithmetic.add(other, self)

def add(self: Arithmetic, other: Arithmetic) -> Arithmetic:
return self + other

def __sub__(self, other: Arithmetic) -> Arithmetic:
return Arithmetic.sub(self, other)

def __rsub__(self, other: Arithmetic) -> Arithmetic:
return Arithmetic.sub(other, self)

def sub(self: Arithmetic, other: Arithmetic) -> Arithmetic:
return self - other

subtract = sub

def __mul__(self, other: Arithmetic) -> Arithmetic:
return Arithmetic.mul(self, other)

def __rmul__(self, other: Arithmetic) -> Arithmetic:
return Arithmetic.mul(other, self)

def mul(self: Arithmetic, other: Arithmetic) -> Arithmetic:
return self * other

multiply = mul

def __truediv__(self, other: Arithmetic) -> Arithmetic:
return self.div(other)

def __rtruediv__(self, other: Arithmetic) -> Arithmetic:
return other.div(self)

def __floordiv__(self, other: Arithmetic) -> Arithmetic:
raise NotImplementedError

def __rfloordiv__(self, other: Arithmetic) -> Arithmetic:
raise NotImplementedError

def div(self: Arithmetic, other: Arithmetic) -> Arithmetic:
return self / other

divide = truediv = div

def __pow__(self, other: Arithmetic) -> Arithmetic:
return self.pow(other)

def __rpow__(self, other: Arithmetic) -> Arithmetic:
return Arithmetic.pow(other, self)

def pow(self: Arithmetic, other: Arithmetic) -> Arithmetic:
return self ** other

def __mod__(self, other: Arithmetic) -> Arithmetic:
return self.mod(other)

def __rmod__(self, other: Arithmetic) -> Arithmetic:
return other.mod(self)

def mod(self, other: Arithmetic) -> Arithmetic:
return self % other

fmod = remainder = mod

def __divmod__(self, other: Arithmetic) -> NoReturn:
raise NotImplementedError

def __rdivmod__(self, other: Arithmetic) -> NoReturn:
raise NotImplementedError

def __abs__(self) -> Arithmetic:
return self.abs()

def abs(self) -> Arithmetic:
return abs(self)

def __hash__(self) -> int:
return id(self)

def __matmul__(self, other: Arithmetic) -> Arithmetic:
return self.matmul(other)

def __rmatmul__(self, other: Arithmetic) -> Arithmetic:
return other.matmul(self)

def matmul(self, other: Arithmetic) -> Arithmetic:
return self @ other

def __eq__(self, other: Arithmetic | Numeric) -> Arithmetic | bool:
if not isinstance(other, (Arithmetic, int, float, bool)):
return False
return self.eq(other)

@abstractmethod
def eq(self, other: Arithmetic | Numeric) -> Arithmetic:
...

def __ne__(self, other: Arithmetic | Numeric) -> Arithmetic | bool:
if not isinstance(other, (Arithmetic, int, float, bool)):
return True
return self.ne(other)

@abstractmethod
def ne(self, other: Arithmetic | Numeric) -> Arithmetic:
...

def __gt__(self, other: Arithmetic | Numeric) -> Arithmetic:
return self.gt(other)

def gt(self, other: Arithmetic | Numeric) -> Arithmetic:
return self > other

def __ge__(self, other: Arithmetic | Numeric) -> Arithmetic:
return self.ge(other)

def ge(self, other: Arithmetic | Numeric) -> Arithmetic:
return self >= other

def __lt__(self, other: Arithmetic | Numeric) -> Arithmetic:
return self.lt(other)

def lt(self, other: Arithmetic | Numeric) -> Arithmetic:
return self < other

def __le__(self, other: Arithmetic | Numeric) -> Arithmetic:
return self.le(other)

def le(self, other: Arithmetic | Numeric) -> Arithmetic:
return self <= other
11 changes: 11 additions & 0 deletions koila/interfaces/components/datatypes.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
from __future__ import annotations

from typing import Protocol

from torch import device as Device
from torch import dtype as DType


class DataType(Protocol):
dtype: DType
device: str | Device
8 changes: 8 additions & 0 deletions koila/interfaces/components/meminfo.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
from typing import Protocol

from .datatypes import DataType
from .multidim import MultiDimensional


class MemoryInfo(MultiDimensional, DataType, Protocol):
pass
41 changes: 41 additions & 0 deletions koila/interfaces/components/multidim.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,41 @@
from __future__ import annotations

import functools
import operator
from abc import abstractmethod
from typing import Protocol, Tuple, overload


class MultiDimensional(Protocol):
def __len__(self) -> int:
return self.size(0)

def dim(self) -> int:
return len(self.size())

@property
def ndim(self) -> int:
return self.dim()

ndimension = ndim

@overload
@abstractmethod
def size(self) -> Tuple[int, ...]:
...

@overload
@abstractmethod
def size(self, dim: int) -> int:
...

@abstractmethod
def size(self, dim: int | None = None) -> int | Tuple[int, ...]:
...

@property
def shape(self) -> Tuple[int, ...]:
return self.size()

def numel(self) -> int:
return functools.reduce(operator.mul, self.shape, 1)
20 changes: 4 additions & 16 deletions koila/tensors/runnables.py → koila/interfaces/runnable.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,8 @@
from abc import abstractmethod
from typing import Callable, NamedTuple, Protocol, TypeVar, runtime_checkable

from .tensors import TensorLike
from .components import MemoryInfo
from .tensorlike import TensorLike

T = TypeVar("T", covariant=True)

Expand All @@ -15,21 +16,8 @@ def run(self) -> T:
...


class BatchedPair(NamedTuple):
batch: int
no_batch: int


class BatchInfo(NamedTuple):
index: int
value: int

def map(self, func: Callable[[int], int]) -> BatchInfo:
index = func(self.index)
return BatchInfo(index, self.value)


class RunnableTensor(Runnable[TensorLike], Protocol):
@runtime_checkable
class RunnableTensor(Runnable[TensorLike], MemoryInfo, Protocol):
@abstractmethod
def run(self, partial: range | None = None) -> TensorLike:
...
42 changes: 42 additions & 0 deletions koila/interfaces/tensorlike.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,42 @@
from __future__ import annotations

from abc import abstractmethod
from typing import Protocol, TypeVar

from .components import Arithmetic, MemoryInfo

Number = TypeVar("Number", int, float)
Numeric = TypeVar("Numeric", int, float, bool)


class TensorLike(Arithmetic, MemoryInfo, Protocol):
"""
TensorLike is a protocol that mimics PyTorch's Tensor.
"""

data: TensorLike

@abstractmethod
def __str__(self) -> str:
...

def __bool__(self) -> bool:
return bool(self.item())

def __int__(self) -> int:
return int(self.item())

def __float__(self) -> float:
return float(self.item())

@abstractmethod
def item(self) -> Numeric:
...

@abstractmethod
def transpose(self, dim0: int, dim1: int) -> TensorLike:
...

@property
def T(self) -> TensorLike:
return self.transpose(0, 1)
Loading

0 comments on commit 08b349b

Please sign in to comment.