optimize skip-block calculation in bwd
GuoxiaWang committed Aug 28, 2024
1 parent c9515ed commit 598f152
Showing 2 changed files with 62 additions and 30 deletions.
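For orientation: this commit teaches the FlashMask backward kernel to bypass whole kBlockM x kBlockN tiles that the sparse attention mask fully covers. prepare_sparsemask precomputes, for each key-column block, extrema of the mask boundary rows (largest "down" start, smallest "down" end, and their "up" counterparts); when those extrema prove that every row of a tile is masked, the tile's gemms and softmax rescale are skipped via if (!SPARSE_MASKED). A minimal sketch of the down-side test, restating SPARSE_MASKED_DOWN with illustrative names (the up-side test is symmetric, and SPARSE_MASKED is their disjunction):

    // Sketch only: the kernel writes this as a macro over per-n_block statistics
    // loaded from gSparseMaskDownMax / gSparseMaskDownEndMin.
    bool tile_fully_masked_down(int m_block, int kBlockM,
                                int down_start_max,  // max mask start row in this n_block
                                bool has_end,
                                int down_end_min) {  // min mask end row in this n_block
        // Every row of the tile lies at or past every column's mask start ...
        return m_block * kBlockM >= down_start_max
            // ... and, if the band has an end, still before every column's mask end.
            && (!has_end || (m_block + 1) * kBlockM < down_end_min);
    }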
88 changes: 62 additions & 26 deletions csrc/flash_attn/src/flash_bwd_kernel.h
@@ -75,7 +75,7 @@ make_tiled_copy_C_warpcontiguousN(Copy_Atom<Args...> const& copy_atom,

template <int THREADS_PER_ROW, typename Engine0, typename Layout0,
typename Engine1, typename Layout1, typename Engine2, typename Layout2>
inline __device__ void dot_do_o(Tensor<Engine0, Layout0> const &do_, Tensor<Engine0, Layout0> const &o,
__forceinline__ __device__ void dot_do_o(Tensor<Engine0, Layout0> const &do_, Tensor<Engine0, Layout0> const &o,
Tensor<Engine1, Layout1> &dP_sum, Tensor<Engine2, Layout2> &sdPsum,
const int gdP_col_stride, const float scale) {
static_assert(Layout0::rank == 3, "Only support 3D Tensor");
@@ -425,7 +425,7 @@ inline __device__ void convert_dKV(const Params &params) {
////////////////////////////////////////////////////////////////////////////////////////////////////

template<typename Kernel_traits, bool Is_dropout, bool Is_causal, bool Is_even_MN, bool Is_even_K, bool Is_first, bool Is_last, bool Is_attn_mask, bool Seq_parallel=false, typename Params>
inline __device__ void compute_dq_dk_dv_1colblock(const Params &params, const int bidb, const int bidh, const int n_block) {
__forceinline__ __device__ void compute_dq_dk_dv_1colblock(const Params &params, const int bidb, const int bidh, const int n_block) {

const bool Is_sparse_attn_mask = params.flashmask_downstart_ptr != nullptr;
int flashmask_startrow = 0;
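The two hunks above also swap inline for __forceinline__ on dot_do_o and compute_dq_dk_dv_1colblock. In CUDA device code inline is only a hint, while __forceinline__ directs the compiler to inline the function; for helpers on the kernel's hot path this removes call overhead and lets registers be allocated across the call site. A trivial illustration of the attribute pair (hypothetical helper, not taken from the diff):

    // __forceinline__ guarantees inlining; __device__ makes the helper callable
    // from device code.
    __forceinline__ __device__ float axpy(float a, float x, float y) {
        return a * x + y;
    }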
@@ -488,9 +488,32 @@ inline __device__ void compute_dq_dk_dv_1colblock(const Params &params, const in
const bool flashmask_has_end = params.flashmask_downend_ptr != nullptr;
int flashmask_upendrow = params.seqlen_q;

#define SPARSE_MASKED_DOWN \
(((m_block * kBlockM) >= flashmask_downstartmax) && (!flashmask_has_end || (m_block + 1) * kBlockM < flashmask_downendmin))

#define SPARSE_MASKED_UP \
(!Is_causal && (m_block + 1) * kBlockM < flashmask_upendmin && (!flashmask_has_end || m_block * kBlockM >= flashmask_upstartmax))

#define SPARSE_MASKED \
(SPARSE_MASKED_DOWN || SPARSE_MASKED_UP)

const bool enable_mask_bypass = params.enable_mask_bypass;

if (Is_sparse_attn_mask && enable_mask_bypass) {
int flashmask_downstartmax = std::numeric_limits<int>::max();
int flashmask_downendmin = 0;
int flashmask_upendmin = 0;
int flashmask_upstartmax = std::numeric_limits<int>::max();

if(params.flashmask_downstart_nblockmax != nullptr)
flashmask_downstartmax = gSparseMaskDownMax[n_block];
if(params.flashmask_downend_nblockmin != nullptr)
flashmask_downendmin = gSparseMaskDownEndMin[n_block];
if(params.flashmask_upend_nblockmin != nullptr)
flashmask_upendmin = gSparseMaskUpMin[n_block];
if(params.flashmask_upstart_nblockmax != nullptr)
flashmask_upstartmax = gSparseMaskUpStartMax[n_block];

if (Is_sparse_attn_mask && enable_mask_bypass && !flashmask_has_end) {
m_block_max = min(m_block_max,
cute::ceil_div(gSparseMaskDownMax[n_block], kBlockM));
/*
@@ -744,7 +767,7 @@ inline __device__ void compute_dq_dk_dv_1colblock(const Params &params, const in

int m_block = m_block_max - 1;
int m_block_min = !Is_causal ? 0 : (n_block * kBlockN) / kBlockM;
if(Is_sparse_attn_mask && enable_mask_bypass){
if(Is_sparse_attn_mask && enable_mask_bypass && !flashmask_has_end){
if (!Is_causal) {
m_block_min = max(m_block_min, gSparseMaskUpMin[n_block] / kBlockM);
}
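When the mask band has no end row, the bypass can be coarser than per-tile tests: the hunks above clamp the m-loop bounds themselves, so fully masked leading and trailing blocks are never visited at all. A sketch of that arithmetic, with down_start_max and up_end_min standing in for gSparseMaskDownMax[n_block] and gSparseMaskUpMin[n_block]:

    // Assumes a mask band that runs to the end of the sequence (!flashmask_has_end);
    // bands with an end row may resume contributing past it, so they rely on the
    // per-tile SPARSE_MASKED tests instead.
    int m_block_max = cute::ceil_div(params.seqlen_q, kBlockM);
    int m_block_min = !Is_causal ? 0 : (n_block * kBlockN) / kBlockM;
    if (Is_sparse_attn_mask && enable_mask_bypass && !flashmask_has_end) {
        // Rows at or past down_start_max are masked in every column of this n_block.
        m_block_max = min(m_block_max, cute::ceil_div(down_start_max, kBlockM));
        if (!Is_causal) {
            // Rows before up_end_min are likewise masked; start the loop after them.
            m_block_min = max(m_block_min, up_end_min / kBlockM);
        }
    }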
@@ -922,8 +945,11 @@ inline __device__ void compute_dq_dk_dv_1colblock(const Params &params, const in
// cute::copy(smem_tiled_copy_KV, tSsK(_, _, k), tSrK_copy_view(_, _, k));
// }
// if (cute::thread0()) { print(tSrK); }
flash::gemm(acc_s, tSrQ, tSrK, tSsQ, tSsK, tiled_mma_sdp,
smem_tiled_copy_QdO, smem_tiled_copy_KV, smem_thr_copy_QdO, smem_thr_copy_KV);

if (!SPARSE_MASKED) {
flash::gemm(acc_s, tSrQ, tSrK, tSsQ, tSsK, tiled_mma_sdp,
smem_tiled_copy_QdO, smem_tiled_copy_KV, smem_thr_copy_QdO, smem_thr_copy_KV);
}

// Reshape acc_s from (MMA=4, MMA_N, MMA_N) to (col=(2, MMA_N), row=(2, MMA_N))
Tensor scores = make_tensor(acc_s.data(), flash::convert_layout_acc_rowcol(acc_s.layout()));
@@ -1005,7 +1031,9 @@ inline __device__ void compute_dq_dk_dv_1colblock(const Params &params, const in
}
// if (cute::thread(32, 0)) { print(scores); }
// Compute the exponential value.
flash::scale_apply_exp2</*scale_max=*/false>(scores, lse, params.scale_softmax_log2);
if (!SPARSE_MASKED) {
flash::scale_apply_exp2</*scale_max=*/false>(scores, lse, params.scale_softmax_log2);
}
if (Is_dropout) {
uint32_t warp_id = tidx / 32;
uint32_t block_row_idx = m_block * (kBlockM / 16) + warp_id % AtomLayoutMS;
@@ -1048,21 +1076,23 @@ inline __device__ void compute_dq_dk_dv_1colblock(const Params &params, const in

// if (cute::thread0()) { print(dP_sum); }

flash::gemm</*A_in_regs=*/false, /*B_in_regs=*/Kernel_traits::Is_V_in_regs>(
acc_dp, tdPrdO, tdPrV, tdPsdO, tdPsV, tiled_mma_sdp,
smem_tiled_copy_QdO, smem_tiled_copy_KV, smem_thr_copy_QdO, smem_thr_copy_KV
);

// Reshape acc_dp from (MMA=4, MMA_N, MMA_N) to (col=(2, MMA_N), row=(2, MMA_N))
Tensor dS = make_tensor(acc_dp.data(), scores.layout());
auto pointwise_mult = [](float p, float dp, float d) {
return p * (!Is_dropout || p >= 0 ? dp - d : d);
};
#pragma unroll
for (int mi = 0; mi < size<0>(dS); ++mi) {
if (!SPARSE_MASKED) {
flash::gemm</*A_in_regs=*/false, /*B_in_regs=*/Kernel_traits::Is_V_in_regs>(
acc_dp, tdPrdO, tdPrV, tdPsdO, tdPsV, tiled_mma_sdp,
smem_tiled_copy_QdO, smem_tiled_copy_KV, smem_thr_copy_QdO, smem_thr_copy_KV
);

// Reshape acc_dp from (MMA=4, MMA_N, MMA_N) to (col=(2, MMA_N), row=(2, MMA_N))
auto pointwise_mult = [](float p, float dp, float d) {
return p * (!Is_dropout || p >= 0 ? dp - d : d);
};
#pragma unroll
for (int ni = 0; ni < size<1>(dS); ++ni) {
dS(mi, ni) = pointwise_mult(scores(mi, ni), dS(mi, ni), dP_sum(mi));
for (int mi = 0; mi < size<0>(dS); ++mi) {
#pragma unroll
for (int ni = 0; ni < size<1>(dS); ++ni) {
dS(mi, ni) = pointwise_mult(scores(mi, ni), dS(mi, ni), dP_sum(mi));
}
}
}
// if (cute::thread0()) { print(dS); }
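Why the if (!SPARSE_MASKED) guards are sound: on a fully masked tile the post-softmax probabilities are all zero, so each guarded computation is a no-op. In the standard FlashAttention-2 backward algebra that this kernel follows, with P the probabilities and D_i = sum_j dO_ij * O_ij (the dP_sum term):

    dS_ij = P_ij * (dP_ij - D_i)
    dV   += P^T  * dO
    dK   += dS^T * Q
    dQ   += dS   * K

With P identically zero on the tile, dS is zero and all three accumulations add nothing, so skipping the gemms (and the exp2 in scale_apply_exp2) leaves acc_dv, acc_dk and acc_dq exactly as the masked computation would have produced.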
@@ -1104,8 +1134,10 @@ inline __device__ void compute_dq_dk_dv_1colblock(const Params &params, const in
// flash::gemm_A_in_regs(acc_dv, tdVrPt, tdVrdO, tdVsdOt, tiled_mma_dkv, smem_thr_copy_QdOt);
// Tensor tdKrdSt = make_tensor(tdSrdS.data(), tdVrPt.layout());
// flash::gemm_A_in_regs(acc_dk, tdKrdSt, tdKrQt, tdKsQt, tiled_mma_dkv, smem_thr_copy_QdOt);
flash::gemm(acc_dv, tdVrPt, tdVrdO, tdVsPt, tdVsdOt, tiled_mma_dkv,
smem_tiled_copy_PdSt, smem_tiled_copy_QdOt, smem_thr_copy_PdSt, smem_thr_copy_QdOt);
if (!SPARSE_MASKED) {
flash::gemm(acc_dv, tdVrPt, tdVrdO, tdVsPt, tdVsdOt, tiled_mma_dkv,
smem_tiled_copy_PdSt, smem_tiled_copy_QdOt, smem_thr_copy_PdSt, smem_thr_copy_QdOt);
}
// if (cute::thread0() && n_block == 0 && m_block == 0) { print(tdVrPt); }
// if (cute::thread0()) { print(acc_dv); }

@@ -1124,8 +1156,10 @@ inline __device__ void compute_dq_dk_dv_1colblock(const Params &params, const in
}
}

flash::gemm(acc_dq, tdQrdS, tdQrKt, tdQsdS, tdQsKt, tiled_mma_dq,
smem_tiled_copy_dS, smem_tiled_copy_Kt, smem_thr_copy_dS, smem_thr_copy_Kt);
if (!SPARSE_MASKED) {
flash::gemm(acc_dq, tdQrdS, tdQrKt, tdQsdS, tdQsKt, tiled_mma_dq,
smem_tiled_copy_dS, smem_tiled_copy_Kt, smem_thr_copy_dS, smem_thr_copy_Kt);
}
// if (cute::thread0()) { print(acc_dq); }

if (m_block > m_block_min) {
@@ -1163,8 +1197,10 @@ inline __device__ void compute_dq_dk_dv_1colblock(const Params &params, const in
cute::copy(smem_tiled_copy_dQ, taccdQrdQ, taccdQsdQ);
}

flash::gemm(acc_dk, tdKrdSt, tdKrQt, tdKsdSt, tdKsQt, tiled_mma_dkv,
smem_tiled_copy_PdSt, smem_tiled_copy_QdOt, smem_thr_copy_PdSt, smem_thr_copy_QdOt);
if (!SPARSE_MASKED) {
flash::gemm(acc_dk, tdKrdSt, tdKrQt, tdKsdSt, tdKsQt, tiled_mma_dkv,
smem_tiled_copy_PdSt, smem_tiled_copy_QdOt, smem_thr_copy_PdSt, smem_thr_copy_QdOt);
}
// if (cute::thread0()) { print(acc_dk); }
if (Double_buffer) { // Double buffer for sQ
tdKsQt.data() = tdKsQt.data() + (m_block % 2 == 0 ? size(sQ) : -size(sQ));
4 changes: 0 additions & 4 deletions csrc/flash_attn/src/flash_bwd_launch_template.h
@@ -64,10 +64,6 @@ void run_flash_bwd_seqk_parallel(Flash_bwd_params &params, cudaStream_t stream,
const bool is_attn_mask = params.attn_mask_ptr != nullptr;
const bool is_deterministic = params.num_splits == 1;
// printf("smem_size_dq_dk_dv = %d\n", smem_size_dq_dk_dv);
if (params.flashmask_downend_ptr != nullptr) {
// bypass is not supported for flashmask_downend
params.enable_mask_bypass = false;
}
prepare_sparsemask<Kernel_traits>(params, stream);
BOOL_SWITCH(params.is_causal, IsCausalConst, [&] {
BOOL_SWITCH(is_even_MN, IsEvenMNConst, [&] {

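The launch-template hunk above is the enabling change: previously the host forced the bypass off whenever flashmask_downend_ptr was set, because the old skip logic only understood mask bands without an end row. The new per-tile SPARSE_MASKED tests consult the end-row statistics (flashmask_downend_nblockmin, flashmask_upstart_nblockmax), so such masks can now keep enable_mask_bypass on; only the coarser loop-bound clamping remains restricted to !flashmask_has_end. The host-side flow after the change, as a comment-level sketch:

    // Removed: if (params.flashmask_downend_ptr != nullptr)
    //              params.enable_mask_bypass = false;
    // enable_mask_bypass is now left as configured; prepare_sparsemask still fills
    // the per-block min/max tables and the kernel decides tile by tile.
    prepare_sparsemask<Kernel_traits>(params, stream);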