Skip to content

Commit

Permalink
feat: add support for typing.Required, NotRequired
Browse files Browse the repository at this point in the history
Previously, annotating a `TypedDict` field with one of the `Required`
and `NotRequired` wrappers introduced in Python 3.11, dataclass-wizard
would raise the following error:

> TypeError: issubclass() arg 1 must be a class

Fix that error by adding support for `Required` and `NotRequired`.

Partially addresses issue #121. [1]

[1]: #121
  • Loading branch information
claui committed Aug 18, 2024
1 parent 40ea64d commit 068a145
Show file tree
Hide file tree
Showing 5 changed files with 167 additions and 3 deletions.
3 changes: 3 additions & 0 deletions dataclass_wizard/constants.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,9 @@
# Check if currently running Python 3.10 or higher
PY310_OR_ABOVE = _PY_VERSION >= (3, 10)

# Check if currently running Python 3.11 or higher
PY311_OR_ABOVE = _PY_VERSION >= (3, 11)

# The name of the dictionary object that contains `load` hooks for each
# object type. Also used to check if a class is a :class:`BaseLoadHook`
_LOAD_HOOKS = '__LOAD_HOOKS__'
Expand Down
7 changes: 7 additions & 0 deletions dataclass_wizard/loaders.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
from .parsers import *
from .type_def import (
ExplicitNull, FrozenKeys, DefFactory, NoneType, JSONObject,
PyRequired, PyNotRequired,
M, N, T, E, U, DD, LSQ, NT
)
from .utils.string_conv import to_snake_case
Expand Down Expand Up @@ -360,6 +361,12 @@ def get_parser_for_annotation(cls, ann_type: Type[T],
cls.get_parser_for_annotation
)

elif base_type in (PyRequired, PyNotRequired):
# Given `Required[T]` or `NotRequired[T]`, we only need `T`
ann_type = get_args(ann_type)[0]
return cls.get_parser_for_annotation(
ann_type, base_cls, extras)

elif issubclass(base_type, defaultdict):
load_hook = hooks[defaultdict]
return DefaultDictParser(
Expand Down
17 changes: 16 additions & 1 deletion dataclass_wizard/type_def.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,8 @@
'PyDeque',
'PyTypedDict',
'PyTypedDicts',
'PyRequired',
'PyNotRequired',
'FrozenKeys',
'DefFactory',
'NoneType',
Expand Down Expand Up @@ -42,7 +44,7 @@
)
from uuid import UUID

from .constants import PY36, PY38_OR_ABOVE
from .constants import PY36, PY38_OR_ABOVE, PY311_OR_ABOVE
from .decorators import discard_kwargs


Expand Down Expand Up @@ -128,10 +130,23 @@
PyTypedDicts.append(PyTypedDict)
except ImportError:
pass

# Python 3.11 introduced `Required` and `NotRequired` wrappers for
# `TypedDict` fields (PEP 655). Python 3.8+ users can import the
# wrappers from `typing_extensions`.
if PY311_OR_ABOVE:
from typing import Required as PyRequired
from typing import NotRequired as PyNotRequired
else:
from typing_extensions import Required as PyRequired
from typing_extensions import NotRequired as PyNotRequired

else: # pragma: no cover
from typing_extensions import Literal as PyLiteral
from typing_extensions import Protocol as PyProtocol
from typing_extensions import TypedDict as PyTypedDict
from typing_extensions import Required as PyRequired
from typing_extensions import NotRequired as PyNotRequired
# Seems like `Deque` was only introduced to `typing` in 3.6.1, so Python
# 3.6.0 won't have it; to be safe, we'll instead import from the
# `typing_extensions` module here.
Expand Down
17 changes: 16 additions & 1 deletion tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,11 +4,15 @@
'PY36',
'PY39_OR_ABOVE',
'PY310_OR_ABOVE',
'PY311_OR_ABOVE',
# For compatibility with Python 3.6 and 3.7
'Literal',
'TypedDict',
'Annotated',
'Deque'
'Deque',
# For compatibility with Python 3.6 through 3.10
'Required',
'NotRequired'
]

import sys
Expand All @@ -27,6 +31,9 @@
# Check if we are running Python 3.10+
PY310_OR_ABOVE = sys.version_info[:2] >= (3, 10)

# Check if we are running Python 3.11+
PY311_OR_ABOVE = sys.version_info[:2] >= (3, 10)

# Ref: https://docs.pytest.org/en/6.2.x/example/parametrize.html#parametrizing-conditional-raising
if sys.version_info[:2] >= (3, 7):
from contextlib import nullcontext as does_not_raise
Expand Down Expand Up @@ -54,6 +61,14 @@
else:
from typing_extensions import Annotated

# typing.Required and typing.NotRequired: Introduced in Python 3.11
if PY311_OR_ABOVE:
from typing import Required
from typing import NotRequired
else:
from typing_extensions import Required
from typing_extensions import NotRequired


def data_file_path(name: str) -> str:
"""Returns the full path to a test file."""
Expand Down
126 changes: 125 additions & 1 deletion tests/unit/test_load.py
Original file line number Diff line number Diff line change
Expand Up @@ -1403,7 +1403,7 @@ class MyClass(JSONSerializable):
)
]
)
def test_typed_dict_with_optional_fields(input, expectation, expected):
def test_typed_dict_with_all_fields_optional(input, expectation, expected):
"""
Test case for loading to a TypedDict which has `total=False`, indicating
that all fields are optional.
Expand All @@ -1427,6 +1427,130 @@ class MyClass(JSONSerializable):
assert result.my_typed_dict == expected


@pytest.mark.skipif(PY36, reason='requires Python 3.7 or higher')
@pytest.mark.parametrize(
'input,expectation,expected',
[
(
{}, pytest.raises(ParseError), None
),
(
{'key': 'value'}, pytest.raises(ParseError), {}
),
(
{'my_str': 'test', 'my_int': 2,
'my_bool': True, 'other_key': 'testing'}, does_not_raise(),
{'my_str': 'test', 'my_int': 2, 'my_bool': True}
),
(
{'my_str': 3}, pytest.raises(ParseError), None
),
(
{'my_str': 'test', 'my_int': 'test', 'my_bool': True},
pytest.raises(ValueError), None,
),
(
{'my_str': 'test', 'my_int': 2, 'my_bool': True},
does_not_raise(),
{'my_str': 'test', 'my_int': 2, 'my_bool': True}
),
(
{'my_str': 'test', 'my_bool': True},
does_not_raise(),
{'my_str': 'test', 'my_bool': True}
),
(
# Incorrect type - `list`, but should be a `dict`
[{'my_str': 'test', 'my_int': 2, 'my_bool': True}],
pytest.raises(ParseError), None
)
]
)
def test_typed_dict_with_one_field_not_required(input, expectation, expected):
"""
Test case for loading to a TypedDict whose fields are all mandatory
except for one field, whose annotated type is NotRequired.
"""
class MyDict(TypedDict):
my_str: str
my_bool: bool
my_int: NotRequired[int]

@dataclass
class MyClass(JSONSerializable):
my_typed_dict: MyDict

d = {'myTypedDict': input}

with expectation:
result = MyClass.from_dict(d)

log.debug('Parsed object: %r', result)
assert result.my_typed_dict == expected


@pytest.mark.skipif(PY36, reason='requires Python 3.7 or higher')
@pytest.mark.parametrize(
'input,expectation,expected',
[
(
{}, pytest.raises(ParseError), None
),
(
{'my_int': 2}, does_not_raise(), {'my_int': 2}
),
(
{'key': 'value'}, pytest.raises(ParseError), None
),
(
{'key': 'value', 'my_int': 2}, does_not_raise(),
{'my_int': 2}
),
(
{'my_str': 'test', 'my_int': 2,
'my_bool': True, 'other_key': 'testing'}, does_not_raise(),
{'my_str': 'test', 'my_int': 2, 'my_bool': True}
),
(
{'my_str': 3}, pytest.raises(ParseError), None
),
(
{'my_str': 'test', 'my_int': 'test', 'my_bool': True},
pytest.raises(ValueError),
{'my_str': 'test', 'my_int': 'test', 'my_bool': True}
),
(
{'my_str': 'test', 'my_int': 2, 'my_bool': True},
does_not_raise(),
{'my_str': 'test', 'my_int': 2, 'my_bool': True}
)
]
)
def test_typed_dict_with_one_field_required(input, expectation, expected):
"""
Test case for loading to a TypedDict whose fields are all optional
except for one field, whose annotated type is Required.
"""
class MyDict(TypedDict, total=False):
my_str: str
my_bool: bool
my_int: Required[int]

@dataclass
class MyClass(JSONSerializable):
my_typed_dict: MyDict

d = {'myTypedDict': input}

with expectation:
result = MyClass.from_dict(d)

log.debug('Parsed object: %r', result)
assert result.my_typed_dict == expected


@pytest.mark.parametrize(
'input,expectation,expected',
[
Expand Down

0 comments on commit 068a145

Please sign in to comment.