Skip to content

Commit

Permalink
Add remaining IO dependencies
Browse files Browse the repository at this point in the history
  • Loading branch information
mtsokol committed Jun 14, 2024
1 parent 8724f7d commit bbcb63c
Show file tree
Hide file tree
Showing 2 changed files with 34 additions and 21 deletions.
28 changes: 25 additions & 3 deletions src/finch/io.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,35 @@
from pathlib import Path

from .julia import jl
from .julia import add_package, jl
from .tensor import Tensor


def _import_deps(filename: str) -> None:
fn = filename
if fn.endswith(".mtx") or fn.endswith(".ttx") or fn.endswith(".tns"):
add_package("TensorMarket", hash="8b7d4fe7-0b45-4d0d-9dd8-5cc9b23b4b77", version="0.2.0")
jl.seval("using TensorMarket")
elif fn.endswith(".bspnpy"):
add_package("NPZ", hash="15e1cf62-19b3-5cfa-8e77-841668bca605", version="0.4.3")
jl.seval("using NPZ")
elif fn.endswith(".bsp.h5"):
add_package("HDF5", hash="f67ccb44-e63f-5c2f-98bd-6dc0ccc4ba2f", version="0.17.2")
jl.seval("using HDF5")
else:
raise ValueError(
f"Unsupported file extension. Supported extensions are "
"`.mtx`, `.ttx`, `.tns`, `.bspnpy`, and `.bsp.h5`."
)


def read(filename: Path | str) -> Tensor:
julia_obj = jl.fread(str(filename))
fn = str(filename)
_import_deps(fn)
julia_obj = jl.fread(fn)
return Tensor(julia_obj)


def write(filename: Path | str, tns: Tensor) -> None:
jl.fwrite(str(filename), tns._obj)
fn = str(filename)
_import_deps(fn)
jl.fwrite(fn, tns._obj)
27 changes: 9 additions & 18 deletions src/finch/julia.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,36 +2,28 @@

import juliapkg


def add_package(name: str, hash: str, version: str) -> None:
deps = juliapkg.deps.load_cur_deps()
if (deps.get("packages", {}).get(name, {}).get("version", None) != version):
juliapkg.add(name, hash, version=version)


_FINCH_NAME = "Finch"
_FINCH_VERSION = "0.6.31"
_FINCH_HASH = "9177782c-1635-4eb9-9bfb-d9dfa25e6bce"
_FINCH_REPO_PATH = os.environ.get("FINCH_REPO_PATH", default=None)
_FINCH_REPO_URL = os.environ.get("FINCH_URL_PATH", default=None)

_TENSOR_MARKET_NAME = "TensorMarket"
_TENSOR_MARKET_HASH = "8b7d4fe7-0b45-4d0d-9dd8-5cc9b23b4b77"
_TENSOR_MARKET_VERSION = "0.2.0"

if _FINCH_REPO_PATH and _FINCH_REPO_URL:
raise ValueError("FINCH_REPO_PATH and FINCH_URL_PATH can't be set at the same time.")

deps = juliapkg.deps.load_cur_deps()

if _FINCH_REPO_PATH: # Also account for empty string
juliapkg.add(_FINCH_NAME, _FINCH_HASH, path=_FINCH_REPO_PATH, dev=True)
elif _FINCH_REPO_URL:
juliapkg.add(_FINCH_NAME, _FINCH_HASH, url=_FINCH_REPO_URL, dev=True)
elif (
deps.get("packages", {}).get(_FINCH_NAME, {}).get("version", None)
!= _FINCH_VERSION
):
juliapkg.add(_FINCH_NAME, _FINCH_HASH, version=_FINCH_VERSION)

if (
deps.get("packages", {}).get(_TENSOR_MARKET_NAME, {}).get("version", None)
!= _FINCH_VERSION
):
juliapkg.add(_TENSOR_MARKET_NAME, _TENSOR_MARKET_HASH, version=_TENSOR_MARKET_VERSION)
else:
add_package(_FINCH_NAME, _FINCH_HASH, _FINCH_VERSION)

import juliacall as jc # noqa

Expand All @@ -40,4 +32,3 @@

jl.seval("using Finch")
jl.seval("using Random")
jl.seval("using TensorMarket")

0 comments on commit bbcb63c

Please sign in to comment.