Skip to content

Commit

Permalink
Add support for typing.Required, NotRequired
Browse files Browse the repository at this point in the history
Previously, annotating a `TypedDict` field with one of the `Required`
and `NotRequired` wrappers introduced in Python 3.11, the library would
raise the following error:

> TypeError: issubclass() arg 1 must be a class

Fix that error by adding support for `Required` and `NotRequired`.

Partially addresses issue rnag#121. [1]

[1]: rnag#121
  • Loading branch information
claui committed Aug 18, 2024
1 parent 1e4043d commit 2eb80be
Show file tree
Hide file tree
Showing 6 changed files with 180 additions and 1 deletion.
1 change: 1 addition & 0 deletions dataclass_witch/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,6 +67,7 @@
:copyright: (c) 2021 by Ritvik Nag.
:license: Apache 2.0, see LICENSE for more details.
"""

__all__ = [
# Force the linter to recognize that these are exports
"JSONSerializable",
Expand Down
7 changes: 7 additions & 0 deletions dataclass_witch/loaders.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,6 +63,8 @@
DefFactory,
NoneType,
JSONObject,
PyRequired,
PyNotRequired,
M,
N,
T,
Expand Down Expand Up @@ -409,6 +411,11 @@ def get_parser_for_annotation(
base_cls, extras, base_types, cls.get_parser_for_annotation
)

elif base_type in (PyRequired, PyNotRequired):
# Given `Required[T]` or `NotRequired[T]`, we only need `T`
ann_type = get_args(ann_type)[0]
return cls.get_parser_for_annotation(ann_type, base_cls, extras)

elif issubclass(base_type, defaultdict):
load_hook = hooks[defaultdict]
return DefaultDictParser(
Expand Down
13 changes: 13 additions & 0 deletions dataclass_witch/type_def.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,8 @@
"PyDeque",
"PyTypedDict",
"PyTypedDicts",
"PyRequired",
"PyNotRequired",
"FrozenKeys",
"DefFactory",
"NoneType",
Expand Down Expand Up @@ -139,6 +141,17 @@

PyTypedDicts.append(PyTypedDict)

# Python 3.11 introduced `Required` and `NotRequired` wrappers for
# `TypedDict` fields (PEP 655). Users of earlier Python versions may
# import them from `typing_extensions`. However, they then need to
# use `TypedDict` from `typing_extensions`, not from the standard
# library.
try:
from typing import Required as PyRequired
from typing import NotRequired as PyNotRequired
except ImportError:
from typing_extensions import Required as PyRequired
from typing_extensions import NotRequired as PyNotRequired

# Forward references can be either strings or explicit `ForwardRef` objects.
# noinspection SpellCheckingInspection
Expand Down
1 change: 1 addition & 0 deletions dataclass_witch/utils/typing_compat.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
"""
Utility module for checking generic types provided by the `typing` library.
"""

import sys
import types
import typing
Expand Down
16 changes: 16 additions & 0 deletions tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,12 @@
"data_file_path",
"PY39_OR_ABOVE",
"PY310_OR_ABOVE",
"PY310_OR_BELOW",
"PY311_OR_ABOVE",
"Literal",
"TypedDict",
"Required",
"NotRequired",
"Annotated",
"Deque",
]
Expand All @@ -23,6 +27,10 @@
# Check if we are running Python 3.10+
PY310_OR_ABOVE = sys.version_info[:2] >= (3, 10)

# Check if we are running Python 3.11+
PY311_OR_ABOVE = sys.version_info[:2] >= (3, 11)
PY310_OR_BELOW = not PY311_OR_ABOVE

try:
from typing import Literal
from typing import TypedDict
Expand All @@ -34,6 +42,14 @@
from typing_extensions import Annotated
from typing_extensions import Deque

# typing.Required and typing.NotRequired were introduced in Python 3.11
if PY311_OR_ABOVE:
from typing import Required
from typing import NotRequired
else:
from typing_extensions import Required
from typing_extensions import NotRequired


def data_file_path(name: str) -> str:
"""Returns the full path to a test file."""
Expand Down
143 changes: 142 additions & 1 deletion tests/unit/test_load.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,11 +56,14 @@
from dataclass_witch.type_def import NoneType, T
from .conftest import MyUUIDSubclass
from ..conftest import (
PY310_OR_BELOW,
does_not_raise,
Literal,
TypedDict,
Annotated,
Deque,
Required,
NotRequired,
)


Expand Down Expand Up @@ -1371,7 +1374,7 @@ class MyClass(JSONSerializable):
),
],
)
def test_typed_dict_with_optional_fields(input, expectation, expected):
def test_typed_dict_with_all_fields_optional(input, expectation, expected):
"""
Test case for loading to a TypedDict which has `total=False`, indicating
that all fields are optional.
Expand All @@ -1396,6 +1399,144 @@ class MyClass(JSONSerializable):
assert result.my_typed_dict == expected


@pytest.mark.skipif(
PY310_OR_BELOW,
reason=""" \
Support for `Required` and `NotRequired` requires Python 3.11 or
higher to work reliably.
Users who still want to use `Required` and `NotRequired` on older
Python versions (using the `typing_extensions` library) in a way
that works with dataclass-witch will have to use the `TypedDict`
version from `typing_extensions`, not from the standard library.
The source code of the `typing_extension` package explains why:
> The standard library TypedDict below Python 3.11 does not
> store runtime information about optional and required keys
> when using Required or NotRequired. [1]
[1]: https://github.com/python/typing_extensions/blob/e1250ff869e7ee5ad05170d8a4b65469f13801c3/src/typing_extensions.py#L879-L880
""",
)
@pytest.mark.parametrize(
"input,expectation,expected",
[
({}, pytest.raises(ParseError), None),
({"key": "value"}, pytest.raises(ParseError), {}),
(
{"my_str": "test", "my_int": 2, "my_bool": True, "other_key": "testing"},
does_not_raise(),
{"my_str": "test", "my_int": 2, "my_bool": True},
),
({"my_str": 3}, pytest.raises(ParseError), None),
(
{"my_str": "test", "my_int": "test", "my_bool": True},
pytest.raises(ValueError),
None,
),
(
{"my_str": "test", "my_int": 2, "my_bool": True},
does_not_raise(),
{"my_str": "test", "my_int": 2, "my_bool": True},
),
(
{"my_str": "test", "my_bool": True},
does_not_raise(),
{"my_str": "test", "my_bool": True},
),
(
# Incorrect type - `list`, but should be a `dict`
[{"my_str": "test", "my_int": 2, "my_bool": True}],
pytest.raises(ParseError),
None,
),
],
)
def test_typed_dict_with_one_field_not_required(input, expectation, expected):
"""
Test case for loading to a TypedDict whose fields are all mandatory
except for one field, whose annotated type is NotRequired.
"""

class MyDict(TypedDict):
my_str: str
my_bool: bool
my_int: NotRequired[int]

@dataclass
class MyClass(JSONSerializable):
my_typed_dict: MyDict

d = {"myTypedDict": input}

with expectation:
result = MyClass.from_dict(d)

log.debug("Parsed object: %r", result)
assert result.my_typed_dict == expected


@pytest.mark.skipif(
PY310_OR_BELOW,
reason=""" \
Support for `Required` and `NotRequired` requires Python 3.11 or
higher to work reliably.
See `test_typed_dict_with_one_field_not_required` for details on
why this is the case.
""",
)
@pytest.mark.parametrize(
"input,expectation,expected",
[
({}, pytest.raises(ParseError), None),
({"my_int": 2}, does_not_raise(), {"my_int": 2}),
({"key": "value"}, pytest.raises(ParseError), None),
({"key": "value", "my_int": 2}, does_not_raise(), {"my_int": 2}),
(
{"my_str": "test", "my_int": 2, "my_bool": True, "other_key": "testing"},
does_not_raise(),
{"my_str": "test", "my_int": 2, "my_bool": True},
),
({"my_str": 3}, pytest.raises(ParseError), None),
(
{"my_str": "test", "my_int": "test", "my_bool": True},
pytest.raises(ValueError),
{"my_str": "test", "my_int": "test", "my_bool": True},
),
(
{"my_str": "test", "my_int": 2, "my_bool": True},
does_not_raise(),
{"my_str": "test", "my_int": 2, "my_bool": True},
),
],
)
def test_typed_dict_with_one_field_required(input, expectation, expected):
"""
Test case for loading to a TypedDict whose fields are all optional
except for one field, whose annotated type is Required.
"""

class MyDict(TypedDict, total=False):
my_str: str
my_bool: bool
my_int: Required[int]

@dataclass
class MyClass(JSONSerializable):
my_typed_dict: MyDict

d = {"myTypedDict": input}

with expectation:
result = MyClass.from_dict(d)

log.debug("Parsed object: %r", result)
assert result.my_typed_dict == expected


@pytest.mark.parametrize(
"input,expectation,expected",
[
Expand Down

0 comments on commit 2eb80be

Please sign in to comment.