Support Flash Attention 2.5.6 for disc backend (#4)
yitongh committed Aug 12, 2024
1 parent e1b8ed7 commit 41baac7
Showing 4 changed files with 266 additions and 113 deletions.
7 changes: 2 additions & 5 deletions test/test_flash_attention_backward.py
@@ -1,4 +1,3 @@
-import os
import sys
import unittest

@@ -150,15 +149,13 @@ def test_flash_attn_gqa_backward_fp16(self):
self._backward_internal(torch.float16, n_heads_kv=int(N_HEADS // 2))

def test_flash_attn_gqa_backward_bf16(self):
-if not os.environ.get('DISC_DEVICE'):
-self._backward_internal(torch.bfloat16, n_heads_kv=int(N_HEADS // 2))
+self._backward_internal(torch.bfloat16, n_heads_kv=int(N_HEADS // 2))

def test_flash_attn_backward_fp16(self):
self._backward_internal(torch.float16, n_heads_kv=N_HEADS)

def test_flash_attn_backward_bf16(self):
-if not os.environ.get('DISC_DEVICE'):
-self._backward_internal(torch.bfloat16, n_heads_kv=N_HEADS)
+self._backward_internal(torch.bfloat16, n_heads_kv=N_HEADS)

def test_flash_attn_gqa_backward_fp16_alibi(self):
self._backward_internal(
2 changes: 1 addition & 1 deletion torch_xla/csrc/ops/flash_attention_forward.cpp
@@ -23,7 +23,7 @@ xla::Shape NodeOutputShape(int batch_size, int num_heads, int seqlen_q,
xla::PrimitiveType::F32, {batch_size, num_heads, seqlen_q});
xla::Shape out_shape = GetXlaShape(q);
xla::Shape rng_state_shape =
-xla::ShapeUtil::MakeShape(xla::PrimitiveType::U64, {2});
+xla::ShapeUtil::MakeShape(xla::PrimitiveType::S64, {2});
return xla::ShapeUtil::MakeTupleShape(
{softmax_lse_shape, out_shape, rng_state_shape});
}
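A note on the U64 -> S64 change above: the backward custom call in the next file receives this two-element buffer as MemRefType<int64_t, 1> and reinterprets it as uint64_t* before handing it to the kernel, so declaring the tuple element as signed 64-bit keeps both sides of the bridge consistent. A minimal sketch of that reinterpretation (the helper name is illustrative, not from this commit):

#include <cstdint>

// Illustrative only: the rng_state buffer travels through XLA as two S64
// values and is viewed as the unsigned state the flash-attention kernel uses.
inline uint64_t* AsKernelRngState(int64_t* rng_state /* shape {2} */) {
  return reinterpret_cast<uint64_t*>(rng_state);
}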
219 changes: 148 additions & 71 deletions torch_xla/csrc/runtime/disc/custom_call_flash_attention_backward.cc
@@ -27,7 +27,6 @@ namespace tao {
namespace ral {

DEFINE_TAO_TYPE_NAME_HELPER(Eigen::half, "f16");
-DEFINE_TAO_TYPE_NAME_HELPER(Eigen::bfloat16, "bf16");

struct FlashAttentionBackwardParams {
using index_t = uint32_t;
@@ -57,6 +56,7 @@ struct FlashAttentionBackwardParams {
// The dimensions.
int b, seqlen_q, seqlen_k, d, seqlen_q_rounded, seqlen_k_rounded, d_rounded;

+int total_q;
int total_k;

// The scaling factors for the kernel.
@@ -73,6 +73,12 @@ struct FlashAttentionBackwardParams {

bool is_bf16;
bool is_causal;
+int window_size_left;
+int window_size_right;
+int alibi_slopes_batch_stride;
+bool enable_alibi_slopes;
+bool is_seqlens_k_cumulative;
+int num_splits;

// Backward specific params
index_t do_batch_stride;
@@ -88,9 +94,11 @@ struct FlashAttentionBackwardParams {
index_t dk_head_stride;
index_t dv_head_stride;

+bool deterministic;

void FromString(const std::string& str) {
std::vector<std::string> params_list = absl::StrSplit(str, "|");
-TORCH_CHECK(params_list.size() == 43);
+TORCH_CHECK(params_list.size() == 51);

// Forward specific param
absl::SimpleAtoi(params_list[0], &this->q_batch_stride);
@@ -102,67 +110,61 @@ struct FlashAttentionBackwardParams {
absl::SimpleAtoi(params_list[6], &this->q_head_stride);
absl::SimpleAtoi(params_list[7], &this->k_head_stride);
absl::SimpleAtoi(params_list[8], &this->v_head_stride);
-absl::SimpleAtoi(params_list[9], &this->total_k);
-absl::SimpleAtoi(params_list[10], &this->h);
-absl::SimpleAtoi(params_list[11], &this->h_k);
-absl::SimpleAtoi(params_list[12], &this->h_h_k_ratio);
-absl::SimpleAtoi(params_list[13], &this->o_batch_stride);
-absl::SimpleAtoi(params_list[14], &this->o_row_stride);
-absl::SimpleAtoi(params_list[15], &this->o_head_stride);
-absl::SimpleAtoi(params_list[16], &this->b);
-absl::SimpleAtoi(params_list[17], &this->seqlen_q);
-absl::SimpleAtoi(params_list[18], &this->seqlen_k);
-absl::SimpleAtoi(params_list[19], &this->d);
-absl::SimpleAtoi(params_list[20], &this->seqlen_q_rounded);
-absl::SimpleAtoi(params_list[21], &this->seqlen_k_rounded);
-absl::SimpleAtoi(params_list[22], &this->d_rounded);
-absl::SimpleAtof(params_list[23], &this->scale_softmax);
-absl::SimpleAtof(params_list[24], &this->scale_softmax_log2);
-absl::SimpleAtof(params_list[25], &this->p_dropout);
+absl::SimpleAtoi(params_list[9], &this->total_q);
+absl::SimpleAtoi(params_list[10], &this->total_k);
+absl::SimpleAtoi(params_list[11], &this->h);
+absl::SimpleAtoi(params_list[12], &this->h_k);
+absl::SimpleAtoi(params_list[13], &this->h_h_k_ratio);
+absl::SimpleAtoi(params_list[14], &this->o_batch_stride);
+absl::SimpleAtoi(params_list[15], &this->o_row_stride);
+absl::SimpleAtoi(params_list[16], &this->o_head_stride);
+absl::SimpleAtoi(params_list[17], &this->b);
+absl::SimpleAtoi(params_list[18], &this->seqlen_q);
+absl::SimpleAtoi(params_list[19], &this->seqlen_k);
+absl::SimpleAtoi(params_list[20], &this->d);
+absl::SimpleAtoi(params_list[21], &this->seqlen_q_rounded);
+absl::SimpleAtoi(params_list[22], &this->seqlen_k_rounded);
+absl::SimpleAtoi(params_list[23], &this->d_rounded);
+absl::SimpleAtof(params_list[24], &this->scale_softmax);
+absl::SimpleAtof(params_list[25], &this->scale_softmax_log2);
+absl::SimpleAtof(params_list[26], &this->p_dropout);
uint32_t tmp;
-absl::SimpleAtoi(params_list[26], &tmp);
+absl::SimpleAtoi(params_list[27], &tmp);
this->p_dropout_in_uint8_t = uint8_t(tmp);
-absl::SimpleAtof(params_list[27], &this->rp_dropout);
-absl::SimpleAtof(params_list[28], &this->scale_softmax_rp_dropout);
-absl::SimpleAtob(params_list[29], &this->is_bf16);
-absl::SimpleAtob(params_list[30], &this->is_causal);
+absl::SimpleAtof(params_list[28], &this->rp_dropout);
+absl::SimpleAtof(params_list[29], &this->scale_softmax_rp_dropout);
+absl::SimpleAtob(params_list[30], &this->is_bf16);
+absl::SimpleAtob(params_list[31], &this->is_causal);
+absl::SimpleAtoi(params_list[32], &this->window_size_left);
+absl::SimpleAtoi(params_list[33], &this->window_size_right);
+absl::SimpleAtoi(params_list[34], &this->alibi_slopes_batch_stride);
+absl::SimpleAtob(params_list[35], &this->is_seqlens_k_cumulative);
+absl::SimpleAtoi(params_list[36], &this->num_splits);
+absl::SimpleAtob(params_list[37], &this->enable_alibi_slopes);

// backward specific params
-absl::SimpleAtoi(params_list[31], &this->do_batch_stride);
-absl::SimpleAtoi(params_list[32], &this->do_row_stride);
-absl::SimpleAtoi(params_list[33], &this->do_head_stride);
-absl::SimpleAtoi(params_list[34], &this->dq_batch_stride);
-absl::SimpleAtoi(params_list[35], &this->dk_batch_stride);
-absl::SimpleAtoi(params_list[36], &this->dv_batch_stride);
-absl::SimpleAtoi(params_list[37], &this->dq_row_stride);
-absl::SimpleAtoi(params_list[38], &this->dk_row_stride);
-absl::SimpleAtoi(params_list[39], &this->dv_row_stride);
-absl::SimpleAtoi(params_list[40], &this->dq_head_stride);
-absl::SimpleAtoi(params_list[41], &this->dk_head_stride);
-absl::SimpleAtoi(params_list[42], &this->dv_head_stride);
+const int offset = 38;  // FlashAttentionForwardParams has 38 variables
+absl::SimpleAtoi(params_list[offset + 0], &this->do_batch_stride);
+absl::SimpleAtoi(params_list[offset + 1], &this->do_row_stride);
+absl::SimpleAtoi(params_list[offset + 2], &this->do_head_stride);
+absl::SimpleAtoi(params_list[offset + 3], &this->dq_batch_stride);
+absl::SimpleAtoi(params_list[offset + 4], &this->dk_batch_stride);
+absl::SimpleAtoi(params_list[offset + 5], &this->dv_batch_stride);
+absl::SimpleAtoi(params_list[offset + 6], &this->dq_row_stride);
+absl::SimpleAtoi(params_list[offset + 7], &this->dk_row_stride);
+absl::SimpleAtoi(params_list[offset + 8], &this->dv_row_stride);
+absl::SimpleAtoi(params_list[offset + 9], &this->dq_head_stride);
+absl::SimpleAtoi(params_list[offset + 10], &this->dk_head_stride);
+absl::SimpleAtoi(params_list[offset + 11], &this->dv_head_stride);
+absl::SimpleAtob(params_list[offset + 12], &this->deterministic);
}
};
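For reference, the serialized layout parsed above: the first 38 fields of the '|'-separated string (indices 0 through 37, ending with enable_alibi_slopes) mirror FlashAttentionForwardParams, and the 13 backward-specific fields (the twelve d* strides plus deterministic) follow at offset 38, which is where the size check of 51 at the top of FromString comes from. A minimal standalone sanity check of that layout (illustrative, not part of this commit):

#include <string>
#include <vector>

#include "absl/strings/str_split.h"

// Illustrative helper: confirm a serialized backward-param string carries the
// 38 forward fields plus 13 backward fields that FromString above expects.
bool LooksLikeBackwardParams(const std::string& str) {
  std::vector<std::string> fields = absl::StrSplit(str, '|');
  constexpr size_t kForwardFields = 38;   // indices 0..37
  constexpr size_t kBackwardFields = 13;  // d* strides + deterministic flag
  return fields.size() == kForwardFields + kBackwardFields;  // == 51
}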

void run_mha_bwd(Flash_bwd_params& params, cudaStream_t stream,
const bool configure) {
FP16_SWITCH(!params.is_bf16, [&] {
-if (params.d <= 32) {
-run_mha_bwd_<elem_type, 32>(params, stream, configure);
-} else if (params.d <= 64) {
-run_mha_bwd_<elem_type, 64>(params, stream, configure);
-} else if (params.d <= 96) {
-run_mha_bwd_<elem_type, 96>(params, stream, configure);
-} else if (params.d <= 128) {
-run_mha_bwd_<elem_type, 128>(params, stream, configure);
-} else if (params.d <= 160) {
-run_mha_bwd_<elem_type, 160>(params, stream, configure);
-} else if (params.d <= 192) {
-run_mha_bwd_<elem_type, 192>(params, stream, configure);
-} else if (params.d <= 224) {
-run_mha_bwd_<elem_type, 224>(params, stream, configure);
-} else if (params.d <= 256) {
-run_mha_bwd_<elem_type, 256>(params, stream, configure);
-}
+HEADDIM_SWITCH(params.d,
+               [&] { run_mha_bwd_<elem_type, kHeadDim>(params, stream); });
});
}
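HEADDIM_SWITCH replaces the hand-written ladder above with flash-attn's compile-time dispatch over head dimensions. Roughly, the macro expands to the same if/else chain, binding kHeadDim as a constexpr before invoking the lambda; a simplified sketch covering only three of the eight head sizes (illustrative, not the actual macro):

// Simplified sketch of what HEADDIM_SWITCH does; the real macro in flash-attn
// covers 32/64/96/128/160/192/224/256, exactly like the removed ladder above.
#define HEADDIM_SWITCH_SKETCH(HEADDIM, ...)  \
  [&] {                                      \
    if (HEADDIM <= 64) {                     \
      constexpr static int kHeadDim = 64;    \
      return __VA_ARGS__();                  \
    } else if (HEADDIM <= 128) {             \
      constexpr static int kHeadDim = 128;   \
      return __VA_ARGS__();                  \
    } else {                                 \
      constexpr static int kHeadDim = 256;   \
      return __VA_ARGS__();                  \
    }                                        \
  }()
// Usage would mirror the call above:
//   HEADDIM_SWITCH_SKETCH(params.d,
//       [&] { run_mha_bwd_<elem_type, kHeadDim>(params, stream); });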

@@ -175,18 +177,21 @@ void run_mha_bwd(Flash_bwd_params& params, cudaStream_t stream,
// buffers[5] = softmax_lse
// buffers[6] = cu_seqlens_q
// buffers[7] = cu_seqlens_k
-// buffers[8] = dq // this is output
-// buffers[9] = dk // this is output
-// buffers[10] = dv // this is output
-// buffers[11] = softmax_d // this is output
+// buffers[8] = rng_state
+// buffers[9] = alibi_slopes
+// buffers[10] = dq // this is output
+// buffers[11] = dk // this is output
+// buffers[12] = dv // this is output
+// buffers[13] = softmax_d // this is output
template <typename T_IN, typename SOFT_MAX_TYPE, int M>
std::tuple<MemRefType<T_IN, M>, MemRefType<T_IN, M>, MemRefType<T_IN, M>,
MemRefType<SOFT_MAX_TYPE, M>>
-custom_call_flash_attention_backward(
+custom_call_flash_attention_backward_impl(
ExecutionContext* ctx, void* stream_handle, MemRefType<T_IN, M> dout,
MemRefType<T_IN, M> q, MemRefType<T_IN, M> k, MemRefType<T_IN, M> v,
MemRefType<T_IN, M> out, MemRefType<SOFT_MAX_TYPE, M> softmax_lse,
MemRefType<int32_t, 1> seqlens_q, MemRefType<int32_t, 1> seqlens_k,
+MemRefType<int64_t, 1> rng_state, void* alibi_slopes_ptr,
void* customAttrs) {
auto attr = getOrParsePDLAttr(ctx, customAttrs,
"custom_call_flash_attention_backward");
@@ -236,7 +241,6 @@ custom_call_flash_attention_backward(
memset(&launch_params, 0, sizeof(launch_params));

launch_params.is_bf16 = params.is_bf16;
-launch_params.is_bf16 = true;

// Set the pointers and strides.
launch_params.q_ptr = q.data;
@@ -256,6 +260,9 @@ custom_call_flash_attention_backward(
launch_params.cu_seqlens_q = static_cast<int*>(seqlens_q.data);
launch_params.cu_seqlens_k = static_cast<int*>(seqlens_k.data);

+launch_params.alibi_slopes_ptr = alibi_slopes_ptr;
+launch_params.alibi_slopes_batch_stride = params.alibi_slopes_batch_stride;

// P = softmax(QK^T)
launch_params.p_ptr = nullptr; // no softmax returned always

@@ -284,6 +291,10 @@ custom_call_flash_attention_backward(
launch_params.scale_softmax_rp_dropout = params.scale_softmax_rp_dropout;

launch_params.is_causal = params.is_causal;
+launch_params.window_size_left = params.window_size_left;
+launch_params.window_size_right = params.window_size_right;

+launch_params.is_seqlens_k_cumulative = true;

launch_params.do_ptr = dout.data;
launch_params.do_row_stride = params.do_row_stride;
@@ -305,10 +316,19 @@ custom_call_flash_attention_backward(
auto opts = torch::TensorOptions().dtype(scalar_type).device(torch::kCUDA);
at::Tensor dq_accum;
if (loop) {
-dq_accum =
-torch::empty({launch_params.b, launch_params.h,
-launch_params.seqlen_q_rounded, launch_params.d_rounded},
-opts.dtype(at::kFloat));
+if (!params.deterministic) {
+dq_accum = torch::empty({params.total_q + 128 * launch_params.b,
+launch_params.h, launch_params.d_rounded},
+opts.dtype(at::kFloat));
+} else {
+auto dprops = at::cuda::getCurrentDeviceProperties();
+const int nsplits = (dprops->multiProcessorCount +
+launch_params.b * launch_params.h - 1) /
+(launch_params.b * launch_params.h);
+dq_accum = torch::zeros({nsplits, params.total_q + 128 * launch_params.b,
+launch_params.h, launch_params.d_rounded},
+opts.dtype(at::kFloat));
+}
}
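// For intuition, with hypothetical numbers (not taken from this commit): on a
// 108-SM GPU with b = 2 and h = 16, the deterministic path above computes
// nsplits = (108 + 32 - 1) / 32 = 4, so four zero-initialized dq accumulators
// are allocated and addressed through dq_accum_split_stride further down.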

at::Tensor dk = torch::from_blob(
@@ -344,6 +364,10 @@ custom_call_flash_attention_backward(
// Softmax sum
launch_params.dsoftmax_sum = dsoftmax.data;

+launch_params.deterministic = params.deterministic;
+launch_params.dq_accum_split_stride =
+!launch_params.deterministic ? 0 : dq_accum.stride(0);

auto launch = &run_mha_bwd;

auto gen = at::get_generator_or_default<at::CUDAGeneratorImpl>(
@@ -353,11 +377,11 @@ custom_call_flash_attention_backward(
int64_t counter_offset = launch_params.b * launch_params.h * 32;

bool is_dropout = (1.f - launch_params.p_dropout) > 0.0;
-if (is_dropout) {
-// See Note [Acquire lock when using random generators]
-std::lock_guard<std::mutex> lock(gen->mutex_);
-launch_params.philox_args = gen->philox_cuda_state(counter_offset);
-}
+// TODO(wenting.swt): According to the implementation in
+// `flash_attn_varlen_func` of flash-attn v2.5.6, the forward generates
+// `rng_state` which is passed as ctx to the backward. Hence, for simplifying
+// the logic, the redundant branch where `rng_state` is None has been omitted.
+launch_params.rng_state = reinterpret_cast<uint64_t*>(rng_state.data);

launch(launch_params, gpu_stream, /*configure=*/false);

@@ -378,12 +402,65 @@ custom_call_flash_attention_backward(
return std::make_tuple(dq_res, dk_res, dv_res, dsoftmax);
}

template <typename T_IN, typename SOFT_MAX_TYPE, int M>
std::tuple<MemRefType<T_IN, M>, MemRefType<T_IN, M>, MemRefType<T_IN, M>,
MemRefType<SOFT_MAX_TYPE, M>>
custom_call_flash_attention_backward_noalibi(
ExecutionContext* ctx, void* stream_handle, MemRefType<T_IN, M> dout,
MemRefType<T_IN, M> q, MemRefType<T_IN, M> k, MemRefType<T_IN, M> v,
MemRefType<T_IN, M> out, MemRefType<SOFT_MAX_TYPE, M> softmax_lse,
MemRefType<int32_t, 1> seqlens_q, MemRefType<int32_t, 1> seqlens_k,
MemRefType<int64_t, 1> rng_state, void* customAttrs) {
return custom_call_flash_attention_backward_impl<T_IN, SOFT_MAX_TYPE, M>(
ctx, stream_handle, dout, q, k, v, out, softmax_lse, seqlens_q, seqlens_k,
rng_state, nullptr, customAttrs);
}

template <typename T_IN, typename SOFT_MAX_TYPE, int M>
std::tuple<MemRefType<T_IN, M>, MemRefType<T_IN, M>, MemRefType<T_IN, M>,
MemRefType<SOFT_MAX_TYPE, M>>
custom_call_flash_attention_backward_alibi_v1(
ExecutionContext* ctx, void* stream_handle, MemRefType<T_IN, M> dout,
MemRefType<T_IN, M> q, MemRefType<T_IN, M> k, MemRefType<T_IN, M> v,
MemRefType<T_IN, M> out, MemRefType<SOFT_MAX_TYPE, M> softmax_lse,
MemRefType<int32_t, 1> seqlens_q, MemRefType<int32_t, 1> seqlens_k,
MemRefType<int64_t, 1> rng_state, MemRefType<float, 1> alibi_slopes,
void* customAttrs) {
return custom_call_flash_attention_backward_impl<T_IN, SOFT_MAX_TYPE, M>(
ctx, stream_handle, dout, q, k, v, out, softmax_lse, seqlens_q, seqlens_k,
rng_state, alibi_slopes.data, customAttrs);
}

template <typename T_IN, typename SOFT_MAX_TYPE, int M>
std::tuple<MemRefType<T_IN, M>, MemRefType<T_IN, M>, MemRefType<T_IN, M>,
MemRefType<SOFT_MAX_TYPE, M>>
custom_call_flash_attention_backward_alibi_v2(
ExecutionContext* ctx, void* stream_handle, MemRefType<T_IN, M> dout,
MemRefType<T_IN, M> q, MemRefType<T_IN, M> k, MemRefType<T_IN, M> v,
MemRefType<T_IN, M> out, MemRefType<SOFT_MAX_TYPE, M> softmax_lse,
MemRefType<int32_t, 1> seqlens_q, MemRefType<int32_t, 1> seqlens_k,
MemRefType<int64_t, 1> rng_state, MemRefType<float, 2> alibi_slopes,
void* customAttrs) {
return custom_call_flash_attention_backward_impl<T_IN, SOFT_MAX_TYPE, M>(
ctx, stream_handle, dout, q, k, v, out, softmax_lse, seqlens_q, seqlens_k,
rng_state, alibi_slopes.data, customAttrs);
}
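The three wrappers above differ only in how the ALiBi slopes arrive: _noalibi forwards a null slopes pointer, _alibi_v1 takes a rank-1 buffer and _alibi_v2 a rank-2 buffer, and all of them funnel into the same _impl. Assuming flash-attn's usual alibi_slopes convention (one float per head, optionally per batch and head), a caller-side sketch of the two layouts (illustrative, not part of this commit):

#include <utility>

#include <torch/torch.h>

// Illustrative only: slope layouts that would map onto the rank-1 (_alibi_v1)
// and rank-2 (_alibi_v2) entry points, assuming one slope per attention head,
// optionally varying per batch element.
std::pair<at::Tensor, at::Tensor> MakeAlibiSlopes(int batch, int num_heads) {
  auto opts = torch::TensorOptions().dtype(torch::kFloat).device(torch::kCUDA);
  at::Tensor per_head = torch::rand({num_heads}, opts);               // rank-1
  at::Tensor per_batch_head = torch::rand({batch, num_heads}, opts);  // rank-2
  return {per_head, per_batch_head};
}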

TAO_RAL_API(
"custom_call_flash_attention_backward", "gpu",
custom_call_flash_attention_backward_noalibi<Eigen::half, float, 3>);
TAO_RAL_API(
"custom_call_flash_attention_backward", "gpu",
custom_call_flash_attention_backward_alibi_v1<Eigen::half, float, 3>);
TAO_RAL_API(
"custom_call_flash_attention_backward", "gpu",
custom_call_flash_attention_backward_alibi_v2<Eigen::half, float, 3>);
TAO_RAL_API("custom_call_flash_attention_backward", "gpu",
-custom_call_flash_attention_backward<float, float, 3>);
+custom_call_flash_attention_backward_noalibi<bfloat16, float, 3>);
TAO_RAL_API("custom_call_flash_attention_backward", "gpu",
-custom_call_flash_attention_backward<Eigen::half, float, 3>);
+custom_call_flash_attention_backward_alibi_v1<bfloat16, float, 3>);
TAO_RAL_API("custom_call_flash_attention_backward", "gpu",
-custom_call_flash_attention_backward<Eigen::bfloat16, float, 3>);
+custom_call_flash_attention_backward_alibi_v2<bfloat16, float, 3>);

} // namespace ral
} // namespace tao
