Complex numbers support for ndarray. #319

Merged: 11 commits, Oct 18, 2023
12 changes: 8 additions & 4 deletions docs/ndarray.rst
@@ -114,6 +114,9 @@ The following constraints are available

- A scalar type (``float``, ``uint8_t``, etc.) constrains the representation
of the ndarray.

For a complex type (``std::complex<float>``, ``std::complex<double>``, etc.),
the header ``<nanobind/stl/complex.h>`` must be included.

- This scalar type can be further annotated with ``const``, which is necessary
if you plan to call nanobind functions with arrays that do not permit write
@@ -469,10 +472,11 @@ For example, the following snippet makes ``__fp16`` (half-precision type on

namespace nanobind {
template <> struct ndarray_traits<__fp16> {
-    static constexpr bool is_float = true;
-    static constexpr bool is_bool = false;
-    static constexpr bool is_int = false;
-    static constexpr bool is_signed = true;
+    static constexpr bool is_complex = false;
+    static constexpr bool is_float = true;
+    static constexpr bool is_bool = false;
+    static constexpr bool is_int = false;
+    static constexpr bool is_signed = true;
};
};

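A usage sketch (not part of this patch) of the constraint described in the docs above: binding a function that only accepts complex64 arrays, with the required extra include. The module name complex_demo_ext and the function conjugate_inplace are hypothetical; the pattern mirrors fill_view_5 in tests/test_ndarray.cpp further down.

// Hedged sketch: a binding restricted to C-contiguous complex64 CPU arrays.
#include <nanobind/nanobind.h>
#include <nanobind/ndarray.h>
#include <nanobind/stl/complex.h>   // required for std::complex dtypes
#include <complex>

namespace nb = nanobind;
using namespace nb::literals;

NB_MODULE(complex_demo_ext, m) {
    // Conjugate a 1-D complex64 array in place; other dtypes are rejected.
    m.def("conjugate_inplace",
          [](nb::ndarray<std::complex<float>, nb::shape<nb::any>,
                         nb::c_contig, nb::device::cpu> a) {
              auto v = a.view();
              for (size_t i = 0; i < v.shape(0); ++i)
                  v(i) = std::conj(v(i));
          },
          "array"_a.noconvert());
}
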
38 changes: 32 additions & 6 deletions include/nanobind/ndarray.h
@@ -65,6 +65,13 @@ struct dltensor {

NAMESPACE_END(dlpack)

NAMESPACE_BEGIN(detail)

template <typename T>
struct is_complex : public std::false_type { };

NAMESPACE_END(detail)

constexpr size_t any = (size_t) -1;

template <size_t... Is> struct shape {
@@ -81,18 +88,19 @@ struct jax { };
struct ro { };

template <typename T> struct ndarray_traits {
-    static constexpr bool is_float = std::is_floating_point_v<T>;
-    static constexpr bool is_bool = std::is_same_v<std::remove_cv_t<T>, bool>;
-    static constexpr bool is_int = std::is_integral_v<T> && !is_bool;
-    static constexpr bool is_signed = std::is_signed_v<T>;
+    static constexpr bool is_complex = detail::is_complex<T>::value;
+    static constexpr bool is_float = std::is_floating_point_v<T>;
+    static constexpr bool is_bool = std::is_same_v<std::remove_cv_t<T>, bool>;
+    static constexpr bool is_int = std::is_integral_v<T> && !is_bool;
+    static constexpr bool is_signed = std::is_signed_v<T>;
};

NAMESPACE_BEGIN(detail)

template <typename T>
constexpr bool is_ndarray_scalar_v =
ndarray_traits<T>::is_float || ndarray_traits<T>::is_int ||
-    ndarray_traits<T>::is_bool;
+    ndarray_traits<T>::is_bool || ndarray_traits<T>::is_complex;

template <typename> struct ndim_shape;
template <size_t... S> struct ndim_shape<std::index_sequence<S...>> {
@@ -115,6 +123,8 @@ template <typename T> constexpr dlpack::dtype dtype() {
result.code = (uint8_t) dlpack::dtype_code::Float;
else if constexpr (ndarray_traits<T>::is_signed)
result.code = (uint8_t) dlpack::dtype_code::Int;
else if constexpr (ndarray_traits<T>::is_complex)
result.code = (uint8_t) dlpack::dtype_code::Complex;
else if constexpr (std::is_same_v<std::remove_cv_t<T>, bool>)
result.code = (uint8_t) dlpack::dtype_code::Bool;
else
@@ -163,6 +173,21 @@ template <typename T> struct ndarray_arg<T, enable_if_t<ndarray_traits<T>::is_fl
}
};

template <typename T> struct ndarray_arg<T, enable_if_t<ndarray_traits<T>::is_complex>> {
static constexpr size_t size = 0;

static constexpr auto name =
const_name("dtype=complex") +
const_name<sizeof(T) * 8>() +
const_name<std::is_const_v<T>>(", writable=False", "");

static void apply(ndarray_req &tr) {
tr.dtype = dtype<T>();
tr.req_dtype = true;
tr.req_ro = std::is_const_v<T>;
}
};

template <typename T> struct ndarray_arg<T, enable_if_t<ndarray_traits<T>::is_int>> {
static constexpr size_t size = 0;

@@ -253,7 +278,8 @@ template <typename... Ts> struct ndarray_info {
template <typename T, typename... Ts> struct ndarray_info<T, Ts...> : ndarray_info<Ts...> {
using scalar_type =
std::conditional_t<ndarray_traits<T>::is_float || ndarray_traits<T>::is_int ||
-    ndarray_traits<T>::is_bool, T, typename ndarray_info<Ts...>::scalar_type>;
+    ndarray_traits<T>::is_bool || ndarray_traits<T>::is_complex,
+    T, typename ndarray_info<Ts...>::scalar_type>;
};

template <size_t... Is, typename... Ts> struct ndarray_info<shape<Is...>, Ts...> : ndarray_info<Ts...> {
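A hedged sketch (the module and function names are illustrative, not from the patch) of what the traits and dtype() changes above enable at the API surface: nb::dtype<T>() now produces a dlpack descriptor with the Complex type code for std::complex scalars, so runtime dtype comparisons work just as they do for int and float arrays.

#include <nanobind/nanobind.h>
#include <nanobind/ndarray.h>
#include <nanobind/stl/complex.h>   // provides the is_complex specializations
#include <complex>

namespace nb = nanobind;

NB_MODULE(dtype_probe_ext, m) {
    // True iff the argument is a complex64 array; for std::complex<float>,
    // nb::dtype<T>() is expected to be { code = Complex, bits = 64, lanes = 1 }.
    m.def("is_complex64", [](const nb::ndarray<> &a) {
        return a.dtype() == nb::dtype<std::complex<float>>();
    });
}
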
9 changes: 9 additions & 0 deletions include/nanobind/stl/complex.h
@@ -15,6 +15,15 @@
NAMESPACE_BEGIN(NB_NAMESPACE)
NAMESPACE_BEGIN(detail)

template <typename T>
struct is_complex;

template<typename T>
struct is_complex<std::complex<T>> : public std::true_type {};

template<typename T>
struct is_complex<const std::complex<T>> : public std::true_type {};

template <typename T> struct type_caster<std::complex<T>> {
NB_TYPE_CASTER(std::complex<T>, const_name("complex") )

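The specializations added above pair with the primary template in ndarray.h, which defaults to std::false_type. A minimal standalone reproduction of the pattern (outside nanobind, for illustration only):

#include <complex>
#include <type_traits>

// Primary template: any T is "not complex" by default.
template <typename T> struct is_complex : std::false_type { };

// Partial specializations: (possibly const) std::complex<T> is complex.
template <typename T> struct is_complex<std::complex<T>> : std::true_type { };
template <typename T> struct is_complex<const std::complex<T>> : std::true_type { };

static_assert(!is_complex<float>::value, "plain float is not complex");
static_assert(is_complex<std::complex<float>>::value, "complex<float> is detected");
static_assert(is_complex<const std::complex<double>>::value, "const is handled too");
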
7 changes: 7 additions & 0 deletions src/nb_ndarray.cpp
@@ -69,6 +69,13 @@ static int nd_ndarray_tpbuffer(PyObject *exporter, Py_buffer *view, int) {
}
break;

case dlpack::dtype_code::Complex:
switch (t.dtype.bits) {
case 64: format = "Zf"; break;
case 128: format = "Zd"; break;
}
break;

case dlpack::dtype_code::Bool:
format = "?";
break;
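The "Zf"/"Zd" strings above follow Python's struct-module syntax, where the 'Z' prefix marks a complex type. A hedged sketch (not part of the patch; the helper function is made up and assumes an initialized interpreter) of checking the exported format through the buffer protocol:

#include <Python.h>
#include <cstdio>

// Prints the buffer-protocol format string of 'obj'. For a complex64
// nanobind ndarray exported via the code above, this should print "Zf";
// for complex128, "Zd".
void print_buffer_format(PyObject *obj) {
    Py_buffer view;
    if (PyObject_GetBuffer(obj, &view, PyBUF_FORMAT | PyBUF_ND) == 0) {
        std::printf("format = %s\n", view.format ? view.format : "(null)");
        PyBuffer_Release(&view);
    }
}
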
14 changes: 12 additions & 2 deletions tests/test_ndarray.cpp
@@ -1,5 +1,6 @@
#include <nanobind/nanobind.h>
#include <nanobind/ndarray.h>
#include <nanobind/stl/complex.h>
#include <algorithm>
#include <vector>

@@ -68,6 +69,7 @@ NB_MODULE(test_ndarray_ext, m) {
});

m.def("pass_float32", [](const nb::ndarray<float> &) { }, "array"_a.noconvert());
m.def("pass_complex64", [](const nb::ndarray<std::complex<float>> &) { }, "array"_a.noconvert());
m.def("pass_uint32", [](const nb::ndarray<uint32_t> &) { }, "array"_a.noconvert());
m.def("pass_bool", [](const nb::ndarray<bool> &) { }, "array"_a.noconvert());
m.def("pass_float32_shaped",
@@ -119,10 +121,11 @@ NB_MODULE(test_ndarray_ext, m) {
}
printf("Tensor is on CPU? %i\n", ndarray.device_type() == nb::device::cpu::value);
printf("Device ID = %u\n", ndarray.device_id());
printf("Tensor dtype check: int16=%i, uint32=%i, float32=%i\n",
printf("Tensor dtype check: int16=%i, uint32=%i, float32=%i complex64=%i\n",
ndarray.dtype() == nb::dtype<int16_t>(),
ndarray.dtype() == nb::dtype<uint32_t>(),
-    ndarray.dtype() == nb::dtype<float>()
+    ndarray.dtype() == nb::dtype<float>(),
+    ndarray.dtype() == nb::dtype<std::complex<float>>()
);
});

@@ -261,6 +264,13 @@ NB_MODULE(test_ndarray_ext, m) {
v(i, j) = (float) (i * 10 + j);
}, "x"_a.noconvert());

m.def("fill_view_5", [](nb::ndarray<std::complex<float>, nb::shape<2, 2>, nb::c_contig, nb::device::cpu> x) {
auto v = x.view();
for (size_t i = 0; i < v.shape(0); ++i)
for (size_t j = 0; j < v.shape(1); ++j)
v(i, j) *= std::complex<float>(-1.0f, 2.0f);
}, "x"_a.noconvert());

#if defined(__aarch64__)
m.def("ret_numpy_half", []() {
__fp16 *f = new __fp16[8] { 1, 2, 3, 4, 5, 6, 7, 8 };
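Not exercised by this patch, but a natural companion to the tests above: returning a newly allocated complex array to NumPy works with the same constructor the existing ret_numpy/ret_numpy_half tests use. The function name and values below are made up; the snippet assumes it sits inside the NB_MODULE(test_ndarray_ext, m) body alongside the other m.def calls.

m.def("ret_numpy_complex64", []() {
    // Heap-allocate 2x2 complex64 data; the capsule frees it once the
    // returned NumPy array is garbage collected.
    std::complex<float> *data = new std::complex<float>[4] {
        { 1.f, 2.f }, { 3.f, 4.f }, { 5.f, 6.f }, { 7.f, 8.f }
    };

    nb::capsule owner(data, [](void *p) noexcept {
        delete[] (std::complex<float> *) p;
    });

    size_t shape[2] = { 2, 2 };
    return nb::ndarray<nb::numpy, std::complex<float>, nb::shape<2, 2>>(
        data, 2, shape, owner);
});
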
25 changes: 24 additions & 1 deletion tests/test_ndarray.py
@@ -71,6 +71,7 @@ def test02_docstr():
assert t.get_shape.__doc__ == "get_shape(array: ndarray[writable=False]) -> list"
assert t.pass_uint32.__doc__ == "pass_uint32(array: ndarray[dtype=uint32]) -> None"
assert t.pass_float32.__doc__ == "pass_float32(array: ndarray[dtype=float32]) -> None"
assert t.pass_complex64.__doc__ == "pass_complex64(array: ndarray[dtype=complex64]) -> None"
assert t.pass_bool.__doc__ == "pass_bool(array: ndarray[dtype=bool]) -> None"
assert t.pass_float32_shaped.__doc__ == "pass_float32_shaped(array: ndarray[dtype=float32, shape=(3, *, 4)]) -> None"
assert t.pass_float32_shaped_ordered.__doc__ == "pass_float32_shaped_ordered(array: ndarray[dtype=float32, order='C', shape=(*, *, 4)]) -> None"
@@ -82,10 +83,12 @@ def test02_docstr():
def test03_constrain_dtype():
a_u32 = np.array([1], dtype=np.uint32)
a_f32 = np.array([1], dtype=np.float32)
a_cf64 = np.array([1+1j], dtype=np.complex64)
a_bool = np.array([1], dtype=np.bool_)

t.pass_uint32(a_u32)
t.pass_float32(a_f32)
t.pass_complex64(a_cf64)
t.pass_bool(a_bool)

with pytest.raises(TypeError) as excinfo:
@@ -96,6 +99,10 @@ def test03_constrain_dtype():
t.pass_float32(a_u32)
assert 'incompatible function arguments' in str(excinfo.value)

with pytest.raises(TypeError) as excinfo:
t.pass_complex64(a_u32)
assert 'incompatible function arguments' in str(excinfo.value)

with pytest.raises(TypeError) as excinfo:
t.pass_bool(a_u32)
assert 'incompatible function arguments' in str(excinfo.value)
@@ -573,7 +580,7 @@ def test31_view():
t.fill_view_1(x2)
assert np.allclose(x1, x2*2)

-    #2
+    # 2
x1 = np.zeros((3, 4), dtype=np.float32, order='C')
x2 = np.zeros((3, 4), dtype=np.float32, order='F')
t.fill_view_2(x1)
@@ -585,6 +592,15 @@

assert np.all(x1 == x2) and np.all(x2 == x3) and np.all(x3 == x4)

# 3
x1 = np.array([[1+2j, 3+4j], [5+6j, 7+8j]], dtype=np.complex64)
x2 = x1 * 2
t.fill_view_1(x1.view(np.float32))
assert np.allclose(x1, x2)
x2 = x1 * (-1+2j)
t.fill_view_5(x1)
assert np.allclose(x1, x2)

@needs_numpy
def test32_half():
if not hasattr(t, 'ret_numpy_half'):
@@ -601,3 +617,10 @@ def test33_cast():
assert a.ndim == 0 and b.ndim == 0
assert a.dtype == np.int32 and b.dtype == np.float32
assert a == 1 and b == 1

@needs_numpy
def test34_complex_decompose():
x1 = np.array([1 + 2j, 3 + 4j, 5 + 6j], dtype=np.complex64)

assert np.all(x1.real == np.array([1, 3, 5], dtype=np.float32))
assert np.all(x1.imag == np.array([2, 4, 6], dtype=np.float32))