Skip to content

Commit

Permalink
Issue #17 initial steps to decouple code table IO from PrefixCodec
Browse files Browse the repository at this point in the history
- Refactor out pickle save/load from PrefixCode
- Initial JSON based code table storage
  • Loading branch information
soxofaan committed Jul 13, 2024
1 parent 8cd37f8 commit 4b6267d
Show file tree
Hide file tree
Showing 6 changed files with 214 additions and 42 deletions.
3 changes: 2 additions & 1 deletion dahuffman/codecs/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
from functools import partial
from pathlib import Path

import dahuffman.codetableio
from dahuffman.huffmancodec import PrefixCodec


Expand All @@ -22,7 +23,7 @@ def load(name: str) -> PrefixCodec:
if not name.endswith(".pickle"):
name = name + ".pickle"
with importlib.resources.path("dahuffman.codecs", resource=name) as path:
return PrefixCodec.load(path)
return dahuffman.codetableio.pickle_load(path)


load_shakespeare = partial(load, "shakespeare")
Expand Down
132 changes: 132 additions & 0 deletions dahuffman/codetableio.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,132 @@
"""
Functionality to save/load a code table to/from a file
"""

import json
import logging
import pickle
from pathlib import Path
from typing import Any, Optional, Union

from dahuffman.huffmancodec import _EOF, PrefixCodec

_log = logging.getLogger(__name__)


def ensure_dir(path: Union[str, Path]) -> Path:
path = Path(path)
path.mkdir(parents=True, exist_ok=True)
return path


def pickle_save(
codec: PrefixCodec, path: Union[str, Path], metadata: Any = None
) -> None:
"""
Persist the code table to a file.
:param path: file path to persist to
:param metadata: additional metadata to include
"""
code_table = codec.get_code_table()
data = {
"code_table": code_table,
"type": type(codec),
"concat": codec._concat,
}
if metadata:
data["metadata"] = metadata
path = Path(path)
ensure_dir(path.parent)
with path.open(mode="wb") as f:
pickle.dump(data, file=f)
_log.info(
f"Saved {type(codec).__name__} code table ({len(code_table)} items) to {str(path)!r}"
)


def pickle_load(path: Union[str, Path]) -> PrefixCodec:
"""
Load a persisted PrefixCodec
:param path: path to serialized PrefixCodec code table data.
"""
path = Path(path)
with path.open(mode="rb") as f:
data = pickle.load(f)
cls = data["type"]
assert issubclass(cls, PrefixCodec)
code_table = data["code_table"]
_log.info(
f"Loading {cls.__name__} with {len(code_table)} code table items from {str(path)!r}"
)
return cls(code_table, concat=data["concat"])


def json_save(
codec: PrefixCodec, path: Union[str, Path], metadata: Optional[dict] = None
) -> None:
"""
Persist the code table as a JSON file.
Requires that all structures in the code table are JSON-serializable.
:param path: file path to persist to
:param metadata: additional metadata to include in the file.
"""
code_table = codec.get_code_table()

# Extract internal _EOF symbol from code table
if _EOF in code_table:
eof_code = code_table.pop(_EOF)
else:
eof_code = None

# Transform code table dictionary to a list, to avoid string-coercion of keys in JSON mappings.
code_table = [[k, *v] for (k, v) in code_table.items()]

data = {
"type": "dahuffman code table",
"version": 1,
"code_table": code_table,
}
if eof_code:
data["eof_code"] = eof_code
if metadata:
data["metadata"] = metadata
if codec._concat == list:
data["concat"] = "list"
elif codec._concat == "".join:
data["concat"] = "str_join"
elif codec._concat == bytes:
data["concat"] = "bytes"
else:
_log.warning(f"Unsupported concat callable {codec._concat!r}")

path = Path(path)
ensure_dir(path.parent)
with path.open("w", encoding="utf8") as f:
json.dump(obj=data, fp=f)
_log.info(
f"Saved {type(codec).__name__} code table ({len(code_table)} items) to {str(path)!r}"
)


def json_load(path: Union[str, Path]) -> PrefixCodec:
path = Path(path)
with path.open(mode="r", encoding="utf8") as f:
data = json.load(fp=f)

assert data["type"] == "dahuffman code table"
assert data["version"] == 1

# Reconstruct code table
code_table = {row[0]: row[1:] for row in data["code_table"]}

if "eof_code" in data:
code_table[_EOF] = data["eof_code"]

concat = {"str_join": "".join, "bytes": bytes}.get(data["concat"], list)

_log.info(
f"Loading PrefixCodec with {len(code_table)} code table items from {str(path)!r}"
)
return PrefixCodec(code_table, concat=concat)
52 changes: 15 additions & 37 deletions dahuffman/huffmancodec.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
import collections
import itertools
import logging
import pickle
import sys
import warnings
from heapq import heapify, heappop, heappush
from io import IOBase
from pathlib import Path
Expand Down Expand Up @@ -55,14 +55,6 @@ def _guess_concat(data: Any) -> Callable:
}.get(type(data), list)


def ensure_dir(path: Union[str, Path]) -> Path:
path = Path(path)
if not path.exists():
path.mkdir(parents=True)
assert path.is_dir()
return path


class PrefixCodec:
"""
Prefix code codec, using given code table.
Expand Down Expand Up @@ -218,23 +210,14 @@ def save(self, path: Union[str, Path], metadata: Any = None) -> None:
:param metadata: additional metadata
:return:
"""
code_table = self.get_code_table()
data = {
"code_table": code_table,
"type": type(self),
"concat": self._concat,
}
if metadata:
data["metadata"] = metadata
path = Path(path)
ensure_dir(path.parent)
with path.open(mode="wb") as f:
# TODO also provide JSON option? Requires handling of _EOF and possibly other non-string code table keys.
pickle.dump(data, file=f)
_log.info(
"Saved {c} code table ({l} items) to {p!r}".format(
c=type(self).__name__, l=len(code_table), p=str(path)
)
warnings.warn(
"`PrefixCodec.save()` is deprecated, use `dahuffman.codetableio` functionality instead",
DeprecationWarning,
)
import dahuffman.codetableio

return dahuffman.codetableio.pickle_save(
codec=self, path=path, metadata=metadata
)

@staticmethod
Expand All @@ -244,18 +227,13 @@ def load(path: Union[str, Path]) -> "PrefixCodec":
:param path: path to serialized PrefixCodec code table data.
:return:
"""
path = Path(path)
with path.open(mode="rb") as f:
data = pickle.load(f)
cls = data["type"]
assert issubclass(cls, PrefixCodec)
code_table = data["code_table"]
_log.info(
"Loading {c} with {l} code table items from {p!r}".format(
c=cls.__name__, l=len(code_table), p=str(path)
)
warnings.warn(
"`PrefixCodec.load()` is deprecated, use `dahuffman.codetableio` functionality instead",
DeprecationWarning,
)
return cls(code_table, concat=data["concat"])
import dahuffman.codetableio

return dahuffman.codetableio.pickle_load(path=path)


class HuffmanCodec(PrefixCodec):
Expand Down
63 changes: 63 additions & 0 deletions tests/test_codetableio.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,63 @@
from pathlib import Path

import pytest

from dahuffman.codetableio import json_load, json_save, pickle_load, pickle_save
from dahuffman.huffmancodec import HuffmanCodec


@pytest.mark.parametrize(
["train_data", "data"],
[
("aabcbcdbabdbcbd", "abcdabcd"),
(
["FR", "UK", "BE", "IT", "FR", "IT", "GR", "FR", "NL", "BE", "DE"],
["FR", "IT", "BE", "FR", "UK"],
),
(b"aabcbcdbabdbcbd", b"abcdabcd"),
(
[(0, 0), (0, 1), (1, 0), (0, 0), (1, 0), (1, 0)],
[(1, 0), (0, 0), (0, 1), (1, 0)],
),
],
)
def test_pickle_save_and_load(tmp_path: Path, train_data, data):
codec1 = HuffmanCodec.from_data(train_data)
encoded1 = codec1.encode(data)

path = tmp_path / "code-table.pickle"
pickle_save(codec=codec1, path=path)
codec2 = pickle_load(path)
encoded2 = codec2.encode(data)

assert encoded1 == encoded2
assert codec1.decode(encoded1) == codec2.decode(encoded2)


@pytest.mark.parametrize(
["train_data", "data"],
[
("aabcbcdbabdbcbd", "abcdabcd"),
(
["FR", "UK", "BE", "IT", "FR", "IT", "GR", "FR", "NL", "BE", "DE"],
["FR", "IT", "BE", "FR", "UK"],
),
(b"aabcbcdbabdbcbd", b"abcdabcd"),
# TODO:
# (
# [(0, 0), (0, 1), (1, 0), (0, 0), (1, 0), (1, 0)],
# [(1, 0), (0, 0), (0, 1), (1, 0)],
# ),
],
)
def test_json_save_and_load(tmp_path: Path, train_data, data):
codec1 = HuffmanCodec.from_data(train_data)
encoded1 = codec1.encode(data)

path = tmp_path / "code-table.json"
json_save(codec=codec1, path=path)
codec2 = json_load(path)
encoded2 = codec2.encode(data)

assert encoded1 == encoded2
assert codec1.decode(encoded1) == codec2.decode(encoded2)
2 changes: 1 addition & 1 deletion tests/test_dahuffman.py
Original file line number Diff line number Diff line change
Expand Up @@ -136,7 +136,7 @@ def test_eof_cut_off():
assert data == codec.decode(encoded)


def test_save(tmp_path: Path):
def test_save_and_load(tmp_path: Path):
codec1 = HuffmanCodec.from_data("aabcbcdbabdbcbd")
path = str(tmp_path / "foo" / "bar.huff")
codec1.save(path)
Expand Down
4 changes: 1 addition & 3 deletions train/train_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,6 @@

import requests

from dahuffman.huffmancodec import ensure_dir

DOWNLOADS = Path(__file__).parent / "data"
CODECS = Path(__file__).parent / "codecs"

Expand All @@ -14,7 +12,7 @@
def download(url: str, path: str) -> Path:
path = DOWNLOADS / path
if not path.exists():
ensure_dir(path.parent)
path.parent.mkdir(parents=True, exists_ok=True)
_log.info(f"Downloading {url}")
with requests.get(url) as r:
r.raise_for_status()
Expand Down

0 comments on commit 4b6267d

Please sign in to comment.