-
Notifications
You must be signed in to change notification settings - Fork 85
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
feat: add support for custom picklers (#2682)
* chore: move metadata to proper destination * feat: add support for custom pickling * feat: add support for custom pickling * test: ensure custom pickler works
- Loading branch information
Showing
5 changed files
with
180 additions
and
2 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,73 @@ | ||
"""Interface for plugin-configurable pickle __reduce_ex__ implementation""" | ||
from __future__ import annotations | ||
|
||
import sys | ||
import threading | ||
import warnings | ||
|
||
from awkward._typing import Any, Protocol, runtime_checkable | ||
|
||
if sys.version_info < (3, 12): | ||
import importlib_metadata | ||
else: | ||
import importlib.metadata as importlib_metadata | ||
|
||
|
||
@runtime_checkable | ||
class PickleReducer(Protocol): | ||
def __call__(self, obj: Any, protocol: int) -> tuple | NotImplemented: | ||
... | ||
|
||
|
||
_register_lock = threading.Lock() | ||
_plugin: PickleReducer | None = None | ||
_is_registered = False | ||
|
||
|
||
def _load_reduce_plugin(): | ||
best_plugin = None | ||
|
||
for entry_point in importlib_metadata.entry_points(group="awkward.pickle.reduce"): | ||
plugin = entry_point.load() | ||
|
||
try: | ||
assert isinstance(plugin, PickleReducer) | ||
except AssertionError: | ||
warnings.warn( | ||
f"Couldn't load `awkward.pickle.reduce` plugin: {entry_point}", | ||
stacklevel=2, | ||
) | ||
continue | ||
|
||
if best_plugin is not None: | ||
raise RuntimeError( | ||
"Encountered multiple Awkward pickle reducers under the `awkward.pickle.reduce` entrypoint" | ||
) | ||
best_plugin = plugin | ||
|
||
return best_plugin | ||
|
||
|
||
def get_custom_reducer() -> PickleReducer | None: | ||
""" | ||
Returns the implementation of a custom __reduce_ex__ function for Awkward | ||
highlevel objects, or None if none provided | ||
""" | ||
global _is_registered, _plugin | ||
|
||
with _register_lock: | ||
if not _is_registered: | ||
_plugin = _load_reduce_plugin() | ||
_is_registered = True | ||
|
||
if _plugin is None: | ||
return None | ||
else: | ||
return _plugin | ||
|
||
|
||
def custom_reduce(obj, protocol) -> tuple | NotImplemented: | ||
plugin = get_custom_reducer() | ||
if plugin is None: | ||
return NotImplemented | ||
return plugin(obj, protocol) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,93 @@ | ||
# BSD 3-Clause License; see https://github.com/scikit-hep/awkward-1.0/blob/main/LICENSE | ||
import multiprocessing | ||
import os | ||
import pickle | ||
import sys | ||
from concurrent.futures import ProcessPoolExecutor | ||
|
||
import pytest | ||
|
||
import awkward as ak | ||
|
||
|
||
def _init_process_with_pickler(pickler_source: str, tmp_path): | ||
# Create custom plugin | ||
(tmp_path / "impl_pickler.py").write_bytes(pickler_source.encode("UTF-8")) | ||
dist_info = tmp_path / "impl_pickler-0.0.0.dist-info" | ||
dist_info.mkdir() | ||
(dist_info / "entry_points.txt").write_bytes( | ||
b"[awkward.pickle.reduce]\nimpl = impl_pickler:plugin\n" | ||
) | ||
sys.path.insert(0, os.fsdecode(tmp_path)) | ||
|
||
|
||
def _pickle_complex_array_and_return_form_impl(): | ||
array = ak.Array([[1, 2, 3], [4, 5, 6], [7, 8, 9], [10, 11, 12]])[[0, 2]] | ||
return pickle.loads(pickle.dumps(array)).layout.form | ||
|
||
|
||
def pickle_complex_array_and_return_form(pickler_source, tmp_path): | ||
"""Create a new (spawned) process, and register the given pickler source | ||
via entrypoints""" | ||
with ProcessPoolExecutor( | ||
1, | ||
initializer=_init_process_with_pickler, | ||
initargs=(pickler_source, tmp_path), | ||
# Don't fork the current process with all of its state | ||
mp_context=multiprocessing.get_context("spawn"), | ||
) as executor: | ||
pickle_future = executor.submit(_pickle_complex_array_and_return_form_impl) | ||
return pickle_future.result() | ||
|
||
|
||
def test_default_pickler(): | ||
assert _pickle_complex_array_and_return_form_impl() == ak.forms.from_dict( | ||
{"class": "ListOffsetArray", "offsets": "i64", "content": "int64"} | ||
) | ||
|
||
|
||
def test_noop_pickler(tmp_path): | ||
assert ( | ||
pickle_complex_array_and_return_form( | ||
""" | ||
def plugin(obj, protocol: int): | ||
return NotImplemented""", | ||
tmp_path, | ||
) | ||
== ak.forms.from_dict( | ||
{"class": "ListOffsetArray", "offsets": "i64", "content": "int64"} | ||
) | ||
) | ||
|
||
|
||
def test_non_packing_pickler(tmp_path): | ||
assert ( | ||
pickle_complex_array_and_return_form( | ||
""" | ||
def plugin(obj, protocol): | ||
import awkward as ak | ||
if isinstance(obj, ak.Array): | ||
form, length, container = ak.to_buffers(obj) | ||
return ( | ||
object.__new__, | ||
(ak.Array,), | ||
(form.to_dict(), length, container, obj.behavior), | ||
) | ||
else: | ||
return NotImplemented""", | ||
tmp_path, | ||
) | ||
== ak.forms.from_dict( | ||
{"class": "ListArray", "starts": "i64", "stops": "i64", "content": "int64"} | ||
) | ||
) | ||
|
||
|
||
def test_malformed_pickler(tmp_path): | ||
with pytest.raises(RuntimeError, match=r"malformed pickler!"): | ||
pickle_complex_array_and_return_form( | ||
""" | ||
def plugin(obj, protocol: int): | ||
raise RuntimeError('malformed pickler!')""", | ||
tmp_path, | ||
) |