Skip to content

Commit

Permalink
Stokhos: fix KokkosKernels trilinos#1959
Browse files Browse the repository at this point in the history
PR trilinos#12190 actually failed to fix KokkosSparse::spmv for Sacado scalar
types, when building with Kokkos/KokkosKernels develop branch.

This actually fixes that issue (tested with develop and master KokkosKernels)
and is quite a bit cleaner (though it uses version macros that can be
taken out for the 4.2 release).
  • Loading branch information
brian-kelley committed Sep 5, 2023
1 parent d3953b4 commit eab6d5d
Show file tree
Hide file tree
Showing 2 changed files with 68 additions and 194 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -1467,7 +1467,11 @@ class MeanMultiply< KokkosSparse::CrsMatrix< Sacado::UQ::PCE<MatrixStorage>,

namespace KokkosSparse {

template <typename AlphaType,
template <
#if KOKKOSKERNELS_VERSION >= 40199
typename ExecutionSpace,
#endif
typename AlphaType,
typename BetaType,
typename MatrixType,
typename InputType,
Expand All @@ -1479,6 +1483,10 @@ typename std::enable_if<
Kokkos::is_view_uq_pce< Kokkos::View< OutputType, OutputP... > >::value
>::type
spmv(
#if KOKKOSKERNELS_VERSION >= 40199
const ExecutionSpace& space,
#endif
KokkosKernels::Experimental::Controls,
const char mode[],
const AlphaType& a,
const MatrixType& A,
Expand All @@ -1494,6 +1502,12 @@ spmv(
typedef Stokhos::MeanMultiply<MatrixType, typename InputVectorType::const_type,
OutputVectorType> mean_multiply_type;

#if KOKKOSKERNELS_VERSION >= 40199
if(space != ExecutionSpace()) {
Kokkos::Impl::raise_error(
"Stokhos spmv not implemented for non-default execution space instance");
}
#endif
if(mode[0]!='N') {
Kokkos::Impl::raise_error(
"Stokhos spmv not implemented for transposed or conjugated matrix-vector multiplies");
Expand All @@ -1514,7 +1528,11 @@ spmv(
Sacado::Value<BetaType>::eval(b) );
}

template <typename AlphaType,
template <
#if KOKKOSKERNELS_VERSION >= 40199
typename ExecutionSpace,
#endif
typename AlphaType,
typename BetaType,
typename MatrixType,
typename InputType,
Expand All @@ -1526,30 +1544,10 @@ typename std::enable_if<
Kokkos::is_view_uq_pce< Kokkos::View< OutputType, OutputP... > >::value
>::type
spmv(
#if KOKKOSKERNELS_VERSION >= 40199
const ExecutionSpace& space,
#endif
KokkosKernels::Experimental::Controls,
const char mode[],
const AlphaType& a,
const MatrixType& A,
const Kokkos::View< InputType, InputP... >& x,
const BetaType& b,
const Kokkos::View< OutputType, OutputP... >& y,
const RANK_ONE)
{
spmv(mode, a, A, x, b, y, RANK_ONE());
}

template <typename AlphaType,
typename BetaType,
typename MatrixType,
typename InputType,
typename ... InputP,
typename OutputType,
typename ... OutputP>
typename std::enable_if<
Kokkos::is_view_uq_pce< Kokkos::View< InputType, InputP... > >::value &&
Kokkos::is_view_uq_pce< Kokkos::View< OutputType, OutputP... > >::value
>::type
spmv(
const char mode[],
const AlphaType& a,
const MatrixType& A,
Expand All @@ -1558,14 +1556,24 @@ spmv(
const Kokkos::View< OutputType, OutputP... >& y,
const RANK_TWO)
{
#if KOKKOSKERNELS_VERSION >= 40199
if(space != ExecutionSpace()) {
Kokkos::Impl::raise_error(
"Stokhos spmv not implemented for non-default execution space instance");
}
#endif
if(mode[0]!='N') {
Kokkos::Impl::raise_error(
"Stokhos spmv not implemented for transposed or conjugated matrix-vector multiplies");
}
if (y.extent(1) == 1) {
auto y_1D = subview(y, Kokkos::ALL(), 0);
auto x_1D = subview(x, Kokkos::ALL(), 0);
spmv(mode, a, A, x_1D, b, y_1D, RANK_ONE());
#if KOKKOSKERNELS_VERSION >= 40199
spmv(space, KokkosKernels::Experimental::Controls(), mode, a, A, x_1D, b, y_1D, RANK_ONE());
#else
spmv(KokkosKernels::Experimental::Controls(), mode, a, A, x_1D, b, y_1D, RANK_ONE());
#endif
}
else {
typedef Kokkos::View< OutputType, OutputP... > OutputVectorType;
Expand Down Expand Up @@ -1595,77 +1603,6 @@ spmv(
}
}

template <typename AlphaType,
typename BetaType,
typename MatrixType,
typename InputType,
typename ... InputP,
typename OutputType,
typename ... OutputP>
typename std::enable_if<
Kokkos::is_view_uq_pce< Kokkos::View< InputType, InputP... > >::value &&
Kokkos::is_view_uq_pce< Kokkos::View< OutputType, OutputP... > >::value
>::type
spmv(
KokkosKernels::Experimental::Controls,
const char mode[],
const AlphaType& a,
const MatrixType& A,
const Kokkos::View< InputType, InputP... >& x,
const BetaType& b,
const Kokkos::View< OutputType, OutputP... >& y,
const RANK_TWO)
{
spmv(mode, a, A, x, b, y, RANK_TWO());
}

template <typename AlphaType,
typename BetaType,
typename MatrixType,
typename InputType,
typename ... InputP,
typename OutputType,
typename ... OutputP>
std::enable_if_t<
Kokkos::is_view_uq_pce< Kokkos::View< InputType, InputP... > >::value &&
Kokkos::is_view_uq_pce< Kokkos::View< OutputType, OutputP... > >::value &&
KokkosSparse::is_crs_matrix<MatrixType>::value>
spmv(
KokkosKernels::Experimental::Controls controls,
const char mode[],
const AlphaType& a,
const MatrixType& A,
const Kokkos::View< InputType, InputP... >& x,
const BetaType& b,
const Kokkos::View< OutputType, OutputP... >& y)
{
using RANK_SPECIALISE = std::conditional_t<std::decay_t<decltype(x)>::rank == 2, RANK_TWO, RANK_ONE>;
spmv(controls, mode, a, A, x, b, y, RANK_SPECIALISE());
}

template <typename AlphaType,
typename BetaType,
typename MatrixType,
typename InputType,
typename ... InputP,
typename OutputType,
typename ... OutputP>
std::enable_if_t<
Kokkos::is_view_uq_pce< Kokkos::View< InputType, InputP... > >::value &&
Kokkos::is_view_uq_pce< Kokkos::View< OutputType, OutputP... > >::value &&
KokkosSparse::is_crs_matrix<MatrixType>::value>
spmv(
const char mode[],
const AlphaType& a,
const MatrixType& A,
const Kokkos::View< InputType, InputP... >& x,
const BetaType& b,
const Kokkos::View< OutputType, OutputP... >& y)
{
using RANK_SPECIALISE = std::conditional_t<std::decay_t<decltype(x)>::rank == 2, RANK_TWO, RANK_ONE>;
spmv(mode, a, A, x, b, y, RANK_SPECIALISE());
}

}

#endif /* #ifndef KOKKOS_CRSMATRIX_UQ_PCE_HPP */
Original file line number Diff line number Diff line change
Expand Up @@ -529,7 +529,11 @@ class Multiply< KokkosSparse::CrsMatrix< Sacado::MP::Vector<MatrixStorage>,

namespace KokkosSparse {

template <typename AlphaType,
template <
#if KOKKOSKERNELS_VERSION >= 40199
typename ExecutionSpace,
#endif
typename AlphaType,
typename BetaType,
typename MatrixType,
typename InputType,
Expand All @@ -541,6 +545,10 @@ typename std::enable_if<
Kokkos::is_view_mp_vector< Kokkos::View< OutputType, OutputP... > >::value
>::type
spmv(
#if KOKKOSKERNELS_VERSION >= 40199
const ExecutionSpace& space,
#endif
KokkosKernels::Experimental::Controls,
const char mode[],
const AlphaType& a,
const MatrixType& A,
Expand All @@ -554,6 +562,12 @@ spmv(
using input_vector_type = const_type_t<InputVectorType>;
typedef typename InputVectorType::array_type::non_const_value_type value_type;

#if KOKKOSKERNELS_VERSION >= 40199
if(space != ExecutionSpace()) {
Kokkos::Impl::raise_error(
"Stokhos spmv not implemented for non-default execution space instance");
}
#endif
if(mode[0]!='N') {
Kokkos::Impl::raise_error(
"Stokhos spmv not implemented for transposed or conjugated matrix-vector multiplies");
Expand Down Expand Up @@ -612,7 +626,11 @@ spmv(
}
}

template <typename AlphaType,
template <
#if KOKKOSKERNELS_VERSION >= 40199
typename ExecutionSpace,
#endif
typename AlphaType,
typename BetaType,
typename MatrixType,
typename InputType,
Expand All @@ -624,30 +642,10 @@ typename std::enable_if<
Kokkos::is_view_mp_vector< Kokkos::View< OutputType, OutputP... > >::value
>::type
spmv(
#if KOKKOSKERNELS_VERSION >= 40199
const ExecutionSpace& space,
#endif
KokkosKernels::Experimental::Controls,
const char mode[],
const AlphaType& a,
const MatrixType& A,
const Kokkos::View< InputType, InputP... >& x,
const BetaType& b,
const Kokkos::View< OutputType, OutputP... >& y,
const RANK_ONE)
{
spmv(mode, a, A, x, b, y, RANK_ONE());
}

template <typename AlphaType,
typename BetaType,
typename MatrixType,
typename InputType,
typename ... InputP,
typename OutputType,
typename ... OutputP>
typename std::enable_if<
Kokkos::is_view_mp_vector< Kokkos::View< InputType, InputP... > >::value &&
Kokkos::is_view_mp_vector< Kokkos::View< OutputType, OutputP... > >::value
>::type
spmv(
const char mode[],
const AlphaType& a,
const MatrixType& A,
Expand All @@ -656,14 +654,24 @@ spmv(
const Kokkos::View< OutputType, OutputP... >& y,
const RANK_TWO)
{
#if KOKKOSKERNELS_VERSION >= 40199
if(space != ExecutionSpace()) {
Kokkos::Impl::raise_error(
"Stokhos spmv not implemented for non-default execution space instance");
}
#endif
if(mode[0]!='N') {
Kokkos::Impl::raise_error(
"Stokhos spmv not implemented for transposed or conjugated matrix-vector multiplies");
}
if (y.extent(1) == 1) {
auto y_1D = subview(y, Kokkos::ALL(), 0);
auto x_1D = subview(x, Kokkos::ALL(), 0);
spmv(mode, a, A, x_1D, b, y_1D, RANK_ONE());
#if KOKKOSKERNELS_VERSION >= 40199
spmv(space, KokkosKernels::Experimental::Controls(), mode, a, A, x_1D, b, y_1D, RANK_ONE());
#else
spmv(KokkosKernels::Experimental::Controls(), mode, a, A, x_1D, b, y_1D, RANK_ONE());
#endif
}
else {
typedef Kokkos::View< OutputType, OutputP... > OutputVectorType;
Expand Down Expand Up @@ -725,77 +733,6 @@ spmv(
}
}

template <typename AlphaType,
typename BetaType,
typename MatrixType,
typename InputType,
typename ... InputP,
typename OutputType,
typename ... OutputP>
typename std::enable_if<
Kokkos::is_view_mp_vector< Kokkos::View< InputType, InputP... > >::value &&
Kokkos::is_view_mp_vector< Kokkos::View< OutputType, OutputP... > >::value
>::type
spmv(
KokkosKernels::Experimental::Controls,
const char mode[],
const AlphaType& a,
const MatrixType& A,
const Kokkos::View< InputType, InputP... >& x,
const BetaType& b,
const Kokkos::View< OutputType, OutputP... >& y,
const RANK_TWO)
{
spmv(mode, a, A, x, b, y, RANK_TWO());
}

template <typename AlphaType,
typename BetaType,
typename MatrixType,
typename InputType,
typename ... InputP,
typename OutputType,
typename ... OutputP>
std::enable_if_t<
Kokkos::is_view_mp_vector< Kokkos::View< InputType, InputP... > >::value &&
Kokkos::is_view_mp_vector< Kokkos::View< OutputType, OutputP... > >::value &&
KokkosSparse::is_crs_matrix<MatrixType>::value>
spmv(
KokkosKernels::Experimental::Controls controls,
const char mode[],
const AlphaType& a,
const MatrixType& A,
const Kokkos::View< InputType, InputP... >& x,
const BetaType& b,
const Kokkos::View< OutputType, OutputP... >& y)
{
using RANK_SPECIALISE = std::conditional_t<std::decay_t<decltype(x)>::rank == 2, RANK_TWO, RANK_ONE>;
spmv(controls, mode, a, A, x, b, y, RANK_SPECIALISE());
}

template <typename AlphaType,
typename BetaType,
typename MatrixType,
typename InputType,
typename ... InputP,
typename OutputType,
typename ... OutputP>
std::enable_if_t<
Kokkos::is_view_mp_vector< Kokkos::View< InputType, InputP... > >::value &&
Kokkos::is_view_mp_vector< Kokkos::View< OutputType, OutputP... > >::value &&
KokkosSparse::is_crs_matrix<MatrixType>::value>
spmv(
const char mode[],
const AlphaType& a,
const MatrixType& A,
const Kokkos::View< InputType, InputP... >& x,
const BetaType& b,
const Kokkos::View< OutputType, OutputP... >& y)
{
using RANK_SPECIALISE = std::conditional_t<std::decay_t<decltype(x)>::rank == 2, RANK_TWO, RANK_ONE>;
spmv(mode, a, A, x, b, y, RANK_SPECIALISE());
}

}

#endif /* #ifndef KOKKOS_CRSMATRIX_MP_VECTOR_HPP */

0 comments on commit eab6d5d

Please sign in to comment.