Skip to content

Commit

Permalink
Merge branch 'branch-24.06' into device_async_resource_ref
Browse files Browse the repository at this point in the history
  • Loading branch information
harrism committed Apr 23, 2024
2 parents bd46d13 + 317a61c commit debb2f5
Show file tree
Hide file tree
Showing 55 changed files with 399 additions and 91 deletions.
6 changes: 3 additions & 3 deletions cpp/cmake/modules/ConfigureCUDA.cmake
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
# =============================================================================
# Copyright (c) 2018-2023, NVIDIA CORPORATION.
# Copyright (c) 2018-2024, NVIDIA CORPORATION.
#
# Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except
# in compliance with the License. You may obtain a copy of the License at
Expand All @@ -13,8 +13,8 @@
# =============================================================================

if(DISABLE_DEPRECATION_WARNINGS)
list(APPEND RAFT_CXX_FLAGS -Wno-deprecated-declarations)
list(APPEND RAFT_CUDA_FLAGS -Xcompiler=-Wno-deprecated-declarations)
list(APPEND RAFT_CXX_FLAGS -Wno-deprecated-declarations -DRAFT_HIDE_DEPRECATION_WARNINGS)
list(APPEND RAFT_CUDA_FLAGS -Xcompiler=-Wno-deprecated-declarations -DRAFT_HIDE_DEPRECATION_WARNINGS)
endif()

# Be very strict when compiling with GCC as host compiler (and thus more lenient when compiling with
Expand Down
4 changes: 3 additions & 1 deletion cpp/include/raft/cluster/specializations.cuh
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/*
* Copyright (c) 2022-2023, NVIDIA CORPORATION.
* Copyright (c) 2022-2024, NVIDIA CORPORATION.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
Expand All @@ -15,8 +15,10 @@
*/
#pragma once

#ifndef RAFT_HIDE_DEPRECATION_WARNINGS
#pragma message( \
__FILE__ \
" is deprecated and will be removed." \
" Including specializations is not necessary any more." \
" For more information, see: https://docs.rapids.ai/api/raft/nightly/using_libraft.html")
#endif
4 changes: 3 additions & 1 deletion cpp/include/raft/common/cub_wrappers.cuh
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/*
* Copyright (c) 2022-2023, NVIDIA CORPORATION.
* Copyright (c) 2022-2024, NVIDIA CORPORATION.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
Expand All @@ -24,9 +24,11 @@

#pragma once

#ifndef RAFT_HIDE_DEPRECATION_WARNINGS
#pragma message(__FILE__ \
" is deprecated and will be removed in a future release." \
" Please note that there is no equivalent in RAFT's public API"
" so this file will eventually be removed altogether.")
#endif

#include <raft/util/detail/cub_wrappers.cuh>
4 changes: 3 additions & 1 deletion cpp/include/raft/common/device_loads_stores.cuh
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/*
* Copyright (c) 2020-2023, NVIDIA CORPORATION.
* Copyright (c) 2020-2024, NVIDIA CORPORATION.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
Expand All @@ -24,8 +24,10 @@

#pragma once

#ifndef RAFT_HIDE_DEPRECATION_WARNINGS
#pragma message(__FILE__ \
" is deprecated and will be removed in a future release." \
" Please use the raft/util version instead.")
#endif

#include <raft/util/device_loads_stores.cuh>
4 changes: 3 additions & 1 deletion cpp/include/raft/common/scatter.cuh
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/*
* Copyright (c) 2020-2023, NVIDIA CORPORATION.
* Copyright (c) 2020-2024, NVIDIA CORPORATION.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
Expand All @@ -24,8 +24,10 @@

#pragma once

#ifndef RAFT_HIDE_DEPRECATION_WARNINGS
#pragma message(__FILE__ \
" is deprecated and will be removed in a future release." \
" Please use the raft/matrix version instead.")
#endif

#include <raft/util/scatter.cuh>
4 changes: 3 additions & 1 deletion cpp/include/raft/common/seive.hpp
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/*
* Copyright (c) 2020-2023, NVIDIA CORPORATION.
* Copyright (c) 2020-2024, NVIDIA CORPORATION.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
Expand All @@ -24,8 +24,10 @@

#pragma once

#ifndef RAFT_HIDE_DEPRECATION_WARNINGS
#pragma message(__FILE__ \
" is deprecated and will be removed in a future release." \
" Please use the raft/util version instead.")
#endif

#include <raft/util/seive.hpp>
4 changes: 3 additions & 1 deletion cpp/include/raft/core/detail/logger.hpp
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/*
* Copyright (c) 2022-2023, NVIDIA CORPORATION.
* Copyright (c) 2022-2024, NVIDIA CORPORATION.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
Expand All @@ -15,8 +15,10 @@
*/
#pragma once

#ifndef RAFT_HIDE_DEPRECATION_WARNINGS
#pragma message(__FILE__ \
" is deprecated and will be removed in future releases." \
" Please use the <raft/core/logger.hpp> version instead.")
#endif

#include <raft/core/logger.hpp>
4 changes: 3 additions & 1 deletion cpp/include/raft/distance/specializations.cuh
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/*
* Copyright (c) 2021-2023, NVIDIA CORPORATION.
* Copyright (c) 2021-2024, NVIDIA CORPORATION.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
Expand All @@ -15,8 +15,10 @@
*/
#pragma once

#ifndef RAFT_HIDE_DEPRECATION_WARNINGS
#pragma message( \
__FILE__ \
" is deprecated and will be removed." \
" Including specializations is not necessary any more." \
" For more information, see: https://docs.rapids.ai/api/raft/nightly/using_libraft.html")
#endif
4 changes: 3 additions & 1 deletion cpp/include/raft/distance/specializations/distance.cuh
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/*
* Copyright (c) 2021-2023, NVIDIA CORPORATION.
* Copyright (c) 2021-2024, NVIDIA CORPORATION.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
Expand All @@ -15,8 +15,10 @@
*/
#pragma once

#ifndef RAFT_HIDE_DEPRECATION_WARNINGS
#pragma message( \
__FILE__ \
" is deprecated and will be removed." \
" Including specializations is not necessary any more." \
" For more information, see: https://docs.rapids.ai/api/raft/nightly/using_libraft.html")
#endif
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/*
* Copyright (c) 2022-2023, NVIDIA CORPORATION.
* Copyright (c) 2022-2024, NVIDIA CORPORATION.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
Expand All @@ -15,8 +15,10 @@
*/
#pragma once

#ifndef RAFT_HIDE_DEPRECATION_WARNINGS
#pragma message( \
__FILE__ \
" is deprecated and will be removed." \
" Including specializations is not necessary any more." \
" For more information, see: https://docs.rapids.ai/api/raft/nightly/using_libraft.html")
#endif
4 changes: 3 additions & 1 deletion cpp/include/raft/lap/lap.cuh
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/*
* Copyright (c) 2020-2023, NVIDIA CORPORATION.
* Copyright (c) 2020-2024, NVIDIA CORPORATION.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
Expand All @@ -24,9 +24,11 @@

#pragma once

#ifndef RAFT_HIDE_DEPRECATION_WARNINGS
#pragma message(__FILE__ \
" is deprecated and will be removed in a future release." \
" Please use the raft/solver version instead.")
#endif

#include <raft/solver/linear_assignment.cuh>

Expand Down
4 changes: 3 additions & 1 deletion cpp/include/raft/lap/lap.hpp
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/*
* Copyright (c) 2020-2023, NVIDIA CORPORATION.
* Copyright (c) 2020-2024, NVIDIA CORPORATION.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
Expand All @@ -24,8 +24,10 @@

#pragma once

#ifndef RAFT_HIDE_DEPRECATION_WARNINGS
#pragma message(__FILE__ \
" is deprecated and will be removed in a future release." \
" Please use the cuh version instead.")
#endif

#include <raft/solver/linear_assignment.cuh>
117 changes: 95 additions & 22 deletions cpp/include/raft/linalg/detail/coalesced_reduction-inl.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -28,11 +28,18 @@ namespace raft {
namespace linalg {
namespace detail {

template <int warpSize, int rpb>
template <int warpSize, int tpb, int rpw, bool noLoop = false>
struct ReductionThinPolicy {
static constexpr int LogicalWarpSize = warpSize;
static constexpr int RowsPerBlock = rpb;
static constexpr int ThreadsPerBlock = LogicalWarpSize * RowsPerBlock;
static_assert(tpb % warpSize == 0);

static constexpr int LogicalWarpSize = warpSize;
static constexpr int ThreadsPerBlock = tpb;
static constexpr int RowsPerLogicalWarp = rpw;
static constexpr int NumLogicalWarps = ThreadsPerBlock / LogicalWarpSize;
static constexpr int RowsPerBlock = NumLogicalWarps * RowsPerLogicalWarp;

// Whether D (run-time arg) will be smaller than warpSize (compile-time parameter)
static constexpr bool NoSequentialReduce = noLoop;
};

template <typename Policy,
Expand All @@ -53,19 +60,72 @@ RAFT_KERNEL __launch_bounds__(Policy::ThreadsPerBlock)
FinalLambda final_op,
bool inplace = false)
{
IdxType i = threadIdx.y + (Policy::RowsPerBlock * static_cast<IdxType>(blockIdx.x));
if (i >= N) return;
/* The strategy to achieve near-SOL memory bandwidth differs based on D:
* - For small D, we need to process multiple rows per logical warp in order to have
* multiple loads per thread and increase bytes in flight and amortize latencies.
* - For large D, we start with a sequential reduction. The compiler partially unrolls
* that loop (e.g. first a loop of stride 16, then 8, 4, and 1).
*/
IdxType i0 = threadIdx.y + (Policy::RowsPerBlock * static_cast<IdxType>(blockIdx.x));
if (i0 >= N) return;

OutType acc = init;
for (IdxType j = threadIdx.x; j < D; j += Policy::LogicalWarpSize) {
acc = reduce_op(acc, main_op(data[j + (D * i)], j));
OutType acc[Policy::RowsPerLogicalWarp];
#pragma unroll
for (int k = 0; k < Policy::RowsPerLogicalWarp; k++) {
acc[k] = init;
}
acc = raft::logicalWarpReduce<Policy::LogicalWarpSize>(acc, reduce_op);
if (threadIdx.x == 0) {

if constexpr (Policy::NoSequentialReduce) {
IdxType j = threadIdx.x;
if (j < D) {
#pragma unroll
for (IdxType k = 0; k < Policy::RowsPerLogicalWarp; k++) {
// Only the first row is known to be within bounds. Clamp to avoid out-of-mem read.
const IdxType i = raft::min(i0 + k * Policy::NumLogicalWarps, N - 1);
acc[k] = reduce_op(acc[k], main_op(data[j + (D * i)], j));
}
}
} else {
for (IdxType j = threadIdx.x; j < D; j += Policy::LogicalWarpSize) {
#pragma unroll
for (IdxType k = 0; k < Policy::RowsPerLogicalWarp; k++) {
const IdxType i = raft::min(i0 + k * Policy::NumLogicalWarps, N - 1);
acc[k] = reduce_op(acc[k], main_op(data[j + (D * i)], j));
}
}
}

/* This vector reduction has two benefits compared to naive separate reductions:
* - It avoids the LSU bottleneck when the number of columns is around 32 (e.g. for 32, 5 shuffles
* are required and there is no initial sequential reduction to amortize that cost).
* - It distributes the outputs to multiple threads, enabling a coalesced store when the number of
* rows per logical warp and logical warp size are equal.
*/
raft::logicalWarpReduceVector<Policy::LogicalWarpSize, Policy::RowsPerLogicalWarp>(
acc, threadIdx.x, reduce_op);

constexpr int reducOutVecWidth =
std::max(1, Policy::RowsPerLogicalWarp / Policy::LogicalWarpSize);
constexpr int reducOutGroupSize =
std::max(1, Policy::LogicalWarpSize / Policy::RowsPerLogicalWarp);
constexpr int reducNumGroups = Policy::LogicalWarpSize / reducOutGroupSize;

if (threadIdx.x % reducOutGroupSize == 0) {
const int groupId = threadIdx.x / reducOutGroupSize;
if (inplace) {
dots[i] = final_op(reduce_op(dots[i], acc));
#pragma unroll
for (int k = 0; k < reducOutVecWidth; k++) {
const int reductionId = k * reducNumGroups + groupId;
const IdxType i = i0 + reductionId * Policy::NumLogicalWarps;
if (i < N) { dots[i] = final_op(reduce_op(dots[i], acc[k])); }
}
} else {
dots[i] = final_op(acc);
#pragma unroll
for (int k = 0; k < reducOutVecWidth; k++) {
const int reductionId = k * reducNumGroups + groupId;
const IdxType i = i0 + reductionId * Policy::NumLogicalWarps;
if (i < N) { dots[i] = final_op(acc[k]); }
}
}
}
}
Expand All @@ -89,8 +149,12 @@ void coalescedReductionThin(OutType* dots,
FinalLambda final_op = raft::identity_op())
{
common::nvtx::range<common::nvtx::domain::raft> fun_scope(
"coalescedReductionThin<%d,%d>", Policy::LogicalWarpSize, Policy::RowsPerBlock);
dim3 threads(Policy::LogicalWarpSize, Policy::RowsPerBlock, 1);
"coalescedReductionThin<%d,%d,%d,%d>",
Policy::LogicalWarpSize,
Policy::ThreadsPerBlock,
Policy::RowsPerLogicalWarp,
static_cast<int>(Policy::NoSequentialReduce));
dim3 threads(Policy::LogicalWarpSize, Policy::NumLogicalWarps, 1);
dim3 blocks(ceildiv<IdxType>(N, Policy::RowsPerBlock), 1, 1);
coalescedReductionThinKernel<Policy>
<<<blocks, threads, 0, stream>>>(dots, data, D, N, init, main_op, reduce_op, final_op, inplace);
Expand All @@ -115,19 +179,28 @@ void coalescedReductionThinDispatcher(OutType* dots,
FinalLambda final_op = raft::identity_op())
{
if (D <= IdxType(2)) {
coalescedReductionThin<ReductionThinPolicy<2, 64>>(
coalescedReductionThin<ReductionThinPolicy<2, 128, 8, true>>(
dots, data, D, N, init, stream, inplace, main_op, reduce_op, final_op);
} else if (D <= IdxType(4)) {
coalescedReductionThin<ReductionThinPolicy<4, 32>>(
coalescedReductionThin<ReductionThinPolicy<4, 128, 8, true>>(
dots, data, D, N, init, stream, inplace, main_op, reduce_op, final_op);
} else if (D <= IdxType(8)) {
coalescedReductionThin<ReductionThinPolicy<8, 16>>(
coalescedReductionThin<ReductionThinPolicy<8, 128, 8, true>>(
dots, data, D, N, init, stream, inplace, main_op, reduce_op, final_op);
} else if (D <= IdxType(16)) {
coalescedReductionThin<ReductionThinPolicy<16, 8>>(
coalescedReductionThin<ReductionThinPolicy<16, 128, 8, true>>(
dots, data, D, N, init, stream, inplace, main_op, reduce_op, final_op);
} else if (D <= IdxType(32)) {
coalescedReductionThin<ReductionThinPolicy<32, 128, 8, true>>(
dots, data, D, N, init, stream, inplace, main_op, reduce_op, final_op);
} else if (D < IdxType(128)) {
coalescedReductionThin<ReductionThinPolicy<32, 128, 4, false>>(
dots, data, D, N, init, stream, inplace, main_op, reduce_op, final_op);
} else {
coalescedReductionThin<ReductionThinPolicy<32, 4>>(
// For D=128 (included) and above, the 4x-unrolled loading loop is used
// and multiple rows per warp are counter-productive in terms of cache-friendliness
// and register use.
coalescedReductionThin<ReductionThinPolicy<32, 128, 1, false>>(
dots, data, D, N, init, stream, inplace, main_op, reduce_op, final_op);
}
}
Expand Down Expand Up @@ -319,10 +392,10 @@ void coalescedReductionThickDispatcher(OutType* dots,
// Note: multiple elements per thread to take advantage of the sequential reduction and loop
// unrolling
if (D < IdxType(32768)) {
coalescedReductionThick<ReductionThickPolicy<256, 32>, ReductionThinPolicy<32, 4>>(
coalescedReductionThick<ReductionThickPolicy<256, 32>, ReductionThinPolicy<32, 128, 1>>(
dots, data, D, N, init, stream, inplace, main_op, reduce_op, final_op);
} else {
coalescedReductionThick<ReductionThickPolicy<256, 64>, ReductionThinPolicy<32, 4>>(
coalescedReductionThick<ReductionThickPolicy<256, 64>, ReductionThinPolicy<32, 128, 1>>(
dots, data, D, N, init, stream, inplace, main_op, reduce_op, final_op);
}
}
Expand Down
2 changes: 2 additions & 0 deletions cpp/include/raft/linalg/detail/gemm.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -15,9 +15,11 @@
*/
#pragma once

#ifndef RAFT_HIDE_DEPRECATION_WARNINGS
#pragma message(__FILE__ \
" is deprecated and will be removed in a future release." \
" Use cublaslt_wrappers.hpp if you really need this low-level api.")
#endif

#include "cublaslt_wrappers.hpp"

Expand Down
Loading

0 comments on commit debb2f5

Please sign in to comment.