Flashmask support upend (#46)
* support upend

* fix bug
kircle888 authored Aug 9, 2024
1 parent 806cdc4 commit 666d5c5
Showing 3 changed files with 150 additions and 36 deletions.
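For context: FlashMask describes the attention mask column-wise. For every key column the kernels carry row indices bounding a lower masked band (downstart/downend) and an upper masked band (upend, and with this commit also upstart), and a score is dropped when its row falls inside either band. Below is a minimal host-side sketch of that rule, as read from the apply_sparse_mask_withend hunk in softmax.h at the bottom of this diff; the function and parameter names here are illustrative, not part of the commit.

#include <cstdint>

// Reference predicate for a single (row, col) score. The four arrays hold one
// row index per key column, as in the kernels below; a score is replaced by
// -inf when the predicate returns true.
bool flashmask_is_masked(uint32_t row, uint32_t col, uint32_t seqlen_k,
                         const int32_t* downstart, const int32_t* downend,
                         const int32_t* upstart, const int32_t* upend) {
    if (col >= seqlen_k) return true;                          // out-of-range key column
    const int32_t r = static_cast<int32_t>(row);
    if (r >= downstart[col] && r < downend[col]) return true;  // lower masked band
    if (r >= upstart[col] && r < upend[col]) return true;      // upper masked band
    return false;
}

Setting downend[col] = seqlen_q and upstart[col] = 0 degenerates to the original two-vector behavior, which is consistent with the kernels keeping both call sites and only taking the withend branch when flashmask_has_end (i.e. flashmask_downend_ptr != nullptr) is true.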
27 changes: 22 additions & 5 deletions csrc/flash_attn/src/flash_bwd_kernel.h
@@ -438,6 +438,7 @@ inline __device__ void compute_dq_dk_dv_1colblock(const Params &params, const in
__shared__ int32_t sparse_mask_smem_[Kernel_traits::kBlockN];
__shared__ int32_t sparse_mask_smem_up[Kernel_traits::kBlockN];
__shared__ int32_t sparse_mask_smem_downend[Kernel_traits::kBlockN];
__shared__ int32_t sparse_mask_smem_upstart[Kernel_traits::kBlockN];
extern __shared__ char smem_[];

// The thread index.
@@ -480,7 +481,9 @@ inline __device__ void compute_dq_dk_dv_1colblock(const Params &params, const in
const int *gSparseMaskDownEndMin =
reinterpret_cast<int32_t *>(params.flashmask_downend_nblockmin) +
row_offset_sparsemask_nblock;

const int* gSparseMaskUpStartMax = reinterpret_cast<int32_t*>(params.flashmask_upstart_nblockmax) + row_offset_sparsemask_nblock;
const int* gSparseMaskUpStartMin = reinterpret_cast<int32_t*>(params.flashmask_upstart_nblockmin) + row_offset_sparsemask_nblock;

int m_block_max = cute::ceil_div(binfo.actual_seqlen_q, kBlockM);
const bool flashmask_has_end = params.flashmask_downend_ptr != nullptr;
int flashmask_upendrow = params.seqlen_q;
@@ -564,6 +567,8 @@ inline __device__ void compute_dq_dk_dv_1colblock(const Params &params, const in
Shape<Int<kBlockN>>{});
Tensor gSparseMaskDownEnd = make_tensor(make_gmem_ptr(reinterpret_cast<int32_t *>(params.flashmask_downend_ptr) + row_offset_sparse_mask),
Shape<Int<kBlockN>>{});
Tensor gSparseMaskUpStart = make_tensor(make_gmem_ptr(reinterpret_cast<int32_t *>(params.flashmask_upstart_ptr) + row_offset_sparse_mask),
Shape<Int<kBlockN>>{});

Tensor sQ = make_tensor(make_smem_ptr(reinterpret_cast<Element *>(smem_)),
typename Kernel_traits::SmemLayoutQdO{});
@@ -590,6 +595,7 @@ inline __device__ void compute_dq_dk_dv_1colblock(const Params &params, const in
Tensor sSparseMask = make_tensor(make_smem_ptr(reinterpret_cast<int32_t *>(sparse_mask_smem_)), Shape<Int<kBlockN>>{});
Tensor sSparseMaskUp = make_tensor(make_smem_ptr(reinterpret_cast<int32_t *>(sparse_mask_smem_up)), Shape<Int<kBlockN>>{});
Tensor sSparseMaskDownEnd = make_tensor(make_smem_ptr(reinterpret_cast<int32_t *>(sparse_mask_smem_downend)), Shape<Int<kBlockN>>{});
Tensor sSparseMaskUpStart = make_tensor(make_smem_ptr(reinterpret_cast<int32_t *>(sparse_mask_smem_upstart)), Shape<Int<kBlockN>>{});
Tensor sdPsum = make_tensor(make_smem_ptr(reinterpret_cast<float2 *>((sP.data() + cute::max(size(sP), size(sdQ))).get())),
Shape<Int<Kernel_traits::kSmemdPsumCount / 2>>{});

@@ -881,8 +887,12 @@ inline __device__ void compute_dq_dk_dv_1colblock(const Params &params, const in
if (Is_sparse_attn_mask) {
if (tidx < kBlockN) {
sSparseMask(tidx) = gSparseMask(tidx);
if(!Is_causal)
if(!Is_causal){
sSparseMaskUp(tidx) = gSparseMaskUp(tidx);
if(flashmask_has_end){
sSparseMaskUpStart(tidx) = gSparseMaskUpStart(tidx);
}
}
if(flashmask_has_end)
sSparseMaskDownEnd(tidx) = gSparseMaskDownEnd(tidx);
}
@@ -935,9 +945,16 @@ inline __device__ void compute_dq_dk_dv_1colblock(const Params &params, const in
if (!Is_causal) {
if (Is_sparse_attn_mask &&
((m_block + 1) * kBlockM >= flashmask_startrow || m_block * kBlockM < flashmask_upendrow)){
flash::apply_sparse_mask(scores, sSparseMask, sSparseMaskUp, n_block * kBlockN + (tidx / 32 / AtomLayoutMS) * MMA_N_SdP * 16, binfo.actual_seqlen_k,
m_block * kBlockM + get<0>(taccScS_row(0)),
AtomLayoutMS * 16, n_block * kBlockN);
if(flashmask_has_end){
flash::apply_sparse_mask_withend(scores, sSparseMask, sSparseMaskDownEnd, sSparseMaskUp, sSparseMaskUpStart, n_block * kBlockN + (tidx / 32 / AtomLayoutMS) * MMA_N_SdP * 16, binfo.actual_seqlen_k,
m_block * kBlockM + get<0>(taccScS_row(0)),
AtomLayoutMS * 16, n_block * kBlockN);
}
else {
flash::apply_sparse_mask(scores, sSparseMask, sSparseMaskUp, n_block * kBlockN + (tidx / 32 / AtomLayoutMS) * MMA_N_SdP * 16, binfo.actual_seqlen_k,
m_block * kBlockM + get<0>(taccScS_row(0)),
AtomLayoutMS * 16, n_block * kBlockN);
}
} else if (!Is_even_MN && (n_block + 1) * kBlockN >= binfo.actual_seqlen_k) {
flash::apply_mask(scores, binfo.actual_seqlen_k,
n_block * kBlockN + (tidx / 32 / AtomLayoutMS) * MMA_N_SdP * 16);
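The backward kernel above reads per-block bounds such as flashmask_upstart_nblockmax / flashmask_upstart_nblockmin through gSparseMaskUpStartMax / gSparseMaskUpStartMin. The precomputation of those arrays is not part of this diff; the sketch below only illustrates the assumed invariant, namely that for each block of kBlockN key columns they hold the max and min of the corresponding column vector (names are hypothetical).

#include <algorithm>
#include <cstdint>
#include <vector>

// Hypothetical host-side sketch: per-n_block max/min of one FlashMask column
// vector (e.g. upstart). The real precompute lives outside this commit.
void flashmask_block_minmax(const std::vector<int32_t>& col_vec, int kBlockN,
                            std::vector<int32_t>& nblockmax,
                            std::vector<int32_t>& nblockmin) {
    const int n_cols = static_cast<int>(col_vec.size());
    const int n_blocks = (n_cols + kBlockN - 1) / kBlockN;
    nblockmax.assign(n_blocks, INT32_MIN);
    nblockmin.assign(n_blocks, INT32_MAX);
    for (int c = 0; c < n_cols; ++c) {
        const int b = c / kBlockN;
        nblockmax[b] = std::max(nblockmax[b], col_vec[c]);
        nblockmin[b] = std::min(nblockmin[b], col_vec[c]);
    }
}

With such bounds a whole kBlockM x kBlockN tile can be classified as fully masked, fully unmasked, or in need of the element-wise mask without touching the full column vectors.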
119 changes: 88 additions & 31 deletions csrc/flash_attn/src/flash_fwd_kernel.h
@@ -131,6 +131,7 @@ inline __device__ void compute_attn_1rowblock(const Params &params, const int bi
__shared__ int32_t sparse_mask_smem_[Kernel_traits::kBlockN];
__shared__ int32_t sparse_mask_smem_up[Kernel_traits::kBlockN];
__shared__ int32_t sparse_mask_smem_downend[Kernel_traits::kBlockN];
__shared__ int32_t sparse_mask_smem_upstart[Kernel_traits::kBlockN];
extern __shared__ char smem_[];

// The thread index.
@@ -207,13 +208,17 @@ inline __device__ void compute_attn_1rowblock(const Params &params, const int bi
Tensor gSparseMaskUp = make_tensor(make_gmem_ptr(reinterpret_cast<int32_t *>(params.flashmask_upend_ptr) + row_offset_sparse_mask),
Shape<Int<kBlockN>>{});
Tensor gSparseMaskDownEnd = make_tensor(make_gmem_ptr(reinterpret_cast<int32_t *>(params.flashmask_downend_ptr) + row_offset_sparse_mask),
Shape<Int<kBlockN>>{});
Shape<Int<kBlockN>>{});
Tensor gSparseMaskUpStart = make_tensor(make_gmem_ptr(reinterpret_cast<int32_t *>(params.flashmask_upstart_ptr) + row_offset_sparse_mask),
Shape<Int<kBlockN>>{});
const int* gSparseMaskDownMax = reinterpret_cast<int32_t*>(params.flashmask_downstart_nblockmax) + row_offset_sparsemask_nblock;
const int* gSparseMaskDownMin = reinterpret_cast<int32_t*>(params.flashmask_downstart_nblockmin) + row_offset_sparsemask_nblock;
const int* gSparseMaskUpMax = reinterpret_cast<int32_t*>(params.flashmask_upend_nblockmax) + row_offset_sparsemask_nblock;
const int* gSparseMaskUpMin = reinterpret_cast<int32_t*>(params.flashmask_upend_nblockmin) + row_offset_sparsemask_nblock;
const int* gSparseMaskDownEndMax = reinterpret_cast<int32_t*>(params.flashmask_downend_nblockmax) + row_offset_sparsemask_nblock;
const int* gSparseMaskDownEndMin = reinterpret_cast<int32_t*>(params.flashmask_downend_nblockmin) + row_offset_sparsemask_nblock;
const int* gSparseMaskUpStartMax = reinterpret_cast<int32_t*>(params.flashmask_upstart_nblockmax) + row_offset_sparsemask_nblock;
const int* gSparseMaskUpStartMin = reinterpret_cast<int32_t*>(params.flashmask_upstart_nblockmin) + row_offset_sparsemask_nblock;
const bool enable_mask_bypass = params.enable_mask_bypass;
const bool flashmask_has_end = params.flashmask_downend_ptr != nullptr;

@@ -228,6 +233,7 @@ inline __device__ void compute_attn_1rowblock(const Params &params, const int bi
Tensor sSparseMask = make_tensor(make_smem_ptr(reinterpret_cast<int32_t *>(sparse_mask_smem_)), Shape<Int<kBlockN>>{});
Tensor sSparseMaskUp = make_tensor(make_smem_ptr(reinterpret_cast<int32_t *>(sparse_mask_smem_up)), Shape<Int<kBlockN>>{});
Tensor sSparseMaskDownEnd = make_tensor(make_smem_ptr(reinterpret_cast<int32_t *>(sparse_mask_smem_downend)), Shape<Int<kBlockN>>{});
Tensor sSparseMaskUpStart = make_tensor(make_smem_ptr(reinterpret_cast<int32_t *>(sparse_mask_smem_upstart)), Shape<Int<kBlockN>>{});

typename Kernel_traits::GmemTiledCopyQKV gmem_tiled_copy_QKV;
auto gmem_thr_copy_QKV = gmem_tiled_copy_QKV.get_thread_slice(tidx);
@@ -380,7 +386,7 @@ inline __device__ void compute_attn_1rowblock(const Params &params, const int bi
#define SPARSE_MASKED_DOWN(N_BLOCK) \
(((m_block * kBlockM) >= gSparseMaskDownMax[(N_BLOCK)]) && (!flashmask_has_end || (m_block + 1) * kBlockM < gSparseMaskDownEndMin[(N_BLOCK)]))
#define SPARSE_MASKED_UP(N_BLOCK) \
(!Is_causal && (m_block + 1) * kBlockM < gSparseMaskUpMin[(N_BLOCK)])
(!Is_causal && (m_block + 1) * kBlockM < gSparseMaskUpMin[(N_BLOCK)] && (!flashmask_has_end || m_block * kBlockM >= gSparseMaskUpStartMax[(N_BLOCK)]))
#define SPARSE_MASKED(N_BLOCK) \
(SPARSE_MASKED_DOWN(N_BLOCK) || SPARSE_MASKED_UP(N_BLOCK))
constexpr int n_masking_steps = Is_causal ? cute::ceil_div(kBlockM, kBlockN) : 1;
@@ -431,18 +437,38 @@ inline __device__ void compute_attn_1rowblock(const Params &params, const int bi
if (tidx < kBlockN) {
sSparseMask(tidx) = gSparseMask(tidx);
sSparseMaskUp(tidx) = gSparseMaskUp(tidx);
if(flashmask_has_end){
sSparseMaskUpStart(tidx) = gSparseMaskUpStart(tidx);
sSparseMaskDownEnd(tidx) = gSparseMaskDownEnd(tidx);
}
}
__syncthreads();
flash::apply_sparse_mask(
scores,
sSparseMask,
sSparseMaskUp,
n_block * kBlockN,
binfo.actual_seqlen_k,
// m_block * kBlockM + get<0>(idx_row(0)),
m_block * kBlockM + (tidx / 32) * 16 + (tidx % 32) / 4,
kNWarps * 16,
n_block * kBlockN);
if(flashmask_has_end){
flash::apply_sparse_mask_withend(
scores,
sSparseMask,
sSparseMaskDownEnd,
sSparseMaskUp,
sSparseMaskUpStart,
n_block * kBlockN,
binfo.actual_seqlen_k,
// m_block * kBlockM + get<0>(idx_row(0)),
m_block * kBlockM + (tidx / 32) * 16 + (tidx % 32) / 4,
kNWarps * 16,
n_block * kBlockN);
}
else{
flash::apply_sparse_mask(
scores,
sSparseMask,
sSparseMaskUp,
n_block * kBlockN,
binfo.actual_seqlen_k,
// m_block * kBlockM + get<0>(idx_row(0)),
m_block * kBlockM + (tidx / 32) * 16 + (tidx % 32) / 4,
kNWarps * 16,
n_block * kBlockN);
}
// m_block * kBlockM + (tidx / 32) * 16, kNWarps * 16);
// m_block * kBlockM + (tidx / 32) * (kBlockM / kNWarps), 16);
} else if (!Is_even_N) {
@@ -511,8 +537,12 @@ inline __device__ void compute_attn_1rowblock(const Params &params, const int bi
if(flashmask_has_end){
gSparseMaskDownEnd.data() = gSparseMaskDownEnd.data() + (-kBlockN);
}
if (!Is_causal)
if (!Is_causal){
gSparseMaskUp.data() = gSparseMaskUp.data() + (-kBlockN);
if(flashmask_has_end){
gSparseMaskUpStart.data() = gSparseMaskUpStart.data() + (-kBlockN);
}
}
}

flash::cp_async_wait<0>();
@@ -539,7 +569,8 @@ inline __device__ void compute_attn_1rowblock(const Params &params, const int bi
// We must check inf if use sparse_attn_mask
masking_step == 0
? softmax_rescale_o</*Is_first=*/true, /*Check_inf=*/true>(scores, scores_max, scores_sum, acc_o, params.scale_softmax_log2)
: softmax_rescale_o</*Is_first=*/false, /*Check_inf=*/true>(scores, scores_max, scores_sum, acc_o, params.scale_softmax_log2); }
: softmax_rescale_o</*Is_first=*/false, /*Check_inf=*/true>(scores, scores_max, scores_sum, acc_o, params.scale_softmax_log2);
}
else{
masking_step == 0
? softmax_rescale_o</*Is_first=*/true, /*Check_inf=*/Is_causal || Is_attn_mask>(scores, scores_max, scores_sum, acc_o, params.scale_softmax_log2)
@@ -588,8 +619,10 @@ inline __device__ void compute_attn_1rowblock(const Params &params, const int bi
tVgV.data() = tVgV.data() + (-int(kBlockN * params.v_row_stride));
gSparseMask.data() = gSparseMask.data() + (-kBlockN);
gSparseMaskDownEnd.data() = gSparseMaskDownEnd.data() + (-kBlockN);
if (!Is_causal)
if (!Is_causal){
gSparseMaskUp.data() = gSparseMaskUp.data() + (-kBlockN);
gSparseMaskUpStart.data() = gSparseMaskUpStart.data() + (-kBlockN);
}
if (Return_softmax)
tPgP.data() = tPgP.data() + (-kBlockN);
continue;
@@ -641,21 +674,41 @@ inline __device__ void compute_attn_1rowblock(const Params &params, const int bi
(!enable_mask_bypass ||
(m_block + 1) * kBlockM >= gSparseMaskDownMin[n_block] ||
m_block * kBlockM < gSparseMaskUpMax[n_block])) {
if (tidx < kBlockN) {
sSparseMask(tidx) = gSparseMask(tidx);
sSparseMaskUp(tidx) = gSparseMaskUp(tidx);
}
__syncthreads();
flash::apply_sparse_mask(
scores,
sSparseMask,
sSparseMaskUp,
n_block * kBlockN,
binfo.actual_seqlen_k,
// m_block * kBlockM + get<0>(idx_row(0)),
m_block * kBlockM + (tidx / 32) * 16 + (tidx % 32) / 4,
kNWarps * 16,
n_block * kBlockN);
if (tidx < kBlockN) {
sSparseMask(tidx) = gSparseMask(tidx);
sSparseMaskUp(tidx) = gSparseMaskUp(tidx);
if(flashmask_has_end){
sSparseMaskUpStart(tidx) = gSparseMaskUpStart(tidx);
sSparseMaskDownEnd(tidx) = gSparseMaskDownEnd(tidx);
}
}
__syncthreads();
if(flashmask_has_end){
flash::apply_sparse_mask_withend(
scores,
sSparseMask,
sSparseMaskDownEnd,
sSparseMaskUp,
sSparseMaskUpStart,
n_block * kBlockN,
binfo.actual_seqlen_k,
// m_block * kBlockM + get<0>(idx_row(0)),
m_block * kBlockM + (tidx / 32) * 16 + (tidx % 32) / 4,
kNWarps * 16,
n_block * kBlockN);
}
else{
flash::apply_sparse_mask(
scores,
sSparseMask,
sSparseMaskUp,
n_block * kBlockN,
binfo.actual_seqlen_k,
// m_block * kBlockM + get<0>(idx_row(0)),
m_block * kBlockM + (tidx / 32) * 16 + (tidx % 32) / 4,
kNWarps * 16,
n_block * kBlockN);
}
// m_block * kBlockM + (tidx / 32) * 16, kNWarps * 16);
// m_block * kBlockM + (tidx / 32) * (kBlockM / kNWarps), 16);
} else if (!Is_attn_mask && Is_causal && Is_sparse_attn_mask &&
@@ -698,8 +751,12 @@ inline __device__ void compute_attn_1rowblock(const Params &params, const int bi
if(flashmask_has_end){
gSparseMaskDownEnd.data() = gSparseMaskDownEnd.data() + (-kBlockN);
}
if (!Is_causal)
if (!Is_causal){
gSparseMaskUp.data() = gSparseMaskUp.data() + (-kBlockN);
if(flashmask_has_end){
gSparseMaskUpStart.data() = gSparseMaskUpStart.data() + (-kBlockN);
}
}
}

if(Is_sparse_attn_mask){
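In the forward kernel, the SPARSE_MASKED_DOWN / SPARSE_MASKED_UP / SPARSE_MASKED macros use those per-block bounds to decide whether an entire key block is masked for the current query block and can be skipped. Below is a plain-function restatement of the macros (names and signatures are illustrative, not part of the commit); the condition this commit adds is the upstart_max check in the upper-band case.

#include <cstdint>

// Hypothetical restatement of the SPARSE_MASKED_* macros. A key block may be
// skipped only when every (row, col) of the kBlockM x kBlockN tile is masked;
// the per-block max/min bounds make that decision conservative but cheap.
bool sparse_masked_down(int m_block, int n_block, int kBlockM, bool has_end,
                        const int32_t* downstart_max, const int32_t* downend_min) {
    // All query rows are at or past the largest downstart in the key block and,
    // when an end vector exists, strictly before the smallest downend.
    return m_block * kBlockM >= downstart_max[n_block] &&
           (!has_end || (m_block + 1) * kBlockM < downend_min[n_block]);
}

bool sparse_masked_up(int m_block, int n_block, int kBlockM, bool is_causal, bool has_end,
                      const int32_t* upend_min, const int32_t* upstart_max) {
    // All query rows are before the smallest upend and, with this commit's new
    // condition, at or past the largest upstart in the key block.
    return !is_causal &&
           (m_block + 1) * kBlockM < upend_min[n_block] &&
           (!has_end || m_block * kBlockM >= upstart_max[n_block]);
}

bool sparse_masked(int m_block, int n_block, int kBlockM, bool is_causal, bool has_end,
                   const int32_t* downstart_max, const int32_t* downend_min,
                   const int32_t* upend_min, const int32_t* upstart_max) {
    return sparse_masked_down(m_block, n_block, kBlockM, has_end, downstart_max, downend_min) ||
           sparse_masked_up(m_block, n_block, kBlockM, is_causal, has_end, upend_min, upstart_max);
}

Without the upstart check, a bounded upper band could let a tile be skipped even though its query rows sit before the band's start and still need to be computed; requiring m_block * kBlockM >= upstart_max closes that hole.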
40 changes: 40 additions & 0 deletions csrc/flash_attn/src/softmax.h
@@ -287,6 +287,46 @@ inline __device__ void apply_sparse_mask(Tensor<Engine, Layout> &tensor, Tensor<
}
}

template <typename Engine, typename Layout, typename Engine1, typename Layout1>
inline __device__ void apply_sparse_mask_withend(Tensor<Engine, Layout> &tensor, Tensor<Engine1, Layout1> &flashmask_downstart, Tensor<Engine1, Layout1> &flashmask_downend, Tensor<Engine1, Layout1> &flashmask_upend, Tensor<Engine1, Layout1> &flashmask_upstart, const uint32_t col_idx_offset_,
const uint32_t max_seqlen_k, const uint32_t row_idx_offset_,
const uint32_t warp_row_stride, const uint32_t mask_col_idx_offset) {
// tensor has shape (ncol=(2, MMA_M), nrow=(2, MMA_N))
static_assert(Layout::rank == 2, "Only support 2D Tensor");
const uint32_t lane_id = threadIdx.x % 32;
// const uint32_t row_idx_offset = row_idx_offset_ + lane_id / 4;
const uint32_t row_idx_offset = row_idx_offset_;
const uint32_t col_idx_offset = col_idx_offset_ + (lane_id % 4) * 2;
#pragma unroll
for (int nj = 0; nj < size<1, 1>(tensor); ++nj) {
const uint32_t col_idx_base = col_idx_offset + nj * 8;
#pragma unroll
for (int j = 0; j < size<1, 0>(tensor); ++j) {
const uint32_t col_idx = col_idx_base + j;
const uint32_t downstart = flashmask_downstart(col_idx - mask_col_idx_offset);
const uint32_t downend = flashmask_downend(col_idx - mask_col_idx_offset);
const uint32_t upstart = flashmask_upstart(col_idx - mask_col_idx_offset);
const uint32_t upend = flashmask_upend(col_idx - mask_col_idx_offset);

#pragma unroll
for (int mi = 0; mi < size<0, 1>(tensor); ++mi) {
const uint32_t row_idx_base = row_idx_offset + mi * warp_row_stride;
#pragma unroll
for (int i = 0; i < size<0, 0>(tensor); ++i) {
const uint32_t row_idx = row_idx_base + i * 8;
if (col_idx >= max_seqlen_k) {
tensor(make_coord(i, mi), make_coord(j, nj)) = -INFINITY;
} else if (row_idx >= downstart && row_idx < downend) {
tensor(make_coord(i, mi), make_coord(j, nj)) = -INFINITY;
} else if (row_idx >= upstart && row_idx < upend) {
tensor(make_coord(i, mi), make_coord(j, nj)) = -INFINITY;
}
}
}
}
}
}

// TODO(umiswing): support cu_attn_mask
// This kernel should work after dealing with input cu_seq indicating mask position.
template <typename Engine, typename Layout, typename T>
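To make the new helper's two-band rule concrete, here is a small self-contained host-side example (not part of the commit; the band values are arbitrary) that prints which entries of a toy 8x8 score matrix would be set to -inf.

#include <cstdint>
#include <cstdio>
#include <vector>

int main() {
    const int seqlen_q = 8, seqlen_k = 8;
    // Arbitrary per-column bands, chosen only for illustration.
    std::vector<int32_t> downstart(seqlen_k), downend(seqlen_k),
                         upstart(seqlen_k), upend(seqlen_k);
    for (int c = 0; c < seqlen_k; ++c) {
        downstart[c] = c + 2;       // lower band: rows [c+2, c+4) are masked
        downend[c]   = c + 4;
        upstart[c]   = 0;           // upper band: rows [0, c-3) are masked
        upend[c]     = c - 3;
    }
    // Same rule as apply_sparse_mask_withend: mask where either band covers (r, c).
    for (int r = 0; r < seqlen_q; ++r) {
        for (int c = 0; c < seqlen_k; ++c) {
            const bool masked = (r >= downstart[c] && r < downend[c]) ||
                                (r >= upstart[c] && r < upend[c]);
            std::printf("%c ", masked ? 'x' : '.');
        }
        std::printf("\n");
    }
    return 0;
}

Each 'x' marks a position where r lies in [downstart[c], downend[c]) or in [upstart[c], upend[c]), matching the two branches inside apply_sparse_mask_withend.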
