diff --git a/docs/api_core.rst b/docs/api_core.rst index e6e88bce..cd51dcd9 100644 --- a/docs/api_core.rst +++ b/docs/api_core.rst @@ -1262,6 +1262,20 @@ Casting implementation may also attempt *implicit conversions* to perform the cast. The function raises a :cpp:type:`cast_error` when the conversion fails. + See :cpp:func:`try_cast()` for an alternative that never raises. + +.. cpp:function:: template bool try_cast(const detail::api &value, T &out, bool convert = true) noexcept + + Convert the Python object `value` (typically a :cpp:class:`handle` or a + :cpp:class:`object` subclass) into a C++ object of type `T`, and store it + in the output parameter `out`. + + When the `convert` argument is set to ``true`` (the default), the + implementation may also attempt *implicit conversions* to perform the cast. + + The function returns `false` when the conversion fails. In this case, the + `out` parameter is left untouched. See :cpp:func:`cast()` for an alternative + that instead raises an exception in this case. .. cpp:function:: template object cast(T &&value, rv_policy policy = rv_policy::automatic_reference) diff --git a/include/nanobind/nb_cast.h b/include/nanobind/nb_cast.h index 737146d7..7bfd16b5 100644 --- a/include/nanobind/nb_cast.h +++ b/include/nanobind/nb_cast.h @@ -347,6 +347,29 @@ struct type_caster : type_caster_base { }; NAMESPACE_END(detail) +template +bool try_cast(const detail::api &value, T &out, bool convert = true) noexcept { + using Caster = detail::make_caster; + using Output = typename Caster::template Cast; + + static_assert(!std::is_same_v, + "nanobind::try_cast(): cannot return a reference to a temporary."); + + Caster caster; + if (caster.from_python(value.derived().ptr(), + convert ? (uint8_t) detail::cast_flags::convert + : (uint8_t) 0, nullptr)) { + if constexpr (Caster::IsClass) + out = caster.operator Output(); + else + out = std::move(caster.operator Output&&()); + + return true; + } + + return false; +} + template T cast(const detail::api &value, bool convert = true) { if constexpr (std::is_same_v) { diff --git a/tests/test_classes.cpp b/tests/test_classes.cpp index 8cb66e5a..7bd5c0e2 100644 --- a/tests/test_classes.cpp +++ b/tests/test_classes.cpp @@ -488,4 +488,29 @@ NB_MODULE(test_classes_ext, m) { m.def("factory_2", []() { return (Base *) new AnotherSubclass(); }); m.def("check_shared", [](Shared *) { }); + + m.def("try_cast_1", [](nb::handle h) { + Struct s; + bool rv = nb::try_cast(h, s); + return std::make_pair(rv, std::move(s)); + }); + + m.def("try_cast_2", [](nb::handle h) { + Struct s; + Struct &s2 = s; + bool rv = nb::try_cast(h, s2); + return std::make_pair(rv, std::move(s2)); + }); + + m.def("try_cast_3", [](nb::handle h) { + Struct *sp = nullptr; + bool rv = nb::try_cast(h, sp); + return std::make_pair(rv, sp); + }, nb::rv_policy::none); + + m.def("try_cast_4", [](nb::handle h) { + int i = 0; + bool rv = nb::try_cast(h, i); + return std::make_pair(rv, i); + }); } diff --git a/tests/test_classes.py b/tests/test_classes.py index 071e1e5e..a751b96e 100644 --- a/tests/test_classes.py +++ b/tests/test_classes.py @@ -666,3 +666,47 @@ def test38_pickle(clean): unpickled=1, destructed=2 ) + +def test39_try_cast(clean): + s = t.Struct(123) + + assert_stats(value_constructed=1) + t.reset() + + rv, s2 = t.try_cast_1(s) + assert rv is True and s2 is not s and s.value() == 123 and s2.value() == 123 + del s2 + assert_stats(default_constructed=1, move_constructed=2, copy_assigned=1, destructed=3) + t.reset() + + rv, s2 = t.try_cast_2(s) + assert rv is True and s2 is not s and s.value() == 123 and s2.value() == 123 + del s2 + assert_stats(default_constructed=1, move_constructed=2, copy_assigned=1, destructed=3) + t.reset() + + rv, s2 = t.try_cast_3(s) + assert rv is True and s2 is s and s.value() == 123 + del s2 + assert_stats() + t.reset() + + rv, s2 = t.try_cast_2(1) + assert rv is False + del s2 + assert_stats(default_constructed=1, move_constructed=2, destructed=3) + t.reset() + + rv, s2 = t.try_cast_3(1) + assert rv is False and s2 is None + del s2 + assert_stats() + t.reset() + + rv, s2 = t.try_cast_4(s) + assert rv is False and s2 == 0 + rv, s2 = t.try_cast_4(123) + assert rv is True and s2 is 123 + del s, s2 + + assert_stats(destructed=1)