Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix recursive deserialization of cbor bytes #194

Merged
merged 1 commit into from
Apr 4, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
80 changes: 42 additions & 38 deletions pycardano/serialization.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,8 @@

from __future__ import annotations

import typing

import re
from collections import OrderedDict, UserList, defaultdict
from copy import deepcopy
Expand Down Expand Up @@ -413,65 +415,67 @@ def _restore_dataclass_field(
Returns:
Union[:const:`Primitive`, CBORSerializable]: A CBOR primitive or a CBORSerializable.
"""

if "object_hook" in f.metadata:
return f.metadata["object_hook"](v)
elif isclass(f.type) and issubclass(f.type, CBORSerializable):
return f.type.from_primitive(v)
elif hasattr(f.type, "__origin__") and (f.type.__origin__ is list):
t_args = f.type.__args__
return _restore_typed_primitive(f.type, v)


def _restore_typed_primitive(
t: typing.Type, v: Primitive
) -> Union[Primitive, CBORSerializable]:
"""Try to restore a value back to its original type based on information given in field.

Args:
f (type): A type
v (:const:`Primitive`): A CBOR primitive.

Returns:
Union[:const:`Primitive`, CBORSerializable]: A CBOR primitive or a CBORSerializable.
"""
if t in PRIMITIVE_TYPES and isinstance(v, t):
return v
elif isclass(t) and issubclass(t, CBORSerializable):
return t.from_primitive(v)
elif hasattr(t, "__origin__") and (t.__origin__ is list):
t_args = t.__args__
if len(t_args) != 1:
raise DeserializeException(
f"List types need exactly one type argument, but got {t_args}"
)
t = t_args[0]
if not isinstance(v, list):
raise DeserializeException(f"Expected type list but got {type(v)}")
if isclass(t) and issubclass(t, CBORSerializable):
return IndefiniteList([t.from_primitive(w) for w in v])
else:
return IndefiniteList(v)
elif isclass(f.type) and issubclass(f.type, IndefiniteList):
return IndefiniteList([_restore_typed_primitive(t, w) for w in v])
elif isclass(t) and issubclass(t, IndefiniteList):
return IndefiniteList(v)
elif hasattr(f.type, "__origin__") and (f.type.__origin__ is dict):
t_args = f.type.__args__
elif hasattr(t, "__origin__") and (t.__origin__ is dict):
t_args = t.__args__
if len(t_args) != 2:
raise DeserializeException(
f"Dict types need exactly two type arguments, but got {t_args}"
)
key_t = t_args[0]
val_t = t_args[1]
if isclass(key_t) and issubclass(key_t, CBORSerializable):
key_converter = key_t.from_primitive
else:
key_converter = _identity
if isclass(val_t) and issubclass(val_t, CBORSerializable):
val_converter = val_t.from_primitive
else:
val_converter = _identity
if not isinstance(v, dict):
raise DeserializeException(f"Expected dict type but got {type(v)}")
return {key_converter(key): val_converter(val) for key, val in v.items()}
elif hasattr(f.type, "__origin__") and (
f.type.__origin__ is Union or f.type.__origin__ is Optional
return {
_restore_typed_primitive(key_t, key): _restore_typed_primitive(val_t, val)
for key, val in v.items()
}
elif hasattr(t, "__origin__") and (
t.__origin__ is Union or t.__origin__ is Optional
):
t_args = f.type.__args__
t_args = t.__args__
for t in t_args:
if isclass(t) and issubclass(t, IndefiniteList):
return IndefiniteList(v)
elif isclass(t) and issubclass(t, CBORSerializable):
try:
return t.from_primitive(v)
except DeserializeException:
pass
else:
if not isclass(t) and hasattr(t, "__origin__"):
t = t.__origin__
if t in PRIMITIVE_TYPES and isinstance(v, t):
return v
try:
return _restore_typed_primitive(t, v)
except DeserializeException:
pass
raise DeserializeException(
f"Cannot deserialize object: \n{v}\n in any valid type from {t_args}."
)
return v
raise DeserializeException(f"Cannot deserialize object: \n{v}\n to type {t}.")


ArrayBase = TypeVar("ArrayBase", bound="ArrayCBORSerializable")
Expand Down Expand Up @@ -556,8 +560,8 @@ def to_shallow_primitive(self) -> List[Primitive]:
return primitives

@classmethod
@limit_primitive_type(list)
def from_primitive(cls: Type[ArrayBase], values: list) -> ArrayBase:
@limit_primitive_type(list, tuple)
def from_primitive(cls: Type[ArrayBase], values: Union[list, tuple]) -> ArrayBase:
"""Restore a primitive value to its original class type.

Args:
Expand Down
6 changes: 4 additions & 2 deletions test/pycardano/test_serialization.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
from typing import Union

from dataclasses import dataclass, field
from test.pycardano.util import check_two_way_cbor

Expand Down Expand Up @@ -40,7 +42,7 @@ def test_array_cbor_serializable():
@dataclass
class Test1(ArrayCBORSerializable):
a: str
b: str = None
b: Union[str, None] = None

@dataclass
class Test2(ArrayCBORSerializable):
Expand Down Expand Up @@ -87,7 +89,7 @@ class Test1(MapCBORSerializable):

@dataclass
class Test2(MapCBORSerializable):
c: str = None
c: Union[str, None] = None
test1: Test1 = field(default_factory=Test1)

t = Test2(test1=Test1(a="a"))
Expand Down