This repository has been archived by the owner on Aug 26, 2024. It is now read-only.
-
Notifications
You must be signed in to change notification settings - Fork 63
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
This makes type checking much faster.
- Loading branch information
RenChu Wang
committed
Jan 2, 2022
1 parent
a4d6657
commit 08b349b
Showing
10 changed files
with
285 additions
and
302 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,14 +0,0 @@ | ||
from . import constants, gpus | ||
from .errors import UnsupportedError | ||
from .immediate import ImmediateNumber, ImmediateTensor, immediate | ||
from .lazy import DelayedTensor, LazyFunction, LazyTensor, lazy | ||
from .prepasses import CallBack, MetaData, PrePass, PrePassFunc | ||
from .runnables import ( | ||
BatchedPair, | ||
BatchInfo, | ||
Runnable, | ||
RunnableTensor, | ||
TensorMixin, | ||
run, | ||
) | ||
from .tensors import TensorLike | ||
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,2 @@ | ||
from .runnable import Runnable, RunnableTensor | ||
from .tensorlike import TensorLike |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,4 @@ | ||
from .arithmetic import Arithmetic | ||
from .datatypes import DataType | ||
from .meminfo import MemoryInfo | ||
from .multidim import MultiDimensional |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,173 @@ | ||
from __future__ import annotations | ||
|
||
from abc import abstractmethod | ||
from typing import NoReturn, Protocol, Union, runtime_checkable | ||
|
||
|
||
Numeric = Union[int, float, bool] | ||
|
||
|
||
@runtime_checkable | ||
class Arithmetic(Protocol): | ||
""" | ||
Arithmetic is a type that supports arithmetic operations. | ||
Operations such as +-*/ etc are considered arithmetic, basically everything that can be used on a scalar. | ||
Inheriting this class, requires half of the methods to be overwritten. | ||
For example, either overload `add` or `__add__`. | ||
If `__add__` is overwritten, `add` is implemented automatically using `__add__`, and vice versa. | ||
The only exception is `eq` and `ne`. They must be manually implemented. | ||
""" | ||
|
||
def __invert__(self) -> Arithmetic: | ||
return self.logical_not() | ||
|
||
@abstractmethod | ||
def logical_not(self) -> Arithmetic: | ||
... | ||
|
||
def __pos__(self) -> Arithmetic: | ||
return self.pos() | ||
|
||
def pos(self) -> Arithmetic: | ||
return +self | ||
|
||
def __neg__(self) -> Arithmetic: | ||
return self.neg() | ||
|
||
def neg(self) -> Arithmetic: | ||
return -self | ||
|
||
def __add__(self, other: Arithmetic) -> Arithmetic: | ||
return Arithmetic.add(self, other) | ||
|
||
def __radd__(self, other: Arithmetic) -> Arithmetic: | ||
return Arithmetic.add(other, self) | ||
|
||
def add(self: Arithmetic, other: Arithmetic) -> Arithmetic: | ||
return self + other | ||
|
||
def __sub__(self, other: Arithmetic) -> Arithmetic: | ||
return Arithmetic.sub(self, other) | ||
|
||
def __rsub__(self, other: Arithmetic) -> Arithmetic: | ||
return Arithmetic.sub(other, self) | ||
|
||
def sub(self: Arithmetic, other: Arithmetic) -> Arithmetic: | ||
return self - other | ||
|
||
subtract = sub | ||
|
||
def __mul__(self, other: Arithmetic) -> Arithmetic: | ||
return Arithmetic.mul(self, other) | ||
|
||
def __rmul__(self, other: Arithmetic) -> Arithmetic: | ||
return Arithmetic.mul(other, self) | ||
|
||
def mul(self: Arithmetic, other: Arithmetic) -> Arithmetic: | ||
return self * other | ||
|
||
multiply = mul | ||
|
||
def __truediv__(self, other: Arithmetic) -> Arithmetic: | ||
return self.div(other) | ||
|
||
def __rtruediv__(self, other: Arithmetic) -> Arithmetic: | ||
return other.div(self) | ||
|
||
def __floordiv__(self, other: Arithmetic) -> Arithmetic: | ||
raise NotImplementedError | ||
|
||
def __rfloordiv__(self, other: Arithmetic) -> Arithmetic: | ||
raise NotImplementedError | ||
|
||
def div(self: Arithmetic, other: Arithmetic) -> Arithmetic: | ||
return self / other | ||
|
||
divide = truediv = div | ||
|
||
def __pow__(self, other: Arithmetic) -> Arithmetic: | ||
return self.pow(other) | ||
|
||
def __rpow__(self, other: Arithmetic) -> Arithmetic: | ||
return Arithmetic.pow(other, self) | ||
|
||
def pow(self: Arithmetic, other: Arithmetic) -> Arithmetic: | ||
return self ** other | ||
|
||
def __mod__(self, other: Arithmetic) -> Arithmetic: | ||
return self.mod(other) | ||
|
||
def __rmod__(self, other: Arithmetic) -> Arithmetic: | ||
return other.mod(self) | ||
|
||
def mod(self, other: Arithmetic) -> Arithmetic: | ||
return self % other | ||
|
||
fmod = remainder = mod | ||
|
||
def __divmod__(self, other: Arithmetic) -> NoReturn: | ||
raise NotImplementedError | ||
|
||
def __rdivmod__(self, other: Arithmetic) -> NoReturn: | ||
raise NotImplementedError | ||
|
||
def __abs__(self) -> Arithmetic: | ||
return self.abs() | ||
|
||
def abs(self) -> Arithmetic: | ||
return abs(self) | ||
|
||
def __hash__(self) -> int: | ||
return id(self) | ||
|
||
def __matmul__(self, other: Arithmetic) -> Arithmetic: | ||
return self.matmul(other) | ||
|
||
def __rmatmul__(self, other: Arithmetic) -> Arithmetic: | ||
return other.matmul(self) | ||
|
||
def matmul(self, other: Arithmetic) -> Arithmetic: | ||
return self @ other | ||
|
||
def __eq__(self, other: Arithmetic | Numeric) -> Arithmetic | bool: | ||
if not isinstance(other, (Arithmetic, int, float, bool)): | ||
return False | ||
return self.eq(other) | ||
|
||
@abstractmethod | ||
def eq(self, other: Arithmetic | Numeric) -> Arithmetic: | ||
... | ||
|
||
def __ne__(self, other: Arithmetic | Numeric) -> Arithmetic | bool: | ||
if not isinstance(other, (Arithmetic, int, float, bool)): | ||
return True | ||
return self.ne(other) | ||
|
||
@abstractmethod | ||
def ne(self, other: Arithmetic | Numeric) -> Arithmetic: | ||
... | ||
|
||
def __gt__(self, other: Arithmetic | Numeric) -> Arithmetic: | ||
return self.gt(other) | ||
|
||
def gt(self, other: Arithmetic | Numeric) -> Arithmetic: | ||
return self > other | ||
|
||
def __ge__(self, other: Arithmetic | Numeric) -> Arithmetic: | ||
return self.ge(other) | ||
|
||
def ge(self, other: Arithmetic | Numeric) -> Arithmetic: | ||
return self >= other | ||
|
||
def __lt__(self, other: Arithmetic | Numeric) -> Arithmetic: | ||
return self.lt(other) | ||
|
||
def lt(self, other: Arithmetic | Numeric) -> Arithmetic: | ||
return self < other | ||
|
||
def __le__(self, other: Arithmetic | Numeric) -> Arithmetic: | ||
return self.le(other) | ||
|
||
def le(self, other: Arithmetic | Numeric) -> Arithmetic: | ||
return self <= other |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,11 @@ | ||
from __future__ import annotations | ||
|
||
from typing import Protocol | ||
|
||
from torch import device as Device | ||
from torch import dtype as DType | ||
|
||
|
||
class DataType(Protocol): | ||
dtype: DType | ||
device: str | Device |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,8 @@ | ||
from typing import Protocol | ||
|
||
from .datatypes import DataType | ||
from .multidim import MultiDimensional | ||
|
||
|
||
class MemoryInfo(MultiDimensional, DataType, Protocol): | ||
pass |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,41 @@ | ||
from __future__ import annotations | ||
|
||
import functools | ||
import operator | ||
from abc import abstractmethod | ||
from typing import Protocol, Tuple, overload | ||
|
||
|
||
class MultiDimensional(Protocol): | ||
def __len__(self) -> int: | ||
return self.size(0) | ||
|
||
def dim(self) -> int: | ||
return len(self.size()) | ||
|
||
@property | ||
def ndim(self) -> int: | ||
return self.dim() | ||
|
||
ndimension = ndim | ||
|
||
@overload | ||
@abstractmethod | ||
def size(self) -> Tuple[int, ...]: | ||
... | ||
|
||
@overload | ||
@abstractmethod | ||
def size(self, dim: int) -> int: | ||
... | ||
|
||
@abstractmethod | ||
def size(self, dim: int | None = None) -> int | Tuple[int, ...]: | ||
... | ||
|
||
@property | ||
def shape(self) -> Tuple[int, ...]: | ||
return self.size() | ||
|
||
def numel(self) -> int: | ||
return functools.reduce(operator.mul, self.shape, 1) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,42 @@ | ||
from __future__ import annotations | ||
|
||
from abc import abstractmethod | ||
from typing import Protocol, TypeVar | ||
|
||
from .components import Arithmetic, MemoryInfo | ||
|
||
Number = TypeVar("Number", int, float) | ||
Numeric = TypeVar("Numeric", int, float, bool) | ||
|
||
|
||
class TensorLike(Arithmetic, MemoryInfo, Protocol): | ||
""" | ||
TensorLike is a protocol that mimics PyTorch's Tensor. | ||
""" | ||
|
||
data: TensorLike | ||
|
||
@abstractmethod | ||
def __str__(self) -> str: | ||
... | ||
|
||
def __bool__(self) -> bool: | ||
return bool(self.item()) | ||
|
||
def __int__(self) -> int: | ||
return int(self.item()) | ||
|
||
def __float__(self) -> float: | ||
return float(self.item()) | ||
|
||
@abstractmethod | ||
def item(self) -> Numeric: | ||
... | ||
|
||
@abstractmethod | ||
def transpose(self, dim0: int, dim1: int) -> TensorLike: | ||
... | ||
|
||
@property | ||
def T(self) -> TensorLike: | ||
return self.transpose(0, 1) |
Oops, something went wrong.