Skip to content

Commit

Permalink
AnyCTO with arbitrary number of functions (#4135)
Browse files Browse the repository at this point in the history
This PR extends AnyCTO to support an arbitrary number of optimized
functions. Some constructs such as PrefixSum need multiple functions
that may need to be optimized. It is also possible to provide no GPU
function, in which case the compile time parameters will be directly
given to the CPU function.
  • Loading branch information
AlexanderSinn committed Sep 8, 2024
1 parent 8a5a989 commit c6b4fde
Showing 1 changed file with 22 additions and 12 deletions.
34 changes: 22 additions & 12 deletions Src/Base/AMReX_CTOParallelForImpl.H
Original file line number Diff line number Diff line change
Expand Up @@ -44,24 +44,33 @@ namespace detail
}
};

// Tests one candidate combination of compile-time options, given as the types As...
// (each exposing a static ::value), against the run-time options.
//
// If runtime_options equals {As::value...}, the CPU function l is called once and
// true is returned; otherwise l is not called and false is returned.
//
// NOTE(review): this fragment shows removed and added diff lines together. The
// declaration and call taking a single function (F, f) are the previous form; the
// ones taking the parameter pack (Fs, cto_functs) are the form introduced by this change.
template <class L, class F, typename... As>
template <class L, typename... As, class... Fs>
bool
AnyCTO_helper2 (const L& l, const F& f, TypeList<As...>,
std::array<int,sizeof...(As)> const& runtime_options)
AnyCTO_helper2 (const L& l, TypeList<As...>,
std::array<int,sizeof...(As)> const& runtime_options, const Fs&...cto_functs)
{
// Only the candidate whose compile-time values match the run-time values is launched.
if (runtime_options == std::array<int,sizeof...(As)>{As::value...}) {
l(CTOWrapper<F, As::value...>{f});
if constexpr (sizeof...(cto_functs) != 0) {
// Apply the CTOWrapper to each function that was given in cto_functs
// and call the CPU function l with all of them
l(CTOWrapper<Fs, As::value...>{cto_functs}...);
} else {
// No functions in cto_functs so we call l directly with the compile time arguments
l(As{}...);
}
return true;
} else {
return false;
}
}

// Tries every candidate option combination in PPs... by calling AnyCTO_helper2 on
// each, and asserts (via AMREX_ASSERT) that one of them matched runtime_options.
//
// NOTE(review): this fragment shows removed and added diff lines together. The
// declaration and fold expression taking a single function (F, f) are the previous
// form; the ones taking the parameter pack (Fs, cto_functs) are the new form.
template <class L, class F, typename... PPs, typename RO>
template <class L, typename... PPs, typename RO, class...Fs>
void
AnyCTO_helper1 (const L& l, const F& f, TypeList<PPs...>, RO const& runtime_options)
AnyCTO_helper1 (const L& l, TypeList<PPs...>,
RO const& runtime_options, const Fs&...cto_functs)
{
// Fold over ||: evaluation stops at the first candidate for which
// AnyCTO_helper2 returns true, so later candidates are not tried.
bool found_option = (false || ... || AnyCTO_helper2(l, f, PPs{}, runtime_options));
bool found_option = (false || ... ||
AnyCTO_helper2(l, PPs{}, runtime_options, cto_functs...));
// found_option is otherwise only read by the assertion below.
amrex::ignore_unused(found_option);
AMREX_ASSERT(found_option);
}
Expand Down Expand Up @@ -168,17 +177,18 @@ namespace detail
* \param list_of_compile_time_options list of all possible values of the parameters.
* \param runtime_options the run time parameters.
* \param l a callable object containing a CPU function that launches the provided GPU kernel.
* \param f a callable object containing the GPU kernel with optimizations.
* \param cto_functs zero or more callable objects containing the GPU kernels with optimizations.
*        If none are given, l is called directly with the compile time options.
*/
template <class L, class F, typename... CTOs>
template <class L, class... Fs, typename... CTOs>
void AnyCTO ([[maybe_unused]] TypeList<CTOs...> list_of_compile_time_options,
std::array<int,sizeof...(CTOs)> const& runtime_options,
L&& l, F&& f)
L&& l, Fs&&...cto_functs)
{
#if (__cplusplus >= 201703L)
detail::AnyCTO_helper1(std::forward<L>(l), std::forward<F>(f),
detail::AnyCTO_helper1(std::forward<L>(l),
CartesianProduct(typename CTOs::list_type{}...),
runtime_options);
runtime_options,
std::forward<Fs>(cto_functs)...);
#else
amrex::ignore_unused(runtime_options, l, f);
static_assert(std::is_integral<F>::value, "This requires C++17");
Expand Down

0 comments on commit c6b4fde

Please sign in to comment.