diff --git a/src/nb_enum.cpp b/src/nb_enum.cpp index 1cf96f77..4a373097 100644 --- a/src/nb_enum.cpp +++ b/src/nb_enum.cpp @@ -119,7 +119,8 @@ void enum_append(PyObject *tp_, const char *name_, int64_t value_, fail("refusing to add duplicate key \"%s\" to enumeration \"%s\"!", name_, type_name(tp).c_str()); - // handle updates of the flag and bit masks by hand, + # if PY_VERSION_HEX >= 0x030B0000 + // In Python 3.11+, update the flag and bit masks by hand, // since enum._proto_member.__set_name__ is not called in this code path. if (t->flags & (uint32_t) enum_flags::flag_enum) { int64_t flag_mask = cast(tp.attr("_flag_mask_")); @@ -133,6 +134,7 @@ void enum_append(PyObject *tp_, const char *name_, int64_t value_, int64_t bit_length = cast(tp.attr("_flag_mask_").attr("bit_length")()); tp.attr("_all_bits_") = int_((2 << bit_length) - 1); } + #endif object el; if (issubclass(tp, val_tp)) diff --git a/tests/test_enum.py b/tests/test_enum.py index 81bdccf0..19d760fe 100644 --- a/tests/test_enum.py +++ b/tests/test_enum.py @@ -1,5 +1,4 @@ import test_enum_ext as t -import enum import pytest def test01_unsigned_enum(): @@ -144,7 +143,9 @@ def test06_enum_flag(): assert (t.Flag(3) == (t.Flag.A | t.Flag.B)) # ensure the flag mask is set correctly by enum_append - assert t.Flag._flag_mask_ == 7 + # in Python 3.11+ + if hasattr(t.Flag, "_flag_mask_"): + assert t.Flag._flag_mask_ == 7 assert (t.from_enum(t.Flag.A | t.Flag.C) == 5) assert (t.from_enum_implicit(t.Flag(1) | t.Flag(4)) == 5) diff --git a/tests/test_enum_ext.pyi.ref b/tests/test_enum_ext.pyi.ref index 56eed6d1..0c7406ad 100644 --- a/tests/test_enum_ext.pyi.ref +++ b/tests/test_enum_ext.pyi.ref @@ -2,11 +2,11 @@ import enum from typing import overload -A: Flag = 1 +A: Flag = Flag.A -B: Flag = 2 +B: Flag = Flag.B -C: Flag = 4 +C: Flag = Flag.C class ClassicEnum(enum.Enum): Item1 = 0 @@ -41,14 +41,9 @@ class EnumProperty: @property def read_enum(self) -> Enum: ... -class Flag(enum.IntFlag): +class Flag(enum.Flag): """enum-level docstring""" - __str__ = __repr__ - - def __repr__(self, /): - """Return repr(self).""" - A = 1 """Value A"""