From db1f57b9837d6e133a1252247a08b06deb27f67d Mon Sep 17 00:00:00 2001 From: Brian Kelley Date: Thu, 31 Aug 2023 17:41:21 -0600 Subject: [PATCH] Stokhos: update for new KokkosSparse::spmv overloads KokkosKernels 4.2 will add some new overloads for spmv that accept execution space instances. The Stokhos specializations for spmv, where Scalar is PCE or MPVector, won't have them yet. So make sure that the existing Stokhos specializations still work by calling a Stokhos implementation (and not a KokkosKernels implementation). --- .../pce/linalg/Kokkos_CrsMatrix_UQ_PCE.hpp | 47 +++++++++++++++++++ .../linalg/Kokkos_CrsMatrix_MP_Vector.hpp | 47 +++++++++++++++++++ 2 files changed, 94 insertions(+) diff --git a/packages/stokhos/src/sacado/kokkos/pce/linalg/Kokkos_CrsMatrix_UQ_PCE.hpp b/packages/stokhos/src/sacado/kokkos/pce/linalg/Kokkos_CrsMatrix_UQ_PCE.hpp index 0d99e835f1b5..714b05c5e17e 100644 --- a/packages/stokhos/src/sacado/kokkos/pce/linalg/Kokkos_CrsMatrix_UQ_PCE.hpp +++ b/packages/stokhos/src/sacado/kokkos/pce/linalg/Kokkos_CrsMatrix_UQ_PCE.hpp @@ -1619,6 +1619,53 @@ spmv( spmv(mode, a, A, x, b, y, RANK_TWO()); } +template +std::enable_if_t< + Kokkos::is_view_uq_pce< Kokkos::View< InputType, InputP... > >::value && + Kokkos::is_view_uq_pce< Kokkos::View< OutputType, OutputP... > >::value && + KokkosSparse::is_crs_matrix::value> +spmv( + KokkosKernels::Experimental::Controls controls, + const char mode[], + const AlphaType& a, + const MatrixType& A, + const Kokkos::View< InputType, InputP... >& x, + const BetaType& b, + const Kokkos::View< OutputType, OutputP... >& y) +{ + using RANK_SPECIALISE = std::conditional_t::rank == 2, RANK_TWO, RANK_ONE>; + spmv(controls, mode, a, A, x, b, y, RANK_SPECIALISE()); +} + +template +std::enable_if_t< + Kokkos::is_view_uq_pce< Kokkos::View< InputType, InputP... > >::value && + Kokkos::is_view_uq_pce< Kokkos::View< OutputType, OutputP... > >::value && + KokkosSparse::is_crs_matrix::value> +spmv( + const char mode[], + const AlphaType& a, + const MatrixType& A, + const Kokkos::View< InputType, InputP... >& x, + const BetaType& b, + const Kokkos::View< OutputType, OutputP... >& y) +{ + using RANK_SPECIALISE = std::conditional_t::rank == 2, RANK_TWO, RANK_ONE>; + spmv(mode, a, A, x, b, y, RANK_SPECIALISE()); +} + } #endif /* #ifndef KOKKOS_CRSMATRIX_UQ_PCE_HPP */ diff --git a/packages/stokhos/src/sacado/kokkos/vector/linalg/Kokkos_CrsMatrix_MP_Vector.hpp b/packages/stokhos/src/sacado/kokkos/vector/linalg/Kokkos_CrsMatrix_MP_Vector.hpp index 1d2d365556ec..eb6aa8239653 100644 --- a/packages/stokhos/src/sacado/kokkos/vector/linalg/Kokkos_CrsMatrix_MP_Vector.hpp +++ b/packages/stokhos/src/sacado/kokkos/vector/linalg/Kokkos_CrsMatrix_MP_Vector.hpp @@ -749,6 +749,53 @@ spmv( spmv(mode, a, A, x, b, y, RANK_TWO()); } +template +std::enable_if_t< + Kokkos::is_view_mp_vector< Kokkos::View< InputType, InputP... > >::value && + Kokkos::is_view_mp_vector< Kokkos::View< OutputType, OutputP... > >::value && + KokkosSparse::is_crs_matrix::value> +spmv( + KokkosKernels::Experimental::Controls controls, + const char mode[], + const AlphaType& a, + const MatrixType& A, + const Kokkos::View< InputType, InputP... >& x, + const BetaType& b, + const Kokkos::View< OutputType, OutputP... >& y) +{ + using RANK_SPECIALISE = std::conditional_t::rank == 2, RANK_TWO, RANK_ONE>; + spmv(controls, mode, a, A, x, b, y, RANK_SPECIALISE()); +} + +template +std::enable_if_t< + Kokkos::is_view_mp_vector< Kokkos::View< InputType, InputP... > >::value && + Kokkos::is_view_mp_vector< Kokkos::View< OutputType, OutputP... > >::value && + KokkosSparse::is_crs_matrix::value> +spmv( + const char mode[], + const AlphaType& a, + const MatrixType& A, + const Kokkos::View< InputType, InputP... >& x, + const BetaType& b, + const Kokkos::View< OutputType, OutputP... >& y) +{ + using RANK_SPECIALISE = std::conditional_t::rank == 2, RANK_TWO, RANK_ONE>; + spmv(mode, a, A, x, b, y, RANK_SPECIALISE()); +} + } #endif /* #ifndef KOKKOS_CRSMATRIX_MP_VECTOR_HPP */