From 2eb80be987c4483c1b623645c0e7ffdb590710aa Mon Sep 17 00:00:00 2001 From: Claudia Pellegrino Date: Sun, 18 Aug 2024 22:57:13 +0200 Subject: [PATCH] Add support for typing.Required, NotRequired Previously, annotating a `TypedDict` field with one of the `Required` and `NotRequired` wrappers introduced in Python 3.11, the library would raise the following error: > TypeError: issubclass() arg 1 must be a class Fix that error by adding support for `Required` and `NotRequired`. Partially addresses issue rnag/dataclass-wizard#121. [1] [1]: https://github.com/rnag/dataclass-wizard/issues/121 --- dataclass_witch/__init__.py | 1 + dataclass_witch/loaders.py | 7 ++ dataclass_witch/type_def.py | 13 +++ dataclass_witch/utils/typing_compat.py | 1 + tests/conftest.py | 16 +++ tests/unit/test_load.py | 143 ++++++++++++++++++++++++- 6 files changed, 180 insertions(+), 1 deletion(-) diff --git a/dataclass_witch/__init__.py b/dataclass_witch/__init__.py index 8bde37d..085aa18 100644 --- a/dataclass_witch/__init__.py +++ b/dataclass_witch/__init__.py @@ -67,6 +67,7 @@ :copyright: (c) 2021 by Ritvik Nag. :license: Apache 2.0, see LICENSE for more details. """ + __all__ = [ # Force the linter to recognize that these are exports "JSONSerializable", diff --git a/dataclass_witch/loaders.py b/dataclass_witch/loaders.py index c7b70c1..9c61974 100644 --- a/dataclass_witch/loaders.py +++ b/dataclass_witch/loaders.py @@ -63,6 +63,8 @@ DefFactory, NoneType, JSONObject, + PyRequired, + PyNotRequired, M, N, T, @@ -409,6 +411,11 @@ def get_parser_for_annotation( base_cls, extras, base_types, cls.get_parser_for_annotation ) + elif base_type in (PyRequired, PyNotRequired): + # Given `Required[T]` or `NotRequired[T]`, we only need `T` + ann_type = get_args(ann_type)[0] + return cls.get_parser_for_annotation(ann_type, base_cls, extras) + elif issubclass(base_type, defaultdict): load_hook = hooks[defaultdict] return DefaultDictParser( diff --git a/dataclass_witch/type_def.py b/dataclass_witch/type_def.py index 9869d18..757b30e 100644 --- a/dataclass_witch/type_def.py +++ b/dataclass_witch/type_def.py @@ -5,6 +5,8 @@ "PyDeque", "PyTypedDict", "PyTypedDicts", + "PyRequired", + "PyNotRequired", "FrozenKeys", "DefFactory", "NoneType", @@ -139,6 +141,17 @@ PyTypedDicts.append(PyTypedDict) +# Python 3.11 introduced `Required` and `NotRequired` wrappers for +# `TypedDict` fields (PEP 655). Users of earlier Python versions may +# import them from `typing_extensions`. However, they then need to +# use `TypedDict` from `typing_extensions`, not from the standard +# library. +try: + from typing import Required as PyRequired + from typing import NotRequired as PyNotRequired +except ImportError: + from typing_extensions import Required as PyRequired + from typing_extensions import NotRequired as PyNotRequired # Forward references can be either strings or explicit `ForwardRef` objects. # noinspection SpellCheckingInspection diff --git a/dataclass_witch/utils/typing_compat.py b/dataclass_witch/utils/typing_compat.py index 35bb657..f0c2cf8 100644 --- a/dataclass_witch/utils/typing_compat.py +++ b/dataclass_witch/utils/typing_compat.py @@ -1,6 +1,7 @@ """ Utility module for checking generic types provided by the `typing` library. """ + import sys import types import typing diff --git a/tests/conftest.py b/tests/conftest.py index 6e874fc..3f62334 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -3,8 +3,12 @@ "data_file_path", "PY39_OR_ABOVE", "PY310_OR_ABOVE", + "PY310_OR_BELOW", + "PY311_OR_ABOVE", "Literal", "TypedDict", + "Required", + "NotRequired", "Annotated", "Deque", ] @@ -23,6 +27,10 @@ # Check if we are running Python 3.10+ PY310_OR_ABOVE = sys.version_info[:2] >= (3, 10) +# Check if we are running Python 3.11+ +PY311_OR_ABOVE = sys.version_info[:2] >= (3, 11) +PY310_OR_BELOW = not PY311_OR_ABOVE + try: from typing import Literal from typing import TypedDict @@ -34,6 +42,14 @@ from typing_extensions import Annotated from typing_extensions import Deque +# typing.Required and typing.NotRequired were introduced in Python 3.11 +if PY311_OR_ABOVE: + from typing import Required + from typing import NotRequired +else: + from typing_extensions import Required + from typing_extensions import NotRequired + def data_file_path(name: str) -> str: """Returns the full path to a test file.""" diff --git a/tests/unit/test_load.py b/tests/unit/test_load.py index 9c54230..6faf6c2 100644 --- a/tests/unit/test_load.py +++ b/tests/unit/test_load.py @@ -56,11 +56,14 @@ from dataclass_witch.type_def import NoneType, T from .conftest import MyUUIDSubclass from ..conftest import ( + PY310_OR_BELOW, does_not_raise, Literal, TypedDict, Annotated, Deque, + Required, + NotRequired, ) @@ -1371,7 +1374,7 @@ class MyClass(JSONSerializable): ), ], ) -def test_typed_dict_with_optional_fields(input, expectation, expected): +def test_typed_dict_with_all_fields_optional(input, expectation, expected): """ Test case for loading to a TypedDict which has `total=False`, indicating that all fields are optional. @@ -1396,6 +1399,144 @@ class MyClass(JSONSerializable): assert result.my_typed_dict == expected +@pytest.mark.skipif( + PY310_OR_BELOW, + reason=""" \ + Support for `Required` and `NotRequired` requires Python 3.11 or + higher to work reliably. + + Users who still want to use `Required` and `NotRequired` on older + Python versions (using the `typing_extensions` library) in a way + that works with dataclass-witch will have to use the `TypedDict` + version from `typing_extensions`, not from the standard library. + + The source code of the `typing_extension` package explains why: + + > The standard library TypedDict below Python 3.11 does not + > store runtime information about optional and required keys + > when using Required or NotRequired. [1] + + [1]: https://github.com/python/typing_extensions/blob/e1250ff869e7ee5ad05170d8a4b65469f13801c3/src/typing_extensions.py#L879-L880 + """, +) +@pytest.mark.parametrize( + "input,expectation,expected", + [ + ({}, pytest.raises(ParseError), None), + ({"key": "value"}, pytest.raises(ParseError), {}), + ( + {"my_str": "test", "my_int": 2, "my_bool": True, "other_key": "testing"}, + does_not_raise(), + {"my_str": "test", "my_int": 2, "my_bool": True}, + ), + ({"my_str": 3}, pytest.raises(ParseError), None), + ( + {"my_str": "test", "my_int": "test", "my_bool": True}, + pytest.raises(ValueError), + None, + ), + ( + {"my_str": "test", "my_int": 2, "my_bool": True}, + does_not_raise(), + {"my_str": "test", "my_int": 2, "my_bool": True}, + ), + ( + {"my_str": "test", "my_bool": True}, + does_not_raise(), + {"my_str": "test", "my_bool": True}, + ), + ( + # Incorrect type - `list`, but should be a `dict` + [{"my_str": "test", "my_int": 2, "my_bool": True}], + pytest.raises(ParseError), + None, + ), + ], +) +def test_typed_dict_with_one_field_not_required(input, expectation, expected): + """ + Test case for loading to a TypedDict whose fields are all mandatory + except for one field, whose annotated type is NotRequired. + + """ + + class MyDict(TypedDict): + my_str: str + my_bool: bool + my_int: NotRequired[int] + + @dataclass + class MyClass(JSONSerializable): + my_typed_dict: MyDict + + d = {"myTypedDict": input} + + with expectation: + result = MyClass.from_dict(d) + + log.debug("Parsed object: %r", result) + assert result.my_typed_dict == expected + + +@pytest.mark.skipif( + PY310_OR_BELOW, + reason=""" \ + Support for `Required` and `NotRequired` requires Python 3.11 or + higher to work reliably. + See `test_typed_dict_with_one_field_not_required` for details on + why this is the case. + """, +) +@pytest.mark.parametrize( + "input,expectation,expected", + [ + ({}, pytest.raises(ParseError), None), + ({"my_int": 2}, does_not_raise(), {"my_int": 2}), + ({"key": "value"}, pytest.raises(ParseError), None), + ({"key": "value", "my_int": 2}, does_not_raise(), {"my_int": 2}), + ( + {"my_str": "test", "my_int": 2, "my_bool": True, "other_key": "testing"}, + does_not_raise(), + {"my_str": "test", "my_int": 2, "my_bool": True}, + ), + ({"my_str": 3}, pytest.raises(ParseError), None), + ( + {"my_str": "test", "my_int": "test", "my_bool": True}, + pytest.raises(ValueError), + {"my_str": "test", "my_int": "test", "my_bool": True}, + ), + ( + {"my_str": "test", "my_int": 2, "my_bool": True}, + does_not_raise(), + {"my_str": "test", "my_int": 2, "my_bool": True}, + ), + ], +) +def test_typed_dict_with_one_field_required(input, expectation, expected): + """ + Test case for loading to a TypedDict whose fields are all optional + except for one field, whose annotated type is Required. + + """ + + class MyDict(TypedDict, total=False): + my_str: str + my_bool: bool + my_int: Required[int] + + @dataclass + class MyClass(JSONSerializable): + my_typed_dict: MyDict + + d = {"myTypedDict": input} + + with expectation: + result = MyClass.from_dict(d) + + log.debug("Parsed object: %r", result) + assert result.my_typed_dict == expected + + @pytest.mark.parametrize( "input,expectation,expected", [