Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

API: Establish pydata_backend #646

Merged
merged 3 commits into from
Mar 13, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@ tests = [
]
tox = ["sparse[tests]", "tox"]
all = ["sparse[docs,tox]", "matrepr"]
finch = ["finch-tensor"]

[project.urls]
Documentation = "https://sparse.pydata.org/"
Expand All @@ -46,7 +47,7 @@ Repository = "https://github.com/pydata/sparse.git"
Discussions = "https://github.com/pydata/sparse/discussions"

[project.entry-points.numba_extensions]
init = "sparse._numba_extension:_init_extension"
init = "sparse.pydata_backend._numba_extension:_init_extension"

[tool.setuptools.packages.find]
where = ["."]
Expand Down
373 changes: 54 additions & 319 deletions sparse/__init__.py
Original file line number Diff line number Diff line change
@@ -1,323 +1,58 @@
from numpy import (
add,
bitwise_and,
bitwise_not,
bitwise_or,
bitwise_xor,
can_cast,
ceil,
complex64,
complex128,
cos,
cosh,
divide,
e,
exp,
expm1,
finfo,
float16,
float32,
float64,
floor,
floor_divide,
greater,
greater_equal,
iinfo,
imag,
inf,
int8,
int16,
int32,
int64,
less,
less_equal,
log,
log1p,
log2,
log10,
logaddexp,
logical_and,
logical_not,
logical_or,
logical_xor,
multiply,
nan,
negative,
newaxis,
not_equal,
pi,
positive,
real,
remainder,
sign,
sin,
sinh,
sqrt,
square,
subtract,
tan,
tanh,
trunc,
uint8,
uint16,
uint32,
uint64,
)
from numpy import arccos as acos
from numpy import arccosh as acosh
from numpy import arcsin as asin
from numpy import arcsinh as asinh
from numpy import arctan as atan
from numpy import arctan2 as atan2
from numpy import arctanh as atanh
from numpy import bool_ as bool
from numpy import invert as bitwise_invert
from numpy import left_shift as bitwise_left_shift
from numpy import power as pow
from numpy import right_shift as bitwise_right_shift
import os
from contextvars import ContextVar
from enum import Enum

from ._common import (
SparseArray,
abs,
all,
any,
asarray,
asnumpy,
astype,
broadcast_arrays,
broadcast_to,
concat,
concatenate,
dot,
einsum,
empty,
empty_like,
equal,
eye,
full,
full_like,
isfinite,
isinf,
isnan,
matmul,
max,
mean,
min,
moveaxis,
nonzero,
ones,
ones_like,
outer,
pad,
permute_dims,
prod,
reshape,
round,
squeeze,
stack,
std,
sum,
tensordot,
var,
vecdot,
zeros,
zeros_like,
)
from ._compressed import GCXS
from ._coo import COO, as_coo
from ._coo.common import (
argmax,
argmin,
argwhere,
asCOO,
clip,
diagonal,
diagonalize,
expand_dims,
flip,
isneginf,
isposinf,
kron,
matrix_transpose,
nanmax,
nanmean,
nanmin,
nanprod,
nanreduce,
nansum,
result_type,
roll,
sort,
take,
tril,
triu,
unique_counts,
unique_values,
where,
)
from ._dok import DOK
from ._io import load_npz, save_npz
from ._umath import elemwise
from ._utils import random
from ._version import __version__, __version_tuple__ # noqa: F401

__all__ = [
"COO",
"DOK",
"GCXS",
"SparseArray",
"abs",
"acos",
"acosh",
"add",
"all",
"any",
"argmax",
"argmin",
"argwhere",
"asCOO",
"as_coo",
"asarray",
"asin",
"asinh",
"asnumpy",
"astype",
"atan",
"atan2",
"atanh",
"bitwise_and",
"bitwise_invert",
"bitwise_left_shift",
"bitwise_not",
"bitwise_or",
"bitwise_right_shift",
"bitwise_xor",
"bool",
"broadcast_arrays",
"broadcast_to",
"can_cast",
"ceil",
"clip",
"complex128",
"complex64",
"concat",
"concatenate",
"cos",
"cosh",
"diagonal",
"diagonalize",
"divide",
"dot",
"e",
"einsum",
"elemwise",
"empty",
"empty_like",
"equal",
"exp",
"expand_dims",
"expm1",
"eye",
"finfo",
"flip",
"float16",
"float32",
"float64",
"floor",
"floor_divide",
"full",
"full_like",
"greater",
"greater_equal",
"iinfo",
"imag",
"inf",
"int16",
"int32",
"int64",
"int8",
"isfinite",
"isinf",
"isnan",
"isneginf",
"isposinf",
"kron",
"less",
"less_equal",
"load_npz",
"log",
"log10",
"log1p",
"log2",
"logaddexp",
"logical_and",
"logical_not",
"logical_or",
"logical_xor",
"matmul",
"matrix_transpose",
"max",
"mean",
"min",
"moveaxis",
"multiply",
"nan",
"nanmax",
"nanmean",
"nanmin",
"nanprod",
"nanreduce",
"nansum",
"negative",
"newaxis",
"nonzero",
"not_equal",
"ones",
"ones_like",
"outer",
"pad",
"permute_dims",
"pi",
"positive",
"pow",
"prod",
"random",
"real",
"remainder",
"reshape",
"result_type",
"roll",
"round",
"save_npz",
"sign",
"sin",
"sinh",
"sort",
"sqrt",
"square",
"squeeze",
"stack",
"std",
"subtract",
"sum",
"take",
"tan",
"tanh",
"tensordot",
"tril",
"triu",
"trunc",
"uint16",
"uint32",
"uint64",
"uint8",
"unique_counts",
"unique_values",
"var",
"vecdot",
"where",
"zeros",
"zeros_like",
]

__array_api_version__ = "2022.12"


class BackendType(Enum):
PyData = "PyData"
Finch = "Finch"


_ENV_VAR_NAME = "SPARSE_BACKEND"

backend_var = ContextVar("backend", default=BackendType.PyData)

if _ENV_VAR_NAME in os.environ:
backend_var.set(BackendType[os.environ[_ENV_VAR_NAME]])


class Backend:
def __init__(self, backend=BackendType.PyData):
self.backend = backend
self.token = None

def __enter__(self):
token = backend_var.set(self.backend)
self.token = token

def __exit__(self, exc_type, exc_value, traceback):
backend_var.reset(self.token)
self.token = None

@staticmethod
def get_backend_module():
backend = backend_var.get()
if backend == BackendType.PyData:
import sparse.pydata_backend as backend_module
elif backend == BackendType.Finch:
import sparse.finch_backend as backend_module
else:
raise ValueError(f"Invalid backend identifier: {backend}")
return backend_module


def __getattr__(attr):
if attr == "pydata_backend":
import sparse.pydata_backend as backend_module

return backend_module
if attr == "finch_backend":
import sparse.finch_backend as backend_module

hameerabbasi marked this conversation as resolved.
Show resolved Hide resolved
return backend_module

return getattr(Backend.get_backend_module(), attr)
Loading
Loading