Skip to content

Commit

Permalink
Update Flag masks in Python 3.11+ only, change stub reference
Browse files Browse the repository at this point in the history
Also incorporates the change to `enum.Flag` and verbose names for enum
members in the stub.
  • Loading branch information
nicholasjng committed Aug 27, 2024
1 parent fda2f2f commit 77f7dfc
Show file tree
Hide file tree
Showing 3 changed files with 10 additions and 12 deletions.
4 changes: 3 additions & 1 deletion src/nb_enum.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -119,7 +119,8 @@ void enum_append(PyObject *tp_, const char *name_, int64_t value_,
fail("refusing to add duplicate key \"%s\" to enumeration \"%s\"!",
name_, type_name(tp).c_str());

// handle updates of the flag and bit masks by hand,
# if PY_VERSION_HEX >= 0x030B0000
// In Python 3.11+, update the flag and bit masks by hand,
// since enum._proto_member.__set_name__ is not called in this code path.
if (t->flags & (uint32_t) enum_flags::flag_enum) {
int64_t flag_mask = cast<int64_t>(tp.attr("_flag_mask_"));
Expand All @@ -133,6 +134,7 @@ void enum_append(PyObject *tp_, const char *name_, int64_t value_,
int64_t bit_length = cast<int64_t>(tp.attr("_flag_mask_").attr("bit_length")());
tp.attr("_all_bits_") = int_((2 << bit_length) - 1);
}
#endif

object el;
if (issubclass(tp, val_tp))
Expand Down
5 changes: 3 additions & 2 deletions tests/test_enum.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
import test_enum_ext as t
import enum
import pytest

def test01_unsigned_enum():
Expand Down Expand Up @@ -144,7 +143,9 @@ def test06_enum_flag():
assert (t.Flag(3) == (t.Flag.A | t.Flag.B))

# ensure the flag mask is set correctly by enum_append
assert t.Flag._flag_mask_ == 7
# in Python 3.11+
if hasattr(t.Flag, "_flag_mask_"):
assert t.Flag._flag_mask_ == 7
assert (t.from_enum(t.Flag.A | t.Flag.C) == 5)
assert (t.from_enum_implicit(t.Flag(1) | t.Flag(4)) == 5)

Expand Down
13 changes: 4 additions & 9 deletions tests/test_enum_ext.pyi.ref
Original file line number Diff line number Diff line change
Expand Up @@ -2,11 +2,11 @@ import enum
from typing import overload


A: Flag = 1
A: Flag = Flag.A

B: Flag = 2
B: Flag = Flag.B

C: Flag = 4
C: Flag = Flag.C

class ClassicEnum(enum.Enum):
Item1 = 0
Expand Down Expand Up @@ -41,14 +41,9 @@ class EnumProperty:
@property
def read_enum(self) -> Enum: ...

class Flag(enum.IntFlag):
class Flag(enum.Flag):
"""enum-level docstring"""

__str__ = __repr__

def __repr__(self, /):
"""Return repr(self)."""

A = 1
"""Value A"""

Expand Down

0 comments on commit 77f7dfc

Please sign in to comment.