diff --git a/xla/service/gpu/cudnn_fused_mha_rewriter.cc b/xla/service/gpu/cudnn_fused_mha_rewriter.cc
index 56f4475b201c1..fd237b76067bb 100644
--- a/xla/service/gpu/cudnn_fused_mha_rewriter.cc
+++ b/xla/service/gpu/cudnn_fused_mha_rewriter.cc
@@ -1770,6 +1770,39 @@ absl::StatusOr<bool> FuseBwdMultiHeadedAttentionBlock(
   }
   return true;
 }
+
+Status RestoreFwdGraph(
+    HloComputation* comp, HloInstruction* fwd_fmha_call, HloInstruction* bmm2,
+    HloInstruction* activation, HloInstruction* original_bmm2_producer0,
+    HloInstruction* original_bmm2_producer1,
+    std::vector<HloInstruction*>& original_activation_producers) {
+  // If the backward pattern is not matched, we need to restore the
+  // original graph structure.
+  // Replace the new GTEs added by the forward FMHA call with cloned old
+  // activations and bmm2.
+  HloInstruction* output_gte = fwd_fmha_call->users()[0];
+  HloInstruction* activation_gte = fwd_fmha_call->users()[1];
+  std::string suffix = "fmha_no_match_clone";
+  HloInstruction* cloned_activation =
+      comp->AddInstruction(activation->CloneWithNewOperands(
+          activation->shape(), original_activation_producers, suffix));
+
+  // Since the old activation is detached by the forward FMHA rewrite, we need
+  // to use the newly cloned activation.
+  HloInstruction* lhs = activation == original_bmm2_producer0
+                            ? cloned_activation
+                            : original_bmm2_producer1;
+  HloInstruction* rhs = activation == original_bmm2_producer0
+                            ? original_bmm2_producer1
+                            : cloned_activation;
+  HloInstruction* cloned_bmm2 = comp->AddInstruction(
+      bmm2->CloneWithNewOperands(bmm2->shape(), {lhs, rhs}, suffix));
+
+  TF_RETURN_IF_ERROR(comp->ReplaceInstruction(output_gte, cloned_bmm2));
+  TF_RETURN_IF_ERROR(
+      comp->ReplaceInstruction(activation_gte, cloned_activation));
+  return OkStatus();
+}
 }  // namespace
 
 absl::StatusOr<bool> CudnnFusedMHARewriter::Run(
@@ -1796,7 +1829,6 @@ absl::StatusOr<bool> CudnnFusedMHARewriter::Run(
       if (!matched_result.has_match) {
         continue;
       }
-      // flash attention TODO: dont fuse fwd if bwd is not supported
       // We check the validity of bmms here before canonicalization so we don't
       // modify the graph if mha fusion is not possible
       // Relax 512 constraint if it is flash attention
@@ -1809,6 +1841,15 @@ absl::StatusOr<bool> CudnnFusedMHARewriter::Run(
               matched_result.matched_custom_call_name, debug_options));
       if (!is_mha_module_supported) continue;
 
+      // flash attention requires cuDNN 8.9.3 to run non-fused QKV;
+      // once we have fused QKV support, we can relax this constraint
+      if (matched_result.is_flash_attention &&
+          !IsComputeCapabilityAndCudnnSupported(
+              compute_capability_, cudnn_version_, stream_executor_,
+              stream_executor::dnn::VersionInfo(8, 9, 3))) {
+        VLOG(2) << "Require cuDNN 8.9.3 to run flash attention.";
+        continue;
+      }
       // If we have an activation with more than 1 users in non-training mode,
       // we cannot rewrite the graph. So skip processing the rest.
      HloInstruction* activation =
@@ -1858,46 +1899,30 @@ absl::StatusOr<bool> CudnnFusedMHARewriter::Run(
              matched_result.is_flash_attention));
      any_changed |= changed;
      if (matched_result.is_training) {
-        // if fwd uses mask input, then bwd needs cudnn 8.9.1 to take in a mask
-        // input if cudnn version < 8.9.1 we won't lower the bwd pass
-        if (matched_result.matched_mask != nullptr &&
-            !IsComputeCapabilityAndCudnnSupported(
-                compute_capability_, cudnn_version_, stream_executor_,
-                stream_executor::dnn::VersionInfo(8, 9, 1))) {
-          continue;
-        }
        MatchBwdResult matched_bwd_result =
            MatchBwdMHAPatternsForCanonicalization(
                fwd_fmha_call, matched_result.matched_bmm_1,
                matched_result.matched_mask, v_transposed);
        if (!matched_bwd_result.has_match) {
          VLOG(2) << "Backward pattern not matching, skipping.";
-          // If backward pattern is not matched, we need to restore the
-          // original graph structure.
-          // Replacing new GTEs added by forward FMHA call with cloned old
-          // activations and bmm2.
-          HloInstruction* output_gte = fwd_fmha_call->users()[0];
-          HloInstruction* activation_gte = fwd_fmha_call->users()[1];
-          std::string suffix = "fmha_no_match_clone";
-          HloInstruction* cloned_activation =
-              comp->AddInstruction(activation->CloneWithNewOperands(
-                  activation->shape(), original_activation_producers, suffix));
-
-          // Since old activation is detached by forward FMHA rewrite, we need
-          // to use the newly cloned activation.
-          HloInstruction* lhs = activation == original_bmm2_producer0
-                                    ? cloned_activation
-                                    : original_bmm2_producer1;
-          HloInstruction* rhs = activation == original_bmm2_producer0
-                                    ? original_bmm2_producer1
-                                    : cloned_activation;
-          HloInstruction* cloned_bmm2 = comp->AddInstruction(
-              matched_result.matched_bmm_2->CloneWithNewOperands(
-                  matched_result.matched_bmm_2->shape(), {lhs, rhs}, suffix));
-
-          TF_RETURN_IF_ERROR(comp->ReplaceInstruction(output_gte, cloned_bmm2));
+          // restore fwd graph if bwd pattern match failed
+          TF_RETURN_IF_ERROR(
+              RestoreFwdGraph(comp, fwd_fmha_call, matched_result.matched_bmm_2,
+                  activation, original_bmm2_producer0, original_bmm2_producer1,
+                  original_activation_producers));
+          continue;
+        }
+        // if fwd uses mask input, then bwd needs cudnn 8.9.1 to take in a mask
+        // input; if cudnn version < 8.9.1 we won't lower the bwd pass
+        if (matched_result.matched_mask != nullptr &&
+            !IsComputeCapabilityAndCudnnSupported(
+                compute_capability_, cudnn_version_, stream_executor_,
+                stream_executor::dnn::VersionInfo(8, 9, 1))) {
+          // restore fwd graph since the bwd pass will not be lowered
          TF_RETURN_IF_ERROR(
-              comp->ReplaceInstruction(activation_gte, cloned_activation));
+              RestoreFwdGraph(comp, fwd_fmha_call, matched_result.matched_bmm_2,
+                  activation, original_bmm2_producer0, original_bmm2_producer1,
+                  original_activation_producers));
          continue;
        }
        // check if dbias exist and the cudnn version is > 8.9.1. We
@@ -1907,6 +1932,11 @@ absl::StatusOr<bool> CudnnFusedMHARewriter::Run(
            !IsComputeCapabilityAndCudnnSupported(
                compute_capability_, cudnn_version_, stream_executor_,
                stream_executor::dnn::VersionInfo(8, 9, 1))) {
+          // restore fwd graph since the bwd pass will not be lowered
+          TF_RETURN_IF_ERROR(
+              RestoreFwdGraph(comp, fwd_fmha_call, matched_result.matched_bmm_2,
+                  activation, original_bmm2_producer0, original_bmm2_producer1,
+                  original_activation_producers));
          continue;
        }
        // Canonicalize gemms