Skip to content

Commit

Permalink
add dim128 support for block sparse
Browse files Browse the repository at this point in the history
  • Loading branch information
kuizhiqing committed May 31, 2023
1 parent dc55747 commit ff74bc0
Show file tree
Hide file tree
Showing 3 changed files with 7 additions and 1 deletion.
2 changes: 1 addition & 1 deletion csrc/flash_attn/src/fmha_block_dgrad_kernel_1xN_loop.h
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,7 @@ inline __device__ void compute_block_dq_dk_dv_1xN_one_iter(const Params &params,
fmha::Cta_tile_extd<Cta_tile_p::N, Cta_tile_p::K, Cta_tile_p::M, Cta_tile_p::WARPS_N, 1, Cta_tile_p::WARPS_M>;

static_assert(Cta_tile_dkv::M == 512 || Cta_tile_dkv::M == 256 || Cta_tile_dkv::M == 128);
- static_assert(Cta_tile_dkv::N == 16 || Cta_tile_dkv::N == 32 || Cta_tile_dkv::N == 64);
+ static_assert(Cta_tile_dkv::N == 16 || Cta_tile_dkv::N == 32 || Cta_tile_dkv::N == 64 || Cta_tile_dkv::N == 128);
static_assert(Cta_tile_dkv::K == 16);

// The MMA tile for the 1st GEMM.
Expand Down
3 changes: 3 additions & 0 deletions csrc/flash_attn/src/fmha_block_dgrad_kernel_loop.sm80.cu
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,9 @@ void run_fmha_block_dgrad_sm80(const FMHA_dgrad_params &params, cudaStream_t str
} else if (params.d == 64) {
using Kernel_traits = FMHA_kernel_traits<256, 64, 16, 1, 8, 0x100u, elem_type>;
run_fmha_block_dgrad_sm80_loop_<Kernel_traits>(params, stream);
} else if (params.d == 128) {
using Kernel_traits = FMHA_kernel_traits<256, 128, 16, 1, 8, 0x08u, elem_type>;
run_fmha_block_dgrad_sm80_loop_<Kernel_traits>(params, stream);
}
}));
}
3 changes: 3 additions & 0 deletions csrc/flash_attn/src/fmha_block_fprop_kernel.sm80.cu
Original file line number Diff line number Diff line change
Expand Up @@ -88,6 +88,9 @@ void run_fmha_block_sm80(Launch_params<FMHA_fprop_params> &launch_params,
} else if (launch_params.params.d == 64) {
using Kernel_traits = FMHA_kernel_traits<256, 64, 16, 1, 4, 0x08u, elem_type>;
run_fmha_block_sm80_loop_<Kernel_traits>(launch_params, configure);
} else if (launch_params.params.d == 128) {
using Kernel_traits = FMHA_kernel_traits<256, 128, 16, 1, 4, 0x08u, elem_type>;
run_fmha_block_sm80_loop_<Kernel_traits>(launch_params, configure);
}
}));
}

0 comments on commit ff74bc0

Please sign in to comment.