Skip to content

Commit

Permalink
Made type_caster<Eigen::Map<T>> correctly consider T's constness (#325)
Browse files Browse the repository at this point in the history
  • Loading branch information
WKarel committed Oct 17, 2023
1 parent d9cecac commit c087ebd
Show file tree
Hide file tree
Showing 3 changed files with 8 additions and 2 deletions.
2 changes: 1 addition & 1 deletion include/nanobind/eigen/dense.h
Original file line number Diff line number Diff line change
Expand Up @@ -244,7 +244,7 @@ struct type_caster<Eigen::Map<T, Options, StrideType>,
is_ndarray_scalar_v<typename T::Scalar>>> {
using Map = Eigen::Map<T, Options, StrideType>;
using NDArray =
array_for_eigen_t<Map, std::conditional_t<std::is_const_v<Map>,
array_for_eigen_t<Map, std::conditional_t<std::is_const_v<T>,
const typename Map::Scalar,
typename Map::Scalar>>;
using NDArrayCaster = type_caster<NDArray>;
Expand Down
5 changes: 4 additions & 1 deletion tests/test_eigen.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -199,9 +199,12 @@ NB_MODULE(test_eigen_ext, m) {
.def(nb::init<>())
.def_rw("member", &ClassWithEigenMember::member);

m.def("castToMapVXi", [](nb::object obj) -> Eigen::Map<Eigen::VectorXi> {
m.def("castToMapVXi", [](nb::object obj) {
return nb::cast<Eigen::Map<Eigen::VectorXi>>(obj);
});
m.def("castToMapCnstVXi", [](nb::object obj) {
return nb::cast<Eigen::Map<const Eigen::VectorXi>>(obj);
});
m.def("castToRefVXi", [](nb::object obj) -> Eigen::VectorXi {
return nb::cast<Eigen::Ref<Eigen::VectorXi>>(obj);
});
Expand Down
3 changes: 3 additions & 0 deletions tests/test_eigen.py
Original file line number Diff line number Diff line change
Expand Up @@ -334,8 +334,11 @@ def test12_cast():
vec2 = vec[::2]
vecf = np.float32(vec)
assert_array_equal(t.castToMapVXi(vec), vec)
assert_array_equal(t.castToMapCnstVXi(vec), vec)
assert_array_equal(t.castToRefVXi(vec), vec)
assert_array_equal(t.castToRefCnstVXi(vec), vec)
assert t.castToMapVXi(vec).flags.writeable
assert not t.castToMapCnstVXi(vec).flags.writeable
assert_array_equal(t.castToDRefCnstVXi(vec), vec)
for v in vec2, vecf:
with pytest.raises(RuntimeError, match="bad[_ ]cast"):
Expand Down

0 comments on commit c087ebd

Please sign in to comment.