Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Complex numbers support for ndarray. #319

Merged
merged 11 commits into from
Oct 18, 2023
Merged

Conversation

luigifcruz
Copy link
Contributor

This PR introduces support for complex numbers in the ndarray type.

  • Added a trait is_complex_t to identify complex number types.
  • Updated ndarray_traits to include an is_complex trait.
  • Modified is_ndarray_scalar_v to account for complex numbers.
  • Introduced a new ndarray_arg specialization for complex numbers.
  • Updated the dtype function to handle complex numbers and return the appropriate dtype code.
  • Extended the format switch-case in nd_ndarray_tpbuffer to handle complex numbers with 64 and 128 bits.

@wjakob
Copy link
Owner

wjakob commented Oct 10, 2023

Well, that's awesome! I did not realize it could be done with so few lines of code changed. But this needs really tests, could I ask you to add a few for returning and receiving complex-valued arrays?

@wjakob
Copy link
Owner

wjakob commented Oct 10, 2023

Also, does it work with another tensor framework besides NumPy? (i.e., something that goes through the regular dlpack conversion as opposed to the buffer protocol)

@chrisrichardson
Copy link

chrisrichardson commented Oct 14, 2023

I think ndarray.h may need an extra || ndarray_traits<T>::is_complex around line 277 for ndarray_info
and something like:
template<typename T> struct is_complex_t<const std::complex<T>> : public std::true_type {};
around line 77

@wjakob
Copy link
Owner

wjakob commented Oct 16, 2023

@luigifcruz -- I was thinking of including this in an upcoming release in the next couple of days, but having test coverage would be a prerequisite for this. Will you be able to add them? Thanks!

@luigifcruz
Copy link
Contributor Author

@wjakob -- Yes! I'll write the tests. Sorry for the delay!

Also, does it work with another tensor framework besides NumPy? (i.e., something that goes through the regular dlpack conversion as opposed to the buffer protocol)

I can't see any reason why it wouldn't work. But I didn't test it with anything other than NumPy yet.

@luigifcruz
Copy link
Contributor Author

@chrisrichardson Good catch! I'll update the PR with your recommendations. Thanks!

@wjakob
Copy link
Owner

wjakob commented Oct 18, 2023

One small request also from me: #include <complex> pulls in the whole math library, which is a huge amount of header file code that may not be needed by other users of ndarray.h. So I would prefer if all the complex detection-specific bits remain in the stl/complex.h header file. You should be able to do so using a variant of the patch below:

diff --git a/include/nanobind/ndarray.h b/include/nanobind/ndarray.h
index 31a6858..145ef88 100644
--- a/include/nanobind/ndarray.h
+++ b/include/nanobind/ndarray.h
@@ -14,7 +14,6 @@
 
 #include <nanobind/nanobind.h>
 #include <initializer_list>
-#include <complex>
 
 NAMESPACE_BEGIN(NB_NAMESPACE)
 
@@ -66,15 +65,19 @@ struct dltensor {
 
 NAMESPACE_END(dlpack)
 
+NAMESPACE_BEGIN(detail)
+
+template <typename T> struct is_complex
+    : public std::false_type { };
+
+NAMESPACE_END(detail)
+
 constexpr size_t any = (size_t) -1;
 
 template <size_t... Is> struct shape {
     static constexpr size_t size = sizeof...(Is);
 };
 
-template<typename T> struct is_complex_t : public std::false_type {};
-template<typename T> struct is_complex_t<std::complex<T>> : public std::true_type {};
-
 struct c_contig { };
 struct f_contig { };
 struct any_contig { };
@@ -85,7 +88,7 @@ struct jax { };
 struct ro { };
 
 template <typename T> struct ndarray_traits {
-    static constexpr bool is_complex = is_complex_t<T>::value;
+    static constexpr bool is_complex = detail::is_complex<T>::value;
     static constexpr bool is_float   = std::is_floating_point_v<T>;
     static constexpr bool is_bool    = std::is_same_v<std::remove_cv_t<T>, bool>;
     static constexpr bool is_int     = std::is_integral_v<T> && !is_bool;
diff --git a/include/nanobind/stl/complex.h b/include/nanobind/stl/complex.h
index d69b846..c8ca034 100644
--- a/include/nanobind/stl/complex.h
+++ b/include/nanobind/stl/complex.h
@@ -15,6 +15,9 @@
 NAMESPACE_BEGIN(NB_NAMESPACE)
 NAMESPACE_BEGIN(detail)
 
+template <typename T>
+struct is_complex<std::complex<T>> : public std::true_type { };
+
 template <typename T> struct type_caster<std::complex<T>> {
     NB_TYPE_CASTER(std::complex<T>, const_name("complex") )

@wjakob
Copy link
Owner

wjakob commented Oct 18, 2023

I also renamed is_complex_t -> is_complex (the _t suffix is usually reserved for using x_t = typename x...::type-style type aliases), and I moved it to the detail namespace since it is not a public API.

@luigifcruz
Copy link
Contributor Author

Cool! Thanks @wjakob for the feedback. I implemented the changes. I changed the style of the first template to keep it consistent with the templates on the stl/complex.h header file. Let me know if you want more changes. I'll start to implement the tests now.

@luigifcruz
Copy link
Contributor Author

I added a struct prototype in the stl/complex.h file. I concluded that this would be better than importing the whole ndarray.h to stl/complex.h. Let me know if prefer otherwise.

@wjakob
Copy link
Owner

wjakob commented Oct 18, 2023

The design looks all good to me — now it will just need tests.

@wjakob
Copy link
Owner

wjakob commented Oct 18, 2023

Don't worry about the " API rate limit exceeded" error message, that just happens sometimes.

@luigifcruz
Copy link
Contributor Author

Hmm, I can't think of any more tests that are not redundant. I think these should cover all the different behaviors expected from complex numbers. Let me know if you need more!

@wjakob
Copy link
Owner

wjakob commented Oct 18, 2023

It looks great -- many thanks for adding the tests. PyTorch et al. are still missing, but I see that it's an experimental feature there. So probably not something we need to look into yet.

@wjakob wjakob merged commit bead53e into wjakob:master Oct 18, 2023
21 checks passed
@luigifcruz
Copy link
Contributor Author

Awesome! Thanks for this nice library!

@luigifcruz luigifcruz deleted the complex_patch branch October 18, 2023 20:16
wjakob pushed a commit that referenced this pull request Oct 18, 2023
@garth-wells
Copy link
Contributor

I'm getting runtime type failures for const complex, i.e. nb::ndarray<const std::complex<double>>. Passing a read-only NumPy array with complex types to a function expecting a nb::ndarray<const std::complex<double>> fails. Real types work fine. The const case isn't covered by the tests.

I can dig into the cause, but maybe someone can spot an easy fix.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

4 participants