Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Improve coalesced reduction performance for tall and thin matrices (up to 2.6x faster) #2259

Merged
merged 8 commits into from
Apr 22, 2024
117 changes: 95 additions & 22 deletions cpp/include/raft/linalg/detail/coalesced_reduction-inl.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -28,11 +28,18 @@ namespace raft {
namespace linalg {
namespace detail {

/**
 * @brief Compile-time configuration for the "thin" coalesced reduction kernel.
 *
 * @tparam warpSize Logical warp size, i.e. the number of threads cooperating on one row
 * @tparam tpb      Threads per block; must be a multiple of warpSize
 * @tparam rpw      Number of rows reduced by each logical warp
 * @tparam noLoop   Set when the run-time row width D is known to be no larger than
 *                  warpSize, so the strided sequential load loop can be skipped
 */
template <int warpSize, int tpb, int rpw, bool noLoop = false>
struct ReductionThinPolicy {
  static_assert(tpb % warpSize == 0);

  static constexpr int LogicalWarpSize    = warpSize;
  static constexpr int ThreadsPerBlock    = tpb;
  static constexpr int RowsPerLogicalWarp = rpw;
  static constexpr int NumLogicalWarps    = ThreadsPerBlock / LogicalWarpSize;
  static constexpr int RowsPerBlock       = NumLogicalWarps * RowsPerLogicalWarp;

  // Whether D (run-time arg) will be smaller than warpSize (compile-time parameter)
  static constexpr bool NoSequentialReduce = noLoop;
};

template <typename Policy,
Expand All @@ -53,19 +60,72 @@ RAFT_KERNEL __launch_bounds__(Policy::ThreadsPerBlock)
FinalLambda final_op,
bool inplace = false)
{
IdxType i = threadIdx.y + (Policy::RowsPerBlock * static_cast<IdxType>(blockIdx.x));
if (i >= N) return;
/* The strategy to achieve near-SOL memory bandwidth differs based on D:
* - For small D, we need to process multiple rows per logical warp in order to have
* multiple loads per thread and increase bytes in flight and amortize latencies.
* - For large D, we start with a sequential reduction. The compiler partially unrolls
* that loop (e.g. first a loop of stride 16, then 8, 4, and 1).
*/
IdxType i0 = threadIdx.y + (Policy::RowsPerBlock * static_cast<IdxType>(blockIdx.x));
if (i0 >= N) return;

OutType acc = init;
for (IdxType j = threadIdx.x; j < D; j += Policy::LogicalWarpSize) {
acc = reduce_op(acc, main_op(data[j + (D * i)], j));
OutType acc[Policy::RowsPerLogicalWarp];
#pragma unroll
for (int k = 0; k < Policy::RowsPerLogicalWarp; k++) {
acc[k] = init;
}
acc = raft::logicalWarpReduce<Policy::LogicalWarpSize>(acc, reduce_op);
if (threadIdx.x == 0) {

if constexpr (Policy::NoSequentialReduce) {
IdxType j = threadIdx.x;
if (j < D) {
#pragma unroll
for (IdxType k = 0; k < Policy::RowsPerLogicalWarp; k++) {
// Only the first row is known to be within bounds. Clamp to avoid out-of-mem read.
const IdxType i = raft::min(i0 + k * Policy::NumLogicalWarps, N - 1);
acc[k] = reduce_op(acc[k], main_op(data[j + (D * i)], j));
}
}
} else {
for (IdxType j = threadIdx.x; j < D; j += Policy::LogicalWarpSize) {
#pragma unroll
for (IdxType k = 0; k < Policy::RowsPerLogicalWarp; k++) {
const IdxType i = raft::min(i0 + k * Policy::NumLogicalWarps, N - 1);
acc[k] = reduce_op(acc[k], main_op(data[j + (D * i)], j));
}
}
}

/* This vector reduction has two benefits compared to naive separate reductions:
* - It avoids the LSU bottleneck when the number of columns is around 32 (e.g. for 32, 5 shuffles
* are required and there is no initial sequential reduction to amortize that cost).
* - It distributes the outputs to multiple threads, enabling a coalesced store when the number of
* rows per logical warp and logical warp size are equal.
*/
raft::logicalWarpReduceVector<Policy::LogicalWarpSize, Policy::RowsPerLogicalWarp>(
acc, threadIdx.x, reduce_op);

constexpr int reducOutVecWidth =
std::max(1, Policy::RowsPerLogicalWarp / Policy::LogicalWarpSize);
constexpr int reducOutGroupSize =
std::max(1, Policy::LogicalWarpSize / Policy::RowsPerLogicalWarp);
constexpr int reducNumGroups = Policy::LogicalWarpSize / reducOutGroupSize;

if (threadIdx.x % reducOutGroupSize == 0) {
const int groupId = threadIdx.x / reducOutGroupSize;
if (inplace) {
dots[i] = final_op(reduce_op(dots[i], acc));
#pragma unroll
for (int k = 0; k < reducOutVecWidth; k++) {
const int reductionId = k * reducNumGroups + groupId;
const IdxType i = i0 + reductionId * Policy::NumLogicalWarps;
if (i < N) { dots[i] = final_op(reduce_op(dots[i], acc[k])); }
}
} else {
dots[i] = final_op(acc);
#pragma unroll
for (int k = 0; k < reducOutVecWidth; k++) {
const int reductionId = k * reducNumGroups + groupId;
const IdxType i = i0 + reductionId * Policy::NumLogicalWarps;
if (i < N) { dots[i] = final_op(acc[k]); }
}
}
}
}
Expand All @@ -89,8 +149,12 @@ void coalescedReductionThin(OutType* dots,
FinalLambda final_op = raft::identity_op())
{
common::nvtx::range<common::nvtx::domain::raft> fun_scope(
"coalescedReductionThin<%d,%d>", Policy::LogicalWarpSize, Policy::RowsPerBlock);
dim3 threads(Policy::LogicalWarpSize, Policy::RowsPerBlock, 1);
"coalescedReductionThin<%d,%d,%d,%d>",
Policy::LogicalWarpSize,
Policy::ThreadsPerBlock,
Policy::RowsPerLogicalWarp,
static_cast<int>(Policy::NoSequentialReduce));
dim3 threads(Policy::LogicalWarpSize, Policy::NumLogicalWarps, 1);
dim3 blocks(ceildiv<IdxType>(N, Policy::RowsPerBlock), 1, 1);
coalescedReductionThinKernel<Policy>
<<<blocks, threads, 0, stream>>>(dots, data, D, N, init, main_op, reduce_op, final_op, inplace);
Expand All @@ -115,19 +179,28 @@ void coalescedReductionThinDispatcher(OutType* dots,
FinalLambda final_op = raft::identity_op())
{
if (D <= IdxType(2)) {
coalescedReductionThin<ReductionThinPolicy<2, 64>>(
coalescedReductionThin<ReductionThinPolicy<2, 128, 8, true>>(
dots, data, D, N, init, stream, inplace, main_op, reduce_op, final_op);
} else if (D <= IdxType(4)) {
coalescedReductionThin<ReductionThinPolicy<4, 32>>(
coalescedReductionThin<ReductionThinPolicy<4, 128, 8, true>>(
dots, data, D, N, init, stream, inplace, main_op, reduce_op, final_op);
} else if (D <= IdxType(8)) {
coalescedReductionThin<ReductionThinPolicy<8, 16>>(
coalescedReductionThin<ReductionThinPolicy<8, 128, 8, true>>(
dots, data, D, N, init, stream, inplace, main_op, reduce_op, final_op);
} else if (D <= IdxType(16)) {
coalescedReductionThin<ReductionThinPolicy<16, 8>>(
coalescedReductionThin<ReductionThinPolicy<16, 128, 8, true>>(
dots, data, D, N, init, stream, inplace, main_op, reduce_op, final_op);
} else if (D <= IdxType(32)) {
coalescedReductionThin<ReductionThinPolicy<32, 128, 8, true>>(
dots, data, D, N, init, stream, inplace, main_op, reduce_op, final_op);
} else if (D < IdxType(128)) {
coalescedReductionThin<ReductionThinPolicy<32, 128, 4, false>>(
dots, data, D, N, init, stream, inplace, main_op, reduce_op, final_op);
} else {
coalescedReductionThin<ReductionThinPolicy<32, 4>>(
// For D=128 (included) and above, the 4x-unrolled loading loop is used
// and multiple rows per warp are counter-productive in terms of cache-friendliness
// and register use.
coalescedReductionThin<ReductionThinPolicy<32, 128, 1, false>>(
dots, data, D, N, init, stream, inplace, main_op, reduce_op, final_op);
}
}
Expand Down Expand Up @@ -319,10 +392,10 @@ void coalescedReductionThickDispatcher(OutType* dots,
// Note: multiple elements per thread to take advantage of the sequential reduction and loop
// unrolling
if (D < IdxType(32768)) {
coalescedReductionThick<ReductionThickPolicy<256, 32>, ReductionThinPolicy<32, 4>>(
coalescedReductionThick<ReductionThickPolicy<256, 32>, ReductionThinPolicy<32, 128, 1>>(
dots, data, D, N, init, stream, inplace, main_op, reduce_op, final_op);
} else {
coalescedReductionThick<ReductionThickPolicy<256, 64>, ReductionThinPolicy<32, 4>>(
coalescedReductionThick<ReductionThickPolicy<256, 64>, ReductionThinPolicy<32, 128, 1>>(
dots, data, D, N, init, stream, inplace, main_op, reduce_op, final_op);
}
}
Expand Down
13 changes: 11 additions & 2 deletions cpp/include/raft/util/pow2_utils.cuh
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/*
* Copyright (c) 2021-2023, NVIDIA CORPORATION.
* Copyright (c) 2021-2024, NVIDIA CORPORATION.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
Expand All @@ -20,6 +20,15 @@

namespace raft {

/**
 * Tells whether an integral value is a power of two.
 * Usable at compile time (e.g. in static_assert) as well as on host and device.
 */
template <typename T>
constexpr HDI std::enable_if_t<std::is_integral_v<T>, bool> is_pow2(T v)
{
  // A power of two has exactly one bit set, so clearing its lowest set bit yields zero.
  return (v != 0) && ((v & (v - 1)) == 0);
}

/**
* @brief Fast arithmetics and alignment checks for power-of-two values known at compile time.
*
Expand All @@ -33,7 +42,7 @@ struct Pow2 {
static constexpr Type Mask = Value - 1;

static_assert(std::is_integral<Type>::value, "Value must be integral.");
static_assert(Value && !(Value & Mask), "Value must be power of two.");
static_assert(is_pow2(Value), "Value must be power of two.");

#define Pow2_FUNC_QUALIFIER static constexpr __host__ __device__ __forceinline__
#define Pow2_WHEN_INTEGRAL(I) std::enable_if_t<Pow2_IS_REPRESENTABLE_AS(I), I>
Expand Down
104 changes: 102 additions & 2 deletions cpp/include/raft/util/reduction.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -39,8 +39,8 @@ DI T logicalWarpReduce(T val, ReduceLambda reduce_op)
{
#pragma unroll
for (int i = logicalWarpSize / 2; i > 0; i >>= 1) {
T tmp = shfl_xor(val, i);
val = reduce_op(val, tmp);
const T tmp = shfl_xor(val, i, logicalWarpSize);
val = reduce_op(val, tmp);
}
return val;
}
Expand Down Expand Up @@ -197,4 +197,104 @@ DI i_t binaryBlockReduce(i_t val, i_t* shmem)
}
}

/**
* @brief Executes a collaborative vector reduction per sub-warp
*
* This uses fewer shuffles than naively reducing each element independently.
* Better performance is achieved with a larger vector width, up to vecWidth == warpSize/2.
* For example, for logicalWarpSize == 32 and vecWidth == 16, the naive method requires 80
* shuffles, this one only 31, 2.58x fewer.
*
* However, the output of the reduction is not broadcasted. The vector is modified in place and
* each thread holds a part of the output vector. The outputs are distributed in a round-robin
* pattern between the threads to facilitate coalesced IO. There are 2 possible layouts based on
* which of logicalWarpSize and vecWidth is larger:
* - If vecWidth >= logicalWarpSize, each thread has vecWidth/logicalWarpSize outputs.
* - If logicalWarpSize > vecWidth, logicalWarpSize/vecWidth threads have a copy of the same output.
*
* Example 1: logicalWarpSize == 4, vecWidth == 8, v = a+b+c+d
* IN OUT
* lane 0 | a0 a1 a2 a3 a4 a5 a6 a7 | v0 v4 - - - - - -
* lane 1 | b0 b1 b2 b3 b4 b5 b6 b7 | v1 v5 - - - - - -
* lane 2 | c0 c1 c2 c3 c4 c5 c6 c7 | v2 v6 - - - - - -
* lane 3 | d0 d1 d2 d3 d4 d5 d6 d7 | v3 v7 - - - - - -
*
* Example 2: logicalWarpSize == 8, vecWidth == 4, v = a+b+c+d+e+f+g+h
* IN OUT
* lane 0 | a0 a1 a2 a3 | v0 - - -
* lane 1 | b0 b1 b2 b3 | v0 - - -
* lane 2 | c0 c1 c2 c3 | v1 - - -
* lane 3 | d0 d1 d2 d3 | v1 - - -
* lane 4 | e0 e1 e2 e3 | v2 - - -
* lane 5 | f0 f1 f2 f3 | v2 - - -
* lane 6 | g0 g1 g2 g3 | v3 - - -
* lane 7 | h0 h1 h2 h3 | v3 - - -
*
* @tparam logicalWarpSize Sub-warp size. Must be 2, 4, 8, 16 or 32.
* @tparam vecWidth Vector width. Must be a power of two.
* @tparam T Vector element type.
* @tparam ReduceLambda Reduction operator type.
* @param[in,out] acc Pointer to a vector of size vecWidth or more in registers
* @param[in] lane_id Lane id between 0 and logicalWarpSize-1
*
* @param[in] reduce_op Reduction operator, assumed to be commutative and associative.
*/
template <int logicalWarpSize, int vecWidth, typename T, typename ReduceLambda>
DI void logicalWarpReduceVector(T* acc, int lane_id, ReduceLambda reduce_op)
{
  static_assert(vecWidth > 0, "Vec width must be strictly positive.");
  static_assert(!(vecWidth & (vecWidth - 1)), "Vec width must be a power of two.");
  static_assert(logicalWarpSize >= 2 && logicalWarpSize <= 32,
                "Logical warp size must be between 2 and 32");
  static_assert(!(logicalWarpSize & (logicalWarpSize - 1)),
                "Logical warp size must be a power of two.");

  // Each step halves the logical warp size and (down to 1) the vector width, then
  // recurses on the two independent half-warps (see tail of this function).
  constexpr int shflStride = logicalWarpSize / 2;
  constexpr int nextWarpSize = logicalWarpSize / 2;

  // One step of the butterfly reduction, applied to each element of the vector.
  // The third argument of shfl_xor keeps the exchange within this logical warp.
#pragma unroll
  for (int k = 0; k < vecWidth; k++) {
    const T tmp = shfl_xor(acc[k], shflStride, logicalWarpSize);
    acc[k] = reduce_op(acc[k], tmp);
  }

  constexpr int nextVecWidth = std::max(1, vecWidth / 2);

  /* Split into 2 smaller logical warps and distribute half of the data to each for the next step.
   * The distribution pattern is designed so that at the end the outputs are coalesced/round-robin.
   * The idea is to distribute contiguous "chunks" of the vectors based on the new warp size. These
   * chunks will be halved in the next step and so on.
   *
   * Example for logicalWarpSize == 4, vecWidth == 8:
   * lane 0 | 0 1 2 3 4 5 6 7 | [0 1] [4 5] - - - - | [0] [4] - - - - - -
   * lane 1 | 0 1 2 3 4 5 6 7 | [0 1] [4 5] - - - - | [1] [5] - - - - - -
   * lane 2 | 0 1 2 3 4 5 6 7 | [2 3] [6 7] - - - - | [2] [6] - - - - - -
   * lane 3 | 0 1 2 3 4 5 6 7 | [2 3] [6 7] - - - - | [3] [7] - - - - - -
   *          chunkSize=2        chunkSize=1
   */
  if constexpr (nextVecWidth < vecWidth) {
    T tmp[nextVecWidth];
    // Lanes in the first half-warp keep the even chunks, the second half the odd chunks.
    const bool firstHalf = (lane_id % logicalWarpSize) < nextWarpSize;
    constexpr int chunkSize = std::min(nextVecWidth, nextWarpSize);
    constexpr int numChunks = nextVecWidth / chunkSize;
#pragma unroll
    for (int c = 0; c < numChunks; c++) {
#pragma unroll
      for (int i = 0; i < chunkSize; i++) {
        const int k = c * chunkSize + i;
        // Pack into tmp the half of the vector that this thread keeps for the next step.
        tmp[k] = firstHalf ? acc[2 * c * chunkSize + i] : acc[(2 * c + 1) * chunkSize + i];
      }
    }
#pragma unroll
    for (int k = 0; k < nextVecWidth; k++) {
      acc[k] = tmp[k];
    }
  }

  // Recursively call with smaller sub-warps and possibly smaller vector width.
  // Terminates when nextWarpSize == 1, i.e. after log2(logicalWarpSize) steps.
  if constexpr (nextWarpSize > 1) {
    logicalWarpReduceVector<nextWarpSize, nextVecWidth>(acc, lane_id % nextWarpSize, reduce_op);
  }
}

} // namespace raft
46 changes: 36 additions & 10 deletions cpp/test/linalg/coalesced_reduction.cu
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,8 @@ struct coalescedReductionInputs {
// Prints the test parameters so the test framework can display them for a failing case.
// Note: closes the brace opened by "{ " so the printed record is balanced.
template <typename T>
::std::ostream& operator<<(::std::ostream& os, const coalescedReductionInputs<T>& dims)
{
  return os << "{ " << dims.tolerance << ", " << dims.rows << ", " << dims.cols << ", "
            << dims.seed << " }";
}

// Or else, we get the following compilation error
Expand Down Expand Up @@ -113,15 +114,40 @@ class coalescedReductionTest : public ::testing::TestWithParam<coalescedReductio
rmm::device_uvector<T> dots_act;
};

// Note: it's important to have a variety of rows/columns combinations to test all possible code
// paths: thin (few cols or many rows), medium, thick (many cols, very few rows).

const std::vector<coalescedReductionInputs<float>> inputsf = {{0.000002f, 50, 2, 1234ULL},
                                                              {0.000002f, 50, 3, 1234ULL},
                                                              {0.000002f, 50, 7, 1234ULL},
                                                              {0.000002f, 50, 9, 1234ULL},
                                                              {0.000002f, 50, 20, 1234ULL},
                                                              {0.000002f, 50, 55, 1234ULL},
                                                              {0.000002f, 50, 100, 1234ULL},
                                                              {0.000002f, 50, 270, 1234ULL},
                                                              {0.000002f, 10000, 3, 1234ULL},
                                                              {0.000002f, 10000, 9, 1234ULL},
                                                              {0.000002f, 10000, 20, 1234ULL},
                                                              {0.000002f, 10000, 55, 1234ULL},
                                                              {0.000002f, 10000, 100, 1234ULL},
                                                              {0.000002f, 10000, 270, 1234ULL},
                                                              {0.0001f, 10, 25000, 1234ULL}};

const std::vector<coalescedReductionInputs<double>> inputsd = {{0.000000001, 50, 2, 1234ULL},
                                                               {0.000000001, 50, 3, 1234ULL},
                                                               {0.000000001, 50, 7, 1234ULL},
                                                               {0.000000001, 50, 9, 1234ULL},
                                                               {0.000000001, 50, 20, 1234ULL},
                                                               {0.000000001, 50, 55, 1234ULL},
                                                               {0.000000001, 50, 100, 1234ULL},
                                                               {0.000000001, 50, 270, 1234ULL},
                                                               {0.000000001, 10000, 3, 1234ULL},
                                                               {0.000000001, 10000, 9, 1234ULL},
                                                               {0.000000001, 10000, 20, 1234ULL},
                                                               {0.000000001, 10000, 55, 1234ULL},
                                                               {0.000000001, 10000, 100, 1234ULL},
                                                               {0.000000001, 10000, 270, 1234ULL},
                                                               {0.0000001, 10, 25000, 1234ULL}};

typedef coalescedReductionTest<float> coalescedReductionTestF;
TEST_P(coalescedReductionTestF, Result)
Expand Down
Loading