Skip to content

Commit

Permalink
Add flag and bit mask handling, encode assumptions in existing test
Browse files Browse the repository at this point in the history
The enum_append() method needed an extra branch to update the
`_flag_mask_`, `_singles_mask_`, and `_all_bits_` members in the Flag
case correctly.
  • Loading branch information
nicholasjng committed Aug 27, 2024
1 parent 6cb336f commit fda2f2f
Show file tree
Hide file tree
Showing 3 changed files with 21 additions and 3 deletions.
2 changes: 1 addition & 1 deletion include/nanobind/nb_class.h
Original file line number Diff line number Diff line change
Expand Up @@ -210,7 +210,7 @@ NB_INLINE void enum_extra_apply(enum_init_data &e, is_arithmetic) {
}

NB_INLINE void enum_extra_apply(enum_init_data &e, flag_enum) {
e.flags |= (uint32_t) type_flags::flag_enum;
e.flags |= (uint32_t) enum_flags::flag_enum;
}

NB_INLINE void enum_extra_apply(enum_init_data &e, const char *doc) {
Expand Down
19 changes: 17 additions & 2 deletions src/nb_enum.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ PyObject *enum_create(enum_init_data *ed) noexcept {
handle scope(ed->scope);

bool is_arithmetic = ed->flags & (uint32_t) enum_flags::is_arithmetic;
bool is_flag_enum = ed->flags & (uint32_t) enum_flags::flag_enum;
bool is_flag = ed->flags & (uint32_t) enum_flags::flag_enum;

str name(ed->name), qualname = name;
object modname;
Expand Down Expand Up @@ -119,6 +119,21 @@ void enum_append(PyObject *tp_, const char *name_, int64_t value_,
fail("refusing to add duplicate key \"%s\" to enumeration \"%s\"!",
name_, type_name(tp).c_str());

// handle updates of the flag and bit masks by hand,
// since enum._proto_member.__set_name__ is not called in this code path.
if (t->flags & (uint32_t) enum_flags::flag_enum) {
int64_t flag_mask = cast<int64_t>(tp.attr("_flag_mask_"));
tp.attr("_flag_mask_") = int_(flag_mask | value_);

bool is_single_bit = (value_ != 0) && (value_ & (value_ - 1)) == 0;
if (is_single_bit) {
int64_t singles_mask = cast<int64_t>(tp.attr("_singles_mask_"));
tp.attr("_singles_mask_") = int_(singles_mask | value_);
}
int64_t bit_length = cast<int64_t>(tp.attr("_flag_mask_").attr("bit_length")());
tp.attr("_all_bits_") = int_((2 << bit_length) - 1);
}

object el;
if (issubclass(tp, val_tp))
el = val_tp.attr("__new__")(tp, val);
Expand Down Expand Up @@ -156,7 +171,7 @@ bool enum_from_python(const std::type_info *tp, PyObject *o, int64_t *out, uint8
if (!t)
return false;

if ((t->flags & (uint32_t) type_flags::flag_enum) != 0 && Py_TYPE(o) == t->type_py) {
if ((t->flags & (uint32_t) enum_flags::flag_enum) != 0 && Py_TYPE(o) == t->type_py) {
auto pValue = PyObject_GetAttrString(o, "value");
if (pValue == nullptr) {
PyErr_Clear();
Expand Down
3 changes: 3 additions & 0 deletions tests/test_enum.py
Original file line number Diff line number Diff line change
Expand Up @@ -142,6 +142,9 @@ def test06_enum_flag():
assert (t.Flag(3) & t.Flag(1)).value == 1
assert (t.Flag(3) ^ t.Flag(1)).value == 2
assert (t.Flag(3) == (t.Flag.A | t.Flag.B))

# ensure the flag mask is set correctly by enum_append
assert t.Flag._flag_mask_ == 7
assert (t.from_enum(t.Flag.A | t.Flag.C) == 5)
assert (t.from_enum_implicit(t.Flag(1) | t.Flag(4)) == 5)

Expand Down

0 comments on commit fda2f2f

Please sign in to comment.