diff --git a/test/test_flash_attention_backward.py b/test/test_flash_attention_backward.py
index df8c15efc81..35b8a7c37fa 100755
--- a/test/test_flash_attention_backward.py
+++ b/test/test_flash_attention_backward.py
@@ -1,4 +1,3 @@
-import os
 import sys
 import unittest
 
@@ -150,15 +149,13 @@ def test_flash_attn_gqa_backward_fp16(self):
     self._backward_internal(torch.float16, n_heads_kv=int(N_HEADS // 2))
 
   def test_flash_attn_gqa_backward_bf16(self):
-    if not os.environ.get('DISC_DEVICE'):
-      self._backward_internal(torch.bfloat16, n_heads_kv=int(N_HEADS // 2))
+    self._backward_internal(torch.bfloat16, n_heads_kv=int(N_HEADS // 2))
 
   def test_flash_attn_backward_fp16(self):
     self._backward_internal(torch.float16, n_heads_kv=N_HEADS)
 
   def test_flash_attn_backward_bf16(self):
-    if not os.environ.get('DISC_DEVICE'):
-      self._backward_internal(torch.bfloat16, n_heads_kv=N_HEADS)
+    self._backward_internal(torch.bfloat16, n_heads_kv=N_HEADS)
 
   def test_flash_attn_gqa_backward_fp16_alibi(self):
     self._backward_internal(
diff --git a/torch_xla/csrc/ops/flash_attention_forward.cpp b/torch_xla/csrc/ops/flash_attention_forward.cpp
index 5c478f69a51..9a73f26a9ba 100644
--- a/torch_xla/csrc/ops/flash_attention_forward.cpp
+++ b/torch_xla/csrc/ops/flash_attention_forward.cpp
@@ -23,7 +23,7 @@ xla::Shape NodeOutputShape(int batch_size, int num_heads, int seqlen_q,
       xla::PrimitiveType::F32, {batch_size, num_heads, seqlen_q});
   xla::Shape out_shape = GetXlaShape(q);
   xla::Shape rng_state_shape =
-      xla::ShapeUtil::MakeShape(xla::PrimitiveType::U64, {2});
+      xla::ShapeUtil::MakeShape(xla::PrimitiveType::S64, {2});
   return xla::ShapeUtil::MakeTupleShape(
       {softmax_lse_shape, out_shape, rng_state_shape});
 }
diff --git a/torch_xla/csrc/runtime/disc/custom_call_flash_attention_backward.cc b/torch_xla/csrc/runtime/disc/custom_call_flash_attention_backward.cc
index 8f4460d8145..402eeedc3de 100644
--- a/torch_xla/csrc/runtime/disc/custom_call_flash_attention_backward.cc
+++ b/torch_xla/csrc/runtime/disc/custom_call_flash_attention_backward.cc
@@ -27,7 +27,6 @@ namespace tao {
 namespace ral {
 
 DEFINE_TAO_TYPE_NAME_HELPER(Eigen::half, "f16");
-DEFINE_TAO_TYPE_NAME_HELPER(Eigen::bfloat16, "bf16");
 
 struct FlashAttentionBackwardParams {
   using index_t = uint32_t;
@@ -57,6 +56,7 @@ struct FlashAttentionBackwardParams {
 
   // The dimensions.
   int b, seqlen_q, seqlen_k, d, seqlen_q_rounded, seqlen_k_rounded, d_rounded;
+  int total_q;
   int total_k;
 
   // The scaling factors for the kernel.
@@ -73,6 +73,12 @@ struct FlashAttentionBackwardParams {
 
   bool is_bf16;
   bool is_causal;
+  int window_size_left;
+  int window_size_right;
+  int alibi_slopes_batch_stride;
+  bool enable_alibi_slopes;
+  bool is_seqlens_k_cumulative;
+  int num_splits;
 
   // Backward specific params
   index_t do_batch_stride;
@@ -88,9 +94,11 @@ struct FlashAttentionBackwardParams {
   index_t dk_head_stride;
   index_t dv_head_stride;
 
+  bool deterministic;
+
   void FromString(const std::string& str) {
     std::vector<std::string> params_list = absl::StrSplit(str, "|");
-    TORCH_CHECK(params_list.size() == 43);
+    TORCH_CHECK(params_list.size() == 51);
 
     // Forward specific param
     absl::SimpleAtoi(params_list[0], &this->q_batch_stride);
@@ -102,67 +110,61 @@ struct FlashAttentionBackwardParams {
     absl::SimpleAtoi(params_list[6], &this->q_head_stride);
     absl::SimpleAtoi(params_list[7], &this->k_head_stride);
     absl::SimpleAtoi(params_list[8], &this->v_head_stride);
-    absl::SimpleAtoi(params_list[9], &this->total_k);
-    absl::SimpleAtoi(params_list[10], &this->h);
-    absl::SimpleAtoi(params_list[11], &this->h_k);
-    absl::SimpleAtoi(params_list[12], &this->h_h_k_ratio);
-    absl::SimpleAtoi(params_list[13], &this->o_batch_stride);
-    absl::SimpleAtoi(params_list[14], &this->o_row_stride);
-    absl::SimpleAtoi(params_list[15], &this->o_head_stride);
-    absl::SimpleAtoi(params_list[16], &this->b);
-    absl::SimpleAtoi(params_list[17], &this->seqlen_q);
-    absl::SimpleAtoi(params_list[18], &this->seqlen_k);
-    absl::SimpleAtoi(params_list[19], &this->d);
-    absl::SimpleAtoi(params_list[20], &this->seqlen_q_rounded);
-    absl::SimpleAtoi(params_list[21], &this->seqlen_k_rounded);
-    absl::SimpleAtoi(params_list[22], &this->d_rounded);
-    absl::SimpleAtof(params_list[23], &this->scale_softmax);
-    absl::SimpleAtof(params_list[24], &this->scale_softmax_log2);
-    absl::SimpleAtof(params_list[25], &this->p_dropout);
+    absl::SimpleAtoi(params_list[9], &this->total_q);
+    absl::SimpleAtoi(params_list[10], &this->total_k);
+    absl::SimpleAtoi(params_list[11], &this->h);
+    absl::SimpleAtoi(params_list[12], &this->h_k);
+    absl::SimpleAtoi(params_list[13], &this->h_h_k_ratio);
+    absl::SimpleAtoi(params_list[14], &this->o_batch_stride);
+    absl::SimpleAtoi(params_list[15], &this->o_row_stride);
+    absl::SimpleAtoi(params_list[16], &this->o_head_stride);
+    absl::SimpleAtoi(params_list[17], &this->b);
+    absl::SimpleAtoi(params_list[18], &this->seqlen_q);
+    absl::SimpleAtoi(params_list[19], &this->seqlen_k);
+    absl::SimpleAtoi(params_list[20], &this->d);
+    absl::SimpleAtoi(params_list[21], &this->seqlen_q_rounded);
+    absl::SimpleAtoi(params_list[22], &this->seqlen_k_rounded);
+    absl::SimpleAtoi(params_list[23], &this->d_rounded);
+    absl::SimpleAtof(params_list[24], &this->scale_softmax);
+    absl::SimpleAtof(params_list[25], &this->scale_softmax_log2);
+    absl::SimpleAtof(params_list[26], &this->p_dropout);
     uint32_t tmp;
-    absl::SimpleAtoi(params_list[26], &tmp);
+    absl::SimpleAtoi(params_list[27], &tmp);
     this->p_dropout_in_uint8_t = uint8_t(tmp);
-    absl::SimpleAtof(params_list[27], &this->rp_dropout);
-    absl::SimpleAtof(params_list[28], &this->scale_softmax_rp_dropout);
-    absl::SimpleAtob(params_list[29], &this->is_bf16);
-    absl::SimpleAtob(params_list[30], &this->is_causal);
+    absl::SimpleAtof(params_list[28], &this->rp_dropout);
+    absl::SimpleAtof(params_list[29], &this->scale_softmax_rp_dropout);
+    absl::SimpleAtob(params_list[30], &this->is_bf16);
+    absl::SimpleAtob(params_list[31], &this->is_causal);
+    absl::SimpleAtoi(params_list[32], &this->window_size_left);
+    absl::SimpleAtoi(params_list[33], &this->window_size_right);
+    absl::SimpleAtoi(params_list[34], &this->alibi_slopes_batch_stride);
+    absl::SimpleAtob(params_list[35], &this->is_seqlens_k_cumulative);
+    absl::SimpleAtoi(params_list[36], &this->num_splits);
+    absl::SimpleAtob(params_list[37], &this->enable_alibi_slopes);
 
     // backward specific params
-    absl::SimpleAtoi(params_list[31], &this->do_batch_stride);
-    absl::SimpleAtoi(params_list[32], &this->do_row_stride);
-    absl::SimpleAtoi(params_list[33], &this->do_head_stride);
-    absl::SimpleAtoi(params_list[34], &this->dq_batch_stride);
-    absl::SimpleAtoi(params_list[35], &this->dk_batch_stride);
-    absl::SimpleAtoi(params_list[36], &this->dv_batch_stride);
-    absl::SimpleAtoi(params_list[37], &this->dq_row_stride);
-    absl::SimpleAtoi(params_list[38], &this->dk_row_stride);
-    absl::SimpleAtoi(params_list[39], &this->dv_row_stride);
-    absl::SimpleAtoi(params_list[40], &this->dq_head_stride);
-    absl::SimpleAtoi(params_list[41], &this->dk_head_stride);
-    absl::SimpleAtoi(params_list[42], &this->dv_head_stride);
+    const int offset = 38; // FlashAttentionForwardParams has 38 variables
+    absl::SimpleAtoi(params_list[offset + 0], &this->do_batch_stride);
+    absl::SimpleAtoi(params_list[offset + 1], &this->do_row_stride);
+    absl::SimpleAtoi(params_list[offset + 2], &this->do_head_stride);
+    absl::SimpleAtoi(params_list[offset + 3], &this->dq_batch_stride);
+    absl::SimpleAtoi(params_list[offset + 4], &this->dk_batch_stride);
+    absl::SimpleAtoi(params_list[offset + 5], &this->dv_batch_stride);
+    absl::SimpleAtoi(params_list[offset + 6], &this->dq_row_stride);
+    absl::SimpleAtoi(params_list[offset + 7], &this->dk_row_stride);
+    absl::SimpleAtoi(params_list[offset + 8], &this->dv_row_stride);
+    absl::SimpleAtoi(params_list[offset + 9], &this->dq_head_stride);
+    absl::SimpleAtoi(params_list[offset + 10], &this->dk_head_stride);
+    absl::SimpleAtoi(params_list[offset + 11], &this->dv_head_stride);
+    absl::SimpleAtob(params_list[offset + 12], &this->deterministic);
   }
 };
 
 void run_mha_bwd(Flash_bwd_params& params, cudaStream_t stream,
                  const bool configure) {
   FP16_SWITCH(!params.is_bf16, [&] {
-    if (params.d <= 32) {
-      run_mha_bwd_<elem_type, 32>(params, stream, configure);
-    } else if (params.d <= 64) {
-      run_mha_bwd_<elem_type, 64>(params, stream, configure);
-    } else if (params.d <= 96) {
-      run_mha_bwd_<elem_type, 96>(params, stream, configure);
-    } else if (params.d <= 128) {
-      run_mha_bwd_<elem_type, 128>(params, stream, configure);
-    } else if (params.d <= 160) {
-      run_mha_bwd_<elem_type, 160>(params, stream, configure);
-    } else if (params.d <= 192) {
-      run_mha_bwd_<elem_type, 192>(params, stream, configure);
-    } else if (params.d <= 224) {
-      run_mha_bwd_<elem_type, 224>(params, stream, configure);
-    } else if (params.d <= 256) {
-      run_mha_bwd_<elem_type, 256>(params, stream, configure);
-    }
+    HEADDIM_SWITCH(params.d,
+                   [&] { run_mha_bwd_<elem_type, kHeadDim>(params, stream); });
   });
 }
@@ -175,18 +177,21 @@ void run_mha_bwd(Flash_bwd_params& params, cudaStream_t stream,
 // buffers[0] = dout
 // buffers[1] = q
 // buffers[2] = k
 // buffers[3] = v
 // buffers[4] = out
 // buffers[5] = softmax_lse
 // buffers[6] = cu_seqlens_q
 // buffers[7] = cu_seqlens_k
-// buffers[8] = dq // this is output
-// buffers[9] = dk // this is output
-// buffers[10] = dv // this is output
-// buffers[11] = softmax_d // this is output
+// buffers[8] = rng_state
+// buffers[9] = alibi_slopes
+// buffers[10] = dq // this is output
+// buffers[11] = dk // this is output
+// buffers[12] = dv // this is output
+// buffers[13] = softmax_d // this is output
 template
 std::tuple, MemRefType, MemRefType, MemRefType>
-custom_call_flash_attention_backward(
+custom_call_flash_attention_backward_impl(
     ExecutionContext* ctx, void* stream_handle, MemRefType dout,
     MemRefType q, MemRefType k, MemRefType v, MemRefType out,
     MemRefType softmax_lse, MemRefType seqlens_q, MemRefType seqlens_k,
+    MemRefType rng_state, void* alibi_slopes_ptr,
     void* customAttrs) {
   auto attr = getOrParsePDLAttr(ctx, customAttrs,
                                 "custom_call_flash_attention_backward");
@@ -236,7 +241,6 @@
   memset(&launch_params, 0, sizeof(launch_params));
 
   launch_params.is_bf16 = params.is_bf16;
-  launch_params.is_bf16 = true;
 
   // Set the pointers and strides.
   launch_params.q_ptr = q.data;
@@ -256,6 +260,9 @@ custom_call_flash_attention_backward(
 
   launch_params.cu_seqlens_q = static_cast<int*>(seqlens_q.data);
   launch_params.cu_seqlens_k = static_cast<int*>(seqlens_k.data);
+  launch_params.alibi_slopes_ptr = alibi_slopes_ptr;
+  launch_params.alibi_slopes_batch_stride = params.alibi_slopes_batch_stride;
+
   // P = softmax(QK^T)
   launch_params.p_ptr = nullptr; // no softmax returned always
 
@@ -284,6 +291,10 @@ custom_call_flash_attention_backward(
   launch_params.scale_softmax_rp_dropout = params.scale_softmax_rp_dropout;
 
   launch_params.is_causal = params.is_causal;
+  launch_params.window_size_left = params.window_size_left;
+  launch_params.window_size_right = params.window_size_right;
+
+  launch_params.is_seqlens_k_cumulative = true;
 
   launch_params.do_ptr = dout.data;
   launch_params.do_row_stride = params.do_row_stride;
@@ -305,10 +316,19 @@
   auto opts = torch::TensorOptions().dtype(scalar_type).device(torch::kCUDA);
   at::Tensor dq_accum;
   if (loop) {
-    dq_accum =
-        torch::empty({launch_params.b, launch_params.h,
-                      launch_params.seqlen_q_rounded, launch_params.d_rounded},
-                     opts.dtype(at::kFloat));
+    if (!params.deterministic) {
+      dq_accum = torch::empty({params.total_q + 128 * launch_params.b,
+                               launch_params.h, launch_params.d_rounded},
+                              opts.dtype(at::kFloat));
+    } else {
+      auto dprops = at::cuda::getCurrentDeviceProperties();
+      const int nsplits = (dprops->multiProcessorCount +
+                           launch_params.b * launch_params.h - 1) /
+                          (launch_params.b * launch_params.h);
+      dq_accum = torch::zeros({nsplits, params.total_q + 128 * launch_params.b,
+                               launch_params.h, launch_params.d_rounded},
+                              opts.dtype(at::kFloat));
+    }
   }
 
   at::Tensor dk = torch::from_blob(
@@ -344,6 +364,10 @@
 
   // Softmax sum
   launch_params.dsoftmax_sum = dsoftmax.data;
+  launch_params.deterministic = params.deterministic;
+  launch_params.dq_accum_split_stride =
+      !launch_params.deterministic ? 0 : dq_accum.stride(0);
+
   auto launch = &run_mha_bwd;
 
   auto gen = at::get_generator_or_default<at::CUDAGeneratorImpl>(
@@ -353,11 +377,11 @@
   int64_t counter_offset = launch_params.b * launch_params.h * 32;
 
   bool is_dropout = (1.f - launch_params.p_dropout) > 0.0;
 
-  if (is_dropout) {
-    // See Note [Acquire lock when using random generators]
-    std::lock_guard<std::mutex> lock(gen->mutex_);
-    launch_params.philox_args = gen->philox_cuda_state(counter_offset);
-  }
+  // TODO(wenting.swt): According to the implementation in
+  // `flash_attn_varlen_func` of flash-attn v2.5.6, the forward generates
+  // `rng_state` which is passed as ctx to the backward. Hence, for simplifying
+  // the logic, the redundant branch where `rng_state` is None has been omitted.
+  launch_params.rng_state = reinterpret_cast<uint64_t*>(rng_state.data);
 
   launch(launch_params, gpu_stream, /*configure=*/false);
@@ -378,12 +402,65 @@
   return std::make_tuple(dq_res, dk_res, dv_res, dsoftmax);
 }
 
+template
+std::tuple, MemRefType, MemRefType,
+           MemRefType>
+custom_call_flash_attention_backward_noalibi(
+    ExecutionContext* ctx, void* stream_handle, MemRefType dout,
+    MemRefType q, MemRefType k, MemRefType v,
+    MemRefType out, MemRefType softmax_lse,
+    MemRefType seqlens_q, MemRefType seqlens_k,
+    MemRefType rng_state, void* customAttrs) {
+  return custom_call_flash_attention_backward_impl(
+      ctx, stream_handle, dout, q, k, v, out, softmax_lse, seqlens_q, seqlens_k,
+      rng_state, nullptr, customAttrs);
+}
+
+template
+std::tuple, MemRefType, MemRefType,
+           MemRefType>
+custom_call_flash_attention_backward_alibi_v1(
+    ExecutionContext* ctx, void* stream_handle, MemRefType dout,
+    MemRefType q, MemRefType k, MemRefType v,
+    MemRefType out, MemRefType softmax_lse,
+    MemRefType seqlens_q, MemRefType seqlens_k,
+    MemRefType rng_state, MemRefType alibi_slopes,
+    void* customAttrs) {
+  return custom_call_flash_attention_backward_impl(
+      ctx, stream_handle, dout, q, k, v, out, softmax_lse, seqlens_q, seqlens_k,
+      rng_state, alibi_slopes.data, customAttrs);
+}
+
+template
+std::tuple, MemRefType, MemRefType,
+           MemRefType>
+custom_call_flash_attention_backward_alibi_v2(
+    ExecutionContext* ctx, void* stream_handle, MemRefType dout,
+    MemRefType q, MemRefType k, MemRefType v,
+    MemRefType out, MemRefType softmax_lse,
+    MemRefType seqlens_q, MemRefType seqlens_k,
+    MemRefType rng_state, MemRefType alibi_slopes,
+    void* customAttrs) {
+  return custom_call_flash_attention_backward_impl(
+      ctx, stream_handle, dout, q, k, v, out, softmax_lse, seqlens_q, seqlens_k,
+      rng_state, alibi_slopes.data, customAttrs);
+}
+
+TAO_RAL_API(
+    "custom_call_flash_attention_backward", "gpu",
+    custom_call_flash_attention_backward_noalibi);
+TAO_RAL_API(
+    "custom_call_flash_attention_backward", "gpu",
+    custom_call_flash_attention_backward_alibi_v1);
+TAO_RAL_API(
+    "custom_call_flash_attention_backward", "gpu",
+    custom_call_flash_attention_backward_alibi_v2);
 TAO_RAL_API("custom_call_flash_attention_backward", "gpu",
-            custom_call_flash_attention_backward);
+            custom_call_flash_attention_backward_noalibi);
 TAO_RAL_API("custom_call_flash_attention_backward", "gpu",
-            custom_call_flash_attention_backward);
+            custom_call_flash_attention_backward_alibi_v1);
 TAO_RAL_API("custom_call_flash_attention_backward", "gpu",
-            custom_call_flash_attention_backward);
+            custom_call_flash_attention_backward_alibi_v2);
 
 } // namespace ral
 } // namespace tao
\ No newline at end of file
diff --git a/torch_xla/csrc/runtime/disc/custom_call_flash_attention_forward.cc b/torch_xla/csrc/runtime/disc/custom_call_flash_attention_forward.cc
index ca281319b85..fcac32fa5c3 100644
--- a/torch_xla/csrc/runtime/disc/custom_call_flash_attention_forward.cc
+++ b/torch_xla/csrc/runtime/disc/custom_call_flash_attention_forward.cc
@@ -57,6 +57,7 @@ struct FlashAttentionForwardParams {
 
   // The dimensions.
   int b, seqlen_q, seqlen_k, d, seqlen_q_rounded, seqlen_k_rounded, d_rounded;
+  int total_q;
   int total_k;
 
   // The scaling factors for the kernel.
@@ -73,10 +74,16 @@ struct FlashAttentionForwardParams {
 
   bool is_bf16;
   bool is_causal;
+  int window_size_left;
+  int window_size_right;
+  int alibi_slopes_batch_stride;
+  bool enable_alibi_slopes;
+  bool is_seqlens_k_cumulative;
+  int num_splits;
 
   void FromString(const std::string& str) {
     std::vector<std::string> params_list = absl::StrSplit(str, "|");
-    TORCH_CHECK(params_list.size() >= 31); // at least 31 variables
+    TORCH_CHECK(params_list.size() >= 38); // at least 38 variables
     absl::SimpleAtoi(params_list[0], &this->q_batch_stride);
     absl::SimpleAtoi(params_list[1], &this->k_batch_stride);
     absl::SimpleAtoi(params_list[2], &this->v_batch_stride);
@@ -86,30 +93,37 @@ struct FlashAttentionForwardParams {
     absl::SimpleAtoi(params_list[6], &this->q_head_stride);
     absl::SimpleAtoi(params_list[7], &this->k_head_stride);
     absl::SimpleAtoi(params_list[8], &this->v_head_stride);
-    absl::SimpleAtoi(params_list[9], &this->total_k);
-    absl::SimpleAtoi(params_list[10], &this->h);
-    absl::SimpleAtoi(params_list[11], &this->h_k);
-    absl::SimpleAtoi(params_list[12], &this->h_h_k_ratio);
-    absl::SimpleAtoi(params_list[13], &this->o_batch_stride);
-    absl::SimpleAtoi(params_list[14], &this->o_row_stride);
-    absl::SimpleAtoi(params_list[15], &this->o_head_stride);
-    absl::SimpleAtoi(params_list[16], &this->b);
-    absl::SimpleAtoi(params_list[17], &this->seqlen_q);
-    absl::SimpleAtoi(params_list[18], &this->seqlen_k);
-    absl::SimpleAtoi(params_list[19], &this->d);
-    absl::SimpleAtoi(params_list[20], &this->seqlen_q_rounded);
-    absl::SimpleAtoi(params_list[21], &this->seqlen_k_rounded);
-    absl::SimpleAtoi(params_list[22], &this->d_rounded);
-    absl::SimpleAtof(params_list[23], &this->scale_softmax);
-    absl::SimpleAtof(params_list[24], &this->scale_softmax_log2);
-    absl::SimpleAtof(params_list[25], &this->p_dropout);
+    absl::SimpleAtoi(params_list[9], &this->total_q);
+    absl::SimpleAtoi(params_list[10], &this->total_k);
+    absl::SimpleAtoi(params_list[11], &this->h);
+    absl::SimpleAtoi(params_list[12], &this->h_k);
+    absl::SimpleAtoi(params_list[13], &this->h_h_k_ratio);
+    absl::SimpleAtoi(params_list[14], &this->o_batch_stride);
+    absl::SimpleAtoi(params_list[15], &this->o_row_stride);
+    absl::SimpleAtoi(params_list[16], &this->o_head_stride);
+    absl::SimpleAtoi(params_list[17], &this->b);
+    absl::SimpleAtoi(params_list[18], &this->seqlen_q);
+    absl::SimpleAtoi(params_list[19], &this->seqlen_k);
+    absl::SimpleAtoi(params_list[20], &this->d);
+    absl::SimpleAtoi(params_list[21], &this->seqlen_q_rounded);
+    absl::SimpleAtoi(params_list[22], &this->seqlen_k_rounded);
+    absl::SimpleAtoi(params_list[23], &this->d_rounded);
+    absl::SimpleAtof(params_list[24], &this->scale_softmax);
+    absl::SimpleAtof(params_list[25], &this->scale_softmax_log2);
+    absl::SimpleAtof(params_list[26], &this->p_dropout);
     uint32_t tmp;
-    absl::SimpleAtoi(params_list[26], &tmp);
+    absl::SimpleAtoi(params_list[27], &tmp);
     this->p_dropout_in_uint8_t = uint8_t(tmp);
-    absl::SimpleAtof(params_list[27], &this->rp_dropout);
-    absl::SimpleAtof(params_list[28], &this->scale_softmax_rp_dropout);
-    absl::SimpleAtob(params_list[29], &this->is_bf16);
-    absl::SimpleAtob(params_list[30], &this->is_causal);
+    absl::SimpleAtof(params_list[28], &this->rp_dropout);
+    absl::SimpleAtof(params_list[29], &this->scale_softmax_rp_dropout);
+    absl::SimpleAtob(params_list[30], &this->is_bf16);
+    absl::SimpleAtob(params_list[31], &this->is_causal);
+    absl::SimpleAtoi(params_list[32], &this->window_size_left);
+    absl::SimpleAtoi(params_list[33], &this->window_size_right);
+    absl::SimpleAtoi(params_list[34], &this->alibi_slopes_batch_stride);
+    absl::SimpleAtob(params_list[35], &this->is_seqlens_k_cumulative);
+    absl::SimpleAtoi(params_list[36], &this->num_splits);
+    absl::SimpleAtob(params_list[37], &this->enable_alibi_slopes);
   }
 };
 
@@ -122,14 +136,13 @@ struct FlashAttentionForwardParams {
 // result[0] = softmax_lse // this is output
 // result[1] = out_for_output // this is output
 template
-std::tuple, MemRefType>
-custom_call_flash_attention_forward(ExecutionContext* ctx, void* stream_handle,
-                                    MemRefType q,
-                                    MemRefType k,
-                                    MemRefType v,
-                                    MemRefType seqlens_q,
-                                    MemRefType seqlens_k,
-                                    void* customAttrs) {
+std::tuple, MemRefType,
+           MemRefType>
+custom_call_flash_attention_forward_impl(
+    ExecutionContext* ctx, void* stream_handle, MemRefType q,
+    MemRefType k, MemRefType v,
+    MemRefType seqlens_q, MemRefType seqlens_k,
+    void* alibi_slopes_ptr, void* customAttrs) {
   auto attr = getOrParsePDLAttr(ctx, customAttrs,
                                 "custom_call_flash_attention_forward");
   if (!attr) {
@@ -163,6 +176,13 @@ custom_call_flash_attention_forward(ExecutionContext* ctx, void* stream_handle,
       gpu_driver->alloc(ctx, output_element_count * sizeof(T_IN)));
   auto output = assignMemRef(output_ptr, q.sizes);
 
+  auto rng_state_ptr =
+      static_cast(gpu_driver->alloc(ctx, 2 * sizeof(int64_t)));
+  auto rng_state =
+      assignMemRef(rng_state_ptr, std::vector{2});
+
+  cudaMemsetAsync(rng_state_ptr, 0, 2 * sizeof(int64_t), gpu_stream);
+
   FlashAttentionForwardParams params;
   params.FromString(std::move(backend_config));
@@ -190,6 +210,8 @@ custom_call_flash_attention_forward(ExecutionContext* ctx, void* stream_handle,
 
   launch_params.cu_seqlens_q = seqlens_q.data;
   launch_params.cu_seqlens_k = seqlens_k.data;
+  launch_params.alibi_slopes_ptr = alibi_slopes_ptr;
+  launch_params.alibi_slopes_batch_stride = params.alibi_slopes_batch_stride;
 
   // P = softmax(QK^T)
   launch_params.p_ptr = nullptr; // no softmax returned always
@@ -219,6 +241,16 @@ custom_call_flash_attention_forward(ExecutionContext* ctx, void* stream_handle,
   launch_params.scale_softmax_rp_dropout = params.scale_softmax_rp_dropout;
 
   launch_params.is_causal = params.is_causal;
+  launch_params.window_size_left = params.window_size_left;
+  launch_params.window_size_right = params.window_size_right;
+
+  launch_params.is_seqlens_k_cumulative = params.is_seqlens_k_cumulative;
+
+  // set params splitkv
+  launch_params.num_splits = params.num_splits;
+
+  // Forward kernel will populate memory with the seed and offset.
+  launch_params.rng_state = reinterpret_cast<uint64_t*>(rng_state_ptr);
 
   if ((1.f - launch_params.p_dropout) > 0.0) {
     // number of times random will be generated per thread, to offset philox
@@ -233,20 +265,67 @@
   }
 
   FP16_SWITCH(!launch_params.is_bf16, [&] {
-    FWD_HEADDIM_SWITCH(launch_params.d, [&] {
+    HEADDIM_SWITCH(launch_params.d, [&] {
+      // TODO(wenting.swt): support split_kv
       run_mha_fwd_<elem_type, kHeadDim>(launch_params, gpu_stream);
     });
   });
 
-  return std::make_tuple(softmax_lse, output);
+  return std::make_tuple(softmax_lse, output, rng_state);
+}
+
+template
+std::tuple, MemRefType,
+           MemRefType>
+custom_call_flash_attention_forward_noalibi(
+    ExecutionContext* ctx, void* stream_handle, MemRefType q,
+    MemRefType k, MemRefType v,
+    MemRefType seqlens_q, MemRefType seqlens_k,
+    void* customAttrs) {
+  return custom_call_flash_attention_forward_impl(
+      ctx, stream_handle, q, k, v, seqlens_q, seqlens_k, nullptr, customAttrs);
 }
 
+template
+std::tuple, MemRefType,
+           MemRefType>
+custom_call_flash_attention_forward_alibi_v1(
+    ExecutionContext* ctx, void* stream_handle, MemRefType q,
+    MemRefType k, MemRefType v,
+    MemRefType seqlens_q, MemRefType seqlens_k,
+    MemRefType alibi_slopes, void* customAttrs) {
+  return custom_call_flash_attention_forward_impl(
+      ctx, stream_handle, q, k, v, seqlens_q, seqlens_k, alibi_slopes.data,
+      customAttrs);
+}
+
+template
+std::tuple, MemRefType,
+           MemRefType>
+custom_call_flash_attention_forward_alibi_v2(
+    ExecutionContext* ctx, void* stream_handle, MemRefType q,
+    MemRefType k, MemRefType v,
+    MemRefType seqlens_q, MemRefType seqlens_k,
+    MemRefType alibi_slopes, void* customAttrs) {
+  return custom_call_flash_attention_forward_impl(
+      ctx, stream_handle, q, k, v, seqlens_q, seqlens_k, alibi_slopes.data,
+      customAttrs);
+}
+
+TAO_RAL_API("custom_call_flash_attention_forward", "gpu",
+            custom_call_flash_attention_forward_noalibi);
+TAO_RAL_API(
+    "custom_call_flash_attention_forward", "gpu",
+    custom_call_flash_attention_forward_alibi_v1);
+TAO_RAL_API(
+    "custom_call_flash_attention_forward", "gpu",
+    custom_call_flash_attention_forward_alibi_v2);
 TAO_RAL_API("custom_call_flash_attention_forward", "gpu",
-            custom_call_flash_attention_forward);
+            custom_call_flash_attention_forward_noalibi);
 TAO_RAL_API("custom_call_flash_attention_forward", "gpu",
-            custom_call_flash_attention_forward);
+            custom_call_flash_attention_forward_alibi_v1);
 TAO_RAL_API("custom_call_flash_attention_forward", "gpu",
-            custom_call_flash_attention_forward);
+            custom_call_flash_attention_forward_alibi_v2);
 
 } // namespace ral
 } // namespace tao
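
Supplementary note (not part of the diff): the backend_config string consumed by FromString is "|"-separated, with the first 38 fields forming the forward-parameter prefix shared by both kernels and 13 backward-specific fields appended after it; that is why the backward parser checks for 51 fields and reads its own values at offset 38. The sketch below only illustrates that layout under those assumptions; the BuildBackwardConfig helper and the placeholder values are hypothetical and do not appear in the patch.

// Illustrative sketch of the 51-field backward config layout (hypothetical helper).
#include <string>
#include <vector>

#include "absl/strings/str_join.h"

std::string BuildBackwardConfig() {
  std::vector<std::string> fields;
  // Fields 0-37: forward-parameter prefix (q/k/v/o strides, total_q, total_k,
  // head counts, dims, softmax scales, dropout, flags, window sizes,
  // alibi_slopes_batch_stride, is_seqlens_k_cumulative, num_splits,
  // enable_alibi_slopes), in the order read by FromString above.
  for (int i = 0; i < 38; ++i) {
    fields.push_back("0");  // placeholder values only
  }
  // Fields 38-49: dout/dq/dk/dv batch, row and head strides.
  for (int i = 0; i < 12; ++i) {
    fields.push_back("0");
  }
  // Field 50: deterministic flag, parsed with absl::SimpleAtob.
  fields.push_back("false");
  return absl::StrJoin(fields, "|");  // 51 fields total
}

Keeping the forward prefix unchanged lets the backward parser reuse the same indices the forward parser validates, at the cost of the hard-coded offset of 38.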