FlashAttention CUDA fixes #6438

Merged
15 changes: 9 additions & 6 deletions ggml-cuda/fattn.cu
@@ -4,6 +4,7 @@
#include <mma.h>

#define FATTN_KQ_STRIDE 256
+#define HALF_MAX_HALF __float2half(65504.0f/2) // Use neg. of this instead of -INFINITY to initialize KQ max vals to avoid NaN upon subtraction.

template<int D, int parallel_blocks> // D == head size
__launch_bounds__(((D + WARP_SIZE - 1) / WARP_SIZE)*WARP_SIZE, 1)
@@ -59,13 +60,13 @@ static __global__ void flash_attn_vec_ext_f16(
KQ[tid] = -INFINITY;
half2 * KQ2 = (half2 *) KQ;

-half kqmax = -INFINITY;
+half kqmax = -HALF_MAX_HALF;
half kqsum = 0.0f;

__shared__ half kqmax_shared[WARP_SIZE];
__shared__ half kqsum_shared[WARP_SIZE];
if (threadIdx.y == 0) {
-kqmax_shared[threadIdx.x] = -INFINITY;
+kqmax_shared[threadIdx.x] = -HALF_MAX_HALF;
kqsum_shared[threadIdx.x] = 0.0f;
}
__syncthreads();
Expand Down Expand Up @@ -139,7 +140,7 @@ static __global__ void flash_attn_vec_ext_f16(
if (tid < D) {
#pragma unroll
for (int k0 = 0; k0 < D; k0 += 2) {
-if (256 % D != 0 && k_VKQ_0 + k0 >= ne11) {
+if (FATTN_KQ_STRIDE % D != 0 && k_VKQ_0 + k0 >= ne11) {
break;
}

@@ -253,9 +254,9 @@ static __global__ void flash_attn_ext_f16(
__shared__ half KQ[mem_KQ >= mem_VKQ_parts ? mem_KQ : mem_VKQ_parts];
half2 * KQ2 = (half2 *) KQ;

-half2 KQ_rowsum[ncols/nwarps] = {{0.0f, 0.0f}};
-half2 KQ_max[ncols/nwarps] = {{-INFINITY, -INFINITY}};
-half2 KQ_max_scale[ncols/nwarps] = {{0.0f, 0.0f}};
+half2 KQ_rowsum[ncols/nwarps] = {{ 0.0f, 0.0f}};
+half2 KQ_max[ncols/nwarps] = {{-HALF_MAX_HALF, -HALF_MAX_HALF}};
+half2 KQ_max_scale[ncols/nwarps] = {{ 0.0f, 0.0f}};

__shared__ half VKQ[ncols*D_padded]; // Accumulator for final VKQ slice.
half2 * VKQ2 = (half2 *) VKQ;
@@ -578,6 +579,8 @@ void ggml_cuda_flash_attn_ext(ggml_backend_cuda_context & ctx, ggml_tensor * dst
GGML_ASSERT(!mask || mask->ne[1] >= GGML_PAD(Q->ne[1], 16) &&
"the Flash-Attention CUDA kernel requires the mask to be padded to 16 and at least n_queries big");

+GGML_ASSERT(K->ne[1] % FATTN_KQ_STRIDE == 0 && "Incorrect KV cache padding.");

ggml_cuda_set_device(ctx.device);

const cudaStream_t main_stream = ctx.stream();
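
A note on the -HALF_MAX_HALF change above: if the running softmax maximum is initialized to -INFINITY, the rescale factor exp(old_max - new_max) evaluates to exp(-inf - (-inf)) = NaN for a row that has not yet seen a finite value, and the NaN then propagates through the accumulators. The small host-side sketch below is illustrative only (it is not part of the diff) and shows the difference between the two initializers:

// Illustrative sketch, not from the PR: the running-maximum rescale factor
// with the two initializers discussed above.
#include <cmath>
#include <cstdio>

int main() {
    const float inf_init  = -INFINITY;      // old initializer
    const float half_init = -65504.0f/2.0f; // negative of HALF_MAX_HALF

    // Factor used to rescale previous accumulators when the maximum is updated;
    // with both old and new maximum still equal to the initializer it should be 1.
    printf("rescale with -INFINITY init:      %f\n", expf(inf_init  - inf_init));  // prints nan
    printf("rescale with -HALF_MAX_HALF init: %f\n", expf(half_init - half_init)); // prints 1.000000
    return 0;
}
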
4 changes: 2 additions & 2 deletions llama.cpp
@@ -9973,7 +9973,7 @@ static int llama_decode_internal(
// a heuristic, to avoid attending the full cache if it is not yet utilized
// after enough generations, the benefit from this heuristic disappears
// if we start defragmenting the cache, the benefit from this will be more important
-kv_self.n = std::min(kv_self.size, std::max(128u, GGML_PAD(llama_kv_cache_cell_max(kv_self), 128)));
+kv_self.n = std::min(kv_self.size, std::max(256u, GGML_PAD(llama_kv_cache_cell_max(kv_self), 256)));
//kv_self.n = llama_kv_cache_cell_max(kv_self);
}
}
@@ -13909,7 +13909,7 @@ struct llama_context * llama_new_context_with_model(
cparams.rope_freq_scale = params.rope_freq_scale == 0.0f ? hparams.rope_freq_scale_train : params.rope_freq_scale;

// this is necessary due to kv_self.n being padded later during inference
-cparams.n_ctx = GGML_PAD(cparams.n_ctx, 32);
+cparams.n_ctx = GGML_PAD(cparams.n_ctx, 256);
Owner:

I think we should be padding to the head size. For example, Phi-2 has HS=80 which is not a multiple of 256

Collaborator Author:

No, (for CUDA) it's better to pad to a power of 2, and the current CUDA implementation expects this padding to work correctly. The matrix shapes are VKQ = (D, BS) = V * KQ = (D, PAD) * (PAD, BS) = V * K * Q = (D, PAD) * (PAD, D) * (D, BS). Initially I did an implementation with PAD == D and warp counts scaling with D, but that leads to issues with warp counts being e.g. 5 for head size 80. With the current implementation you get a better shape for KQ in terms of both the calculation and the softmax. For the calculation of VKQ, awkward D values are handled via more than one accumulator for the VKQ parts. If the padding were a multiple of D, the kernel would need to deal with awkward shapes for all parts of the calculation instead of just a single part. There is also the consideration that exponential functions are very expensive, and the smaller you choose D the more of them you have to calculate.
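
To make the padding argument concrete, the sketch below is illustrative only (pad_to mirrors GGML_PAD's round-up-to-a-multiple behavior, and the cell count is hypothetical); it compares padding the attended KV length to the fixed FATTN_KQ_STRIDE of 256 with padding it to a head size of 80:

// Illustrative sketch, not from the PR: fixed power-of-2 padding vs. padding
// to a multiple of the head size.
#include <cstdio>

// round x up to the next multiple of n (mirrors GGML_PAD)
static int pad_to(int x, int n) {
    return ((x + n - 1) / n) * n;
}

int main() {
    const int fattn_kq_stride = 256;  // fixed padding used by the CUDA kernel
    const int head_size       = 80;   // e.g. Phi-2; not a multiple of 256
    const int kv_cells        = 1000; // hypothetical number of used KV cache cells

    // Padded KV length with the fixed stride used by the CUDA kernel:
    printf("padded to FATTN_KQ_STRIDE: %d\n", pad_to(kv_cells, fattn_kq_stride)); // 1024
    // Padded KV length if one padded to a multiple of the head size instead:
    printf("padded to head size:       %d\n", pad_to(kv_cells, head_size));       // 1040
    return 0;
}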


// with causal attention, the batch size is limited by the context size
cparams.n_batch = hparams.causal_attn ? std::min(cparams.n_ctx, params.n_batch) : params.n_batch;
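
For reference, a small sketch of how the updated kv_self.n heuristic in llama_decode_internal rounds the attended KV range up to multiples of 256 (illustrative only; pad stands in for GGML_PAD, and the cache size and cell counts are hypothetical):

// Illustrative sketch, not from the PR: values produced by
// std::min(kv_self.size, std::max(256u, GGML_PAD(cell_max, 256))).
#include <algorithm>
#include <cstdio>

// round x up to the next multiple of n (stands in for GGML_PAD)
static unsigned pad(unsigned x, unsigned n) {
    return ((x + n - 1) / n) * n;
}

int main() {
    const unsigned kv_size = 4096; // hypothetical total KV cache size
    const unsigned cell_max_values[] = {1, 100, 300, 5000}; // hypothetical cell_max results

    for (const unsigned cell_max : cell_max_values) {
        const unsigned n = std::min(kv_size, std::max(256u, pad(cell_max, 256)));
        printf("cell_max = %4u -> kv_self.n = %u\n", cell_max, n); // 256, 256, 512, 4096
    }
    return 0;
}

Because n_ctx is also padded to 256 (so the cache size itself is a multiple of 256), every resulting value is a multiple of FATTN_KQ_STRIDE, which is what the new GGML_ASSERT in fattn.cu checks.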