Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

add escape hatch for custom JSON serialization #1955

Merged
merged 6 commits into from
Mar 13, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
66 changes: 42 additions & 24 deletions kombu/utils/json.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,9 @@ def default(self, o):

for t, (marker, encoder) in _encoders.items():
if isinstance(o, t):
return _as(marker, encoder(o))
return (
encoder(o) if marker is None else _as(marker, encoder(o))
)

# Bytes is slightly trickier, so we cannot put them directly
# into _encoders, because we use two formats: bytes, and base64.
Expand All @@ -50,7 +52,11 @@ def _as(t: str, v: Any):


def dumps(
s, _dumps=json.dumps, cls=JSONEncoder, default_kwargs=None, **kwargs
s,
_dumps=json.dumps,
cls=JSONEncoder,
default_kwargs=None,
**kwargs
):
"""Serialize object to json string."""
default_kwargs = default_kwargs or {}
Expand Down Expand Up @@ -94,35 +100,47 @@ def loads(s, _loads=json.loads, decode_bytes=True, object_hook=object_hook):

def register_type(
t: type[T],
marker: str,
marker: str | None,
encoder: Callable[[T], EncodedT],
decoder: Callable[[EncodedT], T],
decoder: Callable[[EncodedT], T] = lambda d: d,
):
"""Add support for serializing/deserializing native python type."""
"""Add support for serializing/deserializing native python type.

If marker is `None`, the encoding is a pure transformation and the result
is not placed in an envelope, so `decoder` is unnecessary. Decoding must
instead be handled outside this library.
"""
_encoders[t] = (marker, encoder)
_decoders[marker] = decoder
if marker is not None:
_decoders[marker] = decoder


_encoders: dict[type, tuple[str, EncoderT]] = {}
_encoders: dict[type, tuple[str | None, EncoderT]] = {}
_decoders: dict[str, DecoderT] = {
"bytes": lambda o: o.encode("utf-8"),
"base64": lambda o: base64.b64decode(o.encode("utf-8")),
}

# NOTE: datetime should be registered before date,
# because datetime is also instance of date.
register_type(datetime, "datetime", datetime.isoformat, datetime.fromisoformat)
register_type(
date,
"date",
lambda o: o.isoformat(),
lambda o: datetime.fromisoformat(o).date(),
)
register_type(time, "time", lambda o: o.isoformat(), time.fromisoformat)
register_type(Decimal, "decimal", str, Decimal)
register_type(
uuid.UUID,
"uuid",
lambda o: {"hex": o.hex},
lambda o: uuid.UUID(**o),
)

def _register_default_types():
# NOTE: datetime should be registered before date,
# because datetime is also instance of date.
register_type(datetime, "datetime", datetime.isoformat,
datetime.fromisoformat)
register_type(
date,
"date",
lambda o: o.isoformat(),
lambda o: datetime.fromisoformat(o).date(),
)
register_type(time, "time", lambda o: o.isoformat(), time.fromisoformat)
register_type(Decimal, "decimal", str, Decimal)
register_type(
uuid.UUID,
"uuid",
lambda o: {"hex": o.hex},
lambda o: uuid.UUID(**o),
)


_register_default_types()
43 changes: 42 additions & 1 deletion t/unit/utils/test_json.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
import sys
import uuid
from collections import namedtuple
from dataclasses import dataclass
from datetime import datetime
from decimal import Decimal

Expand All @@ -11,7 +12,8 @@
from hypothesis import strategies as st

from kombu.utils.encoding import str_to_bytes
from kombu.utils.json import dumps, loads
from kombu.utils.json import (_register_default_types, dumps, loads,
register_type)

if sys.version_info >= (3, 9):
from zoneinfo import ZoneInfo
Expand All @@ -28,6 +30,10 @@ def __json__(self):


class test_JSONEncoder:
@pytest.fixture(autouse=True)
def reset_registered_types(self):
_register_default_types()

@pytest.mark.freeze_time("2015-10-21")
def test_datetime(self):
now = datetime.utcnow()
Expand Down Expand Up @@ -82,6 +88,41 @@ def test_UUID(self):
assert loaded_value == {'u': id}
assert loaded_value["u"].version == id.version

def test_register_type_overrides_defaults(self):
# This type is already registered by default, let's override it
register_type(uuid.UUID, "uuid", lambda o: "custom", lambda o: o)
value = uuid.uuid4()
loaded_value = loads(dumps({'u': value}))
assert loaded_value == {'u': "custom"}

def test_register_type_with_new_type(self):
# Guaranteed never before seen type
@dataclass()
class SomeType:
a: int

register_type(SomeType, "some_type", lambda o: "custom", lambda o: o)
value = SomeType(42)
loaded_value = loads(dumps({'u': value}))
assert loaded_value == {'u': "custom"}

def test_register_type_with_empty_marker(self):
register_type(
datetime,
None,
lambda o: o.isoformat(),
lambda o: "should never be used"
)
now = datetime.utcnow()
serialized_str = dumps({'now': now})
deserialized_value = loads(serialized_str)

assert "__type__" not in serialized_str
assert "__value__" not in serialized_str

# Check that there is no extra deserialization happening
assert deserialized_value == {'now': now.isoformat()}

def test_default(self):
with pytest.raises(TypeError):
dumps({'o': object()})
Expand Down