Skip to content

Commit

Permalink
Fix none check in variant caster
Browse files Browse the repository at this point in the history
  • Loading branch information
yosh-matsuda committed Oct 9, 2023
1 parent 30a6bac commit 074375c
Show file tree
Hide file tree
Showing 3 changed files with 8 additions and 5 deletions.
9 changes: 6 additions & 3 deletions include/nanobind/stl/variant.h
Original file line number Diff line number Diff line change
Expand Up @@ -28,9 +28,7 @@ template <> struct type_caster<std::monostate> {
NB_TYPE_CASTER(std::monostate, const_name("None"));

bool from_python(handle src, uint8_t, cleanup_list *) noexcept {
if (src.is_none())
return true;
return false;
return src.is_none();
}

static handle from_cpp(const std::monostate &, rv_policy,
Expand All @@ -54,6 +52,11 @@ template <typename... Ts> struct type_caster<std::variant<Ts...>> {
"type caster was registered to intercept this particular "
"type, which is not allowed.");

if constexpr (!std::is_pointer_v<T> && is_base_caster_v<CasterT>) {
if (src.is_none())
return false;
}

CasterT caster;

if (!caster.from_python(src, flags, cleanup))
Expand Down
2 changes: 1 addition & 1 deletion tests/test_stl.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -256,7 +256,7 @@ NB_MODULE(test_stl_ext, m) {

// ----- test43-test50 ------
m.def("variant_copyable", [](std::variant<Copyable, int> &) {});
m.def("variant_copyable_none", [](std::variant<std::monostate, Copyable, int> &) {}, nb::arg("x").none());
m.def("variant_copyable_none", [](std::variant<int, Copyable, std::monostate> &) {}, nb::arg("x").none());
m.def("variant_copyable_ptr", [](std::variant<Copyable *, int> &) {});
m.def("variant_copyable_ptr_none", [](std::variant<Copyable *, int> &) {}, nb::arg("x").none());
m.def("variant_ret_var_copyable", []() { return std::variant<Copyable, int>(); });
Expand Down
2 changes: 1 addition & 1 deletion tests/test_stl.py
Original file line number Diff line number Diff line change
Expand Up @@ -464,7 +464,7 @@ def test44_std_variant_copyable_none(clean):
t.variant_copyable_none(5)
t.variant_copyable_none(None)
assert t.variant_copyable_none.__doc__ == (
"variant_copyable_none(x: Optional[Union[test_stl_ext.Copyable, int]]) -> None"
"variant_copyable_none(x: Optional[Union[int, test_stl_ext.Copyable]]) -> None"
)
assert_stats(
default_constructed=1,
Expand Down

0 comments on commit 074375c

Please sign in to comment.