From 542fe9f04de423da5ae3600c08b2eafc1d511c12 Mon Sep 17 00:00:00 2001
From: cjkkkk
Date: Wed, 20 Dec 2023 16:42:44 -0800
Subject: [PATCH] address some comments and fix wrong layout for softmax stat
 if O is not [batch, num_heads, seq, head] layout

---
 xla/service/gpu/cudnn_fused_mha_rewriter.cc   |  4 ++--
 .../gpu/cudnn_fused_mha_transpose_fusion.cc   | 16 ++++++++--------
 xla/stream_executor/cuda/cuda_dnn.cc          | 16 ++++++++++++++--
 3 files changed, 24 insertions(+), 12 deletions(-)

diff --git a/xla/service/gpu/cudnn_fused_mha_rewriter.cc b/xla/service/gpu/cudnn_fused_mha_rewriter.cc
index 3b94c5d4570d36..95a1c718a93f29 100644
--- a/xla/service/gpu/cudnn_fused_mha_rewriter.cc
+++ b/xla/service/gpu/cudnn_fused_mha_rewriter.cc
@@ -1552,7 +1552,6 @@ absl::StatusOr<bool> FuseBwdMultiHeadedAttentionBlock(
   HloInstruction* lhs_bmm2_grad_gemm1;
   HloInstruction* rhs_bmm2_grad_gemm2;
   HloInstruction* d_output_grad;
-  HloInstruction* fwd_act;
 
   DotDimensionNumbers orig_bmm1_grad1_config =
       bmm_1_grad_1->dot_dimension_numbers();
@@ -1587,6 +1586,7 @@ absl::StatusOr<bool> FuseBwdMultiHeadedAttentionBlock(
   // Forward activation
   // if it is not flash attention, fwd activation is the P tensor
   // else it is the softmax_stats
+  HloInstruction* fwd_act;
   if (fwd_config.is_flash_attention()) {
     auto fwd_act_index = 2;
     fwd_act = comp->AddInstruction(HloInstruction::CreateGetTupleElement(
@@ -1835,7 +1835,7 @@ absl::StatusOr<bool> CudnnFusedMHARewriter::Run(
           matched_result.matched_custom_call_name, debug_options));
 
       if (!is_mha_module_supported) continue;
-      // flash attention require cuDNN 8.9.3 to run non-fused QKV
+      // flash attention requires cuDNN 8.9.3 to run non-fused QKV
       // once we have fused QKV support, we can relax this contraint
       if (matched_result.is_flash_attention &&
           !IsComputeCapabilityAndCudnnSupported(
diff --git a/xla/service/gpu/cudnn_fused_mha_transpose_fusion.cc b/xla/service/gpu/cudnn_fused_mha_transpose_fusion.cc
index 8729efded923e0..8a4ead418e2c12 100644
--- a/xla/service/gpu/cudnn_fused_mha_transpose_fusion.cc
+++ b/xla/service/gpu/cudnn_fused_mha_transpose_fusion.cc
@@ -102,8 +102,8 @@ absl::StatusOr<bool> FuseArgPrologueTransposeWithcuDNNFMHA(
   absl::Span<const int64_t> checked_dims;
   std::vector<int64_t> checked_dims_vec;
 
-  // `should_contracting_be_fastest` means if contracting dim is the hidden
-  // dim. cuDNN requires hidden dim to be the fastest dim. fwd bmm1 and bwd
+  // `should_contracting_be_fastest` means whether the contracting dim is the
+  // head dim. cuDNN requires the head dim to be the fastest dim. fwd bmm1 and bwd
   // bmm2grad1 should set this value to true.
   if (should_contracting_be_fastest) {
     checked_dims = is_lhs ? new_bmm_dot_dims.lhs_contracting_dimensions()
@@ -133,10 +133,10 @@ absl::StatusOr<bool> FuseArgPrologueTransposeWithcuDNNFMHA(
     new_bmm_checked_dims[i] = std::distance(inverse_perm.begin(), itr);
   }
   // We want to make sure that making the argument to transpose, an input to
-  // fmha, doesn't break cuDNN constraint that the checked dimensions of
-  // corresponding operand of BMM has the fastest moving dimension.
+  // fmha, doesn't break the cuDNN constraint that the head dim of the
+  // corresponding BMM operand is the fastest moving dimension.
   // One exception is the forward activation which doesn't have the constraint
-  // that the fastest dim has to be 64.
+  // since it does not have a head dim.
   absl::Span<const int64_t> minor_to_major_bmm =
       transpose_arg_operand->shape().layout().minor_to_major();
   if ((minor_to_major_bmm[0] != new_bmm_checked_dims[0]) &&
@@ -516,7 +516,7 @@ use(FMHA_out_t)
 */
 absl::StatusOr<bool> FuseEpilogueTransposeWithcuDNNFMHA(HloComputation* comp) {
   bool changed = false;
-  auto onlyOneGTEWithSpecIndex = [](const HloInstruction* instr,
+  auto only_one_gte_with_spec_index = [](const HloInstruction* instr,
                                     int64_t index) {
     int count = 0;
     for (auto user : instr->users()) {
@@ -548,7 +548,7 @@ absl::StatusOr<bool> FuseEpilogueTransposeWithcuDNNFMHA(HloComputation* comp) {
     if (Match(instr, fwd_pattern)) {
       // check if only one gte with such index exist
       int64_t tuple_index = gte->tuple_index();
-      if (!onlyOneGTEWithSpecIndex(fmha, tuple_index)) continue;
+      if (!only_one_gte_with_spec_index(fmha, tuple_index)) continue;
 
       std::vector<int64_t> inverse_perm =
           InversePermutation(transpose->dimensions());
@@ -600,7 +600,7 @@ absl::StatusOr<bool> FuseEpilogueTransposeWithcuDNNFMHA(HloComputation* comp) {
     } else if (Match(instr, bwd_pattern)) {
       // check if only one gte with such index exist
       int64_t operand_tuple_idx = gte->tuple_index();
-      if (!onlyOneGTEWithSpecIndex(fmha, operand_tuple_idx)) continue;
+      if (!only_one_gte_with_spec_index(fmha, operand_tuple_idx)) continue;
 
       std::vector<int64_t> inverse_perm =
           InversePermutation(transpose->dimensions());
diff --git a/xla/stream_executor/cuda/cuda_dnn.cc b/xla/stream_executor/cuda/cuda_dnn.cc
index 0c66995f8f8e49..e5c92794b9d126 100644
--- a/xla/stream_executor/cuda/cuda_dnn.cc
+++ b/xla/stream_executor/cuda/cuda_dnn.cc
@@ -6958,7 +6958,7 @@ GetCudnnFlashAttentionBackwardOperationGraph(
       d_output_descriptor.GetCudnnCompatibleDimensions(false);
   std::vector<int64_t> do_strides =
       d_output_descriptor.GetCudnnCompatibleStrides(false);
-  
+
   VLOG(2) << "\n cuDNN compatible d_output_dims: "
           << absl::StrJoin(do_dims, ",")
           << "\n cuDNN compatible d_output_strides: "
@@ -7144,9 +7144,21 @@ GetCudnnFlashAttentionBackwardOperationGraph(
       CreateCudnnTensor(p_dims, p_strides,
                         CudnnfMHAUid::VIRTUAL_ID + 104,
                         dnn::DataType::kFloat, 1, -1, /*is_virtual*/ true));
+
+  std::vector<int64_t> p_reduction_dims(p_dims.begin(), p_dims.end() - 1);
+  p_reduction_dims.push_back(1);
+
+  // Divide every stride by the last dim value.
+  std::vector<int64_t> p_reduction_strides;
+  p_reduction_strides.reserve(p_strides.size());
+  int64_t p_reduced_dim_len = p_dims.back();
+  for (auto stride : p_strides) {
+    p_reduction_strides.push_back(stride / p_reduced_dim_len);
+  }
+
   TF_ASSIGN_OR_RETURN(
       auto tensor_softmax_stats,
-      CreateCudnnTensor(do_reduction_dims, do_reduction_strides,
+      CreateCudnnTensor(p_reduction_dims, p_reduction_strides,
                         CudnnfMHAUid::P_ID, dnn::DataType::kFloat, 1, -1));
 
   TF_ASSIGN_OR_RETURN(auto sub_desc,
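
Note on the cudnn_fused_mha_transpose_fusion.cc change: fusing an argument transpose into the FMHA call is only legal if, after mapping the BMM's contracting (head) dim through the transpose's inverse permutation, that dim is still the minor-most (fastest-moving) dim of the operand's layout, since cuDNN requires the head dim to be fastest. The following standalone sketch illustrates the check with a hypothetical permutation and layout; it mirrors the pass's use of InversePermutation and minor_to_major but is not XLA code.

    #include <algorithm>
    #include <cstdint>
    #include <iostream>
    #include <iterator>
    #include <vector>

    // inverse[perm[i]] = i, so looking up a dim's value in the inverse
    // permutation recovers its position in the transpose's operand.
    std::vector<int64_t> InversePermutation(const std::vector<int64_t>& perm) {
      std::vector<int64_t> inverse(perm.size());
      for (int64_t i = 0; i < static_cast<int64_t>(perm.size()); ++i) {
        inverse[perm[i]] = i;
      }
      return inverse;
    }

    int main() {
      // Hypothetical transpose: [batch, seq, heads, head] -> [batch, heads, seq, head].
      std::vector<int64_t> perm = {0, 2, 1, 3};
      std::vector<int64_t> inverse_perm = InversePermutation(perm);

      // The BMM contracts over the head dim (dim 3 of the transposed shape).
      int64_t bmm_contracting_dim = 3;
      auto itr = std::find(inverse_perm.begin(), inverse_perm.end(),
                           bmm_contracting_dim);
      // Position of the head dim in the transpose's operand (= perm[3] = 3).
      int64_t operand_dim = std::distance(inverse_perm.begin(), itr);

      // minor_to_major of the operand layout; element [0] is the fastest dim.
      std::vector<int64_t> minor_to_major = {3, 2, 1, 0};
      bool fusable = (minor_to_major[0] == operand_dim);
      std::cout << (fusable ? "fusable" : "not fusable") << "\n";
      return 0;
    }

With this permutation the head dim stays minor-most in the operand, so the transpose can be absorbed into the FMHA operand layout.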
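
Note on the cuda_dnn.cc change (the layout fix named in the subject): the softmax stats tensor has P's shape [batch, num_heads, seq_q, seq_kv] with the last dim reduced to 1, so its strides must follow P. The old code reused do_reduction_dims/do_reduction_strides, which are derived from dO and are only correct when O is laid out as [batch, num_heads, seq, head]. Dividing each of P's strides by the reduced dim's length yields the right layout regardless of O's layout. A self-contained sketch of that derivation with hypothetical, contiguous dims (it mirrors the loop added in the patch but is not part of it):

    #include <cstddef>
    #include <cstdint>
    #include <iostream>
    #include <vector>

    int main() {
      // Hypothetical P = Q*K^T dims: [batch, num_heads, seq_q, seq_kv],
      // with seq_kv as the fastest-moving dim.
      std::vector<int64_t> p_dims = {2, 4, 8, 8};
      std::vector<int64_t> p_strides = {256, 64, 8, 1};

      // Reduce the last (seq_kv) dim to 1: [2, 4, 8, 1].
      std::vector<int64_t> p_reduction_dims(p_dims.begin(), p_dims.end() - 1);
      p_reduction_dims.push_back(1);

      // Divide every stride by the reduced dim's length: [32, 8, 1, 0].
      // The trailing 0 is harmless because that dim has size 1.
      std::vector<int64_t> p_reduction_strides;
      p_reduction_strides.reserve(p_strides.size());
      int64_t p_reduced_dim_len = p_dims.back();
      for (int64_t stride : p_strides) {
        p_reduction_strides.push_back(stride / p_reduced_dim_len);
      }

      for (std::size_t i = 0; i < p_reduction_dims.size(); ++i) {
        std::cout << p_reduction_dims[i] << " (stride "
                  << p_reduction_strides[i] << ")\n";
      }
      return 0;
    }

The resulting [2, 4, 8, 1] tensor with strides [32, 8, 1, 0] is exactly a contiguous [batch, num_heads, seq_q] layout, which is what the per-row softmax stats require.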