Skip to content

Commit

Permalink
fix remaining tests
Browse files Browse the repository at this point in the history
  • Loading branch information
upsj committed Nov 5, 2022
1 parent 7da3987 commit bb21691
Show file tree
Hide file tree
Showing 3 changed files with 70 additions and 64 deletions.
2 changes: 2 additions & 0 deletions test/solver/direct.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -97,6 +97,7 @@ class Direct : public CommonTestFixture {
.with_factorization(factorization_type::build()
.with_symmetric_sparsity(true)
.on(ref))
.with_num_rhs(static_cast<gko::size_type>(nrhs))
.on(ref);
alpha = gen_mtx(1, 1);
beta = gen_mtx(1, 1);
Expand All @@ -106,6 +107,7 @@ class Direct : public CommonTestFixture {
.with_factorization(factorization_type::build()
.with_symmetric_sparsity(true)
.on(exec))
.with_num_rhs(static_cast<gko::size_type>(nrhs))
.on(exec);
dalpha = gko::clone(exec, alpha);
dbeta = gko::clone(exec, beta);
Expand Down
66 changes: 34 additions & 32 deletions test/solver/lower_trs_kernels.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -239,8 +239,8 @@ TEST_F(LowerTrs, ApplyTriangularSparseMtxUnitDiagIsEquivalentToRef)
TEST_F(LowerTrs, ApplyFullDenseMtxMultipleRhsIsEquivalentToRef)
{
initialize_data(50, 4, 50);
auto lower_trs_factory = solver_type::build().on(ref);
auto d_lower_trs_factory = solver_type::build().on(exec);
auto lower_trs_factory = solver_type::build().with_num_rhs(4u).on(ref);
auto d_lower_trs_factory = solver_type::build().with_num_rhs(4u).on(exec);
auto solver = lower_trs_factory->generate(mtx);
auto d_solver = d_lower_trs_factory->generate(dmtx);

Expand All @@ -255,9 +255,9 @@ TEST_F(LowerTrs, ApplyFullDenseMtxUnitDiagMultipleRhsIsEquivalentToRef)
{
initialize_data(50, 5, 50);
auto lower_trs_factory =
solver_type::build().with_unit_diagonal(true).on(ref);
solver_type::build().with_num_rhs(5u).with_unit_diagonal(true).on(ref);
auto d_lower_trs_factory =
solver_type::build().with_unit_diagonal(true).on(exec);
solver_type::build().with_num_rhs(5u).with_unit_diagonal(true).on(exec);
auto solver = lower_trs_factory->generate(mtx);
auto d_solver = d_lower_trs_factory->generate(dmtx);

Expand All @@ -271,8 +271,8 @@ TEST_F(LowerTrs, ApplyFullDenseMtxUnitDiagMultipleRhsIsEquivalentToRef)
TEST_F(LowerTrs, ApplyFullSparseMtxMultipleRhsIsEquivalentToRef)
{
initialize_data(50, 6, 5);
auto lower_trs_factory = solver_type::build().on(ref);
auto d_lower_trs_factory = solver_type::build().on(exec);
auto lower_trs_factory = solver_type::build().with_num_rhs(6u).on(ref);
auto d_lower_trs_factory = solver_type::build().with_num_rhs(6u).on(exec);
auto solver = lower_trs_factory->generate(mtx);
auto d_solver = d_lower_trs_factory->generate(dmtx);

Expand All @@ -287,9 +287,9 @@ TEST_F(LowerTrs, ApplyFullSparseMtxUnitDiagMultipleRhsIsEquivalentToRef)
{
initialize_data(50, 7, 5);
auto lower_trs_factory =
solver_type::build().with_unit_diagonal(true).on(ref);
solver_type::build().with_num_rhs(7u).with_unit_diagonal(true).on(ref);
auto d_lower_trs_factory =
solver_type::build().with_unit_diagonal(true).on(exec);
solver_type::build().with_num_rhs(7u).with_unit_diagonal(true).on(exec);
auto solver = lower_trs_factory->generate(mtx);
auto d_solver = d_lower_trs_factory->generate(dmtx);

Expand All @@ -303,8 +303,8 @@ TEST_F(LowerTrs, ApplyFullSparseMtxUnitDiagMultipleRhsIsEquivalentToRef)
TEST_F(LowerTrs, ApplyTriangularDenseMtxMultipleRhsIsEquivalentToRef)
{
initialize_data(50, 8, 50);
auto lower_trs_factory = solver_type::build().on(ref);
auto d_lower_trs_factory = solver_type::build().on(exec);
auto lower_trs_factory = solver_type::build().with_num_rhs(8u).on(ref);
auto d_lower_trs_factory = solver_type::build().with_num_rhs(8u).on(exec);
auto solver = lower_trs_factory->generate(mtx_l);
auto d_solver = d_lower_trs_factory->generate(dmtx_l);

Expand All @@ -319,9 +319,9 @@ TEST_F(LowerTrs, ApplyTriangularDenseMtxUnitDiagMultipleRhsIsEquivalentToRef)
{
initialize_data(50, 9, 50);
auto lower_trs_factory =
solver_type::build().with_unit_diagonal(true).on(ref);
solver_type::build().with_num_rhs(9u).with_unit_diagonal(true).on(ref);
auto d_lower_trs_factory =
solver_type::build().with_unit_diagonal(true).on(exec);
solver_type::build().with_num_rhs(9u).with_unit_diagonal(true).on(exec);
auto solver = lower_trs_factory->generate(mtx_l);
auto d_solver = d_lower_trs_factory->generate(dmtx_l);

Expand All @@ -335,8 +335,8 @@ TEST_F(LowerTrs, ApplyTriangularDenseMtxUnitDiagMultipleRhsIsEquivalentToRef)
TEST_F(LowerTrs, ApplyTriangularSparseMtxMultipleRhsIsEquivalentToRef)
{
initialize_data(50, 10, 5);
auto lower_trs_factory = solver_type::build().on(ref);
auto d_lower_trs_factory = solver_type::build().on(exec);
auto lower_trs_factory = solver_type::build().with_num_rhs(10u).on(ref);
auto d_lower_trs_factory = solver_type::build().with_num_rhs(10u).on(exec);
auto solver = lower_trs_factory->generate(mtx_l);
auto d_solver = d_lower_trs_factory->generate(dmtx_l);

Expand All @@ -351,9 +351,10 @@ TEST_F(LowerTrs, ApplyTriangularSparseMtxUnitDiagMultipleRhsIsEquivalentToRef)
{
initialize_data(50, 11, 5);
auto lower_trs_factory =
solver_type::build().with_unit_diagonal(true).on(ref);
solver_type::build().with_num_rhs(11u).with_unit_diagonal(true).on(ref);
auto d_lower_trs_factory =
solver_type::build().with_unit_diagonal(true).on(exec);
solver_type::build().with_num_rhs(11u).with_unit_diagonal(true).on(
exec);
auto solver = lower_trs_factory->generate(mtx_l);
auto d_solver = d_lower_trs_factory->generate(dmtx_l);

Expand Down Expand Up @@ -507,8 +508,8 @@ TEST_F(LowerTrs, ClassicalApplyFullDenseMtxMultipleRhsIsEquivalentToRef)
{
initialize_data(50, 4, 50);
dmtx->set_strategy(std::make_shared<mtx_type::classical>());
auto lower_trs_factory = solver_type::build().on(ref);
auto d_lower_trs_factory = solver_type::build().on(exec);
auto lower_trs_factory = solver_type::build().with_num_rhs(4u).on(ref);
auto d_lower_trs_factory = solver_type::build().with_num_rhs(4u).on(exec);
auto solver = lower_trs_factory->generate(mtx);
auto d_solver = d_lower_trs_factory->generate(dmtx);

Expand All @@ -524,9 +525,9 @@ TEST_F(LowerTrs, ClassicalApplyFullDenseMtxUnitDiagMultipleRhsIsEquivalentToRef)
initialize_data(50, 5, 50);
dmtx->set_strategy(std::make_shared<mtx_type::classical>());
auto lower_trs_factory =
solver_type::build().with_unit_diagonal(true).on(ref);
solver_type::build().with_num_rhs(5u).with_unit_diagonal(true).on(ref);
auto d_lower_trs_factory =
solver_type::build().with_unit_diagonal(true).on(exec);
solver_type::build().with_num_rhs(5u).with_unit_diagonal(true).on(exec);
auto solver = lower_trs_factory->generate(mtx);
auto d_solver = d_lower_trs_factory->generate(dmtx);

Expand All @@ -541,8 +542,8 @@ TEST_F(LowerTrs, ClassicalApplyFullSparseMtxMultipleRhsIsEquivalentToRef)
{
initialize_data(50, 6, 5);
dmtx->set_strategy(std::make_shared<mtx_type::classical>());
auto lower_trs_factory = solver_type::build().on(ref);
auto d_lower_trs_factory = solver_type::build().on(exec);
auto lower_trs_factory = solver_type::build().with_num_rhs(6u).on(ref);
auto d_lower_trs_factory = solver_type::build().with_num_rhs(6u).on(exec);
auto solver = lower_trs_factory->generate(mtx);
auto d_solver = d_lower_trs_factory->generate(dmtx);

Expand All @@ -559,9 +560,9 @@ TEST_F(LowerTrs,
initialize_data(50, 7, 5);
dmtx->set_strategy(std::make_shared<mtx_type::classical>());
auto lower_trs_factory =
solver_type::build().with_unit_diagonal(true).on(ref);
solver_type::build().with_num_rhs(7u).with_unit_diagonal(true).on(ref);
auto d_lower_trs_factory =
solver_type::build().with_unit_diagonal(true).on(exec);
solver_type::build().with_num_rhs(7u).with_unit_diagonal(true).on(exec);
auto solver = lower_trs_factory->generate(mtx);
auto d_solver = d_lower_trs_factory->generate(dmtx);

Expand All @@ -576,8 +577,8 @@ TEST_F(LowerTrs, ClassicalApplyTriangularDenseMtxMultipleRhsIsEquivalentToRef)
{
initialize_data(50, 8, 50);
dmtx_l->set_strategy(std::make_shared<mtx_type::classical>());
auto lower_trs_factory = solver_type::build().on(ref);
auto d_lower_trs_factory = solver_type::build().on(exec);
auto lower_trs_factory = solver_type::build().with_num_rhs(8u).on(ref);
auto d_lower_trs_factory = solver_type::build().with_num_rhs(8u).on(exec);
auto solver = lower_trs_factory->generate(mtx_l);
auto d_solver = d_lower_trs_factory->generate(dmtx_l);

Expand All @@ -594,9 +595,9 @@ TEST_F(LowerTrs,
initialize_data(50, 9, 50);
dmtx_l->set_strategy(std::make_shared<mtx_type::classical>());
auto lower_trs_factory =
solver_type::build().with_unit_diagonal(true).on(ref);
solver_type::build().with_num_rhs(9u).with_unit_diagonal(true).on(ref);
auto d_lower_trs_factory =
solver_type::build().with_unit_diagonal(true).on(exec);
solver_type::build().with_num_rhs(9u).with_unit_diagonal(true).on(exec);
auto solver = lower_trs_factory->generate(mtx_l);
auto d_solver = d_lower_trs_factory->generate(dmtx_l);

Expand All @@ -611,8 +612,8 @@ TEST_F(LowerTrs, ClassicalApplyTriangularSparseMtxMultipleRhsIsEquivalentToRef)
{
initialize_data(50, 10, 5);
dmtx_l->set_strategy(std::make_shared<mtx_type::classical>());
auto lower_trs_factory = solver_type::build().on(ref);
auto d_lower_trs_factory = solver_type::build().on(exec);
auto lower_trs_factory = solver_type::build().with_num_rhs(10u).on(ref);
auto d_lower_trs_factory = solver_type::build().with_num_rhs(10u).on(exec);
auto solver = lower_trs_factory->generate(mtx_l);
auto d_solver = d_lower_trs_factory->generate(dmtx_l);

Expand All @@ -629,9 +630,10 @@ TEST_F(LowerTrs,
initialize_data(50, 11, 5);
dmtx_l->set_strategy(std::make_shared<mtx_type::classical>());
auto lower_trs_factory =
solver_type::build().with_unit_diagonal(true).on(ref);
solver_type::build().with_num_rhs(11u).with_unit_diagonal(true).on(ref);
auto d_lower_trs_factory =
solver_type::build().with_unit_diagonal(true).on(exec);
solver_type::build().with_num_rhs(11u).with_unit_diagonal(true).on(
exec);
auto solver = lower_trs_factory->generate(mtx_l);
auto d_solver = d_lower_trs_factory->generate(dmtx_l);

Expand Down
66 changes: 34 additions & 32 deletions test/solver/upper_trs_kernels.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -239,8 +239,8 @@ TEST_F(UpperTrs, ApplyTriangularSparseMtxUnitDiagIsEquivalentToRef)
TEST_F(UpperTrs, ApplyFullDenseMtxMultipleRhsIsEquivalentToRef)
{
initialize_data(50, 4, 50);
auto upper_trs_factory = solver_type::build().on(ref);
auto d_upper_trs_factory = solver_type::build().on(exec);
auto upper_trs_factory = solver_type::build().with_num_rhs(4u).on(ref);
auto d_upper_trs_factory = solver_type::build().with_num_rhs(4u).on(exec);
auto solver = upper_trs_factory->generate(mtx);
auto d_solver = d_upper_trs_factory->generate(dmtx);

Expand All @@ -255,9 +255,9 @@ TEST_F(UpperTrs, ApplyFullDenseMtxUnitDiagMultipleRhsIsEquivalentToRef)
{
initialize_data(50, 5, 50);
auto upper_trs_factory =
solver_type::build().with_unit_diagonal(true).on(ref);
solver_type::build().with_num_rhs(5u).with_unit_diagonal(true).on(ref);
auto d_upper_trs_factory =
solver_type::build().with_unit_diagonal(true).on(exec);
solver_type::build().with_num_rhs(5u).with_unit_diagonal(true).on(exec);
auto solver = upper_trs_factory->generate(mtx);
auto d_solver = d_upper_trs_factory->generate(dmtx);

Expand All @@ -271,8 +271,8 @@ TEST_F(UpperTrs, ApplyFullDenseMtxUnitDiagMultipleRhsIsEquivalentToRef)
TEST_F(UpperTrs, ApplyFullSparseMtxMultipleRhsIsEquivalentToRef)
{
initialize_data(50, 6, 5);
auto upper_trs_factory = solver_type::build().on(ref);
auto d_upper_trs_factory = solver_type::build().on(exec);
auto upper_trs_factory = solver_type::build().with_num_rhs(6u).on(ref);
auto d_upper_trs_factory = solver_type::build().with_num_rhs(6u).on(exec);
auto solver = upper_trs_factory->generate(mtx);
auto d_solver = d_upper_trs_factory->generate(dmtx);

Expand All @@ -287,9 +287,9 @@ TEST_F(UpperTrs, ApplyFullSparseMtxUnitDiagMultipleRhsIsEquivalentToRef)
{
initialize_data(50, 7, 5);
auto upper_trs_factory =
solver_type::build().with_unit_diagonal(true).on(ref);
solver_type::build().with_num_rhs(7u).with_unit_diagonal(true).on(ref);
auto d_upper_trs_factory =
solver_type::build().with_unit_diagonal(true).on(exec);
solver_type::build().with_num_rhs(7u).with_unit_diagonal(true).on(exec);
auto solver = upper_trs_factory->generate(mtx);
auto d_solver = d_upper_trs_factory->generate(dmtx);

Expand All @@ -303,8 +303,8 @@ TEST_F(UpperTrs, ApplyFullSparseMtxUnitDiagMultipleRhsIsEquivalentToRef)
TEST_F(UpperTrs, ApplyTriangularDenseMtxMultipleRhsIsEquivalentToRef)
{
initialize_data(50, 8, 50);
auto upper_trs_factory = solver_type::build().on(ref);
auto d_upper_trs_factory = solver_type::build().on(exec);
auto upper_trs_factory = solver_type::build().with_num_rhs(8u).on(ref);
auto d_upper_trs_factory = solver_type::build().with_num_rhs(8u).on(exec);
auto solver = upper_trs_factory->generate(mtx_u);
auto d_solver = d_upper_trs_factory->generate(dmtx_u);

Expand All @@ -319,9 +319,9 @@ TEST_F(UpperTrs, ApplyTriangularDenseMtxUnitDiagMultipleRhsIsEquivalentToRef)
{
initialize_data(50, 9, 50);
auto upper_trs_factory =
solver_type::build().with_unit_diagonal(true).on(ref);
solver_type::build().with_num_rhs(9u).with_unit_diagonal(true).on(ref);
auto d_upper_trs_factory =
solver_type::build().with_unit_diagonal(true).on(exec);
solver_type::build().with_num_rhs(9u).with_unit_diagonal(true).on(exec);
auto solver = upper_trs_factory->generate(mtx_u);
auto d_solver = d_upper_trs_factory->generate(dmtx_u);

Expand All @@ -335,8 +335,8 @@ TEST_F(UpperTrs, ApplyTriangularDenseMtxUnitDiagMultipleRhsIsEquivalentToRef)
TEST_F(UpperTrs, ApplyTriangularSparseMtxMultipleRhsIsEquivalentToRef)
{
initialize_data(50, 10, 5);
auto upper_trs_factory = solver_type::build().on(ref);
auto d_upper_trs_factory = solver_type::build().on(exec);
auto upper_trs_factory = solver_type::build().with_num_rhs(10u).on(ref);
auto d_upper_trs_factory = solver_type::build().with_num_rhs(10u).on(exec);
auto solver = upper_trs_factory->generate(mtx_u);
auto d_solver = d_upper_trs_factory->generate(dmtx_u);

Expand All @@ -351,9 +351,10 @@ TEST_F(UpperTrs, ApplyTriangularSparseMtxUnitDiagMultipleRhsIsEquivalentToRef)
{
initialize_data(50, 11, 5);
auto upper_trs_factory =
solver_type::build().with_unit_diagonal(true).on(ref);
solver_type::build().with_num_rhs(11u).with_unit_diagonal(true).on(ref);
auto d_upper_trs_factory =
solver_type::build().with_unit_diagonal(true).on(exec);
solver_type::build().with_num_rhs(11u).with_unit_diagonal(true).on(
exec);
auto solver = upper_trs_factory->generate(mtx_u);
auto d_solver = d_upper_trs_factory->generate(dmtx_u);

Expand Down Expand Up @@ -507,8 +508,8 @@ TEST_F(UpperTrs, ClassicalApplyFullDenseMtxMultipleRhsIsEquivalentToRef)
{
initialize_data(50, 4, 50);
dmtx->set_strategy(std::make_shared<mtx_type::classical>());
auto upper_trs_factory = solver_type::build().on(ref);
auto d_upper_trs_factory = solver_type::build().on(exec);
auto upper_trs_factory = solver_type::build().with_num_rhs(4u).on(ref);
auto d_upper_trs_factory = solver_type::build().with_num_rhs(4u).on(exec);
auto solver = upper_trs_factory->generate(mtx);
auto d_solver = d_upper_trs_factory->generate(dmtx);

Expand All @@ -524,9 +525,9 @@ TEST_F(UpperTrs, ClassicalApplyFullDenseMtxUnitDiagMultipleRhsIsEquivalentToRef)
initialize_data(50, 5, 50);
dmtx->set_strategy(std::make_shared<mtx_type::classical>());
auto upper_trs_factory =
solver_type::build().with_unit_diagonal(true).on(ref);
solver_type::build().with_num_rhs(5u).with_unit_diagonal(true).on(ref);
auto d_upper_trs_factory =
solver_type::build().with_unit_diagonal(true).on(exec);
solver_type::build().with_num_rhs(5u).with_unit_diagonal(true).on(exec);
auto solver = upper_trs_factory->generate(mtx);
auto d_solver = d_upper_trs_factory->generate(dmtx);

Expand All @@ -541,8 +542,8 @@ TEST_F(UpperTrs, ClassicalApplyFullSparseMtxMultipleRhsIsEquivalentToRef)
{
initialize_data(50, 6, 5);
dmtx->set_strategy(std::make_shared<mtx_type::classical>());
auto upper_trs_factory = solver_type::build().on(ref);
auto d_upper_trs_factory = solver_type::build().on(exec);
auto upper_trs_factory = solver_type::build().with_num_rhs(6u).on(ref);
auto d_upper_trs_factory = solver_type::build().with_num_rhs(6u).on(exec);
auto solver = upper_trs_factory->generate(mtx);
auto d_solver = d_upper_trs_factory->generate(dmtx);

Expand All @@ -559,9 +560,9 @@ TEST_F(UpperTrs,
initialize_data(50, 7, 5);
dmtx->set_strategy(std::make_shared<mtx_type::classical>());
auto upper_trs_factory =
solver_type::build().with_unit_diagonal(true).on(ref);
solver_type::build().with_num_rhs(7u).with_unit_diagonal(true).on(ref);
auto d_upper_trs_factory =
solver_type::build().with_unit_diagonal(true).on(exec);
solver_type::build().with_num_rhs(7u).with_unit_diagonal(true).on(exec);
auto solver = upper_trs_factory->generate(mtx);
auto d_solver = d_upper_trs_factory->generate(dmtx);

Expand All @@ -576,8 +577,8 @@ TEST_F(UpperTrs, ClassicalApplyTriangularDenseMtxMultipleRhsIsEquivalentToRef)
{
initialize_data(50, 8, 50);
dmtx_u->set_strategy(std::make_shared<mtx_type::classical>());
auto upper_trs_factory = solver_type::build().on(ref);
auto d_upper_trs_factory = solver_type::build().on(exec);
auto upper_trs_factory = solver_type::build().with_num_rhs(8u).on(ref);
auto d_upper_trs_factory = solver_type::build().with_num_rhs(8u).on(exec);
auto solver = upper_trs_factory->generate(mtx_u);
auto d_solver = d_upper_trs_factory->generate(dmtx_u);

Expand All @@ -594,9 +595,9 @@ TEST_F(UpperTrs,
initialize_data(50, 9, 50);
dmtx_u->set_strategy(std::make_shared<mtx_type::classical>());
auto upper_trs_factory =
solver_type::build().with_unit_diagonal(true).on(ref);
solver_type::build().with_num_rhs(9u).with_unit_diagonal(true).on(ref);
auto d_upper_trs_factory =
solver_type::build().with_unit_diagonal(true).on(exec);
solver_type::build().with_num_rhs(9u).with_unit_diagonal(true).on(exec);
auto solver = upper_trs_factory->generate(mtx_u);
auto d_solver = d_upper_trs_factory->generate(dmtx_u);

Expand All @@ -611,8 +612,8 @@ TEST_F(UpperTrs, ClassicalApplyTriangularSparseMtxMultipleRhsIsEquivalentToRef)
{
initialize_data(50, 10, 5);
dmtx_u->set_strategy(std::make_shared<mtx_type::classical>());
auto upper_trs_factory = solver_type::build().on(ref);
auto d_upper_trs_factory = solver_type::build().on(exec);
auto upper_trs_factory = solver_type::build().with_num_rhs(10u).on(ref);
auto d_upper_trs_factory = solver_type::build().with_num_rhs(10u).on(exec);
auto solver = upper_trs_factory->generate(mtx_u);
auto d_solver = d_upper_trs_factory->generate(dmtx_u);

Expand All @@ -629,9 +630,10 @@ TEST_F(UpperTrs,
initialize_data(50, 11, 5);
dmtx_u->set_strategy(std::make_shared<mtx_type::classical>());
auto upper_trs_factory =
solver_type::build().with_unit_diagonal(true).on(ref);
solver_type::build().with_unit_diagonal(true).with_num_rhs(11u).on(ref);
auto d_upper_trs_factory =
solver_type::build().with_unit_diagonal(true).on(exec);
solver_type::build().with_unit_diagonal(true).with_num_rhs(11u).on(
exec);
auto solver = upper_trs_factory->generate(mtx_u);
auto d_solver = d_upper_trs_factory->generate(dmtx_u);

Expand Down

0 comments on commit bb21691

Please sign in to comment.