add flash attention cuDNN version check && restore fwd graph if dbias/mask is not supported

Cjkkkk committed Jan 12, 2024
1 parent d1d9f2c commit 84e9ff3
Showing 1 changed file with 64 additions and 34 deletions.
98 changes: 64 additions & 34 deletions xla/service/gpu/cudnn_fused_mha_rewriter.cc
@@ -1770,6 +1770,39 @@ absl::StatusOr<bool> FuseBwdMultiHeadedAttentionBlock(
   }
   return true;
 }
+
+Status RestoreFwdGraph(HloComputation* comp, HloInstruction* fwd_fmha_call,
+                       HloInstruction* bmm2, HloInstruction* activation,
+                       HloInstruction* original_bmm2_producer0,
+                       HloInstruction* original_bmm2_producer1,
+                       std::vector<HloInstruction*>& original_activation_producers) {
+  // If backward pattern is not matched, we need to restore the
+  // original graph structure.
+  // Replacing new GTEs added by forward FMHA call with cloned old
+  // activations and bmm2.
+  HloInstruction* output_gte = fwd_fmha_call->users()[0];
+  HloInstruction* activation_gte = fwd_fmha_call->users()[1];
+  std::string suffix = "fmha_no_match_clone";
+  HloInstruction* cloned_activation =
+      comp->AddInstruction(activation->CloneWithNewOperands(
+          activation->shape(), original_activation_producers, suffix));
+
+  // Since old activation is detached by forward FMHA rewrite, we need
+  // to use the newly cloned activation.
+  HloInstruction* lhs = activation == original_bmm2_producer0
+                            ? cloned_activation
+                            : original_bmm2_producer1;
+  HloInstruction* rhs = activation == original_bmm2_producer0
+                            ? original_bmm2_producer1
+                            : cloned_activation;
+  HloInstruction* cloned_bmm2 = comp->AddInstruction(
+      bmm2->CloneWithNewOperands(bmm2->shape(), {lhs, rhs}, suffix));
+
+  TF_RETURN_IF_ERROR(comp->ReplaceInstruction(output_gte, cloned_bmm2));
+  TF_RETURN_IF_ERROR(
+      comp->ReplaceInstruction(activation_gte, cloned_activation));
+  return OkStatus();
+}
 } // namespace
 
 absl::StatusOr<bool> CudnnFusedMHARewriter::Run(
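
The helper added above undoes a speculative forward rewrite: by this point the forward FMHA pass has already replaced bmm2 and the activation with get-tuple-elements (GTEs) of the fused custom call, so when the backward pattern later fails to match, the pass clones the detached instructions from their saved producers and splices the clones in where the GTEs sit. Below is a minimal, self-contained sketch of that splice in plain C++17 with no XLA dependencies; Node, Graph, and RestoreFwd are hypothetical stand-ins for HloInstruction, HloComputation, and RestoreFwdGraph.

#include <iostream>
#include <memory>
#include <string>
#include <utility>
#include <vector>

// Hypothetical stand-ins for HloInstruction and HloComputation.
struct Node {
  std::string name;
  std::vector<Node*> operands;
};

struct Graph {
  std::vector<std::unique_ptr<Node>> nodes;
  Node* Add(std::string name, std::vector<Node*> operands) {
    nodes.push_back(
        std::make_unique<Node>(Node{std::move(name), std::move(operands)}));
    return nodes.back().get();
  }
};

// Mirrors the shape of RestoreFwdGraph: rebuild the activation from its
// original producers, rebuild bmm2 on top of the clone, and point the
// fused call's tuple-element users at the clones instead.
void RestoreFwd(Graph& g, Node* output_gte, Node* activation_gte,
                std::vector<Node*> original_activation_producers,
                Node* other_bmm2_operand) {
  Node* cloned_activation = g.Add("activation.fmha_no_match_clone",
                                  std::move(original_activation_producers));
  Node* cloned_bmm2 = g.Add("bmm2.fmha_no_match_clone",
                            {cloned_activation, other_bmm2_operand});
  // In the real pass these two steps are comp->ReplaceInstruction(gte, clone).
  std::cout << output_gte->name << " -> " << cloned_bmm2->name << "\n";
  std::cout << activation_gte->name << " -> " << cloned_activation->name
            << "\n";
}

int main() {
  Graph g;
  Node* q = g.Add("q", {});
  Node* k = g.Add("k", {});
  Node* v = g.Add("v", {});
  Node* fmha = g.Add("fmha-call", {q, k, v});  // speculative forward fusion
  Node* output_gte = g.Add("gte.0", {fmha});
  Node* activation_gte = g.Add("gte.1", {fmha});
  bool bwd_matched = false;  // suppose the backward pattern failed to match
  if (!bwd_matched) {
    RestoreFwd(g, output_gte, activation_gte, /*producers=*/{q, k}, v);
  }
  return 0;
}

The "fmha_no_match_clone" suffix in the real pass appears to serve only to make the restored instructions recognizable in an HLO dump.
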
@@ -1796,7 +1829,6 @@ absl::StatusOr<bool> CudnnFusedMHARewriter::Run(
       if (!matched_result.has_match) {
         continue;
       }
-      // flash attention TODO: dont fuse fwd if bwd is not supported
       // We check the validity of bmms here before canonicalization so we don't
       // modify the graph if mha fusion is not possible
       // Relax 512 constraint if it is flash attention
@@ -1809,6 +1841,15 @@ absl::StatusOr<bool> CudnnFusedMHARewriter::Run(
           matched_result.matched_custom_call_name, debug_options));
 
       if (!is_mha_module_supported) continue;
+      // flash attention requires cuDNN 8.9.3 to run non-fused QKV;
+      // once we have fused QKV support, we can relax this constraint
+      if (matched_result.is_flash_attention &&
+          !IsComputeCapabilityAndCudnnSupported(
+              compute_capability_, cudnn_version_, stream_executor_,
+              stream_executor::dnn::VersionInfo(8, 9, 3))) {
+        VLOG(2) << "Require cuDNN 8.9.3 to run flash attention.";
+        continue;
+      }
       // If we have an activation with more than 1 users in non-training mode,
       // we cannot rewrite the graph. So skip processing the rest.
       HloInstruction* activation =
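
IsComputeCapabilityAndCudnnSupported is defined earlier in this file and is not part of this diff; presumably it compares the loaded cuDNN version against the required triple, lexicographically by (major, minor, patch), alongside the GPU compute-capability check. A minimal, self-contained sketch of just the version comparison, where VersionInfo is a hypothetical stand-in for stream_executor::dnn::VersionInfo and the compute-capability side is omitted:

#include <iostream>
#include <tuple>

// Hypothetical stand-in for stream_executor::dnn::VersionInfo.
struct VersionInfo {
  int major_version;
  int minor_version;
  int patch;
};

// Lexicographic comparison: 8.9.3 > 8.9.1, and 9.0.0 > 8.9.3.
bool AtLeast(const VersionInfo& actual, const VersionInfo& required) {
  return std::tie(actual.major_version, actual.minor_version, actual.patch) >=
         std::tie(required.major_version, required.minor_version,
                  required.patch);
}

int main() {
  VersionInfo loaded{8, 9, 1};     // e.g. what the cuDNN runtime reports
  VersionInfo flash_min{8, 9, 3};  // floor for non-fused-QKV flash attention
  if (!AtLeast(loaded, flash_min)) {
    std::cout << "skipping flash attention rewrite: need cuDNN >= 8.9.3\n";
  }
  return 0;
}
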
@@ -1858,46 +1899,30 @@ absl::StatusOr<bool> CudnnFusedMHARewriter::Run(
             matched_result.is_flash_attention));
       any_changed |= changed;
       if (matched_result.is_training) {
-        // if fwd uses mask input, then bwd needs cudnn 8.9.1 to take in a mask
-        // input if cudnn version < 8.9.1 we won't lower the bwd pass
-        if (matched_result.matched_mask != nullptr &&
-            !IsComputeCapabilityAndCudnnSupported(
-                compute_capability_, cudnn_version_, stream_executor_,
-                stream_executor::dnn::VersionInfo(8, 9, 1))) {
-          continue;
-        }
         MatchBwdResult matched_bwd_result =
             MatchBwdMHAPatternsForCanonicalization(
                 fwd_fmha_call, matched_result.matched_bmm_1,
                 matched_result.matched_mask, v_transposed);
         if (!matched_bwd_result.has_match) {
           VLOG(2) << "Backward pattern not matching, skipping.";
-          // If backward pattern is not matched, we need to restore the
-          // original graph structure.
-          // Replacing new GTEs added by forward FMHA call with cloned old
-          // activations and bmm2.
-          HloInstruction* output_gte = fwd_fmha_call->users()[0];
-          HloInstruction* activation_gte = fwd_fmha_call->users()[1];
-          std::string suffix = "fmha_no_match_clone";
-          HloInstruction* cloned_activation =
-              comp->AddInstruction(activation->CloneWithNewOperands(
-                  activation->shape(), original_activation_producers, suffix));
-
-          // Since old activation is detached by forward FMHA rewrite, we need
-          // to use the newly cloned activation.
-          HloInstruction* lhs = activation == original_bmm2_producer0
-                                    ? cloned_activation
-                                    : original_bmm2_producer1;
-          HloInstruction* rhs = activation == original_bmm2_producer0
-                                    ? original_bmm2_producer1
-                                    : cloned_activation;
-          HloInstruction* cloned_bmm2 = comp->AddInstruction(
-              matched_result.matched_bmm_2->CloneWithNewOperands(
-                  matched_result.matched_bmm_2->shape(), {lhs, rhs}, suffix));
-
-          TF_RETURN_IF_ERROR(comp->ReplaceInstruction(output_gte, cloned_bmm2));
+          // restore fwd graph if bwd pattern match failed
+          TF_RETURN_IF_ERROR(
+              RestoreFwdGraph(comp, fwd_fmha_call, matched_result.matched_bmm_2,
+                              activation, original_bmm2_producer0, original_bmm2_producer1,
+                              original_activation_producers));
+          continue;
+        }
+        // if fwd uses mask input, then bwd needs cuDNN 8.9.1 to take in a
+        // mask input; if cuDNN version < 8.9.1 we won't lower the bwd pass
+        if (matched_result.matched_mask != nullptr &&
+            !IsComputeCapabilityAndCudnnSupported(
+                compute_capability_, cudnn_version_, stream_executor_,
+                stream_executor::dnn::VersionInfo(8, 9, 1))) {
+          // restore fwd graph if the mask input is not supported by bwd
           TF_RETURN_IF_ERROR(
-              comp->ReplaceInstruction(activation_gte, cloned_activation));
+              RestoreFwdGraph(comp, fwd_fmha_call, matched_result.matched_bmm_2,
+                              activation, original_bmm2_producer0, original_bmm2_producer1,
+                              original_activation_producers));
           continue;
         }
         // check if dbias exist and the cudnn version is > 8.9.1. We
@@ -1907,6 +1932,11 @@ absl::StatusOr<bool> CudnnFusedMHARewriter::Run(
             !IsComputeCapabilityAndCudnnSupported(
                 compute_capability_, cudnn_version_, stream_executor_,
                 stream_executor::dnn::VersionInfo(8, 9, 1))) {
+          // restore fwd graph if dbias is not supported by bwd
+          TF_RETURN_IF_ERROR(
+              RestoreFwdGraph(comp, fwd_fmha_call, matched_result.matched_bmm_2,
+                              activation, original_bmm2_producer0, original_bmm2_producer1,
+                              original_activation_producers));
           continue;
         }
         // Canonicalize gemms
