diff --git a/csrc/flash_attn/src/fmha/gemm.h b/csrc/flash_attn/src/fmha/gemm.h
index 2fff2b219..72d68b0ac 100644
--- a/csrc/flash_attn/src/fmha/gemm.h
+++ b/csrc/flash_attn/src/fmha/gemm.h
@@ -135,10 +135,11 @@ struct alignas(static_cast<int>(Base_::ALIGNMENT)) Fragment : public Base_ {
     }
 
     // Multiply by another fragment.
+    template<typename elem_type>
     inline __device__ void hmul(const Fragment &other) {
         #pragma unroll
         for( int ii = 0; ii < Base_::NUM_REGS; ++ii ) {
-            this->reg(ii) = fmha::hmul2(this->reg(ii), other.reg(ii));
+            this->reg(ii) = fmha::hmul2<elem_type>(this->reg(ii), other.reg(ii));
         }
     }
 
diff --git a/csrc/flash_attn/src/fmha/utils.h b/csrc/flash_attn/src/fmha/utils.h
index 110dda25f..0494e4c0b 100644
--- a/csrc/flash_attn/src/fmha/utils.h
+++ b/csrc/flash_attn/src/fmha/utils.h
@@ -272,7 +272,11 @@ static inline __device__ uint32_t hmin2(uint32_t a, uint32_t b) {
 
 ////////////////////////////////////////////////////////////////////////////////////////////////////
 
-static inline __device__ uint32_t hmul2(const uint32_t a, const uint32_t b) {
+template<typename T>
+static inline __device__ uint32_t hmul2(const uint32_t a, const uint32_t b);
+
+template<>
+inline __device__ uint32_t hmul2<__half>(const uint32_t a, const uint32_t b) {
     // uint32_t c;
     // asm volatile("mul.f16x2 %0, %1, %2;\n" : "=r"(c) : "r"(a), "r"(b));
     // return c;
@@ -281,6 +285,18 @@ static inline __device__ uint32_t hmul2(const uint32_t a, const uint32_t b) {
     return reinterpret_cast<uint32_t(&)>(result);
 }
 
+#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800
+template<>
+inline __device__ uint32_t hmul2<__nv_bfloat16>(const uint32_t a, const uint32_t b) {
+    // uint32_t c;
+    // asm volatile("mul.f16x2 %0, %1, %2;\n" : "=r"(c) : "r"(a), "r"(b));
+    // return c;
+    __nv_bfloat162 result = __hmul2(reinterpret_cast<const __nv_bfloat162 (&)>(a),
+                                    reinterpret_cast<const __nv_bfloat162 (&)>(b));
+    return reinterpret_cast<uint32_t(&)>(result);
+}
+#endif
+
 ////////////////////////////////////////////////////////////////////////////////////////////////////
 
 static inline __device__ uint2 hmul4(uint2 a, uint2 b) {
@@ -292,23 +308,25 @@ static inline __device__ uint2 hmul4(uint2 a, uint2 b) {
 
 ////////////////////////////////////////////////////////////////////////////////////////////////////
 
+template<typename elem_type>
 static inline __device__ uint4 hmul8(uint4 a, uint4 b) {
     uint4 c;
-    c.x = hmul2(a.x, b.x);
-    c.y = hmul2(a.y, b.y);
-    c.z = hmul2(a.z, b.z);
-    c.w = hmul2(a.w, b.w);
+    c.x = hmul2<elem_type>(a.x, b.x);
+    c.y = hmul2<elem_type>(a.y, b.y);
+    c.z = hmul2<elem_type>(a.z, b.z);
+    c.w = hmul2<elem_type>(a.w, b.w);
     return c;
 }
 
 ////////////////////////////////////////////////////////////////////////////////////////////////////
 
+template<typename elem_type>
 static inline __device__ uint4 hmul8(uint32_t a, uint4 b) {
     uint4 c;
-    c.x = hmul2(a, b.x);
-    c.y = hmul2(a, b.y);
-    c.z = hmul2(a, b.z);
-    c.w = hmul2(a, b.w);
+    c.x = hmul2<elem_type>(a, b.x);
+    c.y = hmul2<elem_type>(a, b.y);
+    c.z = hmul2<elem_type>(a, b.z);
+    c.w = hmul2<elem_type>(a, b.w);
     return c;
 }
 
diff --git a/csrc/flash_attn/src/fmha_block_dgrad_kernel_1xN_loop.h b/csrc/flash_attn/src/fmha_block_dgrad_kernel_1xN_loop.h
index 01282691c..9ad4569b1 100644
--- a/csrc/flash_attn/src/fmha_block_dgrad_kernel_1xN_loop.h
+++ b/csrc/flash_attn/src/fmha_block_dgrad_kernel_1xN_loop.h
@@ -31,6 +31,7 @@ inline __device__ void compute_block_dq_dk_dv_1xN_one_iter(const Params &params,
                                                            const int loop_step_idx) {
 
 #if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800
+    constexpr bool is_fp16_type = std::is_same<typename Kernel_traits::elem_type, __half>::value;
     using elem_type = typename Kernel_traits::elem_type;
 #else
     constexpr bool is_fp16_type = std::is_same<typename Kernel_traits::elem_type, __half>::value;
@@ -38,6 +39,7 @@ inline __device__ void compute_block_dq_dk_dv_1xN_one_iter(const Params &params,
     using elem_type = __half;
 #endif
 
+
     // The description of the CTA tile for the 1st batched GEMM.
     using Cta_tile_p = typename Kernel_traits::Cta_tile_p;
     // The description of the CTA tile for the 2nd batched GEMM.
@@ -262,7 +264,7 @@ inline __device__ void compute_block_dq_dk_dv_1xN_one_iter(const Params &params,
         const uint32_t scale_dropout = params.scale_dropout;
         #pragma unroll
         for(int it=0; it < Gmem_tile_v::LDGS; it++){
-            gmem_v.fetch_[it] = fmha::hmul8(scale_dropout, gmem_v.fetch_[it]);
+            gmem_v.fetch_[it] = fmha::hmul8<elem_type>(scale_dropout, gmem_v.fetch_[it]);
         }
     }
 
@@ -485,10 +487,10 @@ inline __device__ void compute_block_dq_dk_dv_1xN_one_iter(const Params &params,
             for( int mi = 0; mi < Mma_tile_p::MMAS_M; mi++ ) {
                 #pragma unroll
                 for( int ni = 0; ni < Mma_tile_p::MMAS_N; ni++ ) {
-                    frag_p[mi][ni].hmul(frag_dp[mi][ni]);
+                    frag_p[mi][ni].template hmul<elem_type>(frag_dp[mi][ni]);
                 }
             }
-        } else {
+        } else if (is_fp16_type) {
             __half2 dp_sum_half[Mma_tile_p::MMAS_M * 2];
             for (int mi = 0; mi < Mma_tile_p::MMAS_M * 2; mi++) {
                 dp_sum_half[mi] = __float2half2_rn(dp_sum[mi]);
@@ -511,6 +513,31 @@ inline __device__ void compute_block_dq_dk_dv_1xN_one_iter(const Params &params,
                     }
                 }
             }
+        } else {
+#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800
+            __nv_bfloat162 dp_sum_half[Mma_tile_p::MMAS_M * 2];
+            for (int mi = 0; mi < Mma_tile_p::MMAS_M * 2; mi++) {
+                dp_sum_half[mi] = __float2bfloat162_rn(dp_sum[mi]);
+            }
+            const __nv_bfloat16 zero_h = __nv_bfloat16(0.f);
+            #pragma unroll
+            for( int mi = 0; mi < Mma_tile_p::MMAS_M; mi++ ) {
+                #pragma unroll
+                for( int ni = 0; ni < Mma_tile_p::MMAS_N; ni++ ) {
+                    #pragma unroll
+                    for (int ii = 0; ii < 4; ++ii) {
+                        const __nv_bfloat162 p = frag_p[mi][ni].template elt_as<__nv_bfloat162>(ii);
+                        const __nv_bfloat162 pdp = __hmul2(p, frag_dp[mi][ni].template elt_as<__nv_bfloat162>(ii));
+                        // If this element is dropped, then frag_p stores -p instead of p.
+                        // So pd holds -p * dp_sum in that case.
+                        const __nv_bfloat162 pd = __hmul2(p, dp_sum_half[mi * 2 + (ii % 2)]);
+                        const __nv_bfloat16 low = __low2bfloat16(p) >= zero_h ? __low2bfloat16(pdp) : __low2bfloat16(pd);
+                        const __nv_bfloat16 high = __high2bfloat16(p) >= zero_h ? __high2bfloat16(pdp) : __high2bfloat16(pd);
+                        frag_p[mi][ni].template elt_as<__nv_bfloat162>(ii) = __halves2bfloat162(low, high);
+                    }
+                }
+            }
+#endif
         }
 
     // Store dp to smem for transpose
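
Reviewer note: the core of this change is the `hmul2<T>` primary-template-plus-specializations pattern in `utils.h`, which multiplies two 16-bit values packed pairwise into a `uint32_t` register. Below is a minimal self-contained sketch of that pattern, not taken from the PR; the file name, `demo_kernel`, and the build line are illustrative assumptions.

```cuda
// Sketch of the hmul2<T> dispatch pattern: one primary template, one
// specialization per 16-bit element type, all operating on packed uint32_t.
// Assumed build: nvcc -arch=sm_80 hmul2_demo.cu -o hmul2_demo
#include <cstdio>
#include <cstring>
#include <cuda_fp16.h>
#include <cuda_bf16.h>

template<typename T>
static __device__ uint32_t hmul2(const uint32_t a, const uint32_t b);

// fp16 path: reinterpret the packed register as __half2 and multiply pairwise.
template<>
__device__ uint32_t hmul2<__half>(const uint32_t a, const uint32_t b) {
    __half2 result = __hmul2(reinterpret_cast<const __half2 (&)>(a),
                             reinterpret_cast<const __half2 (&)>(b));
    return reinterpret_cast<uint32_t(&)>(result);
}

// bf16 path: same trick, guarded to sm_80+ exactly as in the diff.
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800
template<>
__device__ uint32_t hmul2<__nv_bfloat16>(const uint32_t a, const uint32_t b) {
    __nv_bfloat162 result = __hmul2(reinterpret_cast<const __nv_bfloat162 (&)>(a),
                                    reinterpret_cast<const __nv_bfloat162 (&)>(b));
    return reinterpret_cast<uint32_t(&)>(result);
}
#endif

__global__ void demo_kernel(uint32_t a, uint32_t b, uint32_t *out) {
    *out = hmul2<__half>(a, b);   // swap in __nv_bfloat16 for the bf16 path
}

int main() {
    // Pack (2.0, 3.0) and (4.0, 0.5) as fp16 pairs into 32-bit registers.
    __half2 ha = __floats2half2_rn(2.f, 3.f);
    __half2 hb = __floats2half2_rn(4.f, 0.5f);
    uint32_t a, b;
    std::memcpy(&a, &ha, sizeof(a));
    std::memcpy(&b, &hb, sizeof(b));

    uint32_t *d_out, h_out;
    cudaMalloc(&d_out, sizeof(uint32_t));
    demo_kernel<<<1, 1>>>(a, b, d_out);
    cudaMemcpy(&h_out, d_out, sizeof(h_out), cudaMemcpyDeviceToHost);
    cudaFree(d_out);

    __half2 r;
    std::memcpy(&r, &h_out, sizeof(r));
    printf("%f %f\n", __half2float(r.x), __half2float(r.y));  // expect 8.0 1.5
    return 0;
}
```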
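Making `Fragment::hmul` a member template is also why the call site in the dgrad kernel becomes `frag_p[mi][ni].template hmul<elem_type>(...)`. A tiny standalone illustration of why the `.template` disambiguator is required (names here are invented for the example):

```cuda
// When the object's type depends on a template parameter, the compiler needs
// ".template" to parse hmul<...> as a member-template call rather than as a
// less-than comparison. Plain C++, also compiles under nvcc.
#include <cstdio>

struct Frag {
    template<typename T> void hmul() { printf("hmul called\n"); }
};

template<typename FragT>
void caller(FragT &f) {
    f.template hmul<float>();   // without ".template", this fails to parse
}

int main() {
    Frag f;
    caller(f);
    return 0;
}
```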
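The new bf16 branch mirrors the fp16 one, including the sign trick the in-code comment describes: under dropout, `frag_p` stores `-p` for dropped elements, so the sign bit doubles as the dropout mask when choosing between `p * dP` and `-p * dp_sum`. A scalar restatement of that predicated select, purely illustrative and not from the PR:

```cuda
// Scalar version of the per-element select in the bf16/fp16 dropout branches.
#include <cstdio>

static float ds_element(float p_stored, float dp, float dp_sum) {
    float pdp = p_stored * dp;      // p * dP: used when the element was kept
    float pd  = p_stored * dp_sum;  // (-p) * dp_sum: used when it was dropped
    return p_stored >= 0.f ? pdp : pd;
}

int main() {
    printf("kept:    %f\n", ds_element( 0.5f, 2.f, 3.f));  //  0.5 * 2.0 =  1.0
    printf("dropped: %f\n", ds_element(-0.5f, 2.f, 3.f));  // -0.5 * 3.0 = -1.5
    return 0;
}
```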