diff --git a/xla/service/gpu/cudnn_fused_mha_rewriter.cc b/xla/service/gpu/cudnn_fused_mha_rewriter.cc index 099e6776a370d3..373feb8094042b 100644 --- a/xla/service/gpu/cudnn_fused_mha_rewriter.cc +++ b/xla/service/gpu/cudnn_fused_mha_rewriter.cc @@ -63,10 +63,18 @@ struct MatchFwdResult { HloInstruction* matched_mask = nullptr; HloInstruction* matched_scale = nullptr; HloInstruction* matched_softmax_input = nullptr; + HloInstruction* matched_reduce_sum = nullptr; double matched_dropout_rate = 0.0; bool need_canonicalization = false; bool is_training = false; + // We use this to keep track of whether the bias or the mask that is being + // applied to the bmm1 is a causal mask, cuDNN can generate causal mask inside + // the attention kernel to save I/O. + bool is_causal_mask = false; + // We use this to keep track of whether the attention block should be lowered + // to flash attention or regular fused attention in cuDNN. + bool is_flash_attention = false; bool has_match = false; std::string matched_custom_call_name; }; @@ -402,6 +410,93 @@ StatusOr IsSupportedBMM2(const HloInstruction* bmm_2, return true; } +StatusOr IsFlashAttention(HloInstruction* bmm_1, bool is_causal_mask, + absl::string_view custom_call_name) { + TF_ASSIGN_OR_RETURN( + std::vector seq_q_dims, + GetNonContractingDims( + bmm_1->operand(0)->shape(), + bmm_1->dot_dimension_numbers().lhs_batch_dimensions(), + bmm_1->dot_dimension_numbers().lhs_contracting_dimensions())); + + TF_ASSIGN_OR_RETURN( + std::vector seq_k_dims, + GetNonContractingDims( + bmm_1->operand(1)->shape(), + bmm_1->dot_dimension_numbers().rhs_batch_dimensions(), + bmm_1->dot_dimension_numbers().rhs_contracting_dimensions())); + + std::vector seq_q = + GetDimensionVector(bmm_1->operand(0)->shape().dimensions(), seq_q_dims); + + std::vector seq_k = + GetDimensionVector(bmm_1->operand(1)->shape().dimensions(), seq_k_dims); + + std::vector hidden_dim = GetDimensionVector( + bmm_1->operand(0)->shape().dimensions(), + bmm_1->dot_dimension_numbers().lhs_contracting_dimensions()); + // for now, seq_q and seq_k should be equal for flash attention to work + // flash attention only supports fixed topology so we check if custom call is + // such topology by checking custom_call_name + TF_RET_CHECK(seq_q.size() == 1); + TF_RET_CHECK(seq_k.size() == 1); + TF_RET_CHECK(hidden_dim.size() == 1); + auto is_fixed_topology = + (custom_call_name == kCudnnfMHAScaleBiasSoftmaxDropoutCallTarget || + custom_call_name == kCudnnfMHAScaleBiasSoftmaxCallTarget); + + auto is_seqlen_supported = + seq_q[0] == seq_k[0] && seq_q[0] > 512 && seq_q[0] % 64 == 0; + auto is_hidden_dim_supported = hidden_dim[0] == 64 || hidden_dim[0] == 128; + return is_seqlen_supported && is_hidden_dim_supported && is_fixed_topology; +} + +bool IsCausalMaskPattern(HloInstruction* mask) { + auto causal_mask = + m::Select(m::Compare(m::Iota(), m::Iota()), m::Broadcast(m::Constant()), + m::Broadcast(m::Constant())); + auto causal_mask_pattern_fwd_remat = + m::Broadcast(OptionalBitcast(causal_mask)); + auto causal_mask_pattern_bwd = m::Broadcast(m::Convert(OptionalBitcast( + m::Minimum(m::Op(), m::Broadcast(OptionalBitcast(causal_mask)))))); + HloInstruction* param = nullptr; + HloInstruction* gte = nullptr; + auto causal_mask_pattern_fwd = m::Broadcast( + OptionalBitcast(m::GetTupleElement(>e, m::Parameter(¶m)))); + auto causal_mask_pattern = m::AnyOf( + causal_mask_pattern_fwd_remat, causal_mask_pattern_fwd, + causal_mask_pattern_bwd); + if (Match(mask, causal_mask_pattern)) { + if (param != nullptr) { + // 
need to track to outside of the while loop body to find the real mask. + auto mask_index = gte->tuple_index(); + auto comp = param->parent(); + auto mod = comp->parent(); + auto name = comp->name(); + auto entry_computation = mod->entry_computation(); + bool is_causal_mask = true; + for (HloInstruction* instr : + entry_computation->MakeInstructionPostOrder()) { + if (instr->opcode() == HloOpcode::kWhile && + instr->while_body()->name() == name && + instr->operand(0)->opcode() == HloOpcode::kTuple) { + auto actual_mask = + instr->mutable_operand(0)->mutable_operand(mask_index); + auto causal_mask_pattern_fwd = + OptionalBitcast(m::Convert(m::MinimumAnyOrder( + m::Op(), + OptionalBitcast(m::MinimumAnyOrder( + m::Op(), m::Broadcast(OptionalBitcast(causal_mask))))))); + is_causal_mask &= Match(actual_mask, causal_mask_pattern_fwd); + } + } + return is_causal_mask; + } + return true; + } + return false; +} + MatchFwdResult MatchDefaultFwdBmmBmm(MatchFwdResult previous_result, int64_t bmm2_operand_position, HloInstruction* instr) { @@ -472,6 +567,14 @@ MatchFwdResult MatchSoftmaxDropoutBmm(MatchFwdResult previous_result, m::Broadcast(m::Constant(&dropout).WithPredicate(IsScalar)), m::Op()))))))))); + // Form3 -> softmax - mul(dropout) - mul(scale) - BMM2 + auto dropout_softmax_pattern_form_3 = m::MultiplyAnyOrder( + m::MultiplyAnyOrder( + OptionalConvert(GetUnfusedReduceMaxSumSoftmaxPattern( + &softmax_input, &softmax_reduce_sum, &softmax_reduce_sum_bcast)), + m::Op()), + m::Broadcast(m::Constant(&dropout).WithPredicate(IsScalar))); + // Try matching BMM1 - (Scale) - (Bias) - (Mask) - Softmax - (Dropout) - // BMM2 Dropout with non-zero drop rate has select(divide(softmax_output, // broadcast(1-dropout_rate))) @@ -485,7 +588,8 @@ MatchFwdResult MatchSoftmaxDropoutBmm(MatchFwdResult previous_result, &softmax_input, &softmax_reduce_sum, &softmax_reduce_sum_bcast))), dropout_softmax_pattern_form_1, - dropout_softmax_pattern_form_2)); + dropout_softmax_pattern_form_2, + dropout_softmax_pattern_form_3)); if (!Match(instr, softmax_dropout_bmm2_pattern) || !IsSupportedPrimitiveType(bmm_2)) { @@ -502,6 +606,7 @@ MatchFwdResult MatchSoftmaxDropoutBmm(MatchFwdResult previous_result, match_result.matched_dropout_rate = GetDropoutRateFromHlo(dropout); } match_result.matched_softmax_input = softmax_input; + match_result.matched_reduce_sum = softmax_reduce_sum; match_result.has_match = true; return match_result; } @@ -539,6 +644,7 @@ MatchFwdResult MatchBmm1UnfusedBiasSoftmaxBmm2(MatchFwdResult previous_result, match_result.matched_custom_call_name = has_dropout ? 
kCudnnfMHAScaleBiasSoftmaxDropoutCallTarget : kCudnnfMHAScaleBiasSoftmaxCallTarget; + match_result.is_causal_mask |= IsCausalMaskPattern(bias); match_result.has_match = true; } else { match_result.has_match = false; @@ -561,18 +667,19 @@ MatchFwdResult MatchBmm1ScaleBiasMaskSoftmaxDropoutBmm2( OptionalConvert( m::Op(&bmm_1).WithPredicate(IsBatchedMatmul).WithOneUse()), m::Broadcast(m::Constant(&scale).WithPredicate(IsScalar)))); - if (Match( softmax_input, OptionalConvert(m::Select( m::Op(&mask).WithPredicate([](const HloInstruction* instr) { return instr->shape().element_type() == PRED; }), - // Match bmm1-scale-bias-mask + // Match bmm1-(scale)-(bias)-mask m::AnyOf( // Scale and bias might or might not be fused // with gemm - m::Op(&bmm_1).WithPredicate(IsBatchedMatmul).WithOneUse(), + OptionalConvert(m::Op(&bmm_1) + .WithPredicate(IsBatchedMatmul) + .WithOneUse()), OptionalConvert(m::AnyOf( // Try to match unfused bias m::AddAnyOrder(m::Op(&bias), @@ -588,9 +695,9 @@ MatchFwdResult MatchBmm1ScaleBiasMaskSoftmaxDropoutBmm2( matched_result.has_match = false; return matched_result; } - + matched_result.is_causal_mask |= IsCausalMaskPattern(mask); if (has_dropout) { - // Found BMM1 - Scale - (bias) - Mask - Softmax - dropout - BMM2 + // Found BMM1 - (Scale) - (bias) - Mask - Softmax - dropout - BMM2 matched_result.matched_custom_call_name = bias == nullptr ? kCudnnfMHAScaleMaskSoftmaxDropoutCallTarget : kCudnnfMHAScaleBiasMaskSoftmaxDropoutCallTarget; @@ -818,8 +925,16 @@ MatchBwdResult MatchBwdBmmSoftmaxDropoutBmm(MatchBwdResult previous_result, m::Broadcast(OptionalConvert( m::Constant().WithPredicate(IsScalar))), m::Op())))))))); + auto bwd_dropout_pattern_form_3 = OptionalConvert(m::MultiplyAnyOrder( + m::MultiplyAnyOrder( + m::Op().WithPredicate([&](const HloInstruction* instr) { + return instr == match_result.matched_bmm_2_grad_2; + }), + m::Broadcast(m::Constant().WithPredicate(IsScalar))), + m::Op())); auto bwd_dropout_pattern = m::AnyOf( - bwd_dropout_pattern_form_1, bwd_dropout_pattern_form_2); + bwd_dropout_pattern_form_1, bwd_dropout_pattern_form_2, + bwd_dropout_pattern_form_3); // Backward softmax pattern HloInstruction* bwd_softmax_input = nullptr; HloInstruction* exp_1; @@ -1011,6 +1126,8 @@ MatchBwdResult MatchBwdMHAPatternsForCanonicalization( StatusOr IsMHABlockSupported(HloInstruction* bmm_1, HloInstruction* bmm_2, bool need_canonicalization, bool is_training, + bool is_causal_mask, + bool& is_flash_attention, std::string& custom_call_name, const DebugOptions& debug_options) { if (MHACallHasDropout(custom_call_name) && @@ -1039,6 +1156,20 @@ StatusOr IsMHABlockSupported(HloInstruction* bmm_1, HloInstruction* bmm_2, return false; } + // check if matched attention block is supported by cuDNN flash attention. + TF_ASSIGN_OR_RETURN( + is_flash_attention, + IsFlashAttention(bmm_1, is_causal_mask, custom_call_name)); + if (is_flash_attention) { + if (is_causal_mask) { + // if bias is causal mask, needs to remove bias from name + custom_call_name = MHACallHasDropout(custom_call_name) + ? kCudnnfMHASoftmaxDropoutCallTarget + : kCudnnfMHASoftmaxCallTarget; + } + return true; + } + // otherwise check if it is supported by regular attention TF_ASSIGN_OR_RETURN(bool is_bmm1_supported, IsSupportedBMM1(bmm_1)); if (!is_bmm1_supported) return false; TF_ASSIGN_OR_RETURN(bool is_bmm2_supported, @@ -1106,28 +1237,27 @@ StatusOr ChangeCheckedDimToFastest( is_lhs ? 
lhs_minor_to_major_bmm : rhs_minor_to_major_bmm; CHECK_EQ(contracting_dims.size(), 1); - TF_ASSIGN_OR_RETURN(std::vector non_contracting_dim_nums_bmm, + TF_ASSIGN_OR_RETURN(std::vector non_contracting_dims, GetNonContractingDims(bmm->operand(bmm_operand)->shape(), batch_dims, contracting_dims)); - CHECK_EQ(non_contracting_dim_nums_bmm.size(), 1); + CHECK_EQ(non_contracting_dims.size(), 1); HloInstruction* operand_bmm = bmm->mutable_operand(bmm_operand); - std::vector contracting_dims_to_check{contracting_dims[0]}; - std::vector dims_to_set = should_contracting_be_fastest - ? contracting_dims_to_check - : non_contracting_dim_nums_bmm; - // If the dimension being checked(contracting or non-contracting) of the - // target operand is not the fastest moving dimension, make it so. - if (minor_to_major_to_check[0] != dims_to_set[0]) { + int64_t hidden_dim = should_contracting_be_fastest ? contracting_dims[0] + : non_contracting_dims[0]; + int64_t minor_dim = minor_to_major_to_check[0]; + // If the hidden dim of the target operand is not the fastest moving + // dimension, make it so. + if (minor_dim != hidden_dim) { std::vector perm(bmm->shape().dimensions_size()); std::iota(perm.begin(), perm.end(), 0); - std::swap(perm[dims_to_set[0]], perm[minor_to_major_to_check[0]]); + std::swap(perm[hidden_dim], perm[minor_dim]); if (is_lhs) { - new_dot_dims_bmm.set_lhs_contracting_dimensions( - 0, non_contracting_dim_nums_bmm[0]); + new_dot_dims_bmm.set_lhs_contracting_dimensions(0, + non_contracting_dims[0]); } else { - new_dot_dims_bmm.set_rhs_contracting_dimensions( - 0, non_contracting_dim_nums_bmm[0]); + new_dot_dims_bmm.set_rhs_contracting_dimensions(0, + non_contracting_dims[0]); } operand_bmm = comp->AddInstruction( @@ -1135,7 +1265,7 @@ StatusOr ChangeCheckedDimToFastest( ShapeUtil::MakeShapeWithDenseLayout( bmm->shape().element_type(), Permute(operand_bmm->shape().dimensions(), perm), - rhs_minor_to_major_bmm), + minor_to_major_to_check), operand_bmm, perm), &operand_bmm->metadata()); *((DynCast(bmm))->mutable_dot_dimension_numbers()) = @@ -1147,13 +1277,15 @@ StatusOr ChangeCheckedDimToFastest( StatusOr FuseFwdMultiHeadedAttentionBlock( HloComputation* comp, HloInstruction* bmm_1, HloInstruction* bmm_2, HloInstruction* bias, HloInstruction* mask, HloInstruction* scale, + HloInstruction* reduce_sum, HloInstruction* softmax_input, double dropout_rate, std::string& custom_call_name, stream_executor::CudaComputeCapability cc, bool is_training, bool& changed, - bool& v_transposed) { + bool& v_transposed, bool is_causal_mask, bool is_flash_attention) { double scale_value = 1.0; HloInstruction* lhs_bmm1; HloInstruction* rhs_bmm1; HloInstruction* rhs_bmm2; + DotDimensionNumbers bmm1dot = bmm_1->dot_dimension_numbers(); TF_ASSIGN_OR_RETURN(rhs_bmm1, ChangeCheckedDimToFastest( comp, bmm_1, false /*is_lhs*/, true /*should_contracting_be_fastest*/)); @@ -1174,7 +1306,8 @@ StatusOr FuseFwdMultiHeadedAttentionBlock( bmm_1->dot_dimension_numbers(); *fmha_config.mutable_bmm2_dot_dimension_numbers() = bmm_2->dot_dimension_numbers(); - + *((DynCast(bmm_1))->mutable_dot_dimension_numbers()) = + bmm1dot; TF_RET_CHECK((dropout_rate >= 0.0 && dropout_rate <= 1.0)); // If scale node is assigned, extract value from it. @@ -1205,16 +1338,28 @@ StatusOr FuseFwdMultiHeadedAttentionBlock( algorithm->set_is_cudnn_frontend(true); algorithm->mutable_workspace_size()->set_value(0); } + + // set is flash attention here + // choose to use flash attention or non-fa attention based on this flag. 
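+  // Illustrative note (not a statement of the cuDNN contract): with the
+  // checks in IsFlashAttention() above, a Q/K pair shaped [2,6,2048,128]
+  // qualifies because seq_q == seq_k == 2048 > 512, 2048 % 64 == 0 and the
+  // hidden dim is 128; a sequence length of 512, or a hidden dim of, say,
+  // 96, would fall back to the regular fused attention path instead.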
+ fmha_config.set_is_flash_attention(is_flash_attention); + // set is_causal_mask here + // choose to generate causal mask inside cuDNN attention or not + fmha_config.set_is_causal_mask(is_causal_mask); + + // Output Order: {O, scratch, Fwd act*} const Shape& output_shape = bmm_2->shape(); Shape call_shape; // Activation output is used by backward gemm. HloInstruction* activation_output = nullptr; - std::vector output_shapes = {output_shape, - ShapeUtil::MakeShape(U8, {0})}; + std::vector output_shapes = { + output_shape, + ShapeUtil::MakeShape( + U8, {is_flash_attention + ? 16 + : 0})}; // reserved 2 int64 for dropout seed and offset if (is_training) { - // TODO Flush attention will have a different shape in training. activation_output = bmm_2->mutable_operand(0); // Sometimes activation output is bitcast, the actual activation is the // second user of the producer of bmm_2's first operand. @@ -1222,21 +1367,37 @@ StatusOr FuseFwdMultiHeadedAttentionBlock( activation_output->opcode() == HloOpcode::kBitcast) { HloInstruction* producer = activation_output->mutable_operand(0); TF_RET_CHECK(producer->user_count() == 2); - activation_output = producer->UserId(activation_output) == 0 - ? producer->users()[1] - : producer->users()[0]; + HloInstruction* bmm2_grad2_user = producer->UserId(activation_output) == 0 + ? producer->users()[1] + : producer->users()[0]; + // might be (transpose) - bmm2_grad2 + if (IsBatchedMatmul(bmm2_grad2_user)) { + activation_output = producer; + } else if (bmm2_grad2_user->opcode() == HloOpcode::kTranspose) { + activation_output = bmm2_grad2_user; + } else { + return InternalError("Unexpected activation patterns"); + } + } + // if it is flash attention, should output softmax stats to the bwd + if (is_flash_attention) { + TF_RET_CHECK(reduce_sum != nullptr); + output_shapes.push_back( + ShapeUtil::MakeShape(F32, reduce_sum->shape().dimensions())); + } else { + output_shapes.push_back(activation_output->shape()); } - output_shapes.push_back(activation_output->shape()); } call_shape = ShapeUtil::MakeTupleShape(output_shapes); + // Input Order: {Q, K, V, mask*, bias*} std::vector operands = {lhs_bmm1, rhs_bmm1, rhs_bmm2}; if (mask != nullptr) { HloInstruction* converted_mask = comp->AddInstruction( HloInstruction::CreateConvert(bmm_1->shape(), mask)); operands.push_back(converted_mask); } - if (bias != nullptr) { + if ((!is_flash_attention || !is_causal_mask) && bias != nullptr) { HloInstruction* original_bias; HloInstruction* original_broadcast; // There will be cases where the bias is up-casted to wider float type, @@ -1300,6 +1461,39 @@ StatusOr FuseFwdMultiHeadedAttentionBlock( return fmha_call; } +Status RematSoftmaxOutput(HloComputation* comp, HloInstruction* fwd_fmha_call, + HloInstruction* softmax_input) { + // if only flash fwd is matched and bwd is not matched, then we need to remat + // the real softmax output because flash fwd only output softmax stat tensor + // following computation recovers the softmax output + // s = sub(softmax_input, broadcast(softmax_stat)) + // r = exp(s) + // find the softmax stat tensor + HloInstruction* softmax_stat; + for (auto user : fwd_fmha_call->users()) { + if (user->opcode() == HloOpcode::kGetTupleElement && + user->tuple_index() == 2) { + softmax_stat = user; + } + } + // should be able to find the softmax stat + TF_RET_CHECK(softmax_stat != nullptr); + auto broadcast = comp->AddInstruction(HloInstruction::CreateBroadcast( + softmax_input->shape(), softmax_stat, {0, 1, 2})); + auto sub = 
comp->AddInstruction(HloInstruction::CreateBinary( + softmax_input->shape(), HloOpcode::kSubtract, softmax_input, broadcast)); + auto exp = comp->AddInstruction(HloInstruction::CreateUnary( + softmax_input->shape(), HloOpcode::kExp, sub)); + // convert fp32 to bf16/fp16 + // we use datatype of Q tensor here + auto new_shape = ShapeUtil::ChangeElementType( + softmax_input->shape(), + fwd_fmha_call->operand(0)->shape().element_type()); + TF_RETURN_IF_ERROR(comp->ReplaceWithNewInstruction( + softmax_stat, HloInstruction::CreateConvert(new_shape, exp))); + return OkStatus(); +} + bool IsDbiasOnlyUserBesidesGradGemm(HloInstruction* d_intermediate, HloInstruction* bmm_1_grad_1, HloInstruction* bmm_1_grad_2, @@ -1329,15 +1523,20 @@ StatusOr FuseBwdMultiHeadedAttentionBlock( HloComputation* comp, HloInstruction* bmm_1_grad_1, HloInstruction* bmm_1_grad_2, HloInstruction* bmm_2_grad_1, HloInstruction* bmm_2_grad_2, HloInstruction* fwd_fmha_call, - HloInstruction* d_intermediate, HloInstruction* mask, - std::string& bwd_custom_call_name, bool fwd_bmm_2_canonicalized, - bool is_bmm2_grad1_canonicalized) { + HloInstruction* d_intermediate, HloInstruction* mask, HloInstruction* bias, + std::string& bwd_custom_call_name) { HloInstruction* rhs_bmm1_grad_gemm1; HloInstruction* lhs_bmm1_grad_gemm2; HloInstruction* lhs_bmm2_grad_gemm1; HloInstruction* rhs_bmm2_grad_gemm2; HloInstruction* d_output_grad; + HloInstruction* fwd_act; + TF_ASSIGN_OR_RETURN(CudnnfMHABackendConfig fwd_config, + fwd_fmha_call->backend_config()); + bool is_flash_attention = fwd_config.is_flash_attention(); + bool is_causal_mask = fwd_config.is_causal_mask(); + CudnnfMHABackendConfig bwd_fmha_config; // Q tensor TF_ASSIGN_OR_RETURN( rhs_bmm1_grad_gemm1, @@ -1348,67 +1547,72 @@ StatusOr FuseBwdMultiHeadedAttentionBlock( lhs_bmm1_grad_gemm2, ChangeCheckedDimToFastest(comp, bmm_1_grad_2, false /*is_lhs*/, false /*should_contracting_be_fastest*/)); - // Forward activation + // P tensor TF_ASSIGN_OR_RETURN( lhs_bmm2_grad_gemm1, ChangeCheckedDimToFastest(comp, bmm_2_grad_1, true /*is_lhs*/, false /*should_contracting_be_fastest*/)); + + // Forward activation + // if it is not flash attention, fwd activation is the P tensor + // else it is the softmax_stats + if (fwd_config.is_flash_attention()) { + auto fwd_act_index = 2; + fwd_act = comp->AddInstruction(HloInstruction::CreateGetTupleElement( + fwd_fmha_call->shape().tuple_shapes(fwd_act_index), fwd_fmha_call, + fwd_act_index)); + } else { + fwd_act = lhs_bmm2_grad_gemm1; + } + // V tensor TF_ASSIGN_OR_RETURN( rhs_bmm2_grad_gemm2, ChangeCheckedDimToFastest(comp, bmm_2_grad_2, false /*is_lhs*/, true /*should_contracting_be_fastest*/)); - // d output + // d output to bmm2_grad2 // Since d_o is the input of 2 bmms, we set the dim number using the // constraint // -> the contracting dimension of the lhs of bmm_2_grad_2 needs to be the // fastest moving dimension. 
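+  // For illustration: ChangeCheckedDimToFastest() below checks whether this
+  // contracting dimension is already the minor-most dimension of d_o and,
+  // if it is not, inserts a transpose and rewrites the dot dimension
+  // numbers so that it becomes the fastest moving dimension.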
- TF_ASSIGN_OR_RETURN(d_output_grad, ChangeCheckedDimToFastest( - comp, bmm_2_grad_2, true /*is_lhs*/, - true /*check_contracting_dim*/)); - // Operand order {Q, K, V, Fwd act, d_o, mask*} + TF_ASSIGN_OR_RETURN( + d_output_grad, + ChangeCheckedDimToFastest(comp, bmm_2_grad_2, true /*is_lhs*/, + true /*should_contracting_be_fastest*/)); + // d output to bmm2_grad1 + // we don't use this value but we call this to make sure dot number is being + // set correctly + TF_ASSIGN_OR_RETURN( + HloInstruction * bmm_2_grad_1_rhs, + ChangeCheckedDimToFastest(comp, bmm_2_grad_1, false /*is_lhs*/, + false /*should_contracting_be_fastest*/)); + // Operand order: {Q, K, V, Fwd act, d_o, mask*, bias*, O*} std::vector operands = { - rhs_bmm1_grad_gemm1, lhs_bmm1_grad_gemm2, rhs_bmm2_grad_gemm2, - lhs_bmm2_grad_gemm1, d_output_grad}; + rhs_bmm1_grad_gemm1, lhs_bmm1_grad_gemm2, rhs_bmm2_grad_gemm2, fwd_act, + d_output_grad}; if (mask) { HloInstruction* converted_mask = comp->AddInstruction( HloInstruction::CreateConvert(bmm_2_grad_2->shape(), mask)); operands.push_back(converted_mask); } - TF_ASSIGN_OR_RETURN(CudnnfMHABackendConfig fwd_config, - fwd_fmha_call->backend_config()); - CudnnfMHABackendConfig bwd_fmha_config; - // If forward bmm_2 is canonicalized, the contracting dimension of lhs - // of bmm_2_grad_1 needs to be changed to the non-contracting dimension. - - if (fwd_bmm_2_canonicalized) { - TF_ASSIGN_OR_RETURN( - std::vector bmm_2_grad_1_lhs_non_contracting_dims, - GetNonContractingDims( - bmm_2_grad_1->shape(), - bmm_2_grad_1->dot_dimension_numbers().lhs_batch_dimensions(), - bmm_2_grad_1->dot_dimension_numbers() - .lhs_contracting_dimensions())); - CHECK_EQ(bmm_2_grad_1_lhs_non_contracting_dims.size(), 1); - (DynCast(bmm_2_grad_1)) - ->mutable_dot_dimension_numbers() - ->set_lhs_contracting_dimensions( - 0, bmm_2_grad_1_lhs_non_contracting_dims[0]); - } - - TF_ASSIGN_OR_RETURN( - std::vector bmm_2_grad_1_new_contracting_dims, - GetNonContractingDims( - bmm_2_grad_1->shape(), - bmm_2_grad_1->dot_dimension_numbers().rhs_batch_dimensions(), - bmm_2_grad_1->dot_dimension_numbers().rhs_contracting_dimensions())); - - if (is_bmm2_grad1_canonicalized) { - (DynCast(bmm_2_grad_1)) - ->mutable_dot_dimension_numbers() - ->set_rhs_contracting_dimensions(0, - bmm_2_grad_1_new_contracting_dims[0]); + // if is flash attention, add fwd output to input list + if (is_flash_attention) { + if (!is_causal_mask && bias) { + operands.push_back(bias); + } + HloInstruction* fwd_output; + for (auto user : fwd_fmha_call->users()) { + if (user->opcode() == HloOpcode::kGetTupleElement && + user->tuple_index() == 0) { + fwd_output = user; + } + } + // should be able to find the instruction + TF_RET_CHECK(fwd_output != nullptr); + // check dO and O have the same layout as it is required by cuDNN + TF_RET_CHECK(fwd_output->shape() == d_output_grad->shape()); + operands.push_back(fwd_output); } *bwd_fmha_config.mutable_bmm1_grad_gemm1_dot_dimension_numbers() = @@ -1427,6 +1631,10 @@ StatusOr FuseBwdMultiHeadedAttentionBlock( // TODO Find a way to compute original seed from dropout keys. 
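+  // Note: when is_flash_attention is set, the forward and backward calls
+  // also reserve 16 bytes (two int64s) of U8 scratch for the dropout seed
+  // and offset; see the output shapes assembled below.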
bwd_fmha_config.set_seed(fwd_config.seed()); + // Set is flash attention + bwd_fmha_config.set_is_flash_attention(is_flash_attention); + bwd_fmha_config.set_is_causal_mask(is_causal_mask); + *bwd_fmha_config.mutable_intermediate_tensor_shape() = fwd_config.intermediate_tensor_shape(); { @@ -1443,15 +1651,26 @@ StatusOr FuseBwdMultiHeadedAttentionBlock( } // Output order: - // dQ(bmm_1_grad_2), dK(bmm_1_grad_1), dV(bmm_2_grad_1), - // d_intermediate_tensor, d_bias_tensor + // {dQ(bmm_1_grad_2), dK(bmm_1_grad_1), dV(bmm_2_grad_1), + // d_intermediate_tensor*, softmax_sum*, d_Q_accum*, scratch, dbias*} std::vector output_shapes = { bmm_1_grad_2->shape(), bmm_1_grad_1->shape(), bmm_2_grad_1->shape()}; - // d_intermediate is required to be output - output_shapes.push_back(lhs_bmm2_grad_gemm1->shape()); - + if (!fwd_config.is_flash_attention()) { + output_shapes.push_back(lhs_bmm2_grad_gemm1->shape()); + } else { + // softmax_sum, d_Q_accum + // add softmax sum here and change the data type + // softmax sum and d_Q_accum should both be fp32 datatype + output_shapes.push_back( + ShapeUtil::MakeShape(F32, fwd_act->shape().dimensions())); + output_shapes.push_back( + ShapeUtil::MakeShape(F32, bmm_1_grad_2->shape().dimensions())); + } // Reserved placeholder for workspace - output_shapes.push_back(ShapeUtil::MakeShape(U8, {0})); + output_shapes.push_back(ShapeUtil::MakeShape( + U8, {is_flash_attention + ? 16 + : 0})); // reserved 2 int64 for dropout seed and offset HloInstruction* dbias = nullptr; if (d_intermediate) { @@ -1518,6 +1737,8 @@ StatusOr CudnnFusedMHARewriter::Run( HloModule* module, const absl::flat_hash_set& execution_threads) { bool any_changed = false; + // we use this set to keep track of all already matched attention block + absl::flat_hash_set matched_bmm1; for (HloComputation* comp : module->MakeNonfusionComputations(execution_threads)) { const DebugOptions& debug_options = @@ -1530,20 +1751,29 @@ StatusOr CudnnFusedMHARewriter::Run( } for (HloInstruction* instr : comp->MakeInstructionPostOrder()) { bool v_transposed = false; + bool changed = false; MatchFwdResult matched_result = MatchFwdMHAPatternsForCanonicalization(instr); if (!matched_result.has_match) { continue; } + // flash attention TODO: dont fuse fwd if bwd is not supported // We check the validity of bmms here before canonicalization so we don't // modify the graph if mha fusion is not possible + // Relax 512 constraint if it is flash attention TF_ASSIGN_OR_RETURN( bool is_mha_module_supported, IsMHABlockSupported( matched_result.matched_bmm_1, matched_result.matched_bmm_2, matched_result.need_canonicalization, matched_result.is_training, + matched_result.is_causal_mask, matched_result.is_flash_attention, matched_result.matched_custom_call_name, debug_options)); + if (!is_mha_module_supported) continue; + // We make sure no attention block is matched and replaced twice here + if (matched_bmm1.find(matched_result.matched_bmm_1) != matched_bmm1.end()) + continue; + matched_bmm1.insert(matched_result.matched_bmm_1); // If we need to canonicalize the bmm, we will assign the newly // canonicalized bmm to bmm_2. if (matched_result.need_canonicalization) { @@ -1551,7 +1781,7 @@ StatusOr CudnnFusedMHARewriter::Run( CanonicalizeBatchedGemmForcuDNNFMHA( matched_result.matched_bmm_2, comp)); } - bool changed = false; + // Fuse the bmms and intermediate nodes into fMHA call, the fused call // will replace bmm_2. 
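+      // Sketch of the fused forward call's result layout (see the output
+      // shapes assembled in FuseFwdMultiHeadedAttentionBlock): tuple index 0
+      // is O, index 1 is the U8 scratch, and, when training, index 2 is the
+      // softmax stats for flash attention or the activation (P) tensor
+      // otherwise; the flash-attention backward fusion reads indices 2 and 0,
+      // and RematSoftmaxOutput reads index 2 if only the forward is matched.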
TF_ASSIGN_OR_RETURN( @@ -1559,11 +1789,14 @@ StatusOr CudnnFusedMHARewriter::Run( FuseFwdMultiHeadedAttentionBlock( comp, matched_result.matched_bmm_1, matched_result.matched_bmm_2, matched_result.matched_bias, matched_result.matched_mask, - matched_result.matched_scale, matched_result.matched_dropout_rate, + matched_result.matched_scale, matched_result.matched_reduce_sum, + matched_result.matched_softmax_input, + matched_result.matched_dropout_rate, matched_result.matched_custom_call_name, compute_capability_, - matched_result.is_training, changed, v_transposed)); + matched_result.is_training, changed, v_transposed, + matched_result.is_causal_mask, + matched_result.is_flash_attention)); any_changed |= changed; - if (matched_result.is_training) { // if fwd uses mask input, then bwd needs cudnn 8.9.1 to take in a mask // input if cudnn version < 8.9.1 we won't lower the bwd pass @@ -1578,6 +1811,16 @@ StatusOr CudnnFusedMHARewriter::Run( fwd_fmha_call, matched_result.matched_bmm_1, matched_result.matched_mask, v_transposed); if (!matched_bwd_result.has_match) { + if (matched_result.is_flash_attention) { + // if only flash fwd is matched but bwd is not matched, we need to + // remat the softmax output from softmax stat for bwd to user, + // otherwise bwd will fail. if both flash fwd and bwd is matched, + // don't do this because flash bwd will remat itself. + TF_RETURN_IF_ERROR(RematSoftmaxOutput( + comp, fwd_fmha_call, matched_result.matched_softmax_input)); + VLOG(2) << "Only flash attention fwd is matched, rematerialize " + "softmax output for bwd."; + } continue; } // check if dbias is the only user of d_intermediate besides @@ -1630,10 +1873,8 @@ StatusOr CudnnFusedMHARewriter::Run( matched_bwd_result.matched_bmm_2_grad_1, matched_bwd_result.matched_bmm_2_grad_2, fwd_fmha_call, matched_bwd_result.matched_d_intermediate, - matched_result.matched_mask, - matched_bwd_result.matched_custom_call_name, - matched_result.need_canonicalization, - matched_bwd_result.bmm_2_grad_1_need_canonicalization)); + matched_result.matched_mask, matched_result.matched_bias, + matched_bwd_result.matched_custom_call_name)); any_changed |= changed; } } diff --git a/xla/service/gpu/cudnn_fused_mha_rewriter_test.cc b/xla/service/gpu/cudnn_fused_mha_rewriter_test.cc index 1545b4a0e39e3f..37962863c34504 100644 --- a/xla/service/gpu/cudnn_fused_mha_rewriter_test.cc +++ b/xla/service/gpu/cudnn_fused_mha_rewriter_test.cc @@ -2893,6 +2893,693 @@ ENTRY main.146 { EXPECT_NEAR(config.dropout_rate(), 0.1, 1e-2); } + +// flash attention +TEST_F(CudnnFusedMhaRewriterTestHloTest, + BF16TrainingBmm1CausalMaskSoftmaxBmm2Pattern) { + const char* module_str = R"( +HloModule jit__unnamed_wrapped_function_, entry_computation_layout={(bf16[2,6,2048,128]{3,2,1,0},bf16[2,6,128,2048]{3,2,1,0},bf16[2,6,2048,128]{3,2,1,0},bf16[2,6,2048,128]{3,2,1,0})->(bf16[2,6,2048,128]{3,2,1,0}, bf16[2,6,2048,128]{3,2,1,0}, bf16[2,6,128,2048]{3,2,1,0}, bf16[2,6,2048,128]{3,2,1,0})}, allow_spmd_sharding_propagation_to_output={true,true,true,true} + +region_0.32 { + Arg_0.33 = bf16[] parameter(0) + Arg_1.34 = bf16[] parameter(1) + ROOT maximum = bf16[] maximum(Arg_0.33, Arg_1.34) +} + +region_1.44 { + Arg_0.45 = f32[] parameter(0) + Arg_1.46 = f32[] parameter(1) + ROOT add = f32[] add(Arg_0.45, Arg_1.46) +} + +region_2.66 { + Arg_0.67 = bf16[] parameter(0) + Arg_1.68 = bf16[] parameter(1) + ROOT add.1 = bf16[] add(Arg_0.67, Arg_1.68) +} + +ENTRY main.92 { + Arg_0.1 = bf16[2,6,2048,128]{3,2,1,0} parameter(0), sharding={replicated} + Arg_1.2 = 
bf16[2,6,128,2048]{3,2,1,0} parameter(1), sharding={replicated} + dot.14 = bf16[2,6,2048,2048]{3,2,1,0} dot(Arg_0.1, Arg_1.2), lhs_batch_dims={0,1}, lhs_contracting_dims={3}, rhs_batch_dims={0,1}, rhs_contracting_dims={2} + constant.17 = bf16[] constant(2) + broadcast.29 = bf16[2,6,2048,2048]{3,2,1,0} broadcast(constant.17), dimensions={} + multiply.2 = bf16[2,6,2048,2048]{3,2,1,0} multiply(dot.14, broadcast.29) + iota.2 = s32[2048,2048]{1,0} iota(), iota_dimension=0 + iota.5 = s32[2048,2048]{1,0} iota(), iota_dimension=1 + compare.1 = pred[2048,2048]{1,0} compare(iota.2, iota.5), direction=LT + constant.6 = bf16[] constant(-2.366e+38) + broadcast.16 = bf16[2048,2048]{1,0} broadcast(constant.6), dimensions={} + constant.16 = bf16[] constant(0) + broadcast.17 = bf16[2048,2048]{1,0} broadcast(constant.16), dimensions={} + select.2 = bf16[2048,2048]{1,0} select(compare.1, broadcast.16, broadcast.17) + broadcast.19 = bf16[2,6,2048,2048]{3,2,1,0} broadcast(select.2), dimensions={2,3} + add.3 = bf16[2,6,2048,2048]{3,2,1,0} add(multiply.2, broadcast.19) + constant.10 = bf16[] constant(-inf) + reduce.36 = bf16[2,6,2048]{2,1,0} reduce(add.3, constant.10), dimensions={3}, to_apply=region_0.32 + broadcast.21 = bf16[2,6,2048,2048]{3,2,1,0} broadcast(reduce.36), dimensions={0,1,2} + subtract.1 = bf16[2,6,2048,2048]{3,2,1,0} subtract(add.3, broadcast.21) + exponential.1 = bf16[2,6,2048,2048]{3,2,1,0} exponential(subtract.1) + convert.5 = f32[2,6,2048,2048]{3,2,1,0} convert(exponential.1) + constant.14 = f32[] constant(0) + reduce.48 = f32[2,6,2048]{2,1,0} reduce(convert.5, constant.14), dimensions={3}, to_apply=region_1.44 + convert.9 = bf16[2,6,2048]{2,1,0} convert(reduce.48) + broadcast.32 = bf16[2,6,2048,2048]{3,2,1,0} broadcast(convert.9), dimensions={0,1,2} + divide.5 = bf16[2,6,2048,2048]{3,2,1,0} divide(exponential.1, broadcast.32) + Arg_2.3 = bf16[2,6,2048,128]{3,2,1,0} parameter(2), sharding={replicated} + dot.57 = bf16[2,6,2048,128]{3,2,1,0} dot(divide.5, Arg_2.3), lhs_batch_dims={0,1}, lhs_contracting_dims={3}, rhs_batch_dims={0,1}, rhs_contracting_dims={2} + Arg_3.4 = bf16[2,6,2048,128]{3,2,1,0} parameter(3), sharding={replicated} + dot.60 = bf16[2,6,2048,2048]{3,2,1,0} dot(Arg_3.4, Arg_2.3), lhs_batch_dims={0,1}, lhs_contracting_dims={3}, rhs_batch_dims={0,1}, rhs_contracting_dims={3} + divide.4 = bf16[2,6,2048,2048]{3,2,1,0} divide(dot.60, broadcast.32) + constant.15 = bf16[] constant(1) + broadcast.25 = bf16[2,6,2048]{2,1,0} broadcast(constant.15), dimensions={} + multiply.3 = bf16[2,6,2048]{2,1,0} multiply(convert.9, convert.9) + divide.3 = bf16[2,6,2048]{2,1,0} divide(broadcast.25, multiply.3) + broadcast.26 = bf16[2,6,2048,2048]{3,2,1,0} broadcast(divide.3), dimensions={0,1,2} + multiply.4 = bf16[2,6,2048,2048]{3,2,1,0} multiply(dot.60, broadcast.26) + multiply.5 = bf16[2,6,2048,2048]{3,2,1,0} multiply(multiply.4, exponential.1) + reduce.70 = bf16[2,6,2048]{2,1,0} reduce(multiply.5, constant.16), dimensions={3}, to_apply=region_2.66 + negate.2 = bf16[2,6,2048]{2,1,0} negate(reduce.70) + broadcast.31 = bf16[2,6,2048,2048]{3,2,1,0} broadcast(negate.2), dimensions={0,1,2} + add.5 = bf16[2,6,2048,2048]{3,2,1,0} add(divide.4, broadcast.31) + multiply.8 = bf16[2,6,2048,2048]{3,2,1,0} multiply(add.5, exponential.1) + multiply.9 = bf16[2,6,2048,2048]{3,2,1,0} multiply(multiply.8, broadcast.29) + dot.90 = bf16[2,6,2048,128]{3,2,1,0} dot(multiply.9, Arg_1.2), lhs_batch_dims={0,1}, lhs_contracting_dims={3}, rhs_batch_dims={0,1}, rhs_contracting_dims={3} + dot = bf16[2,6,128,2048]{3,2,1,0} 
dot(Arg_0.1, multiply.9), lhs_batch_dims={0,1}, lhs_contracting_dims={2}, rhs_batch_dims={0,1}, rhs_contracting_dims={2} + dot.1 = bf16[2,6,2048,128]{3,2,1,0} dot(divide.5, Arg_3.4), lhs_batch_dims={0,1}, lhs_contracting_dims={2}, rhs_batch_dims={0,1}, rhs_contracting_dims={2} + ROOT tuple.91 = (bf16[2,6,2048,128]{3,2,1,0}, bf16[2,6,2048,128]{3,2,1,0}, bf16[2,6,128,2048]{3,2,1,0}, bf16[2,6,2048,128]{3,2,1,0}) tuple(dot.57, dot.90, dot, dot.1) +} + +)"; + + TF_ASSERT_OK_AND_ASSIGN(auto m, ParseAndReturnVerifiedModule(module_str)); + CudnnFusedMHARewriter fusedMhaRewriter{ + GetCudaComputeCapability(), + GetCudnnVersionWithDbiasAndMaskBwdInputSupport()}; + TF_ASSERT_OK(RunHloPass(&fusedMhaRewriter, m.get()).status()); + HloDCE dce; + TF_ASSERT_OK(RunHloPass(&dce, m.get()).status()); + + ComputationLayout computation_layout( + m->entry_computation()->ComputeProgramShape()); + + const HloInstruction* fmha; + SCOPED_TRACE(m->ToString()); + EXPECT_THAT( + m->entry_computation()->root_instruction(), + GmockMatch(m::Tuple( + m::GetTupleElement( + m::CustomCall(&fmha, {kCudnnfMHASoftmaxCallTarget}), 0) + .WithShape(BF16, {2, 6, 2048, 128}), + m::GetTupleElement( + m::CustomCall(&fmha, {kCudnnfMHASoftmaxBackwardCallTarget}), 0) + .WithShape(BF16, {2, 6, 2048, 128}), + m::Transpose( + m::GetTupleElement( + m::CustomCall({kCudnnfMHASoftmaxBackwardCallTarget}), 1)) + .WithShape(BF16, {2, 6, 128, 2048}), + m::GetTupleElement( + m::CustomCall({kCudnnfMHASoftmaxBackwardCallTarget}), 2) + .WithShape(BF16, {2, 6, 2048, 128})))); + TF_ASSERT_OK_AND_ASSIGN(auto config, + fmha->backend_config()); + EXPECT_EQ(fmha->operands().size(), 6); + EXPECT_NEAR(config.dropout_rate(), 0, 1e-2); + EXPECT_EQ(config.is_flash_attention(), true); + EXPECT_EQ(config.is_causal_mask(), true); +} + +TEST_F(CudnnFusedMhaRewriterTestHloTest, + BF16TrainingBmm1BiasSoftmaxBmm2Pattern) { + const char* module_str = R"( +HloModule jit__unnamed_wrapped_function_, entry_computation_layout={(bf16[2,6,2048,128]{3,2,1,0},bf16[2,6,128,2048]{3,2,1,0},bf16[2,6,2048,128]{3,2,1,0},bf16[2,6,2048,128]{3,2,1,0},bf16[2,6,2048,2048]{3,2,1,0})->(bf16[2,6,2048,128]{3,2,1,0}, bf16[2,6,2048,128]{3,2,1,0}, bf16[2,6,128,2048]{3,2,1,0}, bf16[2,6,2048,128]{3,2,1,0})}, allow_spmd_sharding_propagation_to_output={true,true,true,true} + +region_0.32 { + Arg_0.33 = bf16[] parameter(0) + Arg_1.34 = bf16[] parameter(1) + ROOT maximum = bf16[] maximum(Arg_0.33, Arg_1.34) +} + +region_1.44 { + Arg_0.45 = f32[] parameter(0) + Arg_1.46 = f32[] parameter(1) + ROOT add = f32[] add(Arg_0.45, Arg_1.46) +} + +region_2.66 { + Arg_0.67 = bf16[] parameter(0) + Arg_1.68 = bf16[] parameter(1) + ROOT add.1 = bf16[] add(Arg_0.67, Arg_1.68) +} + +ENTRY main.92 { + Arg_0.1 = bf16[2,6,2048,128]{3,2,1,0} parameter(0), sharding={replicated} + Arg_1.2 = bf16[2,6,128,2048]{3,2,1,0} parameter(1), sharding={replicated} + dot.14 = bf16[2,6,2048,2048]{3,2,1,0} dot(Arg_0.1, Arg_1.2), lhs_batch_dims={0,1}, lhs_contracting_dims={3}, rhs_batch_dims={0,1}, rhs_contracting_dims={2} + constant.17 = bf16[] constant(2) + broadcast.29 = bf16[2,6,2048,2048]{3,2,1,0} broadcast(constant.17), dimensions={} + multiply.2 = bf16[2,6,2048,2048]{3,2,1,0} multiply(dot.14, broadcast.29) + // bias + Arg_4.5 = bf16[2,6,2048,2048]{3,2,1,0} parameter(4), sharding={replicated} + add.3 = bf16[2,6,2048,2048]{3,2,1,0} add(multiply.2, Arg_4.5) + constant.10 = bf16[] constant(-inf) + constant.16 = bf16[] constant(0) + reduce.36 = bf16[2,6,2048]{2,1,0} reduce(add.3, constant.10), dimensions={3}, to_apply=region_0.32 + 
broadcast.21 = bf16[2,6,2048,2048]{3,2,1,0} broadcast(reduce.36), dimensions={0,1,2} + subtract.1 = bf16[2,6,2048,2048]{3,2,1,0} subtract(add.3, broadcast.21) + exponential.1 = bf16[2,6,2048,2048]{3,2,1,0} exponential(subtract.1) + convert.5 = f32[2,6,2048,2048]{3,2,1,0} convert(exponential.1) + constant.14 = f32[] constant(0) + reduce.48 = f32[2,6,2048]{2,1,0} reduce(convert.5, constant.14), dimensions={3}, to_apply=region_1.44 + convert.9 = bf16[2,6,2048]{2,1,0} convert(reduce.48) + broadcast.32 = bf16[2,6,2048,2048]{3,2,1,0} broadcast(convert.9), dimensions={0,1,2} + divide.5 = bf16[2,6,2048,2048]{3,2,1,0} divide(exponential.1, broadcast.32) + Arg_2.3 = bf16[2,6,2048,128]{3,2,1,0} parameter(2), sharding={replicated} + dot.57 = bf16[2,6,2048,128]{3,2,1,0} dot(divide.5, Arg_2.3), lhs_batch_dims={0,1}, lhs_contracting_dims={3}, rhs_batch_dims={0,1}, rhs_contracting_dims={2} + Arg_3.4 = bf16[2,6,2048,128]{3,2,1,0} parameter(3), sharding={replicated} + dot.60 = bf16[2,6,2048,2048]{3,2,1,0} dot(Arg_3.4, Arg_2.3), lhs_batch_dims={0,1}, lhs_contracting_dims={3}, rhs_batch_dims={0,1}, rhs_contracting_dims={3} + divide.4 = bf16[2,6,2048,2048]{3,2,1,0} divide(dot.60, broadcast.32) + constant.15 = bf16[] constant(1) + broadcast.25 = bf16[2,6,2048]{2,1,0} broadcast(constant.15), dimensions={} + multiply.3 = bf16[2,6,2048]{2,1,0} multiply(convert.9, convert.9) + divide.3 = bf16[2,6,2048]{2,1,0} divide(broadcast.25, multiply.3) + broadcast.26 = bf16[2,6,2048,2048]{3,2,1,0} broadcast(divide.3), dimensions={0,1,2} + multiply.4 = bf16[2,6,2048,2048]{3,2,1,0} multiply(dot.60, broadcast.26) + multiply.5 = bf16[2,6,2048,2048]{3,2,1,0} multiply(multiply.4, exponential.1) + reduce.70 = bf16[2,6,2048]{2,1,0} reduce(multiply.5, constant.16), dimensions={3}, to_apply=region_2.66 + negate.2 = bf16[2,6,2048]{2,1,0} negate(reduce.70) + broadcast.31 = bf16[2,6,2048,2048]{3,2,1,0} broadcast(negate.2), dimensions={0,1,2} + add.5 = bf16[2,6,2048,2048]{3,2,1,0} add(divide.4, broadcast.31) + multiply.8 = bf16[2,6,2048,2048]{3,2,1,0} multiply(add.5, exponential.1) + multiply.9 = bf16[2,6,2048,2048]{3,2,1,0} multiply(multiply.8, broadcast.29) + dot.90 = bf16[2,6,2048,128]{3,2,1,0} dot(multiply.9, Arg_1.2), lhs_batch_dims={0,1}, lhs_contracting_dims={3}, rhs_batch_dims={0,1}, rhs_contracting_dims={3} + dot = bf16[2,6,128,2048]{3,2,1,0} dot(Arg_0.1, multiply.9), lhs_batch_dims={0,1}, lhs_contracting_dims={2}, rhs_batch_dims={0,1}, rhs_contracting_dims={2} + dot.1 = bf16[2,6,2048,128]{3,2,1,0} dot(divide.5, Arg_3.4), lhs_batch_dims={0,1}, lhs_contracting_dims={2}, rhs_batch_dims={0,1}, rhs_contracting_dims={2} + ROOT tuple.91 = (bf16[2,6,2048,128]{3,2,1,0}, bf16[2,6,2048,128]{3,2,1,0}, bf16[2,6,128,2048]{3,2,1,0}, bf16[2,6,2048,128]{3,2,1,0}) tuple(dot.57, dot.90, dot, dot.1) +} + +)"; + + TF_ASSERT_OK_AND_ASSIGN(auto m, ParseAndReturnVerifiedModule(module_str)); + CudnnFusedMHARewriter fusedMhaRewriter{ + GetCudaComputeCapability(), + GetCudnnVersionWithDbiasAndMaskBwdInputSupport()}; + TF_ASSERT_OK(RunHloPass(&fusedMhaRewriter, m.get()).status()); + HloDCE dce; + TF_ASSERT_OK(RunHloPass(&dce, m.get()).status()); + + ComputationLayout computation_layout( + m->entry_computation()->ComputeProgramShape()); + + const HloInstruction* fmha; + + SCOPED_TRACE(m->ToString()); + EXPECT_THAT( + m->entry_computation()->root_instruction(), + GmockMatch(m::Tuple( + m::GetTupleElement( + m::CustomCall(&fmha, {kCudnnfMHAScaleBiasSoftmaxCallTarget}), 0) + .WithShape(BF16, {2, 6, 2048, 128}), + m::GetTupleElement( + m::CustomCall(&fmha, + 
{kCudnnfMHAScaleBiasSoftmaxBackwardCallTarget}), + 0) + .WithShape(BF16, {2, 6, 2048, 128}), + m::Transpose( + m::GetTupleElement( + m::CustomCall({kCudnnfMHAScaleBiasSoftmaxBackwardCallTarget}), + 1)) + .WithShape(BF16, {2, 6, 128, 2048}), + m::GetTupleElement( + m::CustomCall({kCudnnfMHAScaleBiasSoftmaxBackwardCallTarget}), 2) + .WithShape(BF16, {2, 6, 2048, 128})))); + TF_ASSERT_OK_AND_ASSIGN(auto config, + fmha->backend_config()); + EXPECT_EQ(fmha->operands().size(), 7); + EXPECT_NEAR(config.dropout_rate(), 0, 1e-2); + EXPECT_EQ(config.is_flash_attention(), true); + EXPECT_EQ(config.is_causal_mask(), false); +} + +// GPT3 pattern +TEST_F(CudnnFusedMhaRewriterTestHloTest, BF16TrainingGPT3) { + const char* module_str = R"( +HloModule jit__unnamed_wrapped_function_, entry_computation_layout={((s32[], bf16[4,2048,768]{2,1,0}, bf16[12,3072]{1,0}, bf16[12,768,3072]{2,1,0}, bf16[12,768]{1,0}, bf16[12,3072,768]{2,1,0}, bf16[12,768]{1,0}, bf16[12,768]{1,0}, bf16[12,768]{1,0}, bf16[12,768]{1,0}, bf16[12,3,12,64]{3,2,1,0}, bf16[12,3,768,12,64]{4,3,2,1,0}, bf16[12,768]{1,0}, bf16[12,768,12,64]{3,2,1,0}, bf16[12,3072]{1,0}, bf16[12,768,3072]{2,1,0}, bf16[12,3072,768]{2,1,0}, bf16[12,768]{1,0}, bf16[12,768]{1,0}, bf16[12,768]{1,0}, bf16[12,768]{1,0}, bf16[12,3,12,64]{3,2,1,0}, bf16[12,3,768,12,64]{4,3,2,1,0}, bf16[12,768]{1,0}, bf16[12,768,12,64]{3,2,1,0}, bf16[12,4,2048,768]{3,2,1,0}, bf16[4,1,2048,2048]{3,2,0,1}, bf16[4,2048]{1,0}))->(s32[], bf16[4,2048,768]{2,1,0}, bf16[12,3072]{1,0}, bf16[12,768,3072]{2,1,0}, bf16[12,768]{1,0}, bf16[12,3072,768]{2,1,0}, bf16[12,768]{1,0}, bf16[12,768]{1,0}, bf16[12,768]{1,0}, bf16[12,768]{1,0}, bf16[12,3,12,64]{3,2,1,0}, bf16[12,3,768,12,64]{4,3,2,1,0}, bf16[12,768]{1,0}, bf16[12,768,12,64]{3,2,1,0}, bf16[12,3072]{1,0}, bf16[12,768,3072]{2,1,0}, bf16[12,3072,768]{2,1,0}, bf16[12,768]{1,0}, bf16[12,768]{1,0}, bf16[12,768]{1,0}, bf16[12,768]{1,0}, bf16[12,3,12,64]{3,2,1,0}, bf16[12,3,768,12,64]{4,3,2,1,0}, bf16[12,768]{1,0}, bf16[12,768,12,64]{3,2,1,0}, bf16[12,4,2048,768]{3,2,1,0}, bf16[4,1,2048,2048]{3,2,0,1}, bf16[4,2048]{1,0})} + +region_8.643 { + Arg_0.644 = f32[] parameter(0) + Arg_1.645 = f32[] parameter(1) + ROOT add.646 = f32[] add(Arg_0.644, Arg_1.645) +} + +region_23.860 { + Arg_0.861 = bf16[] parameter(0) + Arg_1.862 = bf16[] parameter(1) + ROOT add.863 = bf16[] add(Arg_0.861, Arg_1.862) +} + +region_33.931 { + Arg_0.932 = f32[] parameter(0) + Arg_1.933 = f32[] parameter(1) + ROOT maximum.934 = f32[] maximum(Arg_0.932, Arg_1.933) +} + +ENTRY main.92 { + arg_tuple.1060 = (s32[], bf16[4,2048,768]{2,1,0}, bf16[12,3072]{1,0}, bf16[12,768,3072]{2,1,0}, bf16[12,768]{1,0}, /*index=5*/bf16[12,3072,768]{2,1,0}, bf16[12,768]{1,0}, bf16[12,768]{1,0}, bf16[12,768]{1,0}, bf16[12,768]{1,0}, /*index=10*/bf16[12,3,12,64]{3,2,1,0}, bf16[12,3,768,12,64]{4,3,2,1,0}, bf16[12,768]{1,0}, bf16[12,768,12,64]{3,2,1,0}, bf16[12,3072]{1,0}, /*index=15*/bf16[12,768,3072]{2,1,0}, bf16[12,3072,768]{2,1,0}, bf16[12,768]{1,0}, bf16[12,768]{1,0}, bf16[12,768]{1,0}, /*index=20*/bf16[12,768]{1,0}, bf16[12,3,12,64]{3,2,1,0}, bf16[12,3,768,12,64]{4,3,2,1,0}, bf16[12,768]{1,0}, bf16[12,768,12,64]{3,2,1,0}, /*index=25*/bf16[12,4,2048,768]{3,2,1,0}, bf16[4,1,2048,2048]{3,2,0,1}, bf16[4,2048]{1,0}) parameter(0) + get-tuple-element.1061 = s32[] get-tuple-element(arg_tuple.1060), index=0 + constant.1121 = s32[] constant(1) + add.1568 = s32[] add(get-tuple-element.1061, constant.1121) + get-tuple-element.1062 = bf16[4,2048,768]{2,1,0} get-tuple-element(arg_tuple.1060), index=1 + 
get-tuple-element.1083 = bf16[12,3,768,12,64]{4,3,2,1,0} get-tuple-element(arg_tuple.1060), index=22 + constant.178 = s32[] constant(11) + subtract.6 = s32[] subtract(constant.178, get-tuple-element.1061) + constant.1120 = s32[] constant(0) + compare.1161 = pred[] compare(subtract.6, constant.1120), direction=LT + constant.205 = s32[] constant(23) + subtract.10 = s32[] subtract(constant.205, get-tuple-element.1061) + select.1163 = s32[] select(compare.1161, subtract.10, subtract.6) + dynamic-slice.1164 = bf16[1,3,768,12,64]{4,3,2,1,0} dynamic-slice(get-tuple-element.1083, select.1163, constant.1120, constant.1120, constant.1120, /*index=5*/constant.1120), dynamic_slice_sizes={1,3,768,12,64} + reshape.1165 = bf16[3,768,12,64]{3,2,1,0} reshape(dynamic-slice.1164) + transpose.12 = bf16[3,12,64,768]{2,1,3,0} transpose(reshape.1165), dimensions={0,2,3,1} + reshape.35 = bf16[2304,768]{1,0} reshape(transpose.12) + get-tuple-element.1086 = bf16[12,4,2048,768]{3,2,1,0} get-tuple-element(arg_tuple.1060), index=25 + dynamic-slice.1178 = bf16[1,4,2048,768]{3,2,1,0} dynamic-slice(get-tuple-element.1086, select.1163, constant.1120, constant.1120, constant.1120), dynamic_slice_sizes={1,4,2048,768} + reshape.1179 = bf16[4,2048,768]{2,1,0} reshape(dynamic-slice.1178) + convert.1196 = f32[4,2048,768]{2,1,0} convert(reshape.1179) + constant.1117 = f32[] constant(0) + reduce.1197 = f32[4,2048]{1,0} reduce(convert.1196, constant.1117), dimensions={2}, to_apply=region_8.643 + constant.41 = f32[] constant(0.00130208337) + broadcast.367 = f32[4,2048]{1,0} broadcast(constant.41), dimensions={} + multiply.44 = f32[4,2048]{1,0} multiply(reduce.1197, broadcast.367) + broadcast.1211 = f32[4,2048,768]{2,1,0} broadcast(multiply.44), dimensions={0,1} + subtract.1212 = f32[4,2048,768]{2,1,0} subtract(convert.1196, broadcast.1211) + multiply.1204 = f32[4,2048,768]{2,1,0} multiply(subtract.1212, subtract.1212) + reduce.1206 = f32[4,2048]{1,0} reduce(multiply.1204, constant.1117), dimensions={2}, to_apply=region_8.643 + multiply.45 = f32[4,2048]{1,0} multiply(reduce.1206, broadcast.367) + constant.1111 = f32[] constant(1e-05) + broadcast.469 = f32[4,2048]{1,0} broadcast(constant.1111), dimensions={} + add.92 = f32[4,2048]{1,0} add(multiply.45, broadcast.469) + reshape.779 = f32[4,2048,1]{1,0,2} reshape(add.92) + rsqrt.1214 = f32[4,2048,1]{1,0,2} rsqrt(reshape.779) + reshape.1218 = f32[4,2048]{1,0} reshape(rsqrt.1214) + broadcast.1219 = f32[4,2048,768]{2,1,0} broadcast(reshape.1218), dimensions={0,1} + multiply.1220 = f32[4,2048,768]{2,1,0} multiply(subtract.1212, broadcast.1219) + convert.1221 = bf16[4,2048,768]{2,1,0} convert(multiply.1220) + get-tuple-element.1081 = bf16[12,768]{1,0} get-tuple-element(arg_tuple.1060), index=20 + dynamic-slice.1155 = bf16[1,768]{1,0} dynamic-slice(get-tuple-element.1081, select.1163, constant.1120), dynamic_slice_sizes={1,768} + constant.1090 = bf16[] constant(1) + broadcast.370 = bf16[1,768]{1,0} broadcast(constant.1090), dimensions={} + add.44 = bf16[1,768]{1,0} add(dynamic-slice.1155, broadcast.370) + reshape.1225 = bf16[768]{0} reshape(add.44) + broadcast.1226 = bf16[4,2048,768]{2,1,0} broadcast(reshape.1225), dimensions={2} + multiply.1227 = bf16[4,2048,768]{2,1,0} multiply(convert.1221, broadcast.1226) + get-tuple-element.1080 = bf16[12,768]{1,0} get-tuple-element(arg_tuple.1060), index=19 + dynamic-slice.1151 = bf16[1,768]{1,0} dynamic-slice(get-tuple-element.1080, select.1163, constant.1120), dynamic_slice_sizes={1,768} + reshape.1230 = bf16[768]{0} reshape(dynamic-slice.1151) + 
broadcast.1231 = bf16[4,2048,768]{2,1,0} broadcast(reshape.1230), dimensions={2} + add.1232 = bf16[4,2048,768]{2,1,0} add(multiply.1227, broadcast.1231) + transpose.13 = bf16[768,4,2048]{0,2,1} transpose(add.1232), dimensions={2,0,1} + reshape.36 = bf16[768,8192]{0,1} reshape(transpose.13) + dot.6 = bf16[2304,8192]{1,0} dot(reshape.35, reshape.36), lhs_contracting_dims={1}, rhs_contracting_dims={0} + reshape.37 = bf16[3,12,64,4,2048]{4,2,1,3,0} reshape(dot.6) + get-tuple-element.1082 = bf16[12,3,12,64]{3,2,1,0} get-tuple-element(arg_tuple.1060), index=21 + dynamic-slice.1160 = bf16[1,3,12,64]{3,2,1,0} dynamic-slice(get-tuple-element.1082, select.1163, constant.1120, constant.1120, constant.1120), dynamic_slice_sizes={1,3,12,64} + reshape.1237 = bf16[3,12,64]{2,1,0} reshape(dynamic-slice.1160) + broadcast.372 = bf16[3,12,64,4,2048]{4,2,1,3,0} broadcast(reshape.1237), dimensions={0,1,2} + add.45 = bf16[3,12,64,4,2048]{4,2,1,3,0} add(reshape.37, broadcast.372) + transpose.67 = bf16[3,4,2048,12,64]{2,4,3,1,0} transpose(add.45), dimensions={0,3,4,1,2} + // V + slice.1244 = bf16[1,4,2048,12,64]{2,4,3,1,0} slice(transpose.67), slice={[2:3], [0:4], [0:2048], [0:12], [0:64]} + reshape.1245 = bf16[4,2048,12,64]{1,3,2,0} reshape(slice.1244) + transpose.16 = bf16[4,12,64,2048]{3,2,1,0} transpose(reshape.1245), dimensions={0,2,3,1} + // Q + slice.1240 = bf16[1,4,2048,12,64]{2,4,3,1,0} slice(transpose.67), slice={[0:1], [0:4], [0:2048], [0:12], [0:64]} + constant.1105 = bf16[] constant(0.125) + broadcast.374 = bf16[1,4,2048,12,64]{2,4,3,1,0} broadcast(constant.1105), dimensions={} + multiply.42 = bf16[1,4,2048,12,64]{2,4,3,1,0} multiply(slice.1240, broadcast.374) + reshape.458 = bf16[4,2048,12,64]{1,3,2,0} reshape(multiply.42) + transpose.14 = bf16[4,12,2048,64]{2,3,1,0} transpose(reshape.458), dimensions={0,2,1,3} + copy = bf16[4,12,2048,64]{3,2,1,0} copy(transpose.14) + // K + slice.1242 = bf16[1,4,2048,12,64]{2,4,3,1,0} slice(transpose.67), slice={[1:2], [0:4], [0:2048], [0:12], [0:64]} + reshape.1243 = bf16[4,2048,12,64]{1,3,2,0} reshape(slice.1242) + transpose.15 = bf16[4,12,64,2048]{3,2,1,0} transpose(reshape.1243), dimensions={0,2,3,1} + // Q K -> S + dot.7 = bf16[4,12,2048,2048]{3,2,1,0} dot(copy, transpose.15), lhs_batch_dims={0,1}, lhs_contracting_dims={3}, rhs_batch_dims={0,1}, rhs_contracting_dims={2} + convert.1251 = f32[4,12,2048,2048]{3,2,1,0} convert(dot.7) + get-tuple-element.1087 = bf16[4,1,2048,2048]{3,2,0,1} get-tuple-element(arg_tuple.1060), index=26 + // causal mask + iota.6 = s32[2048,2048]{1,0} iota(), iota_dimension=0 + iota.9 = s32[2048,2048]{1,0} iota(), iota_dimension=1 + compare.1188 = pred[2048,2048]{1,0} compare(iota.6, iota.9), direction=LT + constant.1118 = bf16[] constant(-2.366e+38) + broadcast.1119 = bf16[2048,2048]{1,0} broadcast(constant.1118), dimensions={} + constant.1089 = bf16[] constant(0) + broadcast.284 = bf16[2048,2048]{1,0} broadcast(constant.1089), dimensions={} + select.168 = bf16[2048,2048]{1,0} select(compare.1188, broadcast.1119, broadcast.284) + broadcast.1194 = bf16[4,1,2048,2048]{3,2,0,1} broadcast(select.168), dimensions={2,3} + minimum.1195 = bf16[4,1,2048,2048]{3,2,0,1} minimum(get-tuple-element.1087, broadcast.1194) + reshape.1247 = bf16[4,2048,2048]{2,1,0} reshape(minimum.1195) + convert.96 = f32[4,2048,2048]{2,1,0} convert(reshape.1247) + broadcast.1255 = f32[4,12,2048,2048]{3,2,1,0} broadcast(convert.96), dimensions={0,2,3} + add.1256 = f32[4,12,2048,2048]{3,2,1,0} add(convert.1251, broadcast.1255) + // softmax + constant.1104 = f32[] 
constant(-inf) + reduce.1257 = f32[4,12,2048]{2,1,0} reduce(add.1256, constant.1104), dimensions={3}, to_apply=region_33.931 + broadcast.1261 = f32[4,12,2048,2048]{3,2,1,0} broadcast(reduce.1257), dimensions={0,1,2} + subtract.1262 = f32[4,12,2048,2048]{3,2,1,0} subtract(add.1256, broadcast.1261) + exponential.1263 = f32[4,12,2048,2048]{3,2,1,0} exponential(subtract.1262) + reduce.1264 = f32[4,12,2048]{2,1,0} reduce(exponential.1263, constant.1117), dimensions={3}, to_apply=region_8.643 + broadcast.1268 = f32[4,12,2048,2048]{3,2,1,0} broadcast(reduce.1264), dimensions={0,1,2} + divide.1269 = f32[4,12,2048,2048]{3,2,1,0} divide(exponential.1263, broadcast.1268) + convert.1272 = bf16[4,12,2048,2048]{3,2,1,0} convert(divide.1269) + // V P -> O + dot.8 = bf16[4,12,64,2048]{3,2,1,0} dot(transpose.16, convert.1272), lhs_batch_dims={0,1}, lhs_contracting_dims={3}, rhs_batch_dims={0,1}, rhs_contracting_dims={3} + transpose.18 = bf16[4,2048,64,12]{1,2,3,0} transpose(dot.8), dimensions={0,3,2,1} + reshape.44 = bf16[8192,768]{1,0} reshape(transpose.18) + get-tuple-element.1085 = bf16[12,768,12,64]{3,2,1,0} get-tuple-element(arg_tuple.1060), index=24 + dynamic-slice.1173 = bf16[1,768,12,64]{3,2,1,0} dynamic-slice(get-tuple-element.1085, select.1163, constant.1120, constant.1120, constant.1120), dynamic_slice_sizes={1,768,12,64} + reshape.1174 = bf16[768,12,64]{2,1,0} reshape(dynamic-slice.1173) + transpose.19 = bf16[64,12,768]{0,1,2} transpose(reshape.1174), dimensions={2,1,0} + reshape.45 = bf16[768,768]{1,0} reshape(transpose.19) + dot.9 = bf16[8192,768]{1,0} dot(reshape.44, reshape.45), lhs_contracting_dims={1}, rhs_contracting_dims={0} + get-tuple-element.1084 = bf16[12,768]{1,0} get-tuple-element(arg_tuple.1060), index=23 + dynamic-slice.1169 = bf16[1,768]{1,0} dynamic-slice(get-tuple-element.1084, select.1163, constant.1120), dynamic_slice_sizes={1,768} + reshape.1279 = bf16[768]{0} reshape(dynamic-slice.1169) + broadcast.383 = bf16[8192,768]{1,0} broadcast(reshape.1279), dimensions={1} + add.46 = bf16[8192,768]{1,0} add(dot.9, broadcast.383) + reshape.462 = bf16[4,2048,768]{2,1,0} reshape(add.46) + add.1283 = bf16[4,2048,768]{2,1,0} add(reshape.462, reshape.1179) + convert.1285 = f32[4,2048,768]{2,1,0} convert(add.1283) + reduce.1286 = f32[4,2048]{1,0} reduce(convert.1285, constant.1117), dimensions={2}, to_apply=region_8.643 + multiply.46 = f32[4,2048]{1,0} multiply(reduce.1286, broadcast.367) + broadcast.1291 = f32[4,2048,768]{2,1,0} broadcast(multiply.46), dimensions={0,1} + subtract.1292 = f32[4,2048,768]{2,1,0} subtract(convert.1285, broadcast.1291) + multiply.1293 = f32[4,2048,768]{2,1,0} multiply(subtract.1292, subtract.1292) + reduce.1295 = f32[4,2048]{1,0} reduce(multiply.1293, constant.1117), dimensions={2}, to_apply=region_8.643 + multiply.47 = f32[4,2048]{1,0} multiply(reduce.1295, broadcast.367) + add.93 = f32[4,2048]{1,0} add(multiply.47, broadcast.469) + reshape.785 = f32[4,2048,1]{1,0,2} reshape(add.93) + rsqrt.1303 = f32[4,2048,1]{1,0,2} rsqrt(reshape.785) + reshape.1307 = f32[4,2048]{1,0} reshape(rsqrt.1303) + broadcast.1308 = f32[4,2048,768]{2,1,0} broadcast(reshape.1307), dimensions={0,1} + multiply.1309 = f32[4,2048,768]{2,1,0} multiply(subtract.1292, broadcast.1308) + convert.1310 = bf16[4,2048,768]{2,1,0} convert(multiply.1309) + get-tuple-element.1079 = bf16[12,768]{1,0} get-tuple-element(arg_tuple.1060), index=18 + dynamic-slice.1146 = bf16[1,768]{1,0} dynamic-slice(get-tuple-element.1079, select.1163, constant.1120), dynamic_slice_sizes={1,768} + add.47 = 
bf16[1,768]{1,0} add(dynamic-slice.1146, broadcast.370) + reshape.1314 = bf16[768]{0} reshape(add.47) + broadcast.1315 = bf16[4,2048,768]{2,1,0} broadcast(reshape.1314), dimensions={2} + multiply.1316 = bf16[4,2048,768]{2,1,0} multiply(convert.1310, broadcast.1315) + get-tuple-element.1078 = bf16[12,768]{1,0} get-tuple-element(arg_tuple.1060), index=17 + dynamic-slice.1142 = bf16[1,768]{1,0} dynamic-slice(get-tuple-element.1078, select.1163, constant.1120), dynamic_slice_sizes={1,768} + reshape.1319 = bf16[768]{0} reshape(dynamic-slice.1142) + broadcast.1320 = bf16[4,2048,768]{2,1,0} broadcast(reshape.1319), dimensions={2} + add.1321 = bf16[4,2048,768]{2,1,0} add(multiply.1316, broadcast.1320) + reshape.47 = bf16[8192,768]{1,0} reshape(add.1321) + get-tuple-element.1076 = bf16[12,768,3072]{2,1,0} get-tuple-element(arg_tuple.1060), index=15 + dynamic-slice.1132 = bf16[1,768,3072]{2,1,0} dynamic-slice(get-tuple-element.1076, select.1163, constant.1120, constant.1120), dynamic_slice_sizes={1,768,3072} + reshape.1133 = bf16[768,3072]{1,0} reshape(dynamic-slice.1132) + dot.10 = bf16[8192,3072]{1,0} dot(reshape.47, reshape.1133), lhs_contracting_dims={1}, rhs_contracting_dims={0} + get-tuple-element.1075 = bf16[12,3072]{1,0} get-tuple-element(arg_tuple.1060), index=14 + dynamic-slice.1128 = bf16[1,3072]{1,0} dynamic-slice(get-tuple-element.1075, select.1163, constant.1120), dynamic_slice_sizes={1,3072} + reshape.1326 = bf16[3072]{0} reshape(dynamic-slice.1128) + broadcast.387 = bf16[8192,3072]{1,0} broadcast(reshape.1326), dimensions={1} + add.48 = bf16[8192,3072]{1,0} add(dot.10, broadcast.387) + reshape.469 = bf16[4,2048,3072]{2,1,0} reshape(add.48) + broadcast.389 = bf16[4,2048]{1,0} broadcast(constant.1090), dimensions={} + get-tuple-element.1088 = bf16[4,2048]{1,0} get-tuple-element(arg_tuple.1060), index=27 + subtract.3 = bf16[4,2048]{1,0} subtract(broadcast.389, get-tuple-element.1088) + broadcast.1349 = bf16[4,2048,768]{2,1,0} broadcast(subtract.3), dimensions={0,1} + multiply.1350 = bf16[4,2048,768]{2,1,0} multiply(get-tuple-element.1062, broadcast.1349) + reshape.53 = bf16[8192,768]{1,0} reshape(multiply.1350) + get-tuple-element.1077 = bf16[12,3072,768]{2,1,0} get-tuple-element(arg_tuple.1060), index=16 + dynamic-slice.1137 = bf16[1,3072,768]{2,1,0} dynamic-slice(get-tuple-element.1077, select.1163, constant.1120, constant.1120), dynamic_slice_sizes={1,3072,768} + reshape.1138 = bf16[3072,768]{1,0} reshape(dynamic-slice.1137) + dot.12 = bf16[8192,3072]{1,0} dot(reshape.53, reshape.1138), lhs_contracting_dims={1}, rhs_contracting_dims={1} + reshape.55 = bf16[4,2048,3072]{2,1,0} reshape(dot.12) + broadcast.1360 = bf16[4,2048,3072]{2,1,0} broadcast(subtract.3), dimensions={0,1} + multiply.1361 = bf16[4,2048,3072]{2,1,0} multiply(reshape.55, broadcast.1360) + multiply.1362 = bf16[4,2048,3072]{2,1,0} multiply(reshape.469, multiply.1361) + constant.1092 = bf16[] constant(0.5) + broadcast.1093 = bf16[4,2048,3072]{2,1,0} broadcast(constant.1092), dimensions={} + multiply.1363 = bf16[4,2048,3072]{2,1,0} multiply(multiply.1362, broadcast.1093) + broadcast.1095 = bf16[4,2048,3072]{2,1,0} broadcast(constant.1090), dimensions={} + multiply.1329 = bf16[4,2048,3072]{2,1,0} multiply(reshape.469, reshape.469) + multiply.1330 = bf16[4,2048,3072]{2,1,0} multiply(reshape.469, multiply.1329) + constant.1098 = bf16[] constant(0.04468) + broadcast.1099 = bf16[4,2048,3072]{2,1,0} broadcast(constant.1098), dimensions={} + multiply.1333 = bf16[4,2048,3072]{2,1,0} multiply(multiply.1330, broadcast.1099) + 
add.1334 = bf16[4,2048,3072]{2,1,0} add(reshape.469, multiply.1333) + constant.1096 = bf16[] constant(0.7969) + broadcast.1097 = bf16[4,2048,3072]{2,1,0} broadcast(constant.1096), dimensions={} + multiply.1335 = bf16[4,2048,3072]{2,1,0} multiply(add.1334, broadcast.1097) + tanh.1336 = bf16[4,2048,3072]{2,1,0} tanh(multiply.1335) + subtract.1337 = bf16[4,2048,3072]{2,1,0} subtract(broadcast.1095, tanh.1336) + multiply.1364 = bf16[4,2048,3072]{2,1,0} multiply(multiply.1363, subtract.1337) + multiply.1365 = bf16[4,2048,3072]{2,1,0} multiply(multiply.1364, tanh.1336) + add.1366 = bf16[4,2048,3072]{2,1,0} add(multiply.1364, multiply.1365) + multiply.1367 = bf16[4,2048,3072]{2,1,0} multiply(add.1366, broadcast.1097) + constant.66 = bf16[] constant(0.03564) + broadcast.289 = bf16[4,2048,3072]{2,1,0} broadcast(constant.66), dimensions={} + multiply.1368 = bf16[4,2048,3072]{2,1,0} multiply(add.1366, broadcast.289) + constant.1100 = bf16[] constant(3) + broadcast.1101 = bf16[4,2048,3072]{2,1,0} broadcast(constant.1100), dimensions={} + multiply.1332 = bf16[4,2048,3072]{2,1,0} multiply(multiply.1329, broadcast.1101) + multiply.1369 = bf16[4,2048,3072]{2,1,0} multiply(multiply.1368, multiply.1332) + add.1370 = bf16[4,2048,3072]{2,1,0} add(multiply.1367, multiply.1369) + add.1338 = bf16[4,2048,3072]{2,1,0} add(tanh.1336, broadcast.1095) + multiply.1339 = bf16[4,2048,3072]{2,1,0} multiply(add.1338, broadcast.1093) + multiply.1371 = bf16[4,2048,3072]{2,1,0} multiply(multiply.1361, multiply.1339) + add.1372 = bf16[4,2048,3072]{2,1,0} add(add.1370, multiply.1371) + reshape.59 = bf16[8192,3072]{1,0} reshape(add.1372) + dot.14 = bf16[8192,768]{1,0} dot(reshape.59, reshape.1133), lhs_contracting_dims={1}, rhs_contracting_dims={1} + reshape.61 = bf16[4,2048,768]{2,1,0} reshape(dot.14) + multiply.1390 = bf16[4,2048,768]{2,1,0} multiply(reshape.61, broadcast.1315) + convert.1391 = f32[4,2048,768]{2,1,0} convert(multiply.1390) + multiply.1392 = f32[4,2048,768]{2,1,0} multiply(subtract.1292, convert.1391) + reduce.1393 = f32[4,2048]{1,0} reduce(multiply.1392, constant.1117), dimensions={2}, to_apply=region_8.643 + reshape.1394 = f32[4,2048,1]{1,0,2} reshape(reduce.1393) + divide.1304 = f32[4,2048,1]{1,0,2} divide(rsqrt.1303, reshape.785) + constant.1109 = f32[] constant(-0.5) + broadcast.1110 = f32[4,2048,1]{1,0,2} broadcast(constant.1109), dimensions={} + multiply.1305 = f32[4,2048,1]{1,0,2} multiply(divide.1304, broadcast.1110) + multiply.1395 = f32[4,2048,1]{1,0,2} multiply(reshape.1394, multiply.1305) + constant.206 = f32[] constant(0.00260416674) + broadcast.474 = f32[4,2048,1]{1,0,2} broadcast(constant.206), dimensions={} + multiply.48 = f32[4,2048,1]{1,0,2} multiply(multiply.1395, broadcast.474) + reshape.510 = f32[4,2048]{1,0} reshape(multiply.48) + broadcast.293 = f32[4,2048,768]{2,1,0} broadcast(reshape.510), dimensions={0,1} + multiply.1399 = f32[4,2048,768]{2,1,0} multiply(subtract.1292, broadcast.293) + multiply.1406 = f32[4,2048,768]{2,1,0} multiply(convert.1391, broadcast.1308) + add.1410 = f32[4,2048,768]{2,1,0} add(multiply.1399, multiply.1406) + negate.1400 = f32[4,2048,768]{2,1,0} negate(multiply.1399) + reduce.1401 = f32[4,2048]{1,0} reduce(negate.1400, constant.1117), dimensions={2}, to_apply=region_8.643 + negate.1407 = f32[4,2048,768]{2,1,0} negate(multiply.1406) + reduce.1408 = f32[4,2048]{1,0} reduce(negate.1407, constant.1117), dimensions={2}, to_apply=region_8.643 + add.49 = f32[4,2048]{1,0} add(reduce.1401, reduce.1408) + multiply.74 = f32[4,2048]{1,0} multiply(add.49, broadcast.367) + 
broadcast.1414 = f32[4,2048,768]{2,1,0} broadcast(multiply.74), dimensions={0,1} + add.1415 = f32[4,2048,768]{2,1,0} add(add.1410, broadcast.1414) + convert.1416 = bf16[4,2048,768]{2,1,0} convert(add.1415) + add.1417 = bf16[4,2048,768]{2,1,0} add(get-tuple-element.1062, convert.1416) + reshape.65 = bf16[8192,768]{1,0} reshape(add.1417) + reshape.66 = bf16[768,768]{1,0} reshape(dynamic-slice.1173) + dot.16 = bf16[8192,768]{1,0} dot(reshape.65, reshape.66), lhs_contracting_dims={1}, rhs_contracting_dims={0} + reshape.67 = bf16[4,2048,12,64]{3,1,2,0} reshape(dot.16) + transpose.34 = bf16[4,12,2048,64]{3,2,1,0} transpose(reshape.67), dimensions={0,2,1,3} + // dO V -> dP + dot.17 = bf16[4,12,2048,2048]{3,2,1,0} dot(transpose.34, transpose.16), lhs_batch_dims={0,1}, lhs_contracting_dims={3}, rhs_batch_dims={0,1}, rhs_contracting_dims={2} + convert.1427 = f32[4,12,2048,2048]{3,2,1,0} convert(dot.17) + constant.1102 = f32[] constant(1) + broadcast.495 = f32[4,12,2048]{2,1,0} broadcast(constant.1102), dimensions={} + multiply.73 = f32[4,12,2048]{2,1,0} multiply(reduce.1264, reduce.1264) + divide.3 = f32[4,12,2048]{2,1,0} divide(broadcast.495, multiply.73) + broadcast.1430 = f32[4,12,2048,2048]{3,2,1,0} broadcast(divide.3), dimensions={0,1,2} + multiply.1431 = f32[4,12,2048,2048]{3,2,1,0} multiply(convert.1427, broadcast.1430) + multiply.1432 = f32[4,12,2048,2048]{3,2,1,0} multiply(multiply.1431, exponential.1263) + reduce.1433 = f32[4,12,2048]{2,1,0} reduce(multiply.1432, constant.1117), dimensions={3}, to_apply=region_8.643 + negate.38 = f32[4,12,2048]{2,1,0} negate(reduce.1433) + broadcast.1437 = f32[4,12,2048,2048]{3,2,1,0} broadcast(negate.38), dimensions={0,1,2} + divide.1441 = f32[4,12,2048,2048]{3,2,1,0} divide(convert.1427, broadcast.1268) + add.1442 = f32[4,12,2048,2048]{3,2,1,0} add(broadcast.1437, divide.1441) + multiply.1443 = f32[4,12,2048,2048]{3,2,1,0} multiply(add.1442, exponential.1263) + convert.1444 = bf16[4,12,2048,2048]{3,2,1,0} convert(multiply.1443) + copy.1 = bf16[4,12,2048,64]{3,2,1,0} copy(transpose.14) + // dS Q -> dK + dot.18 = bf16[4,12,2048,64]{3,2,1,0} dot(convert.1444, copy.1), lhs_batch_dims={0,1}, lhs_contracting_dims={2}, rhs_batch_dims={0,1}, rhs_contracting_dims={2} + transpose.1446 = bf16[4,2048,12,64]{3,1,2,0} transpose(dot.18), dimensions={0,2,1,3} + reshape.1448 = bf16[1,4,2048,12,64]{4,2,3,1,0} reshape(transpose.1446) + pad.1449 = bf16[3,4,2048,12,64]{4,2,3,1,0} pad(reshape.1448, constant.1089), padding=1_1x0_0x0_0x0_0x0_0 + transpose.39 = bf16[4,12,2048,64]{2,3,1,0} transpose(reshape.1243), dimensions={0,2,1,3} + copy.2 = bf16[4,12,2048,64]{3,2,1,0} copy(transpose.39) + // dS K -> dQ + dot.19 = bf16[4,12,2048,64]{3,2,1,0} dot(convert.1444, copy.2), lhs_batch_dims={0,1}, lhs_contracting_dims={3}, rhs_batch_dims={0,1}, rhs_contracting_dims={2} + broadcast.395 = bf16[4,12,2048,64]{3,2,1,0} broadcast(constant.1105), dimensions={} + multiply.43 = bf16[4,12,2048,64]{3,2,1,0} multiply(dot.19, broadcast.395) + transpose.70 = bf16[4,2048,12,64]{3,1,2,0} transpose(multiply.43), dimensions={0,2,1,3} + reshape.1454 = bf16[1,4,2048,12,64]{4,2,3,1,0} reshape(transpose.70) + pad.1455 = bf16[3,4,2048,12,64]{4,2,3,1,0} pad(reshape.1454, constant.1089), padding=0_2x0_0x0_0x0_0x0_0 + add.1456 = bf16[3,4,2048,12,64]{4,2,3,1,0} add(pad.1449, pad.1455) + transpose.1425 = bf16[4,12,64,2048]{2,3,1,0} transpose(reshape.67), dimensions={0,2,3,1} + copy.3 = bf16[4,12,64,2048]{3,2,1,0} copy(transpose.1425) + // dO P -> dV + dot.1457 = bf16[4,12,64,2048]{3,2,1,0} dot(copy.3, 
convert.1272), lhs_batch_dims={0,1}, lhs_contracting_dims={3}, rhs_batch_dims={0,1}, rhs_contracting_dims={2} + copy.4 = bf16[4,12,64,2048]{2,3,1,0} copy(dot.1457) + transpose.1458 = bf16[4,2048,12,64]{3,1,2,0} transpose(copy.4), dimensions={0,3,1,2} + reshape.1460 = bf16[1,4,2048,12,64]{4,2,3,1,0} reshape(transpose.1458) + pad.1461 = bf16[3,4,2048,12,64]{4,2,3,1,0} pad(reshape.1460, constant.1089), padding=2_0x0_0x0_0x0_0x0_0 + add.1462 = bf16[3,4,2048,12,64]{4,2,3,1,0} add(add.1456, pad.1461) + transpose.40 = bf16[4,2048,3,12,64]{4,1,3,0,2} transpose(add.1462), dimensions={1,2,0,3,4} + reshape.77 = bf16[8192,2304]{1,0} reshape(transpose.40) + dot.20 = bf16[8192,768]{1,0} dot(reshape.77, reshape.35), lhs_contracting_dims={1}, rhs_contracting_dims={0} + reshape.79 = bf16[4,2048,768]{2,1,0} reshape(dot.20) + multiply.1478 = bf16[4,2048,768]{2,1,0} multiply(reshape.79, broadcast.1226) + convert.1479 = f32[4,2048,768]{2,1,0} convert(multiply.1478) + multiply.1480 = f32[4,2048,768]{2,1,0} multiply(subtract.1212, convert.1479) + reduce.1481 = f32[4,2048]{1,0} reduce(multiply.1480, constant.1117), dimensions={2}, to_apply=region_8.643 + reshape.1482 = f32[4,2048,1]{1,0,2} reshape(reduce.1481) + divide.1215 = f32[4,2048,1]{1,0,2} divide(rsqrt.1214, reshape.779) + multiply.1216 = f32[4,2048,1]{1,0,2} multiply(divide.1215, broadcast.1110) + multiply.1483 = f32[4,2048,1]{1,0,2} multiply(reshape.1482, multiply.1216) + multiply.49 = f32[4,2048,1]{1,0,2} multiply(multiply.1483, broadcast.474) + reshape.515 = f32[4,2048]{1,0} reshape(multiply.49) + broadcast.298 = f32[4,2048,768]{2,1,0} broadcast(reshape.515), dimensions={0,1} + multiply.1487 = f32[4,2048,768]{2,1,0} multiply(subtract.1212, broadcast.298) + multiply.1494 = f32[4,2048,768]{2,1,0} multiply(convert.1479, broadcast.1219) + add.1498 = f32[4,2048,768]{2,1,0} add(multiply.1487, multiply.1494) + negate.1488 = f32[4,2048,768]{2,1,0} negate(multiply.1487) + reduce.1489 = f32[4,2048]{1,0} reduce(negate.1488, constant.1117), dimensions={2}, to_apply=region_8.643 + negate.1495 = f32[4,2048,768]{2,1,0} negate(multiply.1494) + reduce.1496 = f32[4,2048]{1,0} reduce(negate.1495, constant.1117), dimensions={2}, to_apply=region_8.643 + add.50 = f32[4,2048]{1,0} add(reduce.1489, reduce.1496) + multiply.75 = f32[4,2048]{1,0} multiply(add.50, broadcast.367) + broadcast.1502 = f32[4,2048,768]{2,1,0} broadcast(multiply.75), dimensions={0,1} + add.1503 = f32[4,2048,768]{2,1,0} add(add.1498, broadcast.1502) + convert.1504 = bf16[4,2048,768]{2,1,0} convert(add.1503) + add.1505 = bf16[4,2048,768]{2,1,0} add(add.1417, convert.1504) + get-tuple-element.1063 = bf16[12,3072]{1,0} get-tuple-element(arg_tuple.1060), index=2 + reduce.1373 = bf16[3072]{0} reduce(add.1372, constant.1089), dimensions={0,1}, to_apply=region_23.860 + reshape.1508 = bf16[1,3072]{1,0} reshape(reduce.1373) + dynamic-update-slice.1512 = bf16[12,3072]{1,0} dynamic-update-slice(get-tuple-element.1063, reshape.1508, select.1163, constant.1120) + get-tuple-element.1064 = bf16[12,768,3072]{2,1,0} get-tuple-element(arg_tuple.1060), index=3 + transpose.26 = bf16[3072,4,2048]{0,2,1} transpose(add.1372), dimensions={2,0,1} + reshape.56 = bf16[3072,8192]{0,1} reshape(transpose.26) + dot.29 = bf16[768,3072]{1,0} dot(reshape.47, reshape.56), lhs_contracting_dims={0}, rhs_contracting_dims={1} + reshape.1513 = bf16[1,768,3072]{2,1,0} reshape(dot.29) + dynamic-update-slice.1517 = bf16[12,768,3072]{2,1,0} dynamic-update-slice(get-tuple-element.1064, reshape.1513, select.1163, constant.1120, constant.1120) + 
get-tuple-element.1065 = bf16[12,768]{1,0} get-tuple-element(arg_tuple.1060), index=4 + reduce.1351 = bf16[768]{0} reduce(multiply.1350, constant.1089), dimensions={0,1}, to_apply=region_23.860 + reshape.1518 = bf16[1,768]{1,0} reshape(reduce.1351) + dynamic-update-slice.1522 = bf16[12,768]{1,0} dynamic-update-slice(get-tuple-element.1065, reshape.1518, select.1163, constant.1120) + get-tuple-element.1066 = bf16[12,3072,768]{2,1,0} get-tuple-element(arg_tuple.1060), index=5 + multiply.1340 = bf16[4,2048,3072]{2,1,0} multiply(reshape.469, multiply.1339) + multiply.1345 = bf16[4,2048,3072]{2,1,0} multiply(multiply.1340, broadcast.1360) + reshape.51 = bf16[8192,3072]{1,0} reshape(multiply.1345) + transpose.22 = bf16[768,4,2048]{0,2,1} transpose(multiply.1350), dimensions={2,0,1} + reshape.50 = bf16[768,8192]{0,1} reshape(transpose.22) + dot.30 = bf16[3072,768]{1,0} dot(reshape.51, reshape.50), lhs_contracting_dims={0}, rhs_contracting_dims={1} + reshape.1523 = bf16[1,3072,768]{2,1,0} reshape(dot.30) + dynamic-update-slice.1527 = bf16[12,3072,768]{2,1,0} dynamic-update-slice(get-tuple-element.1066, reshape.1523, select.1163, constant.1120, constant.1120) + get-tuple-element.1067 = bf16[12,768]{1,0} get-tuple-element(arg_tuple.1060), index=6 + reduce.1380 = bf16[768]{0} reduce(dot.14, constant.1089), dimensions={0}, to_apply=region_23.860 + reshape.1528 = bf16[1,768]{1,0} reshape(reduce.1380) + dynamic-update-slice.1532 = bf16[12,768]{1,0} dynamic-update-slice(get-tuple-element.1067, reshape.1528, select.1163, constant.1120) + get-tuple-element.1068 = bf16[12,768]{1,0} get-tuple-element(arg_tuple.1060), index=7 + multiply.1383 = bf16[4,2048,768]{2,1,0} multiply(convert.1310, reshape.61) + reduce.1384 = bf16[768]{0} reduce(multiply.1383, constant.1089), dimensions={0,1}, to_apply=region_23.860 + reshape.1533 = bf16[1,768]{1,0} reshape(reduce.1384) + dynamic-update-slice.1537 = bf16[12,768]{1,0} dynamic-update-slice(get-tuple-element.1068, reshape.1533, select.1163, constant.1120) + get-tuple-element.1069 = bf16[12,768]{1,0} get-tuple-element(arg_tuple.1060), index=8 + reduce.1468 = bf16[768]{0} reduce(dot.20, constant.1089), dimensions={0}, to_apply=region_23.860 + reshape.1538 = bf16[1,768]{1,0} reshape(reduce.1468) + dynamic-update-slice.1542 = bf16[12,768]{1,0} dynamic-update-slice(get-tuple-element.1069, reshape.1538, select.1163, constant.1120) + get-tuple-element.1070 = bf16[12,768]{1,0} get-tuple-element(arg_tuple.1060), index=9 + multiply.1471 = bf16[4,2048,768]{2,1,0} multiply(convert.1221, reshape.79) + reduce.1472 = bf16[768]{0} reduce(multiply.1471, constant.1089), dimensions={0,1}, to_apply=region_23.860 + reshape.1543 = bf16[1,768]{1,0} reshape(reduce.1472) + dynamic-update-slice.1547 = bf16[12,768]{1,0} dynamic-update-slice(get-tuple-element.1070, reshape.1543, select.1163, constant.1120) + get-tuple-element.1071 = bf16[12,3,12,64]{3,2,1,0} get-tuple-element(arg_tuple.1060), index=10 + reduce.1463 = bf16[3,12,64]{2,1,0} reduce(add.1462, constant.1089), dimensions={1,2}, to_apply=region_23.860 + reshape.1548 = bf16[1,3,12,64]{3,2,1,0} reshape(reduce.1463) + dynamic-update-slice.1552 = bf16[12,3,12,64]{3,2,1,0} dynamic-update-slice(get-tuple-element.1071, reshape.1548, select.1163, constant.1120, constant.1120, /*index=5*/constant.1120) + get-tuple-element.1072 = bf16[12,3,768,12,64]{4,3,2,1,0} get-tuple-element(arg_tuple.1060), index=11 + transpose.42 = bf16[3,12,64,4,2048]{2,4,1,3,0} transpose(add.1462), dimensions={0,3,4,1,2} + reshape.80 = bf16[2304,8192]{1,0} 
reshape(transpose.42) + reshape.81 = bf16[8192,768]{1,0} reshape(add.1232) + dot.21 = bf16[2304,768]{1,0} dot(reshape.80, reshape.81), lhs_contracting_dims={1}, rhs_contracting_dims={0} + reshape.82 = bf16[3,12,64,768]{2,1,3,0} reshape(dot.21) + transpose.1507 = bf16[3,768,12,64]{3,2,1,0} transpose(reshape.82), dimensions={0,3,1,2} + reshape.1553 = bf16[1,3,768,12,64]{4,3,2,1,0} reshape(transpose.1507) + dynamic-update-slice.1557 = bf16[12,3,768,12,64]{4,3,2,1,0} dynamic-update-slice(get-tuple-element.1072, reshape.1553, select.1163, constant.1120, constant.1120, /*index=5*/constant.1120, constant.1120) + get-tuple-element.1073 = bf16[12,768]{1,0} get-tuple-element(arg_tuple.1060), index=12 + reduce.1419 = bf16[768]{0} reduce(add.1417, constant.1089), dimensions={0,1}, to_apply=region_23.860 + reshape.1558 = bf16[1,768]{1,0} reshape(reduce.1419) + dynamic-update-slice.1562 = bf16[12,768]{1,0} dynamic-update-slice(get-tuple-element.1073, reshape.1558, select.1163, constant.1120) + get-tuple-element.1074 = bf16[12,768,12,64]{3,2,1,0} get-tuple-element(arg_tuple.1060), index=13 + transpose.30 = bf16[768,4,2048]{0,2,1} transpose(add.1417), dimensions={2,0,1} + reshape.62 = bf16[768,8192]{0,1} reshape(transpose.30) + transpose.31 = bf16[4,2048,12,64]{1,3,2,0} transpose(dot.8), dimensions={0,3,1,2} + reshape.63 = bf16[8192,768]{1,0} reshape(transpose.31) + dot.15 = bf16[768,768]{1,0} dot(reshape.62, reshape.63), lhs_contracting_dims={1}, rhs_contracting_dims={0} + reshape.1563 = bf16[1,768,12,64]{3,2,1,0} reshape(dot.15) + dynamic-update-slice.1567 = bf16[12,768,12,64]{3,2,1,0} dynamic-update-slice(get-tuple-element.1074, reshape.1563, select.1163, constant.1120, constant.1120, /*index=5*/constant.1120) + ROOT tuple.1569 = (s32[], bf16[4,2048,768]{2,1,0}, bf16[12,3072]{1,0}, bf16[12,768,3072]{2,1,0}, bf16[12,768]{1,0}, /*index=5*/bf16[12,3072,768]{2,1,0}, bf16[12,768]{1,0}, bf16[12,768]{1,0}, bf16[12,768]{1,0}, bf16[12,768]{1,0}, /*index=10*/bf16[12,3,12,64]{3,2,1,0}, bf16[12,3,768,12,64]{4,3,2,1,0}, bf16[12,768]{1,0}, bf16[12,768,12,64]{3,2,1,0}, bf16[12,3072]{1,0}, /*index=15*/bf16[12,768,3072]{2,1,0}, bf16[12,3072,768]{2,1,0}, bf16[12,768]{1,0}, bf16[12,768]{1,0}, bf16[12,768]{1,0}, /*index=20*/bf16[12,768]{1,0}, bf16[12,3,12,64]{3,2,1,0}, bf16[12,3,768,12,64]{4,3,2,1,0}, bf16[12,768]{1,0}, bf16[12,768,12,64]{3,2,1,0}, /*index=25*/bf16[12,4,2048,768]{3,2,1,0}, bf16[4,1,2048,2048]{3,2,0,1}, bf16[4,2048]{1,0}) tuple(add.1568, add.1505, dynamic-update-slice.1512, dynamic-update-slice.1517, dynamic-update-slice.1522, /*index=5*/dynamic-update-slice.1527, dynamic-update-slice.1532, dynamic-update-slice.1537, dynamic-update-slice.1542, dynamic-update-slice.1547, /*index=10*/dynamic-update-slice.1552, dynamic-update-slice.1557, dynamic-update-slice.1562, dynamic-update-slice.1567, get-tuple-element.1075, /*index=15*/get-tuple-element.1076, get-tuple-element.1077, get-tuple-element.1078, get-tuple-element.1079, get-tuple-element.1080, /*index=20*/get-tuple-element.1081, get-tuple-element.1082, get-tuple-element.1083, get-tuple-element.1084, get-tuple-element.1085, /*index=25*/get-tuple-element.1086, get-tuple-element.1087, get-tuple-element.1088) +} + +)"; + + TF_ASSERT_OK_AND_ASSIGN(auto m, ParseAndReturnVerifiedModule(module_str)); + AlgebraicSimplifierOptions alg_sim_options; + alg_sim_options.set_supports_non_canonical_dots(false); + alg_sim_options.set_is_layout_sensitive(true); + alg_sim_options.set_enable_conv_operand_swap(false); + AlgebraicSimplifier alge_simp{alg_sim_options}; + 
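// Canonicalize the module (reshape decomposition, layout normalization, + // algebraic simplification, CSE) before running CudnnFusedMHARewriter so that + // its pattern matcher sees canonical HLO. + 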
ReshapeDecomposer reshape_decomposer; + LayoutNormalization layout_normalizer; + HloCSE cse{/*is_layout_sensitive=*/true}; + TF_ASSERT_OK(RunHloPass(&reshape_decomposer, m.get()).status()); + TF_ASSERT_OK(RunHloPass(&layout_normalizer, m.get()).status()); + // TF_ASSERT_OK(RunHloPass(&cse, m.get()).status()); + TF_ASSERT_OK(RunHloPass(&alge_simp, m.get()).status()); + TF_ASSERT_OK(RunHloPass(&cse, m.get()).status()); + CudnnFusedMHARewriter fusedMhaRewriter{GetCudaComputeCapability(), + GetCudnnVersion()}; + TF_ASSERT_OK(RunHloPass(&fusedMhaRewriter, m.get()).status()); + + CudnnFusedMHATransposeFusion fmha_transpose_fusion; + + HloDCE dce; + TF_ASSERT_OK(RunHloPass(&alge_simp, m.get()).status()); + TF_ASSERT_OK(RunHloPass(&fmha_transpose_fusion, m.get()).status()); + + TF_ASSERT_OK(RunHloPass(&dce, m.get()).status()); + + ComputationLayout computation_layout( + m->entry_computation()->ComputeProgramShape()); + + HloInstruction* fwd_instruction = nullptr; + HloInstruction* bwd_instruction = nullptr; + SCOPED_TRACE(m->ToString()); + for (HloInstruction* instr : + m->entry_computation()->MakeInstructionPostOrder()) { + if (instr->opcode() == HloOpcode::kCustomCall && + instr->custom_call_target() == kCudnnfMHASoftmaxCallTarget) { + fwd_instruction = instr; + } + if (instr->opcode() == HloOpcode::kCustomCall && + instr->custom_call_target() == kCudnnfMHASoftmaxBackwardCallTarget) { + bwd_instruction = instr; + } + } + EXPECT_NE(fwd_instruction, nullptr); + EXPECT_NE(bwd_instruction, nullptr); + TF_ASSERT_OK_AND_ASSIGN( + auto config, fwd_instruction->backend_config<CudnnfMHABackendConfig>()); + EXPECT_EQ(config.is_flash_attention(), true); + EXPECT_EQ(config.is_causal_mask(), true); +} + } // anonymous namespace } // namespace gpu } // namespace xla diff --git a/xla/service/gpu/cudnn_fused_mha_transpose_fusion.cc b/xla/service/gpu/cudnn_fused_mha_transpose_fusion.cc index 6d30eac7bf3fdd..513fef387a2481 100644 --- a/xla/service/gpu/cudnn_fused_mha_transpose_fusion.cc +++ b/xla/service/gpu/cudnn_fused_mha_transpose_fusion.cc @@ -102,6 +102,9 @@ StatusOr<bool> FuseArgPrologueTransposeWithcuDNNFMHA( absl::Span<const int64_t> checked_dims; std::vector<int64_t> checked_dims_vec; + // `should_contracting_be_fastest` indicates whether the contracting dim is + // the hidden dim. cuDNN requires the hidden dim to be the fastest-varying + // dim; fwd bmm1 and bwd bmm2grad1 should set this to true. if (should_contracting_be_fastest) { checked_dims = is_lhs ? new_bmm_dot_dims.lhs_contracting_dimensions() : new_bmm_dot_dims.rhs_contracting_dimensions(); @@ -137,9 +140,7 @@ StatusOr<bool> FuseArgPrologueTransposeWithcuDNNFMHA( absl::Span<const int64_t> minor_to_major_bmm = transpose_arg_operand->shape().layout().minor_to_major(); if ((minor_to_major_bmm[0] != new_bmm_checked_dims[0]) && - ((transpose_arg_operand->shape().dimensions().at( - new_bmm_checked_dims[0]) == 64) || - (IsBwdCustomCallTofMHA(*fmha) && operand_index == 3))) { + !(IsBwdCustomCallTofMHA(*fmha) && operand_index == 3)) { return false; } if (should_contracting_be_fastest) { @@ -198,8 +199,10 @@ StatusOr<bool> FuseArgPrologueTransposeWithcuDNNFMHA( } if (IsFwdCustomCallTofMHA(*fmha)) { if (operand_index == 0 || operand_index == 1) { + // Q or K *new_fmha_config.mutable_bmm1_dot_dimension_numbers() = new_bmm_dot_dims; } else { + // V *new_fmha_config.mutable_bmm2_dot_dimension_numbers() = new_bmm_dot_dims; } } else { @@ -455,9 +458,16 @@ StatusOr<bool> FusePrologueTransposeWithcuDNNFMHA(HloComputation* comp) { } // D_output tensor in backward graph is lhs with constraint on // contracting dim. 
- TF_ASSIGN_OR_RETURN(changed, FuseArgPrologueTransposeWithcuDNNFMHA( - fmha, 4, true /*is_lhs*/, - true /*should_contracting_be_fastest*/)); + // Make sure we don't change the layout of dO in the flash attention case, + // since dO should have the same layout as O. + TF_ASSIGN_OR_RETURN(CudnnfMHABackendConfig config, + fmha->backend_config<CudnnfMHABackendConfig>()); + if (!config.is_flash_attention()) { + TF_ASSIGN_OR_RETURN(changed, + FuseArgPrologueTransposeWithcuDNNFMHA( + fmha, 4, true /*is_lhs*/, + true /*should_contracting_be_fastest*/)); + } if (changed && VLOG_IS_ON(2)) { VLOG(2) << "After CudnnFusedMHATransposeFusion Arg 4: \n" @@ -489,16 +499,42 @@ Calling this function with 'result' shape as the input shape and the inverse perm as the permutation will generate an output shape whose dimensions match 'FMHA_out' dimensions but the physical layout is equivalent to 'result'. This is exactly what we want. + +The FMHA output should have exactly one gte instruction for a given tuple index, +so we can safely fuse the transpose that follows that gte into the FMHA custom call: + +FMHA_out = gte(FMHA, index=0) +FMHA_out_t = transpose(FMHA_out) +use(FMHA_out_t) + +after fusion: + +FMHA_out_t = gte(FMHA, index=0) +use(FMHA_out_t) */ StatusOr<bool> FuseEpilogueTransposeWithcuDNNFMHA(HloComputation* comp) { bool changed = false; + + auto onlyOneGTEWithSpecIndex = [](const HloInstruction* instr, + int64_t index) { + int count = 0; + for (auto user : instr->users()) { + if (user->opcode() == HloOpcode::kGetTupleElement && + user->tuple_index() == index) { + count += 1; + } + } + return count == 1; + }; + for (HloInstruction* instr : comp->MakeInstructionPostOrder()) { HloInstruction* fmha; HloInstruction* transpose; HloInstruction* gte; auto fwd_tuple_elem = - m::GetTupleElement(m::Op(&fmha).WithPredicate(IsFwdFMHACustomCall), 0) + m::GetTupleElement(&gte, + m::Op(&fmha).WithPredicate(IsFwdFMHACustomCall), 0) .WithOneUser(); // Note that we don't match any specific tuple index in matcher for // backward. 
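/* A worked example of the epilogue shape computation above (illustrative: the
   shapes and permutation are assumed, not taken from a real module). If
   FMHA_out has dimensions [B, N, S, H] and the epilogue transpose applies
   dimensions={0,2,1,3} to produce result = [B, S, N, H], then
   inverse_perm = InversePermutation({0,2,1,3}) = {0,2,1,3}, and
   ShapeUtil::PermuteDimensions(inverse_perm, result_shape) yields a shape with
   the original [B, N, S, H] dimension order whose layout is physically
   equivalent to `result`, so the transpose can be folded into the
   get-tuple-element that feeds it. */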
@@ -510,6 +546,10 @@ StatusOr FuseEpilogueTransposeWithcuDNNFMHA(HloComputation* comp) { auto bwd_pattern = m::Transpose(&transpose, bwd_tuple_elem); if (Match(instr, fwd_pattern)) { + // check if only one gte with such index exist + int64_t tuple_index = gte->tuple_index(); + if (!onlyOneGTEWithSpecIndex(fmha, tuple_index)) continue; + std::vector inverse_perm = InversePermutation(transpose->dimensions()); @@ -558,9 +598,12 @@ StatusOr FuseEpilogueTransposeWithcuDNNFMHA(HloComputation* comp) { } changed |= true; } else if (Match(instr, bwd_pattern)) { + // check if only one gte with such index exist + int64_t operand_tuple_idx = gte->tuple_index(); + if (!onlyOneGTEWithSpecIndex(fmha, operand_tuple_idx)) continue; + std::vector inverse_perm = InversePermutation(transpose->dimensions()); - int64_t operand_tuple_idx = gte->tuple_index(); auto expected_fmha_shape = ShapeUtil::PermuteDimensions(inverse_perm, transpose->shape()); diff --git a/xla/service/gpu/tests/gpu_fused_mha_test.cc b/xla/service/gpu/tests/gpu_fused_mha_test.cc index 85ec435927e803..e6b78587e047d5 100644 --- a/xla/service/gpu/tests/gpu_fused_mha_test.cc +++ b/xla/service/gpu/tests/gpu_fused_mha_test.cc @@ -2215,6 +2215,442 @@ class MultiHeadedAttentionBMMScaleBiasSoftmaxBMM } }; +class FlashAttentionBMMScaleCausalMaskSoftmaxBMM + : public MultiHeadedAttentionTest { + protected: + const std::string // NOLINT + GetModuleFlash_Attention_BMM1_CausalMask_Softmax_BMM2_HloString_BF16() { // NOLINT + const std::string hlo_text = R"( + HloModule jit__unnamed_wrapped_function_, entry_computation_layout={(bf16[2,6,2048,128]{3,2,1,0},bf16[2,6,128,2048]{3,2,1,0},bf16[2,6,2048,128]{3,2,1,0})->bf16[2,6,2048,128]{3,2,1,0}}, allow_spmd_sharding_propagation_to_output={true} + + region_0.28 { + Arg_0.29 = bf16[] parameter(0) + Arg_1.30 = bf16[] parameter(1) + ROOT maximum.31 = bf16[] maximum(Arg_0.29, Arg_1.30) + } + + region_1.40 { + Arg_0.41 = f32[] parameter(0) + Arg_1.42 = f32[] parameter(1) + ROOT add.43 = f32[] add(Arg_0.41, Arg_1.42) + } + + ENTRY main.52 { + Arg_0.1 = bf16[2,6,2048,128]{3,2,1,0} parameter(0), sharding={replicated} + Arg_1.2 = bf16[2,6,128,2048]{3,2,1,0} parameter(1), sharding={replicated} + dot.10 = bf16[2,6,2048,2048]{3,2,1,0} dot(Arg_0.1, Arg_1.2), lhs_batch_dims={0,1}, lhs_contracting_dims={3}, rhs_batch_dims={0,1}, rhs_contracting_dims={2} + constant.6 = bf16[] constant(2) + broadcast.7 = bf16[2,6,2048,2048]{3,2,1,0} broadcast(constant.6), dimensions={} + multiply.11 = bf16[2,6,2048,2048]{3,2,1,0} multiply(dot.10, broadcast.7) + iota.16 = s32[2048]{0} iota(), iota_dimension=0 + reshape.17 = s32[1,2048,1]{2,1,0} reshape(iota.16) + broadcast.18 = s32[1,2048,2048,1]{3,2,1,0} broadcast(reshape.17), dimensions={0,1,3} + reshape.19 = s32[2048,2048]{1,0} reshape(broadcast.18) + iota.12 = s32[2048]{0} iota(), iota_dimension=0 + reshape.13 = s32[1,1,2048]{2,1,0} reshape(iota.12) + broadcast.14 = s32[2048,1,1,2048]{3,2,1,0} broadcast(reshape.13), dimensions={1,2,3} + reshape.15 = s32[2048,2048]{1,0} reshape(broadcast.14) + compare.20 = pred[2048,2048]{1,0} compare(reshape.19, reshape.15), direction=LT + convert.21 = bf16[2048,2048]{1,0} convert(compare.20) + constant.4 = bf16[] constant(-2.366e+38) + broadcast.5 = bf16[2048,2048]{1,0} broadcast(constant.4), dimensions={} + multiply.22 = bf16[2048,2048]{1,0} multiply(convert.21, broadcast.5) + reshape.23 = bf16[1,1,2048,2048]{3,2,1,0} reshape(multiply.22) + broadcast.24 = bf16[1,1,2048,2048]{3,2,1,0} broadcast(reshape.23), dimensions={0,1,2,3} + reshape.25 = 
bf16[2048,2048]{1,0} reshape(broadcast.24) + broadcast.26 = bf16[2,6,2048,2048]{3,2,1,0} broadcast(reshape.25), dimensions={2,3} + add.27 = bf16[2,6,2048,2048]{3,2,1,0} add(multiply.11, broadcast.26) + constant.9 = bf16[] constant(-inf) + reduce.32 = bf16[2,6,2048]{2,1,0} reduce(add.27, constant.9), dimensions={3}, to_apply=region_0.28 + reshape.33 = bf16[2,6,2048,1]{3,2,1,0} reshape(reduce.32) + broadcast.34 = bf16[2,6,2048,1]{3,2,1,0} broadcast(reshape.33), dimensions={0,1,2,3} + reshape.35 = bf16[2,6,2048]{2,1,0} reshape(broadcast.34) + broadcast.36 = bf16[2,6,2048,2048]{3,2,1,0} broadcast(reshape.35), dimensions={0,1,2} + subtract.37 = bf16[2,6,2048,2048]{3,2,1,0} subtract(add.27, broadcast.36) + exponential.38 = bf16[2,6,2048,2048]{3,2,1,0} exponential(subtract.37) + convert.39 = f32[2,6,2048,2048]{3,2,1,0} convert(exponential.38) + constant.8 = f32[] constant(0) + reduce.44 = f32[2,6,2048]{2,1,0} reduce(convert.39, constant.8), dimensions={3}, to_apply=region_1.40 + reshape.45 = f32[2,6,2048,1]{3,2,1,0} reshape(reduce.44) + convert.46 = bf16[2,6,2048,1]{3,2,1,0} convert(reshape.45) + broadcast.47 = bf16[2,6,2048,1]{3,2,1,0} broadcast(convert.46), dimensions={0,1,2,3} + reshape.48 = bf16[2,6,2048]{2,1,0} reshape(broadcast.47) + broadcast.49 = bf16[2,6,2048,2048]{3,2,1,0} broadcast(reshape.48), dimensions={0,1,2} + divide.50 = bf16[2,6,2048,2048]{3,2,1,0} divide(exponential.38, broadcast.49) + Arg_2.3 = bf16[2,6,2048,128]{3,2,1,0} parameter(2), sharding={replicated} + ROOT dot.51 = bf16[2,6,2048,128]{3,2,1,0} dot(divide.50, Arg_2.3), lhs_batch_dims={0,1}, lhs_contracting_dims={3}, rhs_batch_dims={0,1}, rhs_contracting_dims={2} + } + )"; + return hlo_text; + } + + const std::string // NOLINT + GetModuleFlash_Attention_Training_BMM1_CausalMask_Softmax_BMM2_HloString_BF16() { // NOLINT + const std::string hlo_text = R"( + HloModule jit__unnamed_wrapped_function_, entry_computation_layout={(bf16[2,6,1024,64]{3,2,1,0}, bf16[2,6,64,1024]{3,2,1,0}, bf16[2,6,1024,64]{3,2,1,0}, bf16[2,6,1024,64]{3,2,1,0})->(bf16[2,6,1024,64]{3,2,1,0}, bf16[2,6,1024,64]{3,2,1,0}, bf16[2,6,64,1024]{3,2,1,0}, bf16[2,6,1024,64]{3,2,1,0})}, allow_spmd_sharding_propagation_to_output={true,true,true,true} + + region_0.29 { + Arg_0.30 = bf16[] parameter(0) + Arg_1.31 = bf16[] parameter(1) + ROOT maximum.32 = bf16[] maximum(Arg_0.30, Arg_1.31) + } + + region_1.41 { + Arg_0.42 = f32[] parameter(0) + Arg_1.43 = f32[] parameter(1) + ROOT add.44 = f32[] add(Arg_0.42, Arg_1.43) + } + + region_2.63 { + Arg_0.64 = bf16[] parameter(0) + Arg_1.65 = bf16[] parameter(1) + ROOT add.66 = bf16[] add(Arg_0.64, Arg_1.65) + } + + region_3.75 { + Arg_0.76 = f32[] parameter(0) + Arg_1.77 = f32[] parameter(1) + ROOT add.78 = f32[] add(Arg_0.76, Arg_1.77) + } + + ENTRY main.88 { + Arg_0.1 = bf16[2,6,1024,64]{3,2,1,0} parameter(0), sharding={replicated} + Arg_1.2 = bf16[2,6,64,1024]{3,2,1,0} parameter(1), sharding={replicated} + dot.12 = bf16[2,6,1024,1024]{3,2,1,0} dot(Arg_0.1, Arg_1.2), lhs_batch_dims={0,1}, lhs_contracting_dims={3}, rhs_batch_dims={0,1}, rhs_contracting_dims={2} + iota.17 = s32[1024]{0} iota(), iota_dimension=0 + reshape.18 = s32[1,1024,1]{2,1,0} reshape(iota.17) + broadcast.19 = s32[1,1024,1024,1]{3,2,1,0} broadcast(reshape.18), dimensions={0,1,3} + reshape.20 = s32[1024,1024]{1,0} reshape(broadcast.19) + iota.13 = s32[1024]{0} iota(), iota_dimension=0 + reshape.14 = s32[1,1,1024]{2,1,0} reshape(iota.13) + broadcast.15 = s32[1024,1,1,1024]{3,2,1,0} broadcast(reshape.14), dimensions={1,2,3} + reshape.16 = 
s32[1024,1024]{1,0} reshape(broadcast.15) + compare.21 = pred[1024,1024]{1,0} compare(reshape.20, reshape.16), direction=LT + convert.22 = bf16[1024,1024]{1,0} convert(compare.21) + constant.7 = bf16[] constant(-2.366e+38) + broadcast.8 = bf16[1024,1024]{1,0} broadcast(constant.7), dimensions={} + multiply.23 = bf16[1024,1024]{1,0} multiply(convert.22, broadcast.8) + reshape.24 = bf16[1,1,1024,1024]{3,2,1,0} reshape(multiply.23) + broadcast.25 = bf16[1,1,1024,1024]{3,2,1,0} broadcast(reshape.24), dimensions={0,1,2,3} + reshape.26 = bf16[1024,1024]{1,0} reshape(broadcast.25) + broadcast.27 = bf16[2,6,1024,1024]{3,2,1,0} broadcast(reshape.26), dimensions={2,3} + add.28 = bf16[2,6,1024,1024]{3,2,1,0} add(dot.12, broadcast.27) + constant.11 = bf16[] constant(-inf) + reduce.33 = bf16[2,6,1024]{2,1,0} reduce(add.28, constant.11), dimensions={3}, to_apply=region_0.29 + reshape.34 = bf16[2,6,1024,1]{3,2,1,0} reshape(reduce.33) + broadcast.35 = bf16[2,6,1024,1]{3,2,1,0} broadcast(reshape.34), dimensions={0,1,2,3} + reshape.36 = bf16[2,6,1024]{2,1,0} reshape(broadcast.35) + broadcast.37 = bf16[2,6,1024,1024]{3,2,1,0} broadcast(reshape.36), dimensions={0,1,2} + subtract.38 = bf16[2,6,1024,1024]{3,2,1,0} subtract(add.28, broadcast.37) + exponential.39 = bf16[2,6,1024,1024]{3,2,1,0} exponential(subtract.38) + convert.40 = f32[2,6,1024,1024]{3,2,1,0} convert(exponential.39) + constant.10 = f32[] constant(0) + reduce.45 = f32[2,6,1024]{2,1,0} reduce(convert.40, constant.10), dimensions={3}, to_apply=region_1.41 + reshape.46 = f32[2,6,1024,1]{3,2,1,0} reshape(reduce.45) + convert.47 = bf16[2,6,1024,1]{3,2,1,0} convert(reshape.46) + broadcast.48 = bf16[2,6,1024,1]{3,2,1,0} broadcast(convert.47), dimensions={0,1,2,3} + reshape.49 = bf16[2,6,1024]{2,1,0} reshape(broadcast.48) + broadcast.50 = bf16[2,6,1024,1024]{3,2,1,0} broadcast(reshape.49), dimensions={0,1,2} + divide.51 = bf16[2,6,1024,1024]{3,2,1,0} divide(exponential.39, broadcast.50) + Arg_2.3 = bf16[2,6,1024,64]{3,2,1,0} parameter(2), sharding={replicated} + dot.54 = bf16[2,6,1024,64]{3,2,1,0} dot(divide.51, Arg_2.3), lhs_batch_dims={0,1}, lhs_contracting_dims={3}, rhs_batch_dims={0,1}, rhs_contracting_dims={2} + Arg_3.4 = bf16[2,6,1024,64]{3,2,1,0} parameter(3), sharding={replicated} + dot.57 = bf16[2,6,1024,1024]{3,2,1,0} dot(Arg_3.4, Arg_2.3), lhs_batch_dims={0,1}, lhs_contracting_dims={3}, rhs_batch_dims={0,1}, rhs_contracting_dims={3} + broadcast.70 = bf16[2,6,1024,1]{3,2,1,0} broadcast(convert.47), dimensions={0,1,2,3} + reshape.71 = bf16[2,6,1024]{2,1,0} reshape(broadcast.70) + broadcast.72 = bf16[2,6,1024,1024]{3,2,1,0} broadcast(reshape.71), dimensions={0,1,2} + divide.73 = bf16[2,6,1024,1024]{3,2,1,0} divide(dot.57, broadcast.72) + constant.5 = bf16[] constant(1) + broadcast.6 = bf16[2,6,1024,1]{3,2,1,0} broadcast(constant.5), dimensions={} + multiply.52 = bf16[2,6,1024,1]{3,2,1,0} multiply(convert.47, convert.47) + divide.53 = bf16[2,6,1024,1]{3,2,1,0} divide(broadcast.6, multiply.52) + broadcast.58 = bf16[2,6,1024,1]{3,2,1,0} broadcast(divide.53), dimensions={0,1,2,3} + reshape.59 = bf16[2,6,1024]{2,1,0} reshape(broadcast.58) + broadcast.60 = bf16[2,6,1024,1024]{3,2,1,0} broadcast(reshape.59), dimensions={0,1,2} + multiply.61 = bf16[2,6,1024,1024]{3,2,1,0} multiply(dot.57, broadcast.60) + multiply.62 = bf16[2,6,1024,1024]{3,2,1,0} multiply(multiply.61, exponential.39) + constant.9 = bf16[] constant(0) + reduce.67 = bf16[2,6,1024]{2,1,0} reduce(multiply.62, constant.9), dimensions={3}, to_apply=region_2.63 + reshape.68 = 
bf16[2,6,1024,1]{3,2,1,0} reshape(reduce.67) + negate.69 = bf16[2,6,1024,1]{3,2,1,0} negate(reshape.68) + convert.74 = f32[2,6,1024,1]{3,2,1,0} convert(negate.69) + reduce.79 = f32[2,6,1024]{2,1,0} reduce(convert.74, constant.10), dimensions={3}, to_apply=region_3.75 + broadcast.80 = f32[2,6,1024,1024]{3,2,1,0} broadcast(reduce.79), dimensions={0,1,2} + convert.81 = bf16[2,6,1024,1024]{3,2,1,0} convert(broadcast.80) + add.82 = bf16[2,6,1024,1024]{3,2,1,0} add(divide.73, convert.81) + multiply.83 = bf16[2,6,1024,1024]{3,2,1,0} multiply(add.82, exponential.39) + dot.86 = bf16[2,6,1024,64]{3,2,1,0} dot(multiply.83, Arg_1.2), lhs_batch_dims={0,1}, lhs_contracting_dims={3}, rhs_batch_dims={0,1}, rhs_contracting_dims={3} + dot.84 = bf16[2,6,1024,64]{3,2,1,0} dot(multiply.83, Arg_0.1), lhs_batch_dims={0,1}, lhs_contracting_dims={2}, rhs_batch_dims={0,1}, rhs_contracting_dims={2} + transpose.85 = bf16[2,6,64,1024]{2,3,1,0} transpose(dot.84), dimensions={0,1,3,2} + dot.55 = bf16[2,6,64,1024]{3,2,1,0} dot(Arg_3.4, divide.51), lhs_batch_dims={0,1}, lhs_contracting_dims={2}, rhs_batch_dims={0,1}, rhs_contracting_dims={2} + transpose.56 = bf16[2,6,1024,64]{2,3,1,0} transpose(dot.55), dimensions={0,1,3,2} + ROOT tuple.87 = (bf16[2,6,1024,64]{3,2,1,0}, bf16[2,6,1024,64]{3,2,1,0}, bf16[2,6,64,1024]{2,3,1,0}, bf16[2,6,1024,64]{2,3,1,0}) tuple(dot.54, dot.86, transpose.85, transpose.56) + } + )"; + return hlo_text; + } + + template <typename T> + void TestImpl_Flash_Attention_BMM1_CausalMask_Softmax_BMM2() { + stream_executor::CudaComputeCapability cc = GetCudaComputeCapability(); + se::dnn::VersionInfo real_cudnn_version = GetCudnnVersion(); + if (!(cc.IsAtLeast(se::CudaComputeCapability::AMPERE) && cc.minor == 0 && + real_cudnn_version >= se::dnn::VersionInfo(8, 9, 3))) { + GTEST_SKIP() << "Flash Attention is supported with the Nvidia AMPERE+ " + "GPUs and cuDNN >= 8.9.3."; + } + XlaBuilder builder(TestName()); + auto lhs_bmm1_literal = + GetInput4DLiteral<T>({2, 6, 2048, 128}, {3, 2, 1, 0}); + auto rhs_bmm1_literal = + GetInput4DLiteral<T>({2, 6, 128, 2048}, {3, 2, 1, 0}); + auto rhs_bmm2_literal = + GetInput4DLiteral<T>({2, 6, 2048, 128}, {3, 2, 1, 0}); + std::string hlo_string = ""; + if (std::is_same<T, Eigen::half>::value) { + // + } else if (std::is_same<T, bfloat16>::value) { + hlo_string = + GetModuleFlash_Attention_BMM1_CausalMask_Softmax_BMM2_HloString_BF16(); + } + + ExecuteAndCompare( + hlo_string, {&lhs_bmm1_literal, &rhs_bmm1_literal, &rhs_bmm2_literal}); + } + + + template <typename T> + void TestImpl_Flash_Attention_Training_BMM1_CausalMask_Softmax_BMM2() { + stream_executor::CudaComputeCapability cc = GetCudaComputeCapability(); + se::dnn::VersionInfo real_cudnn_version = GetCudnnVersion(); + if (!(cc.IsAtLeast(se::CudaComputeCapability::AMPERE) && cc.minor == 0 && + real_cudnn_version >= se::dnn::VersionInfo(8, 9, 3))) { + GTEST_SKIP() << "Flash Attention is supported with the Nvidia AMPERE+ " + "GPUs and cuDNN >= 8.9.3."; + } + XlaBuilder builder(TestName()); + auto lhs_bmm1_literal = + GetInput4DLiteral<T>({2, 6, 1024, 64}, {3, 2, 1, 0}); + auto rhs_bmm1_literal = + GetInput4DLiteral<T>({2, 6, 64, 1024}, {3, 2, 1, 0}); + auto rhs_bmm2_literal = + GetInput4DLiteral<T>({2, 6, 1024, 64}, {3, 2, 1, 0}); + auto do_literal = + GetInput4DLiteral<T>({2, 6, 1024, 64}, {3, 2, 1, 0}); + std::string hlo_string = ""; + if (std::is_same<T, Eigen::half>::value) { + // + } else if (std::is_same<T, bfloat16>::value) { + hlo_string = + GetModuleFlash_Attention_Training_BMM1_CausalMask_Softmax_BMM2_HloString_BF16(); + } + + ExecuteAndCompare( + hlo_string, {&lhs_bmm1_literal, &rhs_bmm1_literal, 
&rhs_bmm2_literal, &do_literal}, true); + } +}; + +class FlashAttentionBMMScaleBiasSoftmaxBMM : public MultiHeadedAttentionTest { + protected: + const std::string // NOLINT + GetModuleFlash_Attention_BMM1_Bias_Softmax_BMM2_HloString_BF16() { // NOLINT + const std::string hlo_text = R"( + HloModule jit__unnamed_wrapped_function_, entry_computation_layout={(bf16[2,6,2048,128]{3,2,1,0},bf16[2,6,128,2048]{3,2,1,0},bf16[2,6,2048,128]{3,2,1,0},bf16[2,6,2048,2048]{3,2,1,0})->bf16[2,6,2048,128]{3,2,1,0}}, allow_spmd_sharding_propagation_to_output={true} + + region_0.28 { + Arg_0.29 = bf16[] parameter(0) + Arg_1.30 = bf16[] parameter(1) + ROOT maximum.31 = bf16[] maximum(Arg_0.29, Arg_1.30) + } + + region_1.40 { + Arg_0.41 = f32[] parameter(0) + Arg_1.42 = f32[] parameter(1) + ROOT add.43 = f32[] add(Arg_0.41, Arg_1.42) + } + + ENTRY main.52 { + Arg_0.1 = bf16[2,6,2048,128]{3,2,1,0} parameter(0), sharding={replicated} + Arg_1.2 = bf16[2,6,128,2048]{3,2,1,0} parameter(1), sharding={replicated} + dot.10 = bf16[2,6,2048,2048]{3,2,1,0} dot(Arg_0.1, Arg_1.2), lhs_batch_dims={0,1}, lhs_contracting_dims={3}, rhs_batch_dims={0,1}, rhs_contracting_dims={2} + constant.6 = bf16[] constant(2) + broadcast.7 = bf16[2,6,2048,2048]{3,2,1,0} broadcast(constant.6), dimensions={} + multiply.11 = bf16[2,6,2048,2048]{3,2,1,0} multiply(dot.10, broadcast.7) + Arg_3.4 = bf16[2,6,2048,2048]{3,2,1,0} parameter(3), sharding={replicated} + add.27 = bf16[2,6,2048,2048]{3,2,1,0} add(multiply.11, Arg_3.4) + constant.9 = bf16[] constant(-inf) + reduce.32 = bf16[2,6,2048]{2,1,0} reduce(add.27, constant.9), dimensions={3}, to_apply=region_0.28 + reshape.33 = bf16[2,6,2048,1]{3,2,1,0} reshape(reduce.32) + broadcast.34 = bf16[2,6,2048,1]{3,2,1,0} broadcast(reshape.33), dimensions={0,1,2,3} + reshape.35 = bf16[2,6,2048]{2,1,0} reshape(broadcast.34) + broadcast.36 = bf16[2,6,2048,2048]{3,2,1,0} broadcast(reshape.35), dimensions={0,1,2} + subtract.37 = bf16[2,6,2048,2048]{3,2,1,0} subtract(add.27, broadcast.36) + exponential.38 = bf16[2,6,2048,2048]{3,2,1,0} exponential(subtract.37) + convert.39 = f32[2,6,2048,2048]{3,2,1,0} convert(exponential.38) + constant.8 = f32[] constant(0) + reduce.44 = f32[2,6,2048]{2,1,0} reduce(convert.39, constant.8), dimensions={3}, to_apply=region_1.40 + reshape.45 = f32[2,6,2048,1]{3,2,1,0} reshape(reduce.44) + convert.46 = bf16[2,6,2048,1]{3,2,1,0} convert(reshape.45) + broadcast.47 = bf16[2,6,2048,1]{3,2,1,0} broadcast(convert.46), dimensions={0,1,2,3} + reshape.48 = bf16[2,6,2048]{2,1,0} reshape(broadcast.47) + broadcast.49 = bf16[2,6,2048,2048]{3,2,1,0} broadcast(reshape.48), dimensions={0,1,2} + divide.50 = bf16[2,6,2048,2048]{3,2,1,0} divide(exponential.38, broadcast.49) + Arg_2.3 = bf16[2,6,2048,128]{3,2,1,0} parameter(2), sharding={replicated} + ROOT dot.51 = bf16[2,6,2048,128]{3,2,1,0} dot(divide.50, Arg_2.3), lhs_batch_dims={0,1}, lhs_contracting_dims={3}, rhs_batch_dims={0,1}, rhs_contracting_dims={2} + } + )"; + return hlo_text; + } + + const std::string // NOLINT + GetModuleFlash_Attention_Training_BMM1_Bias_Softmax_BMM2_HloString_BF16() { // NOLINT + const std::string hlo_text = R"( + HloModule jit__unnamed_wrapped_function_, entry_computation_layout={(bf16[2,6,1024,64]{3,2,1,0}, bf16[2,6,64,1024]{3,2,1,0}, bf16[2,6,1024,64]{3,2,1,0}, bf16[2,6,1024,1024]{3,2,1,0}, bf16[2,6,1024,64]{3,2,1,0})->(bf16[2,6,1024,64]{3,2,1,0}, bf16[2,6,1024,64]{3,2,1,0}, bf16[2,6,64,1024]{3,2,1,0}, bf16[2,6,1024,64]{3,2,1,0})}, allow_spmd_sharding_propagation_to_output={true,true,true,true} + + region_0.13 { + 
Arg_0.14 = bf16[] parameter(0) + Arg_1.15 = bf16[] parameter(1) + ROOT maximum.16 = bf16[] maximum(Arg_0.14, Arg_1.15) + } + + region_1.25 { + Arg_0.26 = f32[] parameter(0) + Arg_1.27 = f32[] parameter(1) + ROOT add.28 = f32[] add(Arg_0.26, Arg_1.27) + } + + region_2.47 { + Arg_0.48 = bf16[] parameter(0) + Arg_1.49 = bf16[] parameter(1) + ROOT add.50 = bf16[] add(Arg_0.48, Arg_1.49) + } + + region_3.59 { + Arg_0.60 = f32[] parameter(0) + Arg_1.61 = f32[] parameter(1) + ROOT add.62 = f32[] add(Arg_0.60, Arg_1.61) + } + + ENTRY main.72 { + Arg_0.1 = bf16[2,6,1024,64]{3,2,1,0} parameter(0), sharding={replicated} + Arg_1.2 = bf16[2,6,64,1024]{3,2,1,0} parameter(1), sharding={replicated} + dot.11 = bf16[2,6,1024,1024]{3,2,1,0} dot(Arg_0.1, Arg_1.2), lhs_batch_dims={0,1}, lhs_contracting_dims={3}, rhs_batch_dims={0,1}, rhs_contracting_dims={2} + Arg_3.4 = bf16[2,6,1024,1024]{3,2,1,0} parameter(3), sharding={replicated} + add.12 = bf16[2,6,1024,1024]{3,2,1,0} add(dot.11, Arg_3.4) + constant.9 = bf16[] constant(-inf) + reduce.17 = bf16[2,6,1024]{2,1,0} reduce(add.12, constant.9), dimensions={3}, to_apply=region_0.13 + reshape.18 = bf16[2,6,1024,1]{3,2,1,0} reshape(reduce.17) + broadcast.19 = bf16[2,6,1024,1]{3,2,1,0} broadcast(reshape.18), dimensions={0,1,2,3} + reshape.20 = bf16[2,6,1024]{2,1,0} reshape(broadcast.19) + broadcast.21 = bf16[2,6,1024,1024]{3,2,1,0} broadcast(reshape.20), dimensions={0,1,2} + subtract.22 = bf16[2,6,1024,1024]{3,2,1,0} subtract(add.12, broadcast.21) + exponential.23 = bf16[2,6,1024,1024]{3,2,1,0} exponential(subtract.22) + convert.24 = f32[2,6,1024,1024]{3,2,1,0} convert(exponential.23) + constant.8 = f32[] constant(0) + reduce.29 = f32[2,6,1024]{2,1,0} reduce(convert.24, constant.8), dimensions={3}, to_apply=region_1.25 + reshape.30 = f32[2,6,1024,1]{3,2,1,0} reshape(reduce.29) + convert.31 = bf16[2,6,1024,1]{3,2,1,0} convert(reshape.30) + broadcast.32 = bf16[2,6,1024,1]{3,2,1,0} broadcast(convert.31), dimensions={0,1,2,3} + reshape.33 = bf16[2,6,1024]{2,1,0} reshape(broadcast.32) + broadcast.34 = bf16[2,6,1024,1024]{3,2,1,0} broadcast(reshape.33), dimensions={0,1,2} + divide.35 = bf16[2,6,1024,1024]{3,2,1,0} divide(exponential.23, broadcast.34) + Arg_2.3 = bf16[2,6,1024,64]{3,2,1,0} parameter(2), sharding={replicated} + dot.38 = bf16[2,6,1024,64]{3,2,1,0} dot(divide.35, Arg_2.3), lhs_batch_dims={0,1}, lhs_contracting_dims={3}, rhs_batch_dims={0,1}, rhs_contracting_dims={2} + Arg_4.5 = bf16[2,6,1024,64]{3,2,1,0} parameter(4), sharding={replicated} + dot.41 = bf16[2,6,1024,1024]{3,2,1,0} dot(Arg_4.5, Arg_2.3), lhs_batch_dims={0,1}, lhs_contracting_dims={3}, rhs_batch_dims={0,1}, rhs_contracting_dims={3} + broadcast.54 = bf16[2,6,1024,1]{3,2,1,0} broadcast(convert.31), dimensions={0,1,2,3} + reshape.55 = bf16[2,6,1024]{2,1,0} reshape(broadcast.54) + broadcast.56 = bf16[2,6,1024,1024]{3,2,1,0} broadcast(reshape.55), dimensions={0,1,2} + divide.57 = bf16[2,6,1024,1024]{3,2,1,0} divide(dot.41, broadcast.56) + constant.5 = bf16[] constant(1) + broadcast.6 = bf16[2,6,1024,1]{3,2,1,0} broadcast(constant.5), dimensions={} + multiply.36 = bf16[2,6,1024,1]{3,2,1,0} multiply(convert.31, convert.31) + divide.37 = bf16[2,6,1024,1]{3,2,1,0} divide(broadcast.6, multiply.36) + broadcast.42 = bf16[2,6,1024,1]{3,2,1,0} broadcast(divide.37), dimensions={0,1,2,3} + reshape.43 = bf16[2,6,1024]{2,1,0} reshape(broadcast.42) + broadcast.44 = bf16[2,6,1024,1024]{3,2,1,0} broadcast(reshape.43), dimensions={0,1,2} + multiply.45 = bf16[2,6,1024,1024]{3,2,1,0} multiply(dot.41, broadcast.44) + 
multiply.46 = bf16[2,6,1024,1024]{3,2,1,0} multiply(multiply.45, exponential.23) + constant.7 = bf16[] constant(0) + reduce.51 = bf16[2,6,1024]{2,1,0} reduce(multiply.46, constant.7), dimensions={3}, to_apply=region_2.47 + reshape.52 = bf16[2,6,1024,1]{3,2,1,0} reshape(reduce.51) + negate.53 = bf16[2,6,1024,1]{3,2,1,0} negate(reshape.52) + convert.58 = f32[2,6,1024,1]{3,2,1,0} convert(negate.53) + reduce.63 = f32[2,6,1024]{2,1,0} reduce(convert.58, constant.8), dimensions={3}, to_apply=region_3.59 + broadcast.64 = f32[2,6,1024,1024]{3,2,1,0} broadcast(reduce.63), dimensions={0,1,2} + convert.65 = bf16[2,6,1024,1024]{3,2,1,0} convert(broadcast.64) + add.66 = bf16[2,6,1024,1024]{3,2,1,0} add(divide.57, convert.65) + multiply.67 = bf16[2,6,1024,1024]{3,2,1,0} multiply(add.66, exponential.23) + dot.70 = bf16[2,6,1024,64]{3,2,1,0} dot(multiply.67, Arg_1.2), lhs_batch_dims={0,1}, lhs_contracting_dims={3}, rhs_batch_dims={0,1}, rhs_contracting_dims={3} + dot.68 = bf16[2,6,1024,64]{3,2,1,0} dot(multiply.67, Arg_0.1), lhs_batch_dims={0,1}, lhs_contracting_dims={2}, rhs_batch_dims={0,1}, rhs_contracting_dims={2} + transpose.69 = bf16[2,6,64,1024]{2,3,1,0} transpose(dot.68), dimensions={0,1,3,2} + dot.39 = bf16[2,6,64,1024]{3,2,1,0} dot(Arg_4.5, divide.35), lhs_batch_dims={0,1}, lhs_contracting_dims={2}, rhs_batch_dims={0,1}, rhs_contracting_dims={2} + transpose.40 = bf16[2,6,1024,64]{2,3,1,0} transpose(dot.39), dimensions={0,1,3,2} + ROOT tuple.71 = (bf16[2,6,1024,64]{3,2,1,0}, bf16[2,6,1024,64]{3,2,1,0}, bf16[2,6,64,1024]{2,3,1,0}, bf16[2,6,1024,64]{2,3,1,0}) tuple(dot.38, dot.70, transpose.69, transpose.40) + } + )"; + return hlo_text; + } + template + void TestImpl_Flash_Attention_BMM1_Bias_Softmax_BMM2() { + stream_executor::CudaComputeCapability cc = GetCudaComputeCapability(); + se::dnn::VersionInfo real_cudnn_version = GetCudnnVersion(); + if (!(cc.IsAtLeast(se::CudaComputeCapability::AMPERE) && cc.minor == 0 && + real_cudnn_version >= se::dnn::VersionInfo(8, 9, 3))) { + GTEST_SKIP() << "Flash Attention is supported with the Nvidia AMPERE+ " + "GPUs and cuDNN >= 8.9.3."; + } + XlaBuilder builder(TestName()); + auto lhs_bmm1_literal = + GetInput4DLiteral({2, 6, 2048, 128}, {3, 2, 1, 0}); + auto rhs_bmm1_literal = + GetInput4DLiteral({2, 6, 128, 2048}, {3, 2, 1, 0}); + auto rhs_bmm2_literal = + GetInput4DLiteral({2, 6, 2048, 128}, {3, 2, 1, 0}); + auto bias_literal = GetInput4DLiteral({2, 6, 2048, 2048}, {3, 2, 1, 0}); + std::string hlo_string = ""; + if (std::is_same::value) { + // + } else if (std::is_same::value) { + hlo_string = + GetModuleFlash_Attention_BMM1_Bias_Softmax_BMM2_HloString_BF16(); + } + + ExecuteAndCompare(hlo_string, {&lhs_bmm1_literal, &rhs_bmm1_literal, + &rhs_bmm2_literal, &bias_literal}); + } + + template + void TestImpl_Flash_Attention_Training_BMM1_Bias_Softmax_BMM2() { + stream_executor::CudaComputeCapability cc = GetCudaComputeCapability(); + se::dnn::VersionInfo real_cudnn_version = GetCudnnVersion(); + if (!(cc.IsAtLeast(se::CudaComputeCapability::AMPERE) && cc.minor == 0 && + real_cudnn_version >= se::dnn::VersionInfo(8, 9, 3))) { + GTEST_SKIP() << "Flash Attention is supported with the Nvidia AMPERE+ " + "GPUs and cuDNN >= 8.9.3."; + } + XlaBuilder builder(TestName()); + auto lhs_bmm1_literal = + GetInput4DLiteral({2, 6, 1024, 64}, {3, 2, 1, 0}); + auto rhs_bmm1_literal = + GetInput4DLiteral({2, 6, 64, 1024}, {3, 2, 1, 0}); + auto rhs_bmm2_literal = + GetInput4DLiteral({2, 6, 1024, 64}, {3, 2, 1, 0}); + auto bias_literal = GetInput4DLiteral({2, 6, 1024, 1024}, 
{3, 2, 1, 0}); + auto do_literal = GetInput4DLiteral<T>({2, 6, 1024, 64}, {3, 2, 1, 0}); + std::string hlo_string = ""; + if (std::is_same<T, Eigen::half>::value) { + // + } else if (std::is_same<T, bfloat16>::value) { + hlo_string = + GetModuleFlash_Attention_Training_BMM1_Bias_Softmax_BMM2_HloString_BF16(); + } + + ExecuteAndCompare(hlo_string, {&lhs_bmm1_literal, &rhs_bmm1_literal, + &rhs_bmm2_literal, &bias_literal, &do_literal}, true); + } +}; + // BMM1 - BMM2 XLA_TEST_F(MultiHeadedAttentionBMMBMM, FMHABMM_BMM_vanilla_F16) { TestImpl_FMHABMM_BMM_vanilla<Eigen::half>(); @@ -2345,5 +2781,28 @@ XLA_TEST_F(MultiHeadedAttentionBMMScaleBiasSoftmaxBMM, FMHA_Training_BMM1_Scale_Bias_Softmax_BMM2_vanilla_BF16) { TestImpl_FMHA_Training_BMM1_Scale_Bias_Softmax_BMM2_vanilla<bfloat16>(); } + +// flash attention +// BMM1 - Scale - CausalMask - Softmax - BMM2 +XLA_TEST_F(FlashAttentionBMMScaleCausalMaskSoftmaxBMM, + Flash_Attention_BMM1_CausalMask_Softmax_BMM2_BF16) { + TestImpl_Flash_Attention_BMM1_CausalMask_Softmax_BMM2<bfloat16>(); +} + +XLA_TEST_F(FlashAttentionBMMScaleCausalMaskSoftmaxBMM, + Flash_Attention_Training_BMM1_CausalMask_Softmax_BMM2_BF16) { + TestImpl_Flash_Attention_Training_BMM1_CausalMask_Softmax_BMM2<bfloat16>(); +} + +// BMM1 - Scale - Bias - Softmax - BMM2 +XLA_TEST_F(FlashAttentionBMMScaleBiasSoftmaxBMM, + Flash_Attention_BMM1_Bias_Softmax_BMM2_BF16) { + TestImpl_Flash_Attention_BMM1_Bias_Softmax_BMM2<bfloat16>(); +} + +XLA_TEST_F(FlashAttentionBMMScaleBiasSoftmaxBMM, + Flash_Attention_Training_BMM1_Bias_Softmax_BMM2_BF16) { + TestImpl_Flash_Attention_Training_BMM1_Bias_Softmax_BMM2<bfloat16>(); +} } // namespace gpu } // namespace xla