Skip to content

Commit

Permalink
Merge check for number of rhs in CUDA triangular solvers
Browse files Browse the repository at this point in the history
This makes specifying the with_num_rhs parameter for the triangular solvers a requirement on the CUDA executor
to avoid the (apparent) internal data corruption in cuSPARSE that may occur otherwise when analyzing
with <= 32 rhs and solving with > 32 rhs.

Additionally, I needed to modify the solver tests, since instead of the typical `forall_solvers(..., forall_vectors(...))`,
here we need to use `forall_vectors(..., forall_solvers_with_num_rhs(...))`.

Related PR: #1184
  • Loading branch information
upsj authored Nov 5, 2022
2 parents b1cf825 + bb21691 commit 64c18ca
Show file tree
Hide file tree
Showing 10 changed files with 311 additions and 184 deletions.
13 changes: 9 additions & 4 deletions core/solver/direct.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -151,10 +151,15 @@ Direct<ValueType, IndexType>::Direct(const Factory* factory,
if (separate_diag) {
GKO_NOT_SUPPORTED(type);
}
const auto lower_factory =
lower_type::build().with_unit_diagonal(lower_unit_diag).on(exec);
const auto upper_factory =
upper_type::build().with_unit_diagonal(upper_unit_diag).on(exec);
const auto num_rhs = factory->get_parameters().num_rhs;
const auto lower_factory = lower_type::build()
.with_num_rhs(num_rhs)
.with_unit_diagonal(lower_unit_diag)
.on(exec);
const auto upper_factory = upper_type::build()
.with_num_rhs(num_rhs)
.with_unit_diagonal(upper_unit_diag)
.on(exec);
switch (type) {
case storage_type::empty:
// remove the factor storage entirely
Expand Down
33 changes: 33 additions & 0 deletions cuda/solver/common_trs_kernels.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,7 @@ OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
#include <cusparse.h>


#include <ginkgo/core/base/exception.hpp>
#include <ginkgo/core/base/exception_helpers.hpp>
#include <ginkgo/core/base/math.hpp>

Expand Down Expand Up @@ -84,6 +85,7 @@ struct CudaSolveStruct : gko::solver::SolveStruct {
cusparseHandle_t handle;
cusparseSpSMDescr_t spsm_descr;
cusparseSpMatDescr_t descr_a;
size_type num_rhs;

// Implicit parameter in spsm_solve, therefore stored here.
array<char> work;
Expand All @@ -94,8 +96,12 @@ struct CudaSolveStruct : gko::solver::SolveStruct {
: handle{exec->get_cusparse_handle()},
spsm_descr{},
descr_a{},
num_rhs{num_rhs},
work{exec}
{
if (num_rhs == 0) {
return;
}
cusparse::pointer_mode_guard pm_guard(handle);
spsm_descr = cusparse::create_spsm_descr();
descr_a = cusparse::create_csr(
Expand Down Expand Up @@ -143,6 +149,17 @@ struct CudaSolveStruct : gko::solver::SolveStruct {
matrix::Dense<ValueType>* output, matrix::Dense<ValueType>*,
matrix::Dense<ValueType>*) const
{
if (input->get_size()[1] != num_rhs) {
throw gko::ValueMismatch{
__FILE__,
__LINE__,
__FUNCTION__,
input->get_size()[1],
num_rhs,
"the dimensions of the multivector do not match the value "
"provided at generation time. Check the value specified in "
".with_num_rhs(...)."};
}
cusparse::pointer_mode_guard pm_guard(handle);
auto descr_b = cusparse::create_dnmat(
input->get_size(), input->get_stride(),
Expand Down Expand Up @@ -191,6 +208,7 @@ struct CudaSolveStruct : gko::solver::SolveStruct {
csrsm2Info_t solve_info;
cusparseSolvePolicy_t policy;
cusparseMatDescr_t factor_descr;
size_type num_rhs;
mutable array<char> work;

CudaSolveStruct(std::shared_ptr<const gko::CudaExecutor> exec,
Expand All @@ -202,8 +220,12 @@ struct CudaSolveStruct : gko::solver::SolveStruct {
solve_info{},
policy{},
factor_descr{},
num_rhs{num_rhs},
work{exec}
{
if (num_rhs == 0) {
return;
}
cusparse::pointer_mode_guard pm_guard(handle);
factor_descr = cusparse::create_mat_descr();
solve_info = cusparse::create_solve_info();
Expand Down Expand Up @@ -243,6 +265,17 @@ struct CudaSolveStruct : gko::solver::SolveStruct {
matrix::Dense<ValueType>* output, matrix::Dense<ValueType>*,
matrix::Dense<ValueType>*) const
{
if (input->get_size()[1] != num_rhs) {
throw gko::ValueMismatch{
__FILE__,
__LINE__,
__FUNCTION__,
input->get_size()[1],
num_rhs,
"the dimensions of the multivector do not match the value "
"provided at generation time. Check the value specified in "
".with_num_rhs(...)."};
}
cusparse::pointer_mode_guard pm_guard(handle);
dense::copy(exec, input, output);
cusparse::csrsm2_solve(
Expand Down
10 changes: 10 additions & 0 deletions cuda/test/solver/lower_trs_kernels.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -213,4 +213,14 @@ TEST_F(LowerTrs, CudaMultipleRhsApplyIsEquivalentToRef)
}


TEST_F(LowerTrs, CudaApplyThrowsWithWrongNumRHS)
{
initialize_data(50, 3);
auto d_lower_trs_factory = gko::solver::LowerTrs<>::build().on(cuda);
auto d_solver = d_lower_trs_factory->generate(d_csr_mtx);

ASSERT_THROW(d_solver->apply(d_b2.get(), d_x.get()), gko::ValueMismatch);
}


} // namespace
10 changes: 10 additions & 0 deletions cuda/test/solver/upper_trs_kernels.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -213,4 +213,14 @@ TEST_F(UpperTrs, CudaMultipleRhsApplyIsEquivalentToRef)
}


TEST_F(UpperTrs, CudaApplyThrowsWithWrongNumRHS)
{
initialize_data(50, 3);
auto d_lower_trs_factory = gko::solver::UpperTrs<>::build().on(cuda);
auto d_solver = d_lower_trs_factory->generate(d_csr_mtx);

ASSERT_THROW(d_solver->apply(d_b2.get(), d_x.get()), gko::ValueMismatch);
}


} // namespace
9 changes: 9 additions & 0 deletions include/ginkgo/core/solver/direct.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -76,6 +76,15 @@ class Direct : public EnableLinOp<Direct<ValueType, IndexType>>,

GKO_CREATE_FACTORY_PARAMETERS(parameters, Factory)
{
/**
* Number of right hand sides.
*
* @note This value is currently only required for the CUDA executor,
* which will throw an exception if a different number of rhs is
* passed to Direct::apply.
*/
gko::size_type GKO_FACTORY_PARAMETER_SCALAR(num_rhs, 1u);

/** The factorization factory to use for generating the factors. */
std::shared_ptr<const LinOpFactory> GKO_FACTORY_PARAMETER_SCALAR(
factorization, nullptr);
Expand Down
14 changes: 4 additions & 10 deletions include/ginkgo/core/solver/triangular.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -111,11 +111,8 @@ class LowerTrs : public EnableLinOp<LowerTrs<ValueType, IndexType>>,
/**
* Number of right hand sides.
*
* @note This value is currently a dummy value which is not used by the
* analysis step. It is possible that future algorithms (cusparse
* csrsm2) make use of the number of right hand sides for a more
* sophisticated implementation. Hence this parameter is left
* here. But currently, there is no need to use it.
* @note This value is currently only required for the CUDA
* trisolve_algorithm::sparselib algorithm.
*/
gko::size_type GKO_FACTORY_PARAMETER_SCALAR(num_rhs, 1u);

Expand Down Expand Up @@ -264,11 +261,8 @@ class UpperTrs : public EnableLinOp<UpperTrs<ValueType, IndexType>>,
/**
* Number of right hand sides.
*
* @note This value is currently a dummy value which is not used by the
* analysis step. It is possible that future algorithms (cusparse
* csrsm2) make use of the number of right hand sides for a more
* sophisticated implementation. Hence this parameter is left
* here. But currently, there is no need to use it.
* @note This value is currently only required for the CUDA
* trisolve_algorithm::sparselib algorithm.
*/
gko::size_type GKO_FACTORY_PARAMETER_SCALAR(num_rhs, 1u);

Expand Down
2 changes: 2 additions & 0 deletions test/solver/direct.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -97,6 +97,7 @@ class Direct : public CommonTestFixture {
.with_factorization(factorization_type::build()
.with_symmetric_sparsity(true)
.on(ref))
.with_num_rhs(static_cast<gko::size_type>(nrhs))
.on(ref);
alpha = gen_mtx(1, 1);
beta = gen_mtx(1, 1);
Expand All @@ -106,6 +107,7 @@ class Direct : public CommonTestFixture {
.with_factorization(factorization_type::build()
.with_symmetric_sparsity(true)
.on(exec))
.with_num_rhs(static_cast<gko::size_type>(nrhs))
.on(exec);
dalpha = gko::clone(exec, alpha);
dbeta = gko::clone(exec, beta);
Expand Down
66 changes: 34 additions & 32 deletions test/solver/lower_trs_kernels.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -239,8 +239,8 @@ TEST_F(LowerTrs, ApplyTriangularSparseMtxUnitDiagIsEquivalentToRef)
TEST_F(LowerTrs, ApplyFullDenseMtxMultipleRhsIsEquivalentToRef)
{
initialize_data(50, 4, 50);
auto lower_trs_factory = solver_type::build().on(ref);
auto d_lower_trs_factory = solver_type::build().on(exec);
auto lower_trs_factory = solver_type::build().with_num_rhs(4u).on(ref);
auto d_lower_trs_factory = solver_type::build().with_num_rhs(4u).on(exec);
auto solver = lower_trs_factory->generate(mtx);
auto d_solver = d_lower_trs_factory->generate(dmtx);

Expand All @@ -255,9 +255,9 @@ TEST_F(LowerTrs, ApplyFullDenseMtxUnitDiagMultipleRhsIsEquivalentToRef)
{
initialize_data(50, 5, 50);
auto lower_trs_factory =
solver_type::build().with_unit_diagonal(true).on(ref);
solver_type::build().with_num_rhs(5u).with_unit_diagonal(true).on(ref);
auto d_lower_trs_factory =
solver_type::build().with_unit_diagonal(true).on(exec);
solver_type::build().with_num_rhs(5u).with_unit_diagonal(true).on(exec);
auto solver = lower_trs_factory->generate(mtx);
auto d_solver = d_lower_trs_factory->generate(dmtx);

Expand All @@ -271,8 +271,8 @@ TEST_F(LowerTrs, ApplyFullDenseMtxUnitDiagMultipleRhsIsEquivalentToRef)
TEST_F(LowerTrs, ApplyFullSparseMtxMultipleRhsIsEquivalentToRef)
{
initialize_data(50, 6, 5);
auto lower_trs_factory = solver_type::build().on(ref);
auto d_lower_trs_factory = solver_type::build().on(exec);
auto lower_trs_factory = solver_type::build().with_num_rhs(6u).on(ref);
auto d_lower_trs_factory = solver_type::build().with_num_rhs(6u).on(exec);
auto solver = lower_trs_factory->generate(mtx);
auto d_solver = d_lower_trs_factory->generate(dmtx);

Expand All @@ -287,9 +287,9 @@ TEST_F(LowerTrs, ApplyFullSparseMtxUnitDiagMultipleRhsIsEquivalentToRef)
{
initialize_data(50, 7, 5);
auto lower_trs_factory =
solver_type::build().with_unit_diagonal(true).on(ref);
solver_type::build().with_num_rhs(7u).with_unit_diagonal(true).on(ref);
auto d_lower_trs_factory =
solver_type::build().with_unit_diagonal(true).on(exec);
solver_type::build().with_num_rhs(7u).with_unit_diagonal(true).on(exec);
auto solver = lower_trs_factory->generate(mtx);
auto d_solver = d_lower_trs_factory->generate(dmtx);

Expand All @@ -303,8 +303,8 @@ TEST_F(LowerTrs, ApplyFullSparseMtxUnitDiagMultipleRhsIsEquivalentToRef)
TEST_F(LowerTrs, ApplyTriangularDenseMtxMultipleRhsIsEquivalentToRef)
{
initialize_data(50, 8, 50);
auto lower_trs_factory = solver_type::build().on(ref);
auto d_lower_trs_factory = solver_type::build().on(exec);
auto lower_trs_factory = solver_type::build().with_num_rhs(8u).on(ref);
auto d_lower_trs_factory = solver_type::build().with_num_rhs(8u).on(exec);
auto solver = lower_trs_factory->generate(mtx_l);
auto d_solver = d_lower_trs_factory->generate(dmtx_l);

Expand All @@ -319,9 +319,9 @@ TEST_F(LowerTrs, ApplyTriangularDenseMtxUnitDiagMultipleRhsIsEquivalentToRef)
{
initialize_data(50, 9, 50);
auto lower_trs_factory =
solver_type::build().with_unit_diagonal(true).on(ref);
solver_type::build().with_num_rhs(9u).with_unit_diagonal(true).on(ref);
auto d_lower_trs_factory =
solver_type::build().with_unit_diagonal(true).on(exec);
solver_type::build().with_num_rhs(9u).with_unit_diagonal(true).on(exec);
auto solver = lower_trs_factory->generate(mtx_l);
auto d_solver = d_lower_trs_factory->generate(dmtx_l);

Expand All @@ -335,8 +335,8 @@ TEST_F(LowerTrs, ApplyTriangularDenseMtxUnitDiagMultipleRhsIsEquivalentToRef)
TEST_F(LowerTrs, ApplyTriangularSparseMtxMultipleRhsIsEquivalentToRef)
{
initialize_data(50, 10, 5);
auto lower_trs_factory = solver_type::build().on(ref);
auto d_lower_trs_factory = solver_type::build().on(exec);
auto lower_trs_factory = solver_type::build().with_num_rhs(10u).on(ref);
auto d_lower_trs_factory = solver_type::build().with_num_rhs(10u).on(exec);
auto solver = lower_trs_factory->generate(mtx_l);
auto d_solver = d_lower_trs_factory->generate(dmtx_l);

Expand All @@ -351,9 +351,10 @@ TEST_F(LowerTrs, ApplyTriangularSparseMtxUnitDiagMultipleRhsIsEquivalentToRef)
{
initialize_data(50, 11, 5);
auto lower_trs_factory =
solver_type::build().with_unit_diagonal(true).on(ref);
solver_type::build().with_num_rhs(11u).with_unit_diagonal(true).on(ref);
auto d_lower_trs_factory =
solver_type::build().with_unit_diagonal(true).on(exec);
solver_type::build().with_num_rhs(11u).with_unit_diagonal(true).on(
exec);
auto solver = lower_trs_factory->generate(mtx_l);
auto d_solver = d_lower_trs_factory->generate(dmtx_l);

Expand Down Expand Up @@ -507,8 +508,8 @@ TEST_F(LowerTrs, ClassicalApplyFullDenseMtxMultipleRhsIsEquivalentToRef)
{
initialize_data(50, 4, 50);
dmtx->set_strategy(std::make_shared<mtx_type::classical>());
auto lower_trs_factory = solver_type::build().on(ref);
auto d_lower_trs_factory = solver_type::build().on(exec);
auto lower_trs_factory = solver_type::build().with_num_rhs(4u).on(ref);
auto d_lower_trs_factory = solver_type::build().with_num_rhs(4u).on(exec);
auto solver = lower_trs_factory->generate(mtx);
auto d_solver = d_lower_trs_factory->generate(dmtx);

Expand All @@ -524,9 +525,9 @@ TEST_F(LowerTrs, ClassicalApplyFullDenseMtxUnitDiagMultipleRhsIsEquivalentToRef)
initialize_data(50, 5, 50);
dmtx->set_strategy(std::make_shared<mtx_type::classical>());
auto lower_trs_factory =
solver_type::build().with_unit_diagonal(true).on(ref);
solver_type::build().with_num_rhs(5u).with_unit_diagonal(true).on(ref);
auto d_lower_trs_factory =
solver_type::build().with_unit_diagonal(true).on(exec);
solver_type::build().with_num_rhs(5u).with_unit_diagonal(true).on(exec);
auto solver = lower_trs_factory->generate(mtx);
auto d_solver = d_lower_trs_factory->generate(dmtx);

Expand All @@ -541,8 +542,8 @@ TEST_F(LowerTrs, ClassicalApplyFullSparseMtxMultipleRhsIsEquivalentToRef)
{
initialize_data(50, 6, 5);
dmtx->set_strategy(std::make_shared<mtx_type::classical>());
auto lower_trs_factory = solver_type::build().on(ref);
auto d_lower_trs_factory = solver_type::build().on(exec);
auto lower_trs_factory = solver_type::build().with_num_rhs(6u).on(ref);
auto d_lower_trs_factory = solver_type::build().with_num_rhs(6u).on(exec);
auto solver = lower_trs_factory->generate(mtx);
auto d_solver = d_lower_trs_factory->generate(dmtx);

Expand All @@ -559,9 +560,9 @@ TEST_F(LowerTrs,
initialize_data(50, 7, 5);
dmtx->set_strategy(std::make_shared<mtx_type::classical>());
auto lower_trs_factory =
solver_type::build().with_unit_diagonal(true).on(ref);
solver_type::build().with_num_rhs(7u).with_unit_diagonal(true).on(ref);
auto d_lower_trs_factory =
solver_type::build().with_unit_diagonal(true).on(exec);
solver_type::build().with_num_rhs(7u).with_unit_diagonal(true).on(exec);
auto solver = lower_trs_factory->generate(mtx);
auto d_solver = d_lower_trs_factory->generate(dmtx);

Expand All @@ -576,8 +577,8 @@ TEST_F(LowerTrs, ClassicalApplyTriangularDenseMtxMultipleRhsIsEquivalentToRef)
{
initialize_data(50, 8, 50);
dmtx_l->set_strategy(std::make_shared<mtx_type::classical>());
auto lower_trs_factory = solver_type::build().on(ref);
auto d_lower_trs_factory = solver_type::build().on(exec);
auto lower_trs_factory = solver_type::build().with_num_rhs(8u).on(ref);
auto d_lower_trs_factory = solver_type::build().with_num_rhs(8u).on(exec);
auto solver = lower_trs_factory->generate(mtx_l);
auto d_solver = d_lower_trs_factory->generate(dmtx_l);

Expand All @@ -594,9 +595,9 @@ TEST_F(LowerTrs,
initialize_data(50, 9, 50);
dmtx_l->set_strategy(std::make_shared<mtx_type::classical>());
auto lower_trs_factory =
solver_type::build().with_unit_diagonal(true).on(ref);
solver_type::build().with_num_rhs(9u).with_unit_diagonal(true).on(ref);
auto d_lower_trs_factory =
solver_type::build().with_unit_diagonal(true).on(exec);
solver_type::build().with_num_rhs(9u).with_unit_diagonal(true).on(exec);
auto solver = lower_trs_factory->generate(mtx_l);
auto d_solver = d_lower_trs_factory->generate(dmtx_l);

Expand All @@ -611,8 +612,8 @@ TEST_F(LowerTrs, ClassicalApplyTriangularSparseMtxMultipleRhsIsEquivalentToRef)
{
initialize_data(50, 10, 5);
dmtx_l->set_strategy(std::make_shared<mtx_type::classical>());
auto lower_trs_factory = solver_type::build().on(ref);
auto d_lower_trs_factory = solver_type::build().on(exec);
auto lower_trs_factory = solver_type::build().with_num_rhs(10u).on(ref);
auto d_lower_trs_factory = solver_type::build().with_num_rhs(10u).on(exec);
auto solver = lower_trs_factory->generate(mtx_l);
auto d_solver = d_lower_trs_factory->generate(dmtx_l);

Expand All @@ -629,9 +630,10 @@ TEST_F(LowerTrs,
initialize_data(50, 11, 5);
dmtx_l->set_strategy(std::make_shared<mtx_type::classical>());
auto lower_trs_factory =
solver_type::build().with_unit_diagonal(true).on(ref);
solver_type::build().with_num_rhs(11u).with_unit_diagonal(true).on(ref);
auto d_lower_trs_factory =
solver_type::build().with_unit_diagonal(true).on(exec);
solver_type::build().with_num_rhs(11u).with_unit_diagonal(true).on(
exec);
auto solver = lower_trs_factory->generate(mtx_l);
auto d_solver = d_lower_trs_factory->generate(dmtx_l);

Expand Down
Loading

0 comments on commit 64c18ca

Please sign in to comment.