From b4c7fbadb0785748d983deaed923955ead7a6d06 Mon Sep 17 00:00:00 2001
From: cjkkkk
Date: Mon, 27 Nov 2023 14:18:41 -0800
Subject: [PATCH] add gpu backend to fmha e2e tests and address some
 formatting issues

---
 xla/service/gpu/cudnn_fused_mha_rewriter.cc   | 19 ++++++++++---------
 .../gpu/cudnn_fused_mha_transpose_fusion.cc   |  4 ++--
 xla/service/gpu/tests/BUILD                   |  8 +++++++-
 3 files changed, 19 insertions(+), 12 deletions(-)

diff --git a/xla/service/gpu/cudnn_fused_mha_rewriter.cc b/xla/service/gpu/cudnn_fused_mha_rewriter.cc
index 25423e6eeba4cf..38fa7af821759f 100644
--- a/xla/service/gpu/cudnn_fused_mha_rewriter.cc
+++ b/xla/service/gpu/cudnn_fused_mha_rewriter.cc
@@ -412,19 +412,20 @@ StatusOr<bool> IsSupportedBMM2(const HloInstruction* bmm_2,
 
 StatusOr<bool> IsFlashAttention(HloInstruction* bmm_1, bool is_causal_mask,
                                 absl::string_view custom_call_name) {
+  const DotDimensionNumbers& dnums = bmm_1->dot_dimension_numbers();
   TF_ASSIGN_OR_RETURN(
       std::vector<int64_t> seq_q_dims,
       GetNonContractingDims(
           bmm_1->operand(0)->shape(),
-          bmm_1->dot_dimension_numbers().lhs_batch_dimensions(),
-          bmm_1->dot_dimension_numbers().lhs_contracting_dimensions()));
+          dnums.lhs_batch_dimensions(),
+          dnums.lhs_contracting_dimensions()));
 
   TF_ASSIGN_OR_RETURN(
       std::vector<int64_t> seq_k_dims,
       GetNonContractingDims(
           bmm_1->operand(1)->shape(),
-          bmm_1->dot_dimension_numbers().rhs_batch_dimensions(),
-          bmm_1->dot_dimension_numbers().rhs_contracting_dimensions()));
+          dnums.rhs_batch_dimensions(),
+          dnums.rhs_contracting_dimensions()));
 
   std::vector<int64_t> seq_q =
       GetDimensionVector(bmm_1->operand(0)->shape().dimensions(), seq_q_dims);
@@ -434,7 +435,7 @@ StatusOr<bool> IsFlashAttention(HloInstruction* bmm_1, bool is_causal_mask,
 
   std::vector<int64_t> hidden_dim = GetDimensionVector(
       bmm_1->operand(0)->shape().dimensions(),
-      bmm_1->dot_dimension_numbers().lhs_contracting_dimensions());
+      dnums.lhs_contracting_dimensions());
   // for now, seq_q and seq_k should be equal for flash attention to work
   // flash attention only supports fixed topology so we check if custom call is
   // such topology by checking custom_call_name
@@ -1350,12 +1351,12 @@ StatusOr<bool> FuseFwdMultiHeadedAttentionBlock(
   if (is_training) {
     activation_output = bmm_2->mutable_operand(0);
     // Sometimes activation output is bitcast, the actual activation is the
-    // second user of the producer of bmm_2's first operand.
+    // other user of the producer of bmm_2's first operand.
     if (activation_output->user_count() < 2 &&
         activation_output->opcode() == HloOpcode::kBitcast) {
       HloInstruction* producer = activation_output->mutable_operand(0);
       TF_RET_CHECK(producer->user_count() == 2);
-      HloInstruction* bmm2_grad2_user = producer->UserId(activation_output) == 0
+      HloInstruction* bmm2_grad2_user = producer->users()[0] == activation_output
                                             ? producer->users()[1]
                                             : producer->users()[0];
       // might be (transpose) - bmm2_grad2
@@ -1759,9 +1760,9 @@ StatusOr<bool> CudnnFusedMHARewriter::Run(
       if (!is_mha_module_supported) continue;
 
       // We make sure no attention block is matched and replaced twice here
-      if (matched_bmm1.find(matched_result.matched_bmm_1) != matched_bmm1.end())
+      if (!matched_bmm1.insert(matched_result.matched_bmm_1).second) {
        continue;
-      matched_bmm1.insert(matched_result.matched_bmm_1);
+      }
       // If we need to canonicalize the bmm, we will assign the newly
       // canonicalized bmm to bmm_2.
       if (matched_result.need_canonicalization) {
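
Note for reviewers: the IsFlashAttention hunks hoist the repeated
bmm_1->dot_dimension_numbers() calls into one const reference, dnums, and
the Run() hunk collapses the find()-then-insert() pair into a single
insert() call, whose returned pair reports in .second whether the key was
actually new. Below is a minimal standalone sketch of that dedup idiom,
using std::unordered_set and illustrative names (the pass itself stores
HloInstruction* keys; its exact set type is not shown in this patch):

    #include <iostream>
    #include <unordered_set>
    #include <vector>

    int main() {
      std::vector<int> matches = {1, 2, 2, 3, 1};
      std::unordered_set<int> seen;
      for (int m : matches) {
        // insert() performs a single hash lookup and reports via .second
        // whether the key was new; find() followed by insert() needs two.
        if (!seen.insert(m).second) continue;  // duplicate match, skip it
        std::cout << "rewriting match " << m << "\n";
      }
      return 0;
    }

The insert()-based form also keeps the membership test and the registration
in one statement, so they cannot drift apart the way a separate find() and
insert() can.
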
diff --git a/xla/service/gpu/cudnn_fused_mha_transpose_fusion.cc b/xla/service/gpu/cudnn_fused_mha_transpose_fusion.cc
index 513fef387a2481..5003b731df22d9 100644
--- a/xla/service/gpu/cudnn_fused_mha_transpose_fusion.cc
+++ b/xla/service/gpu/cudnn_fused_mha_transpose_fusion.cc
@@ -465,8 +465,8 @@ StatusOr<bool> FusePrologueTransposeWithcuDNNFMHA(HloComputation* comp) {
     if (!config.is_flash_attention()) {
       TF_ASSIGN_OR_RETURN(changed,
                           FuseArgPrologueTransposeWithcuDNNFMHA(
-                              fmha, 4, true /*is_lhs*/,
-                              true /*should_contracting_be_fastest*/));
+                              fmha, 4, /*is_lhs=*/true,
+                              /*should_contracting_be_fastest=*/true));
     }
 
     if (changed && VLOG_IS_ON(2)) {
diff --git a/xla/service/gpu/tests/BUILD b/xla/service/gpu/tests/BUILD
index d437f1b8a45648..34a38ae6feba7d 100644
--- a/xla/service/gpu/tests/BUILD
+++ b/xla/service/gpu/tests/BUILD
@@ -886,10 +886,16 @@ xla_test(
     ],
 )
 
-xla_cc_test(
+xla_test(
     name = "gpu_fused_mha_test",
     srcs = ["gpu_fused_mha_test.cc"],
     tags = tf_cuda_tests_tags(),
+    backend_tags = {"gpu": [
+        "requires-gpu-sm80",
+    ]},
+    backends = [
+        "gpu",
+    ],
     deps = [
         ":gpu_codegen_test",
         "//xla:array4d",
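
Note for reviewers: clang-tidy's argument-comment check only recognizes
comments of the form /*param_name=*/ placed immediately before the value,
which is how the transpose-fusion hunk now writes them. A minimal sketch of
the convention (Describe is a hypothetical function, not from the XLA
sources):

    #include <string>

    // Hypothetical helper, used only to illustrate argument comments.
    std::string Describe(bool is_lhs, bool should_contracting_be_fastest) {
      return std::string(is_lhs ? "lhs" : "rhs") +
             (should_contracting_be_fastest ? ", contracting fastest" : "");
    }

    int main() {
      // Each /*name=*/ comment sits directly before its argument, letting
      // tooling verify that the name matches the declared parameter.
      Describe(/*is_lhs=*/true, /*should_contracting_be_fastest=*/true);
      return 0;
    }

On the BUILD hunk: switching from xla_cc_test to the xla_test macro builds
the test once per backend listed in backends, and the requires-gpu-sm80
backend tag restricts the GPU variant to machines with SM 8.0 (Ampere) or
newer, which the cuDNN fused multi-headed attention kernels target.
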