Skip to content

Commit

Permalink
Stokhos: update for new KokkosSparse::spmv overloads
Browse files Browse the repository at this point in the history
KokkosKernels 4.2 will add some new overloads for spmv that accept
execution space instances. The Stokhos specializations for spmv,
where Scalar is PCE or MPVector, won't have them yet.
So make sure that the existing Stokhos specializations still work by
calling a Stokhos implementation (and not a KokkosKernels implementation).
  • Loading branch information
brian-kelley committed Aug 31, 2023
1 parent e764653 commit db1f57b
Show file tree
Hide file tree
Showing 2 changed files with 94 additions and 0 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -1619,6 +1619,53 @@ spmv(
spmv(mode, a, A, x, b, y, RANK_TWO());
}

template <typename AlphaType,
typename BetaType,
typename MatrixType,
typename InputType,
typename ... InputP,
typename OutputType,
typename ... OutputP>
std::enable_if_t<
Kokkos::is_view_uq_pce< Kokkos::View< InputType, InputP... > >::value &&
Kokkos::is_view_uq_pce< Kokkos::View< OutputType, OutputP... > >::value &&
KokkosSparse::is_crs_matrix<MatrixType>::value>
spmv(
KokkosKernels::Experimental::Controls controls,
const char mode[],
const AlphaType& a,
const MatrixType& A,
const Kokkos::View< InputType, InputP... >& x,
const BetaType& b,
const Kokkos::View< OutputType, OutputP... >& y)
{
using RANK_SPECIALISE = std::conditional_t<std::decay_t<decltype(x)>::rank == 2, RANK_TWO, RANK_ONE>;
spmv(controls, mode, a, A, x, b, y, RANK_SPECIALISE());
}

template <typename AlphaType,
typename BetaType,
typename MatrixType,
typename InputType,
typename ... InputP,
typename OutputType,
typename ... OutputP>
std::enable_if_t<
Kokkos::is_view_uq_pce< Kokkos::View< InputType, InputP... > >::value &&
Kokkos::is_view_uq_pce< Kokkos::View< OutputType, OutputP... > >::value &&
KokkosSparse::is_crs_matrix<MatrixType>::value>
spmv(
const char mode[],
const AlphaType& a,
const MatrixType& A,
const Kokkos::View< InputType, InputP... >& x,
const BetaType& b,
const Kokkos::View< OutputType, OutputP... >& y)
{
using RANK_SPECIALISE = std::conditional_t<std::decay_t<decltype(x)>::rank == 2, RANK_TWO, RANK_ONE>;
spmv(mode, a, A, x, b, y, RANK_SPECIALISE());
}

}

#endif /* #ifndef KOKKOS_CRSMATRIX_UQ_PCE_HPP */
Original file line number Diff line number Diff line change
Expand Up @@ -749,6 +749,53 @@ spmv(
spmv(mode, a, A, x, b, y, RANK_TWO());
}

template <typename AlphaType,
typename BetaType,
typename MatrixType,
typename InputType,
typename ... InputP,
typename OutputType,
typename ... OutputP>
std::enable_if_t<
Kokkos::is_view_mp_vector< Kokkos::View< InputType, InputP... > >::value &&
Kokkos::is_view_mp_vector< Kokkos::View< OutputType, OutputP... > >::value &&
KokkosSparse::is_crs_matrix<MatrixType>::value>
spmv(
KokkosKernels::Experimental::Controls controls,
const char mode[],
const AlphaType& a,
const MatrixType& A,
const Kokkos::View< InputType, InputP... >& x,
const BetaType& b,
const Kokkos::View< OutputType, OutputP... >& y)
{
using RANK_SPECIALISE = std::conditional_t<std::decay_t<decltype(x)>::rank == 2, RANK_TWO, RANK_ONE>;
spmv(controls, mode, a, A, x, b, y, RANK_SPECIALISE());
}

template <typename AlphaType,
typename BetaType,
typename MatrixType,
typename InputType,
typename ... InputP,
typename OutputType,
typename ... OutputP>
std::enable_if_t<
Kokkos::is_view_mp_vector< Kokkos::View< InputType, InputP... > >::value &&
Kokkos::is_view_mp_vector< Kokkos::View< OutputType, OutputP... > >::value &&
KokkosSparse::is_crs_matrix<MatrixType>::value>
spmv(
const char mode[],
const AlphaType& a,
const MatrixType& A,
const Kokkos::View< InputType, InputP... >& x,
const BetaType& b,
const Kokkos::View< OutputType, OutputP... >& y)
{
using RANK_SPECIALISE = std::conditional_t<std::decay_t<decltype(x)>::rank == 2, RANK_TWO, RANK_ONE>;
spmv(mode, a, A, x, b, y, RANK_SPECIALISE());
}

}

#endif /* #ifndef KOKKOS_CRSMATRIX_MP_VECTOR_HPP */

0 comments on commit db1f57b

Please sign in to comment.