Skip to content

Commit

Permalink
add num_rhs to Direct parameters
Browse files Browse the repository at this point in the history
  • Loading branch information
upsj committed Nov 5, 2022
1 parent b613bc0 commit 7da3987
Show file tree
Hide file tree
Showing 2 changed files with 18 additions and 4 deletions.
13 changes: 9 additions & 4 deletions core/solver/direct.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -151,10 +151,15 @@ Direct<ValueType, IndexType>::Direct(const Factory* factory,
if (separate_diag) {
GKO_NOT_SUPPORTED(type);
}
const auto lower_factory =
lower_type::build().with_unit_diagonal(lower_unit_diag).on(exec);
const auto upper_factory =
upper_type::build().with_unit_diagonal(upper_unit_diag).on(exec);
const auto num_rhs = factory->get_parameters().num_rhs;
const auto lower_factory = lower_type::build()
.with_num_rhs(num_rhs)
.with_unit_diagonal(lower_unit_diag)
.on(exec);
const auto upper_factory = upper_type::build()
.with_num_rhs(num_rhs)
.with_unit_diagonal(upper_unit_diag)
.on(exec);
switch (type) {
case storage_type::empty:
// remove the factor storage entirely
Expand Down
9 changes: 9 additions & 0 deletions include/ginkgo/core/solver/direct.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -76,6 +76,15 @@ class Direct : public EnableLinOp<Direct<ValueType, IndexType>>,

GKO_CREATE_FACTORY_PARAMETERS(parameters, Factory)
{
/**
* Number of right hand sides.
*
* @note This value is currently only required for the CUDA executor,
* which will throw an exception if a different number of rhs is
* passed to Direct::apply.
*/
gko::size_type GKO_FACTORY_PARAMETER_SCALAR(num_rhs, 1u);

/** The factorization factory to use for generating the factors. */
std::shared_ptr<const LinOpFactory> GKO_FACTORY_PARAMETER_SCALAR(
factorization, nullptr);
Expand Down

0 comments on commit 7da3987

Please sign in to comment.