-
Notifications
You must be signed in to change notification settings - Fork 190
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Complex numbers support for ndarray. #319
Conversation
Well, that's awesome! I did not realize it could be done with so few lines of code changed. But this needs really tests, could I ask you to add a few for returning and receiving complex-valued arrays? |
Also, does it work with another tensor framework besides NumPy? (i.e., something that goes through the regular dlpack conversion as opposed to the buffer protocol) |
I think |
@luigifcruz -- I was thinking of including this in an upcoming release in the next couple of days, but having test coverage would be a prerequisite for this. Will you be able to add them? Thanks! |
@wjakob -- Yes! I'll write the tests. Sorry for the delay!
I can't see any reason why it wouldn't work. But I didn't test it with anything other than NumPy yet. |
@chrisrichardson Good catch! I'll update the PR with your recommendations. Thanks! |
One small request also from me: diff --git a/include/nanobind/ndarray.h b/include/nanobind/ndarray.h
index 31a6858..145ef88 100644
--- a/include/nanobind/ndarray.h
+++ b/include/nanobind/ndarray.h
@@ -14,7 +14,6 @@
#include <nanobind/nanobind.h>
#include <initializer_list>
-#include <complex>
NAMESPACE_BEGIN(NB_NAMESPACE)
@@ -66,15 +65,19 @@ struct dltensor {
NAMESPACE_END(dlpack)
+NAMESPACE_BEGIN(detail)
+
+template <typename T> struct is_complex
+ : public std::false_type { };
+
+NAMESPACE_END(detail)
+
constexpr size_t any = (size_t) -1;
template <size_t... Is> struct shape {
static constexpr size_t size = sizeof...(Is);
};
-template<typename T> struct is_complex_t : public std::false_type {};
-template<typename T> struct is_complex_t<std::complex<T>> : public std::true_type {};
-
struct c_contig { };
struct f_contig { };
struct any_contig { };
@@ -85,7 +88,7 @@ struct jax { };
struct ro { };
template <typename T> struct ndarray_traits {
- static constexpr bool is_complex = is_complex_t<T>::value;
+ static constexpr bool is_complex = detail::is_complex<T>::value;
static constexpr bool is_float = std::is_floating_point_v<T>;
static constexpr bool is_bool = std::is_same_v<std::remove_cv_t<T>, bool>;
static constexpr bool is_int = std::is_integral_v<T> && !is_bool;
diff --git a/include/nanobind/stl/complex.h b/include/nanobind/stl/complex.h
index d69b846..c8ca034 100644
--- a/include/nanobind/stl/complex.h
+++ b/include/nanobind/stl/complex.h
@@ -15,6 +15,9 @@
NAMESPACE_BEGIN(NB_NAMESPACE)
NAMESPACE_BEGIN(detail)
+template <typename T>
+struct is_complex<std::complex<T>> : public std::true_type { };
+
template <typename T> struct type_caster<std::complex<T>> {
NB_TYPE_CASTER(std::complex<T>, const_name("complex") ) |
I also renamed |
Cool! Thanks @wjakob for the feedback. I implemented the changes. I changed the style of the first template to keep it consistent with the templates on the |
I added a struct prototype in the |
The design looks all good to me — now it will just need tests. |
Don't worry about the " API rate limit exceeded" error message, that just happens sometimes. |
Hmm, I can't think of any more tests that are not redundant. I think these should cover all the different behaviors expected from complex numbers. Let me know if you need more! |
It looks great -- many thanks for adding the tests. PyTorch et al. are still missing, but I see that it's an experimental feature there. So probably not something we need to look into yet. |
Awesome! Thanks for this nice library! |
I'm getting runtime type failures for I can dig into the cause, but maybe someone can spot an easy fix. |
This PR introduces support for complex numbers in the
ndarray
type.is_complex_t
to identify complex number types.ndarray_traits
to include anis_complex
trait.is_ndarray_scalar_v
to account for complex numbers.ndarray_arg
specialization for complex numbers.nd_ndarray_tpbuffer
to handle complex numbers with 64 and 128 bits.