Skip to content

Commit

Permalink
Added a non-throwing function nb::try_cast as alternative to ``nb…
Browse files Browse the repository at this point in the history
…::cast``.
  • Loading branch information
wjakob committed Jun 28, 2023
1 parent 0237e60 commit 6ca852c
Show file tree
Hide file tree
Showing 4 changed files with 106 additions and 0 deletions.
14 changes: 14 additions & 0 deletions docs/api_core.rst
Original file line number Diff line number Diff line change
Expand Up @@ -1262,6 +1262,20 @@ Casting
implementation may also attempt *implicit conversions* to perform the cast.

The function raises a :cpp:type:`cast_error` when the conversion fails.
See :cpp:func:`try_cast()` for an alternative that never raises.

.. cpp:function:: template <typename T, typename Derived> bool try_cast(const detail::api<Derived> &value, T &out, bool convert = true) noexcept

Convert the Python object `value` (typically a :cpp:class:`handle` or a
:cpp:class:`object` subclass) into a C++ object of type `T`, and store it
in the output parameter `out`.

When the `convert` argument is set to ``true`` (the default), the
implementation may also attempt *implicit conversions* to perform the cast.

The function returns `false` when the conversion fails. In this case, the
`out` parameter is left untouched. See :cpp:func:`cast()` for an alternative
that instead raises an exception in this case.

.. cpp:function:: template <typename T> object cast(T &&value, rv_policy policy = rv_policy::automatic_reference)

Expand Down
23 changes: 23 additions & 0 deletions include/nanobind/nb_cast.h
Original file line number Diff line number Diff line change
Expand Up @@ -347,6 +347,29 @@ struct type_caster : type_caster_base<Type> { };

NAMESPACE_END(detail)

template <typename T, typename Derived>
bool try_cast(const detail::api<Derived> &value, T &out, bool convert = true) noexcept {
using Caster = detail::make_caster<T>;
using Output = typename Caster::template Cast<T>;

static_assert(!std::is_same_v<const char *, T>,
"nanobind::try_cast(): cannot return a reference to a temporary.");

Caster caster;
if (caster.from_python(value.derived().ptr(),
convert ? (uint8_t) detail::cast_flags::convert
: (uint8_t) 0, nullptr)) {
if constexpr (Caster::IsClass)
out = caster.operator Output();
else
out = std::move(caster.operator Output&&());

return true;
}

return false;
}

template <typename T, typename Derived>
T cast(const detail::api<Derived> &value, bool convert = true) {
if constexpr (std::is_same_v<T, void>) {
Expand Down
25 changes: 25 additions & 0 deletions tests/test_classes.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -488,4 +488,29 @@ NB_MODULE(test_classes_ext, m) {
m.def("factory_2", []() { return (Base *) new AnotherSubclass(); });

m.def("check_shared", [](Shared *) { });

m.def("try_cast_1", [](nb::handle h) {
Struct s;
bool rv = nb::try_cast<Struct>(h, s);
return std::make_pair(rv, std::move(s));
});

m.def("try_cast_2", [](nb::handle h) {
Struct s;
Struct &s2 = s;
bool rv = nb::try_cast<Struct &>(h, s2);
return std::make_pair(rv, std::move(s2));
});

m.def("try_cast_3", [](nb::handle h) {
Struct *sp = nullptr;
bool rv = nb::try_cast<Struct *>(h, sp);
return std::make_pair(rv, sp);
}, nb::rv_policy::none);

m.def("try_cast_4", [](nb::handle h) {
int i = 0;
bool rv = nb::try_cast<int>(h, i);
return std::make_pair(rv, i);
});
}
44 changes: 44 additions & 0 deletions tests/test_classes.py
Original file line number Diff line number Diff line change
Expand Up @@ -666,3 +666,47 @@ def test38_pickle(clean):
unpickled=1,
destructed=2
)

def test39_try_cast(clean):
s = t.Struct(123)

assert_stats(value_constructed=1)
t.reset()

rv, s2 = t.try_cast_1(s)
assert rv is True and s2 is not s and s.value() == 123 and s2.value() == 123
del s2
assert_stats(default_constructed=1, move_constructed=2, copy_assigned=1, destructed=3)
t.reset()

rv, s2 = t.try_cast_2(s)
assert rv is True and s2 is not s and s.value() == 123 and s2.value() == 123
del s2
assert_stats(default_constructed=1, move_constructed=2, copy_assigned=1, destructed=3)
t.reset()

rv, s2 = t.try_cast_3(s)
assert rv is True and s2 is s and s.value() == 123
del s2
assert_stats()
t.reset()

rv, s2 = t.try_cast_2(1)
assert rv is False
del s2
assert_stats(default_constructed=1, move_constructed=2, destructed=3)
t.reset()

rv, s2 = t.try_cast_3(1)
assert rv is False and s2 is None
del s2
assert_stats()
t.reset()

rv, s2 = t.try_cast_4(s)
assert rv is False and s2 == 0
rv, s2 = t.try_cast_4(123)
assert rv is True and s2 is 123
del s, s2

assert_stats(destructed=1)

0 comments on commit 6ca852c

Please sign in to comment.