fix bf16 bug in grad
kuizhiqing committed May 29, 2023
1 parent c51e944 commit dc55747
Showing 3 changed files with 59 additions and 13 deletions.
3 changes: 2 additions & 1 deletion csrc/flash_attn/src/fmha/gemm.h
@@ -135,10 +135,11 @@ struct alignas(static_cast<int>(Base_::ALIGNMENT)) Fragment : public Base_ {
}

// Multiply by another fragment.
+ template <typename elem_type>
inline __device__ void hmul(const Fragment &other) {
#pragma unroll
for( int ii = 0; ii < Base_::NUM_REGS; ++ii ) {
- this->reg(ii) = fmha::hmul2(this->reg(ii), other.reg(ii));
+ this->reg(ii) = fmha::hmul2<elem_type>(this->reg(ii), other.reg(ii));
}
}
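Usage note: because hmul is now a member template, call sites inside other templates must name the element type explicitly and use the template keyword on the dependent name. A minimal sketch of such a caller, using a hypothetical traits/fragment shape rather than the real fmha types:

// Hypothetical caller shape (not the actual fmha kernel code): the element type
// from the kernel traits is forwarded so the packed multiply picks the __half2
// or __nv_bfloat162 path.
template<typename Kernel_traits, typename Fragment_p, typename Fragment_dp>
inline __device__ void mul_p_by_dp(Fragment_p &frag_p, const Fragment_dp &frag_dp) {
    using elem_type = typename Kernel_traits::elem_type;  // __half or __nv_bfloat16
    // 'template' is required because hmul is a member template of a dependent type.
    frag_p.template hmul<elem_type>(frag_dp);
}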

36 changes: 27 additions & 9 deletions csrc/flash_attn/src/fmha/utils.h
@@ -272,7 +272,11 @@ static inline __device__ uint32_t hmin2(uint32_t a, uint32_t b) {

////////////////////////////////////////////////////////////////////////////////////////////////////

- static inline __device__ uint32_t hmul2(const uint32_t a, const uint32_t b) {
+ template<typename T=__half >
+ static inline __device__ uint32_t hmul2(const uint32_t a, const uint32_t b);
+
+ template<>
+ inline __device__ uint32_t hmul2<__half>(const uint32_t a, const uint32_t b) {
// uint32_t c;
// asm volatile("mul.f16x2 %0, %1, %2;\n" : "=r"(c) : "r"(a), "r"(b));
// return c;
@@ -281,6 +285,18 @@ static inline __device__ uint32_t hmul2(const uint32_t a, const uint32_t b) {
return reinterpret_cast<uint32_t(&)>(result);
}

+ #if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800
+ template<>
+ inline __device__ uint32_t hmul2<__nv_bfloat16>(const uint32_t a, const uint32_t b) {
+ // uint32_t c;
+ // asm volatile("mul.f16x2 %0, %1, %2;\n" : "=r"(c) : "r"(a), "r"(b));
+ // return c;
+ __nv_bfloat162 result = __hmul2(reinterpret_cast<const __nv_bfloat162 (&)>(a),
+ reinterpret_cast<const __nv_bfloat162 (&)>(b));
+ return reinterpret_cast<uint32_t(&)>(result);
+ }
+ #endif

////////////////////////////////////////////////////////////////////////////////////////////////////

static inline __device__ uint2 hmul4(uint2 a, uint2 b) {
@@ -292,23 +308,25 @@ static inline __device__ uint2 hmul4(uint2 a, uint2 b) {

////////////////////////////////////////////////////////////////////////////////////////////////////

+ template<typename elem_type=__half >
static inline __device__ uint4 hmul8(uint4 a, uint4 b) {
uint4 c;
- c.x = hmul2(a.x, b.x);
- c.y = hmul2(a.y, b.y);
- c.z = hmul2(a.z, b.z);
- c.w = hmul2(a.w, b.w);
+ c.x = hmul2<elem_type>(a.x, b.x);
+ c.y = hmul2<elem_type>(a.y, b.y);
+ c.z = hmul2<elem_type>(a.z, b.z);
+ c.w = hmul2<elem_type>(a.w, b.w);
return c;
}

////////////////////////////////////////////////////////////////////////////////////////////////////

+ template<typename elem_type=__half >
static inline __device__ uint4 hmul8(uint32_t a, uint4 b) {
uint4 c;
- c.x = hmul2(a, b.x);
- c.y = hmul2(a, b.y);
- c.z = hmul2(a, b.z);
- c.w = hmul2(a, b.w);
+ c.x = hmul2<elem_type>(a, b.x);
+ c.y = hmul2<elem_type>(a, b.y);
+ c.z = hmul2<elem_type>(a, b.z);
+ c.w = hmul2<elem_type>(a, b.w);
return c;
}
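To make the dispatch concrete, here is a small standalone sketch (not part of the commit) of the same pack-and-multiply pattern that the new hmul2<__nv_bfloat16> specialization relies on. The helper names are hypothetical, and the bf16 __hmul2 path assumes a device of compute capability 8.0 or newer.

#include <cstdint>
#include <cstdio>
#include <cuda_bf16.h>

// Stand-in for hmul2<__nv_bfloat16>: treat each 32-bit register as two packed
// bf16 lanes and multiply lane-wise.
__device__ uint32_t packed_bf16_mul(uint32_t a, uint32_t b) {
    __nv_bfloat162 r = __hmul2(reinterpret_cast<const __nv_bfloat162 &>(a),
                               reinterpret_cast<const __nv_bfloat162 &>(b));
    return reinterpret_cast<uint32_t &>(r);
}

__global__ void demo() {
    // Pack (1.5, 2.0) and (4.0, 0.25) into one 32-bit register each.
    __nv_bfloat162 x = __floats2bfloat162_rn(1.5f, 2.0f);
    __nv_bfloat162 y = __floats2bfloat162_rn(4.0f, 0.25f);
    uint32_t c = packed_bf16_mul(reinterpret_cast<uint32_t &>(x),
                                 reinterpret_cast<uint32_t &>(y));
    __nv_bfloat162 z = reinterpret_cast<__nv_bfloat162 &>(c);
    printf("%f %f\n", __low2float(z), __high2float(z));  // expect 6.000000 0.500000
}

int main() {
    demo<<<1, 1>>>();
    cudaDeviceSynchronize();
    return 0;
}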

33 changes: 30 additions & 3 deletions csrc/flash_attn/src/fmha_block_dgrad_kernel_1xN_loop.h
@@ -31,13 +31,15 @@ inline __device__ void compute_block_dq_dk_dv_1xN_one_iter(const Params &params,
const int loop_step_idx) {

#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800
+ constexpr bool is_fp16_type = std::is_same<typename Kernel_traits::elem_type, __half>::value;
using elem_type = typename Kernel_traits::elem_type;
#else
constexpr bool is_fp16_type = std::is_same<typename Kernel_traits::elem_type, __half>::value;
assert(is_fp16_type);
using elem_type = __half;
#endif


// The description of the CTA tile for the 1st batched GEMM.
using Cta_tile_p = typename Kernel_traits::Cta_tile_p;
// The description of the CTA tile for the 2nd batched GEMM.
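The only substantive change in this hunk is that is_fp16_type is now also computed in the SM80+ branch, so the dP code further down can tell an fp16 instantiation apart from a bf16 one. A tiny illustration of what that trait check evaluates to, using hypothetical traits structs in place of the real Kernel_traits:

#include <type_traits>
#include <cuda_fp16.h>
#include <cuda_bf16.h>

// Hypothetical stand-ins for the real Kernel_traits instantiations.
struct Traits_fp16 { using elem_type = __half; };
struct Traits_bf16 { using elem_type = __nv_bfloat16; };

// is_fp16_type is true only for the fp16 instantiation; the bf16 one takes the
// new __nv_bfloat162 code paths instead.
static_assert(std::is_same<Traits_fp16::elem_type, __half>::value, "fp16 traits");
static_assert(!std::is_same<Traits_bf16::elem_type, __half>::value, "bf16 traits");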
@@ -262,7 +264,7 @@ inline __device__ void compute_block_dq_dk_dv_1xN_one_iter(const Params &params,
const uint32_t scale_dropout = params.scale_dropout;
#pragma unroll
for(int it=0; it < Gmem_tile_v::LDGS; it++){
- gmem_v.fetch_[it] = fmha::hmul8(scale_dropout, gmem_v.fetch_[it]);
+ gmem_v.fetch_[it] = fmha::hmul8<elem_type>(scale_dropout, gmem_v.fetch_[it]);
}
}
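For the dropout case, hmul8<elem_type>(scale_dropout, fetch) broadcasts one packed scale register across the eight packed values of a uint4. A rough standalone sketch of that broadcast, with hypothetical helper names and the assumption that the scale is one value replicated into both lanes:

#include <cstdint>
#include <cuda_bf16.h>

// Stand-in for fmha::hmul2<__nv_bfloat16> (bf16 math assumes SM80+).
__device__ uint32_t bf16x2_mul(uint32_t a, uint32_t b) {
    __nv_bfloat162 r = __hmul2(reinterpret_cast<const __nv_bfloat162 &>(a),
                               reinterpret_cast<const __nv_bfloat162 &>(b));
    return reinterpret_cast<uint32_t &>(r);
}

// Stand-in for the broadcast hmul8 overload: one 32-bit register (two bf16 lanes)
// scales all four registers of a uint4 (eight bf16 lanes).
__device__ uint4 bf16x8_scale(uint32_t scale, uint4 v) {
    uint4 c;
    c.x = bf16x2_mul(scale, v.x);
    c.y = bf16x2_mul(scale, v.y);
    c.z = bf16x2_mul(scale, v.z);
    c.w = bf16x2_mul(scale, v.w);
    return c;
}

__global__ void rescale_fetch(uint4 *fetch, float inv_keep_prob) {
    // Assumption: the dropout scale is the same value in both bf16 lanes, mirroring
    // how the single scale_dropout register is reused for every lane above.
    __nv_bfloat162 s2 = __float2bfloat162_rn(inv_keep_prob);
    const uint32_t scale = reinterpret_cast<const uint32_t &>(s2);
    fetch[threadIdx.x] = bf16x8_scale(scale, fetch[threadIdx.x]);
}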

@@ -485,10 +487,10 @@ inline __device__ void compute_block_dq_dk_dv_1xN_one_iter(const Params &params,
for( int mi = 0; mi < Mma_tile_p::MMAS_M; mi++ ) {
#pragma unroll
for( int ni = 0; ni < Mma_tile_p::MMAS_N; ni++ ) {
- frag_p[mi][ni].hmul(frag_dp[mi][ni]);
+ frag_p[mi][ni].template hmul<elem_type>(frag_dp[mi][ni]);
}
}
- } else {
+ } else if (is_fp16_type) {
__half2 dp_sum_half[Mma_tile_p::MMAS_M * 2];
for (int mi = 0; mi < Mma_tile_p::MMAS_M * 2; mi++) {
dp_sum_half[mi] = __float2half2_rn(dp_sum[mi]);
@@ -511,6 +513,31 @@ inline __device__ void compute_block_dq_dk_dv_1xN_one_iter(const Params &params,
}
}
}
+ } else {
+ #if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800
+ __nv_bfloat162 dp_sum_half[Mma_tile_p::MMAS_M * 2];
+ for (int mi = 0; mi < Mma_tile_p::MMAS_M * 2; mi++) {
+ dp_sum_half[mi] = __float2bfloat162_rn(dp_sum[mi]);
+ }
+ const __nv_bfloat16 zero_h = __nv_bfloat16(0.f);
+ #pragma unroll
+ for( int mi = 0; mi < Mma_tile_p::MMAS_M; mi++ ) {
+ #pragma unroll
+ for( int ni = 0; ni < Mma_tile_p::MMAS_N; ni++ ) {
+ #pragma unroll
+ for (int ii = 0; ii < 4; ++ii) {
+ const __nv_bfloat162 p = frag_p[mi][ni].template elt_as<__nv_bfloat162>(ii);
+ const __nv_bfloat162 pdp = __hmul2(p, frag_dp[mi][ni].template elt_as<__nv_bfloat162>(ii));
+ // If this element is dropped, then frag_p stores -p instead of p.
+ // So pd holds -p * dp_sum in that case.
+ const __nv_bfloat162 pd = __hmul2(p, dp_sum_half[mi * 2 + (ii % 2)]);
+ const __nv_bfloat16 low = __low2bfloat16(p) >= zero_h ? __low2bfloat16(pdp) : __low2bfloat16(pd);
+ const __nv_bfloat16 high = __high2bfloat16(p) >= zero_h ? __high2bfloat16(pdp) : __high2bfloat16(pd);
+ frag_p[mi][ni].template elt_as<__nv_bfloat162>(ii) = __halves2bfloat162(low, high);
+ }
+ }
+ }
+ #endif
}

// Store dp to smem for transpose
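The new bf16 branch mirrors the fp16 path: dropped elements are stored as -p, so the sign of each lane of p selects between p * dP (kept) and p * dp_sum (dropped), with the low and high bf16 lanes handled independently. A minimal device-side sketch of that per-lane selection (hypothetical function name; bf16 comparisons and math assume SM80+):

#include <cuda_bf16.h>

// Per-lane selection for one packed pair, matching the shape of the bf16 branch above.
__device__ __nv_bfloat162 dgrad_select_bf16(__nv_bfloat162 p,
                                            __nv_bfloat162 dp,
                                            __nv_bfloat162 dp_sum) {
    const __nv_bfloat16 zero = __float2bfloat16(0.f);
    const __nv_bfloat162 pdp = __hmul2(p, dp);      // kept elements use p * dP
    const __nv_bfloat162 pd  = __hmul2(p, dp_sum);  // dropped elements (stored as -p) use p * dp_sum
    // Low and high lanes are selected independently, each by its own sign test.
    const __nv_bfloat16 low  = __low2bfloat16(p)  >= zero ? __low2bfloat16(pdp)  : __low2bfloat16(pd);
    const __nv_bfloat16 high = __high2bfloat16(p) >= zero ? __high2bfloat16(pdp) : __high2bfloat16(pd);
    return __halves2bfloat162(low, high);
}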
