Skip to content

Commit

Permalink
fixes non-mpi residual norm dispatch
Browse files Browse the repository at this point in the history
  • Loading branch information
MarcelKoch committed Mar 1, 2022
1 parent 4f09d10 commit 1dc6d94
Show file tree
Hide file tree
Showing 2 changed files with 30 additions and 36 deletions.
49 changes: 27 additions & 22 deletions core/stop/residual_norm.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -101,28 +101,6 @@ bool use_distributed(Arg* linop, Rest*... rest)
}


#else


template <typename ValueType, typename LinOp, typename... Rest>
bool any_is_complex(const LinOp* in, Rest&&... rest)
{
return !(is_complex<ValueType>() ||
dynamic_cast<const ConvertibleTo<matrix::Dense<>>*>(in)) ||
any_is_complex<ValueType>(std::forward<Rest>(rest)...);
}


template <typename... Args>
bool use_distributed(Args*...)
{
return false;
}


#endif


template <typename ValueType, typename Function, typename... LinOps>
void norm_dispatch(Function&& fn, LinOps*... linops)
{
Expand All @@ -146,6 +124,33 @@ void norm_dispatch(Function&& fn, LinOps*... linops)
}


#else


template <typename ValueType, typename LinOp, typename... Rest>
bool any_is_complex(const LinOp* in, Rest&&... rest)
{
return !(is_complex<ValueType>() ||
dynamic_cast<const ConvertibleTo<matrix::Dense<>>*>(in)) ||
any_is_complex<ValueType>(std::forward<Rest>(rest)...);
}


template <typename ValueType, typename Function, typename... LinOps>
void norm_dispatch(Function&& fn, LinOps*... linops)
{
if (any_is_complex<ValueType>(linops...)) {
precision_dispatch<to_complex<ValueType>>(std::forward<Function>(fn),
linops...);
} else {
precision_dispatch<ValueType>(std::forward<Function>(fn), linops...);
}
}


#endif


template <typename ValueType>
ResidualNormBase<ValueType>::ResidualNormBase(
std::shared_ptr<const gko::Executor> exec, const CriterionArgs& args,
Expand Down
17 changes: 3 additions & 14 deletions include/ginkgo/core/base/precision_dispatch.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -477,21 +477,10 @@ void precision_dispatch_real_complex_distributed(Function fn,
#else


template <typename ValueType, typename Function>
void precision_dispatch_real_complex_distributed(Function fn, const LinOp* in,
LinOp* out)
{
precision_dispatch_real_complex<ValueType>(fn, in, out);
}


template <typename ValueType, typename Function>
void precision_dispatch_real_complex_distributed(Function fn,
const LinOp* alpha,
const LinOp* in,
const LinOp* beta, LinOp* out)
template <typename ValueType, typename Function, typename... Args>
void precision_dispatch_real_complex_distributed(Function fn, Args*... args)
{
precision_dispatch_real_complex<ValueType>(fn, alpha, in, beta, out);
precision_dispatch_real_complex<ValueType>(fn, args...);
}


Expand Down

0 comments on commit 1dc6d94

Please sign in to comment.