Skip to content

Commit

Permalink
feat: add support for custom picklers (#2682)
Browse files Browse the repository at this point in the history
* chore: move metadata to proper destination

* feat: add support for custom pickling

* feat: add support for custom pickling

* test: ensure custom pickler works
  • Loading branch information
agoose77 committed Sep 1, 2023
1 parent 81fc063 commit 519bba6
Show file tree
Hide file tree
Showing 5 changed files with 180 additions and 2 deletions.
3 changes: 2 additions & 1 deletion awkward-cpp/pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,8 @@ build-backend = "scikit_build_core.build"
name = "awkward_cpp"
version = "22"
dependencies = [
"numpy>=1.18.0"
"numpy>=1.18.0",
"importlib_resources;python_version < \"3.9\""
]
readme = "README.md"
description = "CPU kernels and compiled extensions for Awkward Array"
Expand Down
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@ classifiers = [
]
dependencies = [
"awkward_cpp==22",
"importlib_resources;python_version < \"3.9\"",
"importlib_metadata>=4.13.0;python_version < \"3.12\"",
"numpy>=1.18.0",
"packaging",
"typing_extensions>=4.1.0; python_version < \"3.11\""
Expand Down
73 changes: 73 additions & 0 deletions src/awkward/_pickle.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,73 @@
"""Interface for plugin-configurable pickle __reduce_ex__ implementation"""
from __future__ import annotations

import sys
import threading
import warnings

from awkward._typing import Any, Protocol, runtime_checkable

if sys.version_info < (3, 12):
import importlib_metadata
else:
import importlib.metadata as importlib_metadata


@runtime_checkable
class PickleReducer(Protocol):
def __call__(self, obj: Any, protocol: int) -> tuple | NotImplemented:
...


_register_lock = threading.Lock()
_plugin: PickleReducer | None = None
_is_registered = False


def _load_reduce_plugin():
best_plugin = None

for entry_point in importlib_metadata.entry_points(group="awkward.pickle.reduce"):
plugin = entry_point.load()

try:
assert isinstance(plugin, PickleReducer)
except AssertionError:
warnings.warn(
f"Couldn't load `awkward.pickle.reduce` plugin: {entry_point}",
stacklevel=2,
)
continue

if best_plugin is not None:
raise RuntimeError(
"Encountered multiple Awkward pickle reducers under the `awkward.pickle.reduce` entrypoint"
)
best_plugin = plugin

return best_plugin


def get_custom_reducer() -> PickleReducer | None:
"""
Returns the implementation of a custom __reduce_ex__ function for Awkward
highlevel objects, or None if none provided
"""
global _is_registered, _plugin

with _register_lock:
if not _is_registered:
_plugin = _load_reduce_plugin()
_is_registered = True

if _plugin is None:
return None
else:
return _plugin


def custom_reduce(obj, protocol) -> tuple | NotImplemented:
plugin = get_custom_reducer()
if plugin is None:
return NotImplemented
return plugin(obj, protocol)
11 changes: 11 additions & 0 deletions src/awkward/highlevel.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@
from awkward._nplikes.numpy import Numpy
from awkward._nplikes.numpylike import NumpyMetadata
from awkward._operators import NDArrayOperatorsMixin
from awkward._pickle import custom_reduce
from awkward._regularize import is_non_string_like_iterable
from awkward._typing import TypeVar

Expand Down Expand Up @@ -1458,6 +1459,10 @@ def numba_type(self):
return numba.typeof(self._numbaview)

def __reduce_ex__(self, protocol: int) -> tuple:
result = custom_reduce(self, protocol)
if result is not NotImplemented:
return result

packed_layout = ak.operations.to_packed(self._layout, highlevel=False)
form, length, container = ak.operations.to_buffers(
packed_layout,
Expand Down Expand Up @@ -2125,6 +2130,11 @@ def numba_type(self):
return numba.typeof(self._numbaview)

def __reduce_ex__(self, protocol: int) -> tuple:
# Allow third-party libraries to customise pickling
result = custom_reduce(self, protocol)
if result is not NotImplemented:
return result

packed_layout = ak.operations.to_packed(self._layout, highlevel=False)
form, length, container = ak.operations.to_buffers(
packed_layout.array,
Expand All @@ -2136,6 +2146,7 @@ def __reduce_ex__(self, protocol: int) -> tuple:
# For pickle >= 5, we can avoid copying the buffers
if protocol >= 5:
container = {k: pickle.PickleBuffer(v) for k, v in container.items()}

if self._behavior is ak.behavior:
behavior = None
else:
Expand Down
93 changes: 93 additions & 0 deletions tests/test_2682_custom_pickler.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,93 @@
# BSD 3-Clause License; see https://github.com/scikit-hep/awkward-1.0/blob/main/LICENSE
import multiprocessing
import os
import pickle
import sys
from concurrent.futures import ProcessPoolExecutor

import pytest

import awkward as ak


def _init_process_with_pickler(pickler_source: str, tmp_path):
# Create custom plugin
(tmp_path / "impl_pickler.py").write_bytes(pickler_source.encode("UTF-8"))
dist_info = tmp_path / "impl_pickler-0.0.0.dist-info"
dist_info.mkdir()
(dist_info / "entry_points.txt").write_bytes(
b"[awkward.pickle.reduce]\nimpl = impl_pickler:plugin\n"
)
sys.path.insert(0, os.fsdecode(tmp_path))


def _pickle_complex_array_and_return_form_impl():
array = ak.Array([[1, 2, 3], [4, 5, 6], [7, 8, 9], [10, 11, 12]])[[0, 2]]
return pickle.loads(pickle.dumps(array)).layout.form


def pickle_complex_array_and_return_form(pickler_source, tmp_path):
"""Create a new (spawned) process, and register the given pickler source
via entrypoints"""
with ProcessPoolExecutor(
1,
initializer=_init_process_with_pickler,
initargs=(pickler_source, tmp_path),
# Don't fork the current process with all of its state
mp_context=multiprocessing.get_context("spawn"),
) as executor:
pickle_future = executor.submit(_pickle_complex_array_and_return_form_impl)
return pickle_future.result()


def test_default_pickler():
assert _pickle_complex_array_and_return_form_impl() == ak.forms.from_dict(
{"class": "ListOffsetArray", "offsets": "i64", "content": "int64"}
)


def test_noop_pickler(tmp_path):
assert (
pickle_complex_array_and_return_form(
"""
def plugin(obj, protocol: int):
return NotImplemented""",
tmp_path,
)
== ak.forms.from_dict(
{"class": "ListOffsetArray", "offsets": "i64", "content": "int64"}
)
)


def test_non_packing_pickler(tmp_path):
assert (
pickle_complex_array_and_return_form(
"""
def plugin(obj, protocol):
import awkward as ak
if isinstance(obj, ak.Array):
form, length, container = ak.to_buffers(obj)
return (
object.__new__,
(ak.Array,),
(form.to_dict(), length, container, obj.behavior),
)
else:
return NotImplemented""",
tmp_path,
)
== ak.forms.from_dict(
{"class": "ListArray", "starts": "i64", "stops": "i64", "content": "int64"}
)
)


def test_malformed_pickler(tmp_path):
with pytest.raises(RuntimeError, match=r"malformed pickler!"):
pickle_complex_array_and_return_form(
"""
def plugin(obj, protocol: int):
raise RuntimeError('malformed pickler!')""",
tmp_path,
)

0 comments on commit 519bba6

Please sign in to comment.