diff --git a/test/solver/direct.cpp b/test/solver/direct.cpp index 005082a8849..1f7e186b886 100644 --- a/test/solver/direct.cpp +++ b/test/solver/direct.cpp @@ -97,6 +97,7 @@ class Direct : public CommonTestFixture { .with_factorization(factorization_type::build() .with_symmetric_sparsity(true) .on(ref)) + .with_num_rhs(static_cast(nrhs)) .on(ref); alpha = gen_mtx(1, 1); beta = gen_mtx(1, 1); @@ -106,6 +107,7 @@ class Direct : public CommonTestFixture { .with_factorization(factorization_type::build() .with_symmetric_sparsity(true) .on(exec)) + .with_num_rhs(static_cast(nrhs)) .on(exec); dalpha = gko::clone(exec, alpha); dbeta = gko::clone(exec, beta); diff --git a/test/solver/lower_trs_kernels.cpp b/test/solver/lower_trs_kernels.cpp index d437a188d50..23d9fa1f8e1 100644 --- a/test/solver/lower_trs_kernels.cpp +++ b/test/solver/lower_trs_kernels.cpp @@ -239,8 +239,8 @@ TEST_F(LowerTrs, ApplyTriangularSparseMtxUnitDiagIsEquivalentToRef) TEST_F(LowerTrs, ApplyFullDenseMtxMultipleRhsIsEquivalentToRef) { initialize_data(50, 4, 50); - auto lower_trs_factory = solver_type::build().on(ref); - auto d_lower_trs_factory = solver_type::build().on(exec); + auto lower_trs_factory = solver_type::build().with_num_rhs(4u).on(ref); + auto d_lower_trs_factory = solver_type::build().with_num_rhs(4u).on(exec); auto solver = lower_trs_factory->generate(mtx); auto d_solver = d_lower_trs_factory->generate(dmtx); @@ -255,9 +255,9 @@ TEST_F(LowerTrs, ApplyFullDenseMtxUnitDiagMultipleRhsIsEquivalentToRef) { initialize_data(50, 5, 50); auto lower_trs_factory = - solver_type::build().with_unit_diagonal(true).on(ref); + solver_type::build().with_num_rhs(5u).with_unit_diagonal(true).on(ref); auto d_lower_trs_factory = - solver_type::build().with_unit_diagonal(true).on(exec); + solver_type::build().with_num_rhs(5u).with_unit_diagonal(true).on(exec); auto solver = lower_trs_factory->generate(mtx); auto d_solver = d_lower_trs_factory->generate(dmtx); @@ -271,8 +271,8 @@ TEST_F(LowerTrs, ApplyFullDenseMtxUnitDiagMultipleRhsIsEquivalentToRef) TEST_F(LowerTrs, ApplyFullSparseMtxMultipleRhsIsEquivalentToRef) { initialize_data(50, 6, 5); - auto lower_trs_factory = solver_type::build().on(ref); - auto d_lower_trs_factory = solver_type::build().on(exec); + auto lower_trs_factory = solver_type::build().with_num_rhs(6u).on(ref); + auto d_lower_trs_factory = solver_type::build().with_num_rhs(6u).on(exec); auto solver = lower_trs_factory->generate(mtx); auto d_solver = d_lower_trs_factory->generate(dmtx); @@ -287,9 +287,9 @@ TEST_F(LowerTrs, ApplyFullSparseMtxUnitDiagMultipleRhsIsEquivalentToRef) { initialize_data(50, 7, 5); auto lower_trs_factory = - solver_type::build().with_unit_diagonal(true).on(ref); + solver_type::build().with_num_rhs(7u).with_unit_diagonal(true).on(ref); auto d_lower_trs_factory = - solver_type::build().with_unit_diagonal(true).on(exec); + solver_type::build().with_num_rhs(7u).with_unit_diagonal(true).on(exec); auto solver = lower_trs_factory->generate(mtx); auto d_solver = d_lower_trs_factory->generate(dmtx); @@ -303,8 +303,8 @@ TEST_F(LowerTrs, ApplyFullSparseMtxUnitDiagMultipleRhsIsEquivalentToRef) TEST_F(LowerTrs, ApplyTriangularDenseMtxMultipleRhsIsEquivalentToRef) { initialize_data(50, 8, 50); - auto lower_trs_factory = solver_type::build().on(ref); - auto d_lower_trs_factory = solver_type::build().on(exec); + auto lower_trs_factory = solver_type::build().with_num_rhs(8u).on(ref); + auto d_lower_trs_factory = solver_type::build().with_num_rhs(8u).on(exec); auto solver = lower_trs_factory->generate(mtx_l); auto d_solver = d_lower_trs_factory->generate(dmtx_l); @@ -319,9 +319,9 @@ TEST_F(LowerTrs, ApplyTriangularDenseMtxUnitDiagMultipleRhsIsEquivalentToRef) { initialize_data(50, 9, 50); auto lower_trs_factory = - solver_type::build().with_unit_diagonal(true).on(ref); + solver_type::build().with_num_rhs(9u).with_unit_diagonal(true).on(ref); auto d_lower_trs_factory = - solver_type::build().with_unit_diagonal(true).on(exec); + solver_type::build().with_num_rhs(9u).with_unit_diagonal(true).on(exec); auto solver = lower_trs_factory->generate(mtx_l); auto d_solver = d_lower_trs_factory->generate(dmtx_l); @@ -335,8 +335,8 @@ TEST_F(LowerTrs, ApplyTriangularDenseMtxUnitDiagMultipleRhsIsEquivalentToRef) TEST_F(LowerTrs, ApplyTriangularSparseMtxMultipleRhsIsEquivalentToRef) { initialize_data(50, 10, 5); - auto lower_trs_factory = solver_type::build().on(ref); - auto d_lower_trs_factory = solver_type::build().on(exec); + auto lower_trs_factory = solver_type::build().with_num_rhs(10u).on(ref); + auto d_lower_trs_factory = solver_type::build().with_num_rhs(10u).on(exec); auto solver = lower_trs_factory->generate(mtx_l); auto d_solver = d_lower_trs_factory->generate(dmtx_l); @@ -351,9 +351,10 @@ TEST_F(LowerTrs, ApplyTriangularSparseMtxUnitDiagMultipleRhsIsEquivalentToRef) { initialize_data(50, 11, 5); auto lower_trs_factory = - solver_type::build().with_unit_diagonal(true).on(ref); + solver_type::build().with_num_rhs(11u).with_unit_diagonal(true).on(ref); auto d_lower_trs_factory = - solver_type::build().with_unit_diagonal(true).on(exec); + solver_type::build().with_num_rhs(11u).with_unit_diagonal(true).on( + exec); auto solver = lower_trs_factory->generate(mtx_l); auto d_solver = d_lower_trs_factory->generate(dmtx_l); @@ -507,8 +508,8 @@ TEST_F(LowerTrs, ClassicalApplyFullDenseMtxMultipleRhsIsEquivalentToRef) { initialize_data(50, 4, 50); dmtx->set_strategy(std::make_shared()); - auto lower_trs_factory = solver_type::build().on(ref); - auto d_lower_trs_factory = solver_type::build().on(exec); + auto lower_trs_factory = solver_type::build().with_num_rhs(4u).on(ref); + auto d_lower_trs_factory = solver_type::build().with_num_rhs(4u).on(exec); auto solver = lower_trs_factory->generate(mtx); auto d_solver = d_lower_trs_factory->generate(dmtx); @@ -524,9 +525,9 @@ TEST_F(LowerTrs, ClassicalApplyFullDenseMtxUnitDiagMultipleRhsIsEquivalentToRef) initialize_data(50, 5, 50); dmtx->set_strategy(std::make_shared()); auto lower_trs_factory = - solver_type::build().with_unit_diagonal(true).on(ref); + solver_type::build().with_num_rhs(5u).with_unit_diagonal(true).on(ref); auto d_lower_trs_factory = - solver_type::build().with_unit_diagonal(true).on(exec); + solver_type::build().with_num_rhs(5u).with_unit_diagonal(true).on(exec); auto solver = lower_trs_factory->generate(mtx); auto d_solver = d_lower_trs_factory->generate(dmtx); @@ -541,8 +542,8 @@ TEST_F(LowerTrs, ClassicalApplyFullSparseMtxMultipleRhsIsEquivalentToRef) { initialize_data(50, 6, 5); dmtx->set_strategy(std::make_shared()); - auto lower_trs_factory = solver_type::build().on(ref); - auto d_lower_trs_factory = solver_type::build().on(exec); + auto lower_trs_factory = solver_type::build().with_num_rhs(6u).on(ref); + auto d_lower_trs_factory = solver_type::build().with_num_rhs(6u).on(exec); auto solver = lower_trs_factory->generate(mtx); auto d_solver = d_lower_trs_factory->generate(dmtx); @@ -559,9 +560,9 @@ TEST_F(LowerTrs, initialize_data(50, 7, 5); dmtx->set_strategy(std::make_shared()); auto lower_trs_factory = - solver_type::build().with_unit_diagonal(true).on(ref); + solver_type::build().with_num_rhs(7u).with_unit_diagonal(true).on(ref); auto d_lower_trs_factory = - solver_type::build().with_unit_diagonal(true).on(exec); + solver_type::build().with_num_rhs(7u).with_unit_diagonal(true).on(exec); auto solver = lower_trs_factory->generate(mtx); auto d_solver = d_lower_trs_factory->generate(dmtx); @@ -576,8 +577,8 @@ TEST_F(LowerTrs, ClassicalApplyTriangularDenseMtxMultipleRhsIsEquivalentToRef) { initialize_data(50, 8, 50); dmtx_l->set_strategy(std::make_shared()); - auto lower_trs_factory = solver_type::build().on(ref); - auto d_lower_trs_factory = solver_type::build().on(exec); + auto lower_trs_factory = solver_type::build().with_num_rhs(8u).on(ref); + auto d_lower_trs_factory = solver_type::build().with_num_rhs(8u).on(exec); auto solver = lower_trs_factory->generate(mtx_l); auto d_solver = d_lower_trs_factory->generate(dmtx_l); @@ -594,9 +595,9 @@ TEST_F(LowerTrs, initialize_data(50, 9, 50); dmtx_l->set_strategy(std::make_shared()); auto lower_trs_factory = - solver_type::build().with_unit_diagonal(true).on(ref); + solver_type::build().with_num_rhs(9u).with_unit_diagonal(true).on(ref); auto d_lower_trs_factory = - solver_type::build().with_unit_diagonal(true).on(exec); + solver_type::build().with_num_rhs(9u).with_unit_diagonal(true).on(exec); auto solver = lower_trs_factory->generate(mtx_l); auto d_solver = d_lower_trs_factory->generate(dmtx_l); @@ -611,8 +612,8 @@ TEST_F(LowerTrs, ClassicalApplyTriangularSparseMtxMultipleRhsIsEquivalentToRef) { initialize_data(50, 10, 5); dmtx_l->set_strategy(std::make_shared()); - auto lower_trs_factory = solver_type::build().on(ref); - auto d_lower_trs_factory = solver_type::build().on(exec); + auto lower_trs_factory = solver_type::build().with_num_rhs(10u).on(ref); + auto d_lower_trs_factory = solver_type::build().with_num_rhs(10u).on(exec); auto solver = lower_trs_factory->generate(mtx_l); auto d_solver = d_lower_trs_factory->generate(dmtx_l); @@ -629,9 +630,10 @@ TEST_F(LowerTrs, initialize_data(50, 11, 5); dmtx_l->set_strategy(std::make_shared()); auto lower_trs_factory = - solver_type::build().with_unit_diagonal(true).on(ref); + solver_type::build().with_num_rhs(11u).with_unit_diagonal(true).on(ref); auto d_lower_trs_factory = - solver_type::build().with_unit_diagonal(true).on(exec); + solver_type::build().with_num_rhs(11u).with_unit_diagonal(true).on( + exec); auto solver = lower_trs_factory->generate(mtx_l); auto d_solver = d_lower_trs_factory->generate(dmtx_l); diff --git a/test/solver/upper_trs_kernels.cpp b/test/solver/upper_trs_kernels.cpp index 141ab78c86c..13da5adac69 100644 --- a/test/solver/upper_trs_kernels.cpp +++ b/test/solver/upper_trs_kernels.cpp @@ -239,8 +239,8 @@ TEST_F(UpperTrs, ApplyTriangularSparseMtxUnitDiagIsEquivalentToRef) TEST_F(UpperTrs, ApplyFullDenseMtxMultipleRhsIsEquivalentToRef) { initialize_data(50, 4, 50); - auto upper_trs_factory = solver_type::build().on(ref); - auto d_upper_trs_factory = solver_type::build().on(exec); + auto upper_trs_factory = solver_type::build().with_num_rhs(4u).on(ref); + auto d_upper_trs_factory = solver_type::build().with_num_rhs(4u).on(exec); auto solver = upper_trs_factory->generate(mtx); auto d_solver = d_upper_trs_factory->generate(dmtx); @@ -255,9 +255,9 @@ TEST_F(UpperTrs, ApplyFullDenseMtxUnitDiagMultipleRhsIsEquivalentToRef) { initialize_data(50, 5, 50); auto upper_trs_factory = - solver_type::build().with_unit_diagonal(true).on(ref); + solver_type::build().with_num_rhs(5u).with_unit_diagonal(true).on(ref); auto d_upper_trs_factory = - solver_type::build().with_unit_diagonal(true).on(exec); + solver_type::build().with_num_rhs(5u).with_unit_diagonal(true).on(exec); auto solver = upper_trs_factory->generate(mtx); auto d_solver = d_upper_trs_factory->generate(dmtx); @@ -271,8 +271,8 @@ TEST_F(UpperTrs, ApplyFullDenseMtxUnitDiagMultipleRhsIsEquivalentToRef) TEST_F(UpperTrs, ApplyFullSparseMtxMultipleRhsIsEquivalentToRef) { initialize_data(50, 6, 5); - auto upper_trs_factory = solver_type::build().on(ref); - auto d_upper_trs_factory = solver_type::build().on(exec); + auto upper_trs_factory = solver_type::build().with_num_rhs(6u).on(ref); + auto d_upper_trs_factory = solver_type::build().with_num_rhs(6u).on(exec); auto solver = upper_trs_factory->generate(mtx); auto d_solver = d_upper_trs_factory->generate(dmtx); @@ -287,9 +287,9 @@ TEST_F(UpperTrs, ApplyFullSparseMtxUnitDiagMultipleRhsIsEquivalentToRef) { initialize_data(50, 7, 5); auto upper_trs_factory = - solver_type::build().with_unit_diagonal(true).on(ref); + solver_type::build().with_num_rhs(7u).with_unit_diagonal(true).on(ref); auto d_upper_trs_factory = - solver_type::build().with_unit_diagonal(true).on(exec); + solver_type::build().with_num_rhs(7u).with_unit_diagonal(true).on(exec); auto solver = upper_trs_factory->generate(mtx); auto d_solver = d_upper_trs_factory->generate(dmtx); @@ -303,8 +303,8 @@ TEST_F(UpperTrs, ApplyFullSparseMtxUnitDiagMultipleRhsIsEquivalentToRef) TEST_F(UpperTrs, ApplyTriangularDenseMtxMultipleRhsIsEquivalentToRef) { initialize_data(50, 8, 50); - auto upper_trs_factory = solver_type::build().on(ref); - auto d_upper_trs_factory = solver_type::build().on(exec); + auto upper_trs_factory = solver_type::build().with_num_rhs(8u).on(ref); + auto d_upper_trs_factory = solver_type::build().with_num_rhs(8u).on(exec); auto solver = upper_trs_factory->generate(mtx_u); auto d_solver = d_upper_trs_factory->generate(dmtx_u); @@ -319,9 +319,9 @@ TEST_F(UpperTrs, ApplyTriangularDenseMtxUnitDiagMultipleRhsIsEquivalentToRef) { initialize_data(50, 9, 50); auto upper_trs_factory = - solver_type::build().with_unit_diagonal(true).on(ref); + solver_type::build().with_num_rhs(9u).with_unit_diagonal(true).on(ref); auto d_upper_trs_factory = - solver_type::build().with_unit_diagonal(true).on(exec); + solver_type::build().with_num_rhs(9u).with_unit_diagonal(true).on(exec); auto solver = upper_trs_factory->generate(mtx_u); auto d_solver = d_upper_trs_factory->generate(dmtx_u); @@ -335,8 +335,8 @@ TEST_F(UpperTrs, ApplyTriangularDenseMtxUnitDiagMultipleRhsIsEquivalentToRef) TEST_F(UpperTrs, ApplyTriangularSparseMtxMultipleRhsIsEquivalentToRef) { initialize_data(50, 10, 5); - auto upper_trs_factory = solver_type::build().on(ref); - auto d_upper_trs_factory = solver_type::build().on(exec); + auto upper_trs_factory = solver_type::build().with_num_rhs(10u).on(ref); + auto d_upper_trs_factory = solver_type::build().with_num_rhs(10u).on(exec); auto solver = upper_trs_factory->generate(mtx_u); auto d_solver = d_upper_trs_factory->generate(dmtx_u); @@ -351,9 +351,10 @@ TEST_F(UpperTrs, ApplyTriangularSparseMtxUnitDiagMultipleRhsIsEquivalentToRef) { initialize_data(50, 11, 5); auto upper_trs_factory = - solver_type::build().with_unit_diagonal(true).on(ref); + solver_type::build().with_num_rhs(11u).with_unit_diagonal(true).on(ref); auto d_upper_trs_factory = - solver_type::build().with_unit_diagonal(true).on(exec); + solver_type::build().with_num_rhs(11u).with_unit_diagonal(true).on( + exec); auto solver = upper_trs_factory->generate(mtx_u); auto d_solver = d_upper_trs_factory->generate(dmtx_u); @@ -507,8 +508,8 @@ TEST_F(UpperTrs, ClassicalApplyFullDenseMtxMultipleRhsIsEquivalentToRef) { initialize_data(50, 4, 50); dmtx->set_strategy(std::make_shared()); - auto upper_trs_factory = solver_type::build().on(ref); - auto d_upper_trs_factory = solver_type::build().on(exec); + auto upper_trs_factory = solver_type::build().with_num_rhs(4u).on(ref); + auto d_upper_trs_factory = solver_type::build().with_num_rhs(4u).on(exec); auto solver = upper_trs_factory->generate(mtx); auto d_solver = d_upper_trs_factory->generate(dmtx); @@ -524,9 +525,9 @@ TEST_F(UpperTrs, ClassicalApplyFullDenseMtxUnitDiagMultipleRhsIsEquivalentToRef) initialize_data(50, 5, 50); dmtx->set_strategy(std::make_shared()); auto upper_trs_factory = - solver_type::build().with_unit_diagonal(true).on(ref); + solver_type::build().with_num_rhs(5u).with_unit_diagonal(true).on(ref); auto d_upper_trs_factory = - solver_type::build().with_unit_diagonal(true).on(exec); + solver_type::build().with_num_rhs(5u).with_unit_diagonal(true).on(exec); auto solver = upper_trs_factory->generate(mtx); auto d_solver = d_upper_trs_factory->generate(dmtx); @@ -541,8 +542,8 @@ TEST_F(UpperTrs, ClassicalApplyFullSparseMtxMultipleRhsIsEquivalentToRef) { initialize_data(50, 6, 5); dmtx->set_strategy(std::make_shared()); - auto upper_trs_factory = solver_type::build().on(ref); - auto d_upper_trs_factory = solver_type::build().on(exec); + auto upper_trs_factory = solver_type::build().with_num_rhs(6u).on(ref); + auto d_upper_trs_factory = solver_type::build().with_num_rhs(6u).on(exec); auto solver = upper_trs_factory->generate(mtx); auto d_solver = d_upper_trs_factory->generate(dmtx); @@ -559,9 +560,9 @@ TEST_F(UpperTrs, initialize_data(50, 7, 5); dmtx->set_strategy(std::make_shared()); auto upper_trs_factory = - solver_type::build().with_unit_diagonal(true).on(ref); + solver_type::build().with_num_rhs(7u).with_unit_diagonal(true).on(ref); auto d_upper_trs_factory = - solver_type::build().with_unit_diagonal(true).on(exec); + solver_type::build().with_num_rhs(7u).with_unit_diagonal(true).on(exec); auto solver = upper_trs_factory->generate(mtx); auto d_solver = d_upper_trs_factory->generate(dmtx); @@ -576,8 +577,8 @@ TEST_F(UpperTrs, ClassicalApplyTriangularDenseMtxMultipleRhsIsEquivalentToRef) { initialize_data(50, 8, 50); dmtx_u->set_strategy(std::make_shared()); - auto upper_trs_factory = solver_type::build().on(ref); - auto d_upper_trs_factory = solver_type::build().on(exec); + auto upper_trs_factory = solver_type::build().with_num_rhs(8u).on(ref); + auto d_upper_trs_factory = solver_type::build().with_num_rhs(8u).on(exec); auto solver = upper_trs_factory->generate(mtx_u); auto d_solver = d_upper_trs_factory->generate(dmtx_u); @@ -594,9 +595,9 @@ TEST_F(UpperTrs, initialize_data(50, 9, 50); dmtx_u->set_strategy(std::make_shared()); auto upper_trs_factory = - solver_type::build().with_unit_diagonal(true).on(ref); + solver_type::build().with_num_rhs(9u).with_unit_diagonal(true).on(ref); auto d_upper_trs_factory = - solver_type::build().with_unit_diagonal(true).on(exec); + solver_type::build().with_num_rhs(9u).with_unit_diagonal(true).on(exec); auto solver = upper_trs_factory->generate(mtx_u); auto d_solver = d_upper_trs_factory->generate(dmtx_u); @@ -611,8 +612,8 @@ TEST_F(UpperTrs, ClassicalApplyTriangularSparseMtxMultipleRhsIsEquivalentToRef) { initialize_data(50, 10, 5); dmtx_u->set_strategy(std::make_shared()); - auto upper_trs_factory = solver_type::build().on(ref); - auto d_upper_trs_factory = solver_type::build().on(exec); + auto upper_trs_factory = solver_type::build().with_num_rhs(10u).on(ref); + auto d_upper_trs_factory = solver_type::build().with_num_rhs(10u).on(exec); auto solver = upper_trs_factory->generate(mtx_u); auto d_solver = d_upper_trs_factory->generate(dmtx_u); @@ -629,9 +630,10 @@ TEST_F(UpperTrs, initialize_data(50, 11, 5); dmtx_u->set_strategy(std::make_shared()); auto upper_trs_factory = - solver_type::build().with_unit_diagonal(true).on(ref); + solver_type::build().with_unit_diagonal(true).with_num_rhs(11u).on(ref); auto d_upper_trs_factory = - solver_type::build().with_unit_diagonal(true).on(exec); + solver_type::build().with_unit_diagonal(true).with_num_rhs(11u).on( + exec); auto solver = upper_trs_factory->generate(mtx_u); auto d_solver = d_upper_trs_factory->generate(dmtx_u);