From fd47111ca5ca36a0beb4cc03d02243000806c65b Mon Sep 17 00:00:00 2001 From: Carl Pearson Date: Thu, 10 Aug 2023 15:41:19 -0600 Subject: [PATCH] Disallow BsrMatrix tensor core SpMV for non-scalar types --- .../impl/KokkosSparse_spmv_bsrmatrix_impl.hpp | 56 ++++++++++++------- .../impl/KokkosSparse_spmv_bsrmatrix_spec.hpp | 20 ++----- 2 files changed, 41 insertions(+), 35 deletions(-) diff --git a/sparse/impl/KokkosSparse_spmv_bsrmatrix_impl.hpp b/sparse/impl/KokkosSparse_spmv_bsrmatrix_impl.hpp index abf44589f7..f5374e3c6d 100644 --- a/sparse/impl/KokkosSparse_spmv_bsrmatrix_impl.hpp +++ b/sparse/impl/KokkosSparse_spmv_bsrmatrix_impl.hpp @@ -37,6 +37,40 @@ struct BsrMatrixSpMVTensorCoreFunctorParams { int leagueDim_y; }; +/*! \brief Can the tensor core impl be used in ExecutionSpace to operate on + AMatrix, XMatrix, and YMatrix? +*/ +template +class TensorCoresAvailable { +#if defined(KOKKOS_ENABLE_CUDA) + using AScalar = typename AMatrix::non_const_value_type; + using YScalar = typename YMatrix::non_const_value_type; + using XScalar = typename XMatrix::non_const_value_type; + + using a_mem_space = typename AMatrix::memory_space; + using x_mem_space = typename XMatrix::memory_space; + using y_mem_space = typename YMatrix::memory_space; + + template + constexpr static bool is_scalar() { + return std::is_scalar_v || + std::is_same_v, Kokkos::Experimental::half_t>; + } + + public: + constexpr static inline bool value = + Kokkos::SpaceAccessibility::accessible && + Kokkos::SpaceAccessibility::accessible && + Kokkos::SpaceAccessibility::accessible && + is_scalar() && is_scalar() && is_scalar() && + std::is_same_v; +#else + public: + constexpr static inline bool value = false; +#endif +}; + /// \brief Functor for the BsrMatrix SpMV multivector implementation utilizing /// tensor cores. /// @@ -471,30 +505,12 @@ struct BsrMatrixSpMVTensorCoreDispatcher { "execution spaces"); } - /*true if none of T1, T2, or T3 are complex*/ - template - struct none_complex { - const static bool value = !Kokkos::ArithTraits::is_complex && - !Kokkos::ArithTraits::is_complex && - !Kokkos::ArithTraits::is_complex; - }; - - /*true if T1::execution_space, T2, or T3 are all GPU exec space*/ - template - struct all_gpu { - const static bool value = KokkosKernels::Impl::kk_is_gpu_exec_space() && - KokkosKernels::Impl::kk_is_gpu_exec_space() && - KokkosKernels::Impl::kk_is_gpu_exec_space(); - }; - static void dispatch(YScalar alpha, AMatrix a, XMatrix x, YScalar beta, YMatrix y) { // tag will be false unless all conditions are met using tag = std::integral_constant< - bool, none_complex::value && - all_gpu::value>; + bool, TensorCoresAvailable::value>; tag_dispatch(tag{}, alpha, a, x, beta, y); } }; diff --git a/sparse/impl/KokkosSparse_spmv_bsrmatrix_spec.hpp b/sparse/impl/KokkosSparse_spmv_bsrmatrix_spec.hpp index 69ff744e9d..90d9b25759 100644 --- a/sparse/impl/KokkosSparse_spmv_bsrmatrix_spec.hpp +++ b/sparse/impl/KokkosSparse_spmv_bsrmatrix_spec.hpp @@ -228,25 +228,15 @@ struct SPMV_MV_BSRMATRIX::is_complex) method = Method::Fallback; - if (Kokkos::ArithTraits::is_complex) method = Method::Fallback; - if (Kokkos::ArithTraits::is_complex) method = Method::Fallback; - // can't use tensor cores outside GPU - if (!KokkosKernels::Impl::kk_is_gpu_exec_space< - typename AMatrix::execution_space>()) - method = Method::Fallback; - if (!KokkosKernels::Impl::kk_is_gpu_exec_space< - typename XVector::execution_space>()) - method = Method::Fallback; - if (!KokkosKernels::Impl::kk_is_gpu_exec_space< - typename YVector::execution_space>()) + + if (!KokkosSparse::Experimental::Impl::TensorCoresAvailable< + typename AMatrix::execution_space, AMatrix, XVector, + YVector>::value) { method = Method::Fallback; + } // can't use tensor cores unless mode is no-transpose if (mode[0] != KokkosSparse::NoTranspose[0]) method = Method::Fallback; #if KOKKOS_HALF_T_IS_FLOAT