diff --git a/sparse/__init__.py b/sparse/__init__.py index 3c1831ac..06400ca0 100644 --- a/sparse/__init__.py +++ b/sparse/__init__.py @@ -4,24 +4,24 @@ from ._version import __version__, __version_tuple__ # noqa: F401 - __array_api_version__ = "2022.12" + class BackendType(Enum): - pydata = "PyData" - finch = "Finch" + PyData = "PyData" + Finch = "Finch" _ENV_VAR_NAME = "SPARSE_BACKEND" -backend_var = ContextVar("backend", default=BackendType.pydata) +backend_var = ContextVar("backend", default=BackendType.PyData) if _ENV_VAR_NAME in os.environ: backend_var.set(BackendType[os.environ[_ENV_VAR_NAME]]) class Backend: - def __init__(self, backend=BackendType.pydata): + def __init__(self, backend=BackendType.PyData): self.backend = backend self.token = None @@ -36,9 +36,9 @@ def __exit__(self, exc_type, exc_value, traceback): @staticmethod def get_backend_module(): backend = backend_var.get() - if backend == BackendType.pydata: + if backend == BackendType.PyData: import sparse.pydata_backend as backend_module - elif backend == BackendType.finch: + elif backend == BackendType.Finch: import sparse.finch_backend as backend_module else: raise ValueError(f"Invalid backend identifier: {backend}") diff --git a/sparse/tests/conftest.py b/sparse/tests/conftest.py index dfc5102d..8dd080ce 100644 --- a/sparse/tests/conftest.py +++ b/sparse/tests/conftest.py @@ -3,7 +3,7 @@ import pytest -@pytest.fixture(scope="session", params=[sparse.BackendType.pydata, sparse.BackendType.finch]) +@pytest.fixture(scope="session", params=[sparse.BackendType.PyData, sparse.BackendType.Finch]) def backend(request): with sparse.Backend(backend=request.param): yield request.param diff --git a/sparse/tests/test_backends.py b/sparse/tests/test_backends.py index 9b8ff097..9cc8b5df 100644 --- a/sparse/tests/test_backends.py +++ b/sparse/tests/test_backends.py @@ -6,7 +6,7 @@ def test_backend_contex_manager(backend): - if backend == sparse.BackendType.finch: + if backend == sparse.BackendType.Finch: with pytest.raises(NotImplementedError): sparse.COO.from_numpy(np.eye(5)) else: