diff --git a/dahuffman/codecs/__init__.py b/dahuffman/codecs/__init__.py index d8c53c8..9533519 100644 --- a/dahuffman/codecs/__init__.py +++ b/dahuffman/codecs/__init__.py @@ -9,6 +9,7 @@ from functools import partial from pathlib import Path +import dahuffman.codetableio from dahuffman.huffmancodec import PrefixCodec @@ -22,7 +23,7 @@ def load(name: str) -> PrefixCodec: if not name.endswith(".pickle"): name = name + ".pickle" with importlib.resources.path("dahuffman.codecs", resource=name) as path: - return PrefixCodec.load(path) + return dahuffman.codetableio.pickle_load(path) load_shakespeare = partial(load, "shakespeare") diff --git a/dahuffman/codetableio.py b/dahuffman/codetableio.py new file mode 100644 index 0000000..1d595c4 --- /dev/null +++ b/dahuffman/codetableio.py @@ -0,0 +1,132 @@ +""" +Functionality to save/load a code table to/from a file +""" + +import json +import logging +import pickle +from pathlib import Path +from typing import Any, Optional, Union + +from dahuffman.huffmancodec import _EOF, PrefixCodec + +_log = logging.getLogger(__name__) + + +def ensure_dir(path: Union[str, Path]) -> Path: + path = Path(path) + path.mkdir(parents=True, exist_ok=True) + return path + + +def pickle_save( + codec: PrefixCodec, path: Union[str, Path], metadata: Any = None +) -> None: + """ + Persist the code table to a file. + + :param path: file path to persist to + :param metadata: additional metadata to include + """ + code_table = codec.get_code_table() + data = { + "code_table": code_table, + "type": type(codec), + "concat": codec._concat, + } + if metadata: + data["metadata"] = metadata + path = Path(path) + ensure_dir(path.parent) + with path.open(mode="wb") as f: + pickle.dump(data, file=f) + _log.info( + f"Saved {type(codec).__name__} code table ({len(code_table)} items) to {str(path)!r}" + ) + + +def pickle_load(path: Union[str, Path]) -> PrefixCodec: + """ + Load a persisted PrefixCodec + :param path: path to serialized PrefixCodec code table data. + """ + path = Path(path) + with path.open(mode="rb") as f: + data = pickle.load(f) + cls = data["type"] + assert issubclass(cls, PrefixCodec) + code_table = data["code_table"] + _log.info( + f"Loading {cls.__name__} with {len(code_table)} code table items from {str(path)!r}" + ) + return cls(code_table, concat=data["concat"]) + + +def json_save( + codec: PrefixCodec, path: Union[str, Path], metadata: Optional[dict] = None +) -> None: + """ + Persist the code table as a JSON file. + Requires that all structures in the code table are JSON-serializable. + + :param path: file path to persist to + :param metadata: additional metadata to include in the file. + """ + code_table = codec.get_code_table() + + # Extract internal _EOF symbol from code table + if _EOF in code_table: + eof_code = code_table.pop(_EOF) + else: + eof_code = None + + # Transform code table dictionary to a list, to avoid string-coercion of keys in JSON mappings. + code_table = [[k, *v] for (k, v) in code_table.items()] + + data = { + "type": "dahuffman code table", + "version": 1, + "code_table": code_table, + } + if eof_code: + data["eof_code"] = eof_code + if metadata: + data["metadata"] = metadata + if codec._concat == list: + data["concat"] = "list" + elif codec._concat == "".join: + data["concat"] = "str_join" + elif codec._concat == bytes: + data["concat"] = "bytes" + else: + _log.warning(f"Unsupported concat callable {codec._concat!r}") + + path = Path(path) + ensure_dir(path.parent) + with path.open("w", encoding="utf8") as f: + json.dump(obj=data, fp=f) + _log.info( + f"Saved {type(codec).__name__} code table ({len(code_table)} items) to {str(path)!r}" + ) + + +def json_load(path: Union[str, Path]) -> PrefixCodec: + path = Path(path) + with path.open(mode="r", encoding="utf8") as f: + data = json.load(fp=f) + + assert data["type"] == "dahuffman code table" + assert data["version"] == 1 + + # Reconstruct code table + code_table = {row[0]: row[1:] for row in data["code_table"]} + + if "eof_code" in data: + code_table[_EOF] = data["eof_code"] + + concat = {"str_join": "".join, "bytes": bytes}.get(data["concat"], list) + + _log.info( + f"Loading PrefixCodec with {len(code_table)} code table items from {str(path)!r}" + ) + return PrefixCodec(code_table, concat=concat) diff --git a/dahuffman/huffmancodec.py b/dahuffman/huffmancodec.py index fb28cec..461ed7f 100644 --- a/dahuffman/huffmancodec.py +++ b/dahuffman/huffmancodec.py @@ -1,8 +1,8 @@ import collections import itertools import logging -import pickle import sys +import warnings from heapq import heapify, heappop, heappush from io import IOBase from pathlib import Path @@ -55,14 +55,6 @@ def _guess_concat(data: Any) -> Callable: }.get(type(data), list) -def ensure_dir(path: Union[str, Path]) -> Path: - path = Path(path) - if not path.exists(): - path.mkdir(parents=True) - assert path.is_dir() - return path - - class PrefixCodec: """ Prefix code codec, using given code table. @@ -218,23 +210,14 @@ def save(self, path: Union[str, Path], metadata: Any = None) -> None: :param metadata: additional metadata :return: """ - code_table = self.get_code_table() - data = { - "code_table": code_table, - "type": type(self), - "concat": self._concat, - } - if metadata: - data["metadata"] = metadata - path = Path(path) - ensure_dir(path.parent) - with path.open(mode="wb") as f: - # TODO also provide JSON option? Requires handling of _EOF and possibly other non-string code table keys. - pickle.dump(data, file=f) - _log.info( - "Saved {c} code table ({l} items) to {p!r}".format( - c=type(self).__name__, l=len(code_table), p=str(path) - ) + warnings.warn( + "`PrefixCodec.save()` is deprecated, use `dahuffman.codetableio` functionality instead", + DeprecationWarning, + ) + import dahuffman.codetableio + + return dahuffman.codetableio.pickle_save( + codec=self, path=path, metadata=metadata ) @staticmethod @@ -244,18 +227,13 @@ def load(path: Union[str, Path]) -> "PrefixCodec": :param path: path to serialized PrefixCodec code table data. :return: """ - path = Path(path) - with path.open(mode="rb") as f: - data = pickle.load(f) - cls = data["type"] - assert issubclass(cls, PrefixCodec) - code_table = data["code_table"] - _log.info( - "Loading {c} with {l} code table items from {p!r}".format( - c=cls.__name__, l=len(code_table), p=str(path) - ) + warnings.warn( + "`PrefixCodec.load()` is deprecated, use `dahuffman.codetableio` functionality instead", + DeprecationWarning, ) - return cls(code_table, concat=data["concat"]) + import dahuffman.codetableio + + return dahuffman.codetableio.pickle_load(path=path) class HuffmanCodec(PrefixCodec): diff --git a/tests/test_codetableio.py b/tests/test_codetableio.py new file mode 100644 index 0000000..82b7b55 --- /dev/null +++ b/tests/test_codetableio.py @@ -0,0 +1,63 @@ +from pathlib import Path + +import pytest + +from dahuffman.codetableio import json_load, json_save, pickle_load, pickle_save +from dahuffman.huffmancodec import HuffmanCodec + + +@pytest.mark.parametrize( + ["train_data", "data"], + [ + ("aabcbcdbabdbcbd", "abcdabcd"), + ( + ["FR", "UK", "BE", "IT", "FR", "IT", "GR", "FR", "NL", "BE", "DE"], + ["FR", "IT", "BE", "FR", "UK"], + ), + (b"aabcbcdbabdbcbd", b"abcdabcd"), + ( + [(0, 0), (0, 1), (1, 0), (0, 0), (1, 0), (1, 0)], + [(1, 0), (0, 0), (0, 1), (1, 0)], + ), + ], +) +def test_pickle_save_and_load(tmp_path: Path, train_data, data): + codec1 = HuffmanCodec.from_data(train_data) + encoded1 = codec1.encode(data) + + path = tmp_path / "code-table.pickle" + pickle_save(codec=codec1, path=path) + codec2 = pickle_load(path) + encoded2 = codec2.encode(data) + + assert encoded1 == encoded2 + assert codec1.decode(encoded1) == codec2.decode(encoded2) + + +@pytest.mark.parametrize( + ["train_data", "data"], + [ + ("aabcbcdbabdbcbd", "abcdabcd"), + ( + ["FR", "UK", "BE", "IT", "FR", "IT", "GR", "FR", "NL", "BE", "DE"], + ["FR", "IT", "BE", "FR", "UK"], + ), + (b"aabcbcdbabdbcbd", b"abcdabcd"), + # TODO: + # ( + # [(0, 0), (0, 1), (1, 0), (0, 0), (1, 0), (1, 0)], + # [(1, 0), (0, 0), (0, 1), (1, 0)], + # ), + ], +) +def test_json_save_and_load(tmp_path: Path, train_data, data): + codec1 = HuffmanCodec.from_data(train_data) + encoded1 = codec1.encode(data) + + path = tmp_path / "code-table.json" + json_save(codec=codec1, path=path) + codec2 = json_load(path) + encoded2 = codec2.encode(data) + + assert encoded1 == encoded2 + assert codec1.decode(encoded1) == codec2.decode(encoded2) diff --git a/tests/test_dahuffman.py b/tests/test_dahuffman.py index a1a1405..6c50042 100644 --- a/tests/test_dahuffman.py +++ b/tests/test_dahuffman.py @@ -136,7 +136,7 @@ def test_eof_cut_off(): assert data == codec.decode(encoded) -def test_save(tmp_path: Path): +def test_save_and_load(tmp_path: Path): codec1 = HuffmanCodec.from_data("aabcbcdbabdbcbd") path = str(tmp_path / "foo" / "bar.huff") codec1.save(path) diff --git a/train/train_utils.py b/train/train_utils.py index 98a82d6..1ce0ab6 100644 --- a/train/train_utils.py +++ b/train/train_utils.py @@ -3,8 +3,6 @@ import requests -from dahuffman.huffmancodec import ensure_dir - DOWNLOADS = Path(__file__).parent / "data" CODECS = Path(__file__).parent / "codecs" @@ -14,7 +12,7 @@ def download(url: str, path: str) -> Path: path = DOWNLOADS / path if not path.exists(): - ensure_dir(path.parent) + path.parent.mkdir(parents=True, exists_ok=True) _log.info(f"Downloading {url}") with requests.get(url) as r: r.raise_for_status()