From c8c0b3120306f76b17b285f69b7aa72386d01ee1 Mon Sep 17 00:00:00 2001 From: Qingran Zheng Date: Wed, 1 Nov 2023 13:03:35 -0700 Subject: [PATCH] Add type_caster for std::nullopt --- include/nanobind/stl/optional.h | 15 +++++++++++++++ tests/test_stl.cpp | 1 + tests/test_stl.py | 14 ++++++++------ 3 files changed, 24 insertions(+), 6 deletions(-) diff --git a/include/nanobind/stl/optional.h b/include/nanobind/stl/optional.h index 7ba128f8..90eb5a49 100644 --- a/include/nanobind/stl/optional.h +++ b/include/nanobind/stl/optional.h @@ -59,5 +59,20 @@ struct type_caster> { } }; + +template <> struct type_caster { + bool from_python(handle src, uint8_t, cleanup_list *) noexcept { + if (src.is_none()) + return true; + return false; + } + + static handle from_cpp(std::nullopt_t, rv_policy, cleanup_list *) noexcept { + return none().release(); + } + + NB_TYPE_CASTER(std::nullopt_t, const_name("None")); +}; + NAMESPACE_END(detail) NAMESPACE_END(NB_NAMESPACE) diff --git a/tests/test_stl.cpp b/tests/test_stl.cpp index 4e1efc2d..1ede2e56 100644 --- a/tests/test_stl.cpp +++ b/tests/test_stl.cpp @@ -253,6 +253,7 @@ NB_MODULE(test_stl_ext, m) { m.def("optional_ret_opt_movable_ptr", []() { return new std::optional(new Movable()); }); m.def("optional_ret_opt_none", []() { return std::optional(); }); m.def("optional_unbound_type", [](std::optional &x) { return x; }, nb::arg("x") = nb::none()); + m.def("optional_unbound_type_with_nullopt_as_default", [](std::optional &x) { return x; }, nb::arg("x") = std::nullopt); // ----- test43-test50 ------ m.def("variant_copyable", [](std::variant &) {}); diff --git a/tests/test_stl.py b/tests/test_stl.py index c4547760..e591973e 100644 --- a/tests/test_stl.py +++ b/tests/test_stl.py @@ -438,12 +438,14 @@ def test41_std_optional_ret_opt_none(): def test42_std_optional_unbound_type(): - assert t.optional_unbound_type(3) == 3 - assert t.optional_unbound_type(None) is None - assert t.optional_unbound_type() is None - assert t.optional_unbound_type.__doc__ == ( - "optional_unbound_type(x: Optional[int] = None) -> Optional[int]" - ) + for method_name in ("optional_unbound_type", "optional_unbound_type_with_nullopt_as_default"): + method = getattr(t, method_name) + assert method(3) == 3 + assert method(None) is None + assert method() is None + assert method.__doc__ == ( + f"{method_name}(x: Optional[int] = None) -> Optional[int]" + ) def test43_std_variant_copyable(clean):