diff --git a/xla/service/gpu/cudnn_fused_mha_rewriter.cc b/xla/service/gpu/cudnn_fused_mha_rewriter.cc
index d86b13ea6a7c1f..295ed9e6aedaab 100644
--- a/xla/service/gpu/cudnn_fused_mha_rewriter.cc
+++ b/xla/service/gpu/cudnn_fused_mha_rewriter.cc
@@ -1791,11 +1791,12 @@ absl::StatusOr<bool> FuseBwdMultiHeadedAttentionBlock(
   return true;
 }
 
-Status RestoreFwdGraph(HloComputation* comp, HloInstruction* fwd_fmha_call,
-                       HloInstruction* bmm2, HloInstruction* activation,
-                       HloInstruction* original_bmm2_producer0,
-                       HloInstruction* original_bmm2_producer1,
-                       std::vector<HloInstruction*>& original_activation_producers) {
+Status RestoreFwdGraph(
+    HloComputation* comp, HloInstruction* fwd_fmha_call, HloInstruction* bmm2,
+    HloInstruction* activation, HloInstruction* original_bmm2_producer0,
+    HloInstruction* original_bmm2_producer1,
+    std::vector<HloInstruction*>& original_activation_producers,
+    bool bmm_2_need_canonicalization) {
   // If backward pattern is not matched, we need to restore the
   // original graph structure.
   // Replacing new GTEs added by forward FMHA call with cloned old
@@ -1811,14 +1812,19 @@ Status RestoreFwdGraph(HloComputation* comp, HloInstruction* fwd_fmha_call,
   // to use the newly cloned activation.
   HloInstruction* lhs = activation == original_bmm2_producer0
                             ? cloned_activation
-                            : original_bmm2_producer1;
+                            : original_bmm2_producer0;
   HloInstruction* rhs = activation == original_bmm2_producer0
                             ? original_bmm2_producer1
                             : cloned_activation;
   HloInstruction* cloned_bmm2 = comp->AddInstruction(
       bmm2->CloneWithNewOperands(bmm2->shape(), {lhs, rhs}, suffix));
-
-  TF_RETURN_IF_ERROR(comp->ReplaceInstruction(output_gte, cloned_bmm2));
+  if (bmm_2_need_canonicalization) {
+    TF_RET_CHECK(output_gte->users()[0]->opcode() == HloOpcode::kTranspose);
+    TF_RETURN_IF_ERROR(
+        comp->ReplaceInstruction(output_gte->users()[0], cloned_bmm2));
+  } else {
+    TF_RETURN_IF_ERROR(comp->ReplaceInstruction(output_gte, cloned_bmm2));
+  }
   TF_RETURN_IF_ERROR(
       comp->ReplaceInstruction(activation_gte, cloned_activation));
   return OkStatus();
@@ -1881,6 +1887,7 @@ absl::StatusOr<bool> CudnnFusedMHARewriter::Run(
       HloInstruction* original_bmm2_producer1 =
           matched_result.matched_bmm_2->mutable_operand(1);
 
+      HloInstruction* original_bmm2 = matched_result.matched_bmm_2;
       std::vector<HloInstruction*> original_activation_producers;
       for (HloInstruction* operand : activation->mutable_operands()) {
         original_activation_producers.push_back(operand);
@@ -1921,9 +1928,10 @@ absl::StatusOr<bool> CudnnFusedMHARewriter::Run(
         VLOG(2) << "Backward pattern not matching, skipping.";
         // restore fwd graph if bwd pattern match failed
         TF_RETURN_IF_ERROR(
-            RestoreFwdGraph(comp, fwd_fmha_call, matched_result.matched_bmm_2,
-            activation, original_bmm2_producer0, original_bmm2_producer1,
-            original_activation_producers));
+            RestoreFwdGraph(comp, fwd_fmha_call, original_bmm2, activation,
+                            original_bmm2_producer0, original_bmm2_producer1,
+                            original_activation_producers,
+                            matched_result.need_canonicalization));
         continue;
       }
       // if fwd uses mask input, then bwd needs cudnn 8.9.1 to take in a mask
@@ -1934,9 +1942,10 @@ absl::StatusOr<bool> CudnnFusedMHARewriter::Run(
               stream_executor::dnn::VersionInfo(8, 9, 1))) {
         // restore fwd graph if bwd pattern match failed
         TF_RETURN_IF_ERROR(
-            RestoreFwdGraph(comp, fwd_fmha_call, matched_result.matched_bmm_2,
-            activation, original_bmm2_producer0, original_bmm2_producer1,
-            original_activation_producers));
+            RestoreFwdGraph(comp, fwd_fmha_call, original_bmm2, activation,
+                            original_bmm2_producer0, original_bmm2_producer1,
+                            original_activation_producers,
+                            matched_result.need_canonicalization));
         continue;
       }
       // check if dbias exist and the cudnn version is > 8.9.1. We
@@ -1948,9 +1957,10 @@ absl::StatusOr<bool> CudnnFusedMHARewriter::Run(
               stream_executor::dnn::VersionInfo(8, 9, 1))) {
         // restore fwd graph if bwd pattern match failed
         TF_RETURN_IF_ERROR(
-            RestoreFwdGraph(comp, fwd_fmha_call, matched_result.matched_bmm_2,
-            activation, original_bmm2_producer0, original_bmm2_producer1,
-            original_activation_producers));
+            RestoreFwdGraph(comp, fwd_fmha_call, original_bmm2, activation,
+                            original_bmm2_producer0, original_bmm2_producer1,
+                            original_activation_producers,
+                            matched_result.need_canonicalization));
         continue;
       }
       // Canonicalize gemms
diff --git a/xla/service/gpu/cudnn_fused_mha_rewriter_test.cc b/xla/service/gpu/cudnn_fused_mha_rewriter_test.cc
index b21a5387b67d7e..beb8a934f31428 100644
--- a/xla/service/gpu/cudnn_fused_mha_rewriter_test.cc
+++ b/xla/service/gpu/cudnn_fused_mha_rewriter_test.cc
@@ -4623,6 +4623,185 @@ main {
   EXPECT_EQ(config.is_flash_attention(), true);
   EXPECT_EQ(config.is_causal_mask(), true);
 }
+
+TEST_F(CudnnFusedMhaRewriterTestHloTest,
+       BF16TrainingBmm2CanonicalizationRestoreFwdGraph) {
+  const char* module_str = R"(
+HloModule pjit__unnamed_function_, entry_computation_layout={(bf16[2,256,4,64]{3,2,1,0}, bf16[2,256,4,64]{3,2,1,0}, bf16[2,256,4,64]{3,2,1,0}, bf16[2,256,4,64]{3,2,1,0}, bf16[2,4,256,256]{3,2,1,0})->(bf16[4,256,8,64]{3,2,1,0}, bf16[2,256,4,64]{3,2,1,0}, bf16[2,256,4,64]{3,2,1,0}, bf16[2,256,4,64]{3,2,1,0})}, allow_spmd_sharding_propagation_to_output={false,false,false,false}, num_partitions=4
+
+region_0.6 {
+  Arg_0.7 = bf16[] parameter(0)
+  Arg_1.8 = bf16[] parameter(1)
+  ROOT maximum.5 = bf16[] maximum(Arg_0.7, Arg_1.8)
+}
+
+region_1.10 {
+  Arg_0.11 = f32[] parameter(0)
+  Arg_1.12 = f32[] parameter(1)
+  ROOT add.14 = f32[] add(Arg_0.11, Arg_1.12)
+}
+
+add.clone {
+  x.1 = u32[] parameter(0)
+  y.1 = u32[] parameter(1)
+  ROOT add.15 = u32[] add(x.1, y.1)
+}
+
+region_2.65 {
+  Arg_0.66 = bf16[] parameter(0)
+  Arg_1.67 = bf16[] parameter(1)
+  ROOT add.16 = bf16[] add(Arg_0.66, Arg_1.67)
+}
+
+ENTRY main.164_spmd {
+  param = bf16[2,256,4,64]{3,2,1,0} parameter(2), sharding={devices=[2,1,2,1]<=[4]}
+  transpose.26 = bf16[2,4,64,256]{3,2,1,0} transpose(param), dimensions={0,2,3,1}
+  param.1 = bf16[2,256,4,64]{3,2,1,0} parameter(0), sharding={devices=[2,1,2,1]<=[4]}
+  transpose.27 = bf16[2,4,256,64]{3,2,1,0} transpose(param.1), dimensions={0,2,1,3}
+  constant.46 = bf16[] constant(0.5)
+  broadcast.126 = bf16[2,4,256,64]{3,2,1,0} broadcast(constant.46), dimensions={}
+  multiply.34 = bf16[2,4,256,64]{3,2,1,0} multiply(transpose.27, broadcast.126)
+  param.2 = bf16[2,256,4,64]{3,2,1,0} parameter(1), sharding={devices=[2,1,2,1]<=[4]}
+  transpose.29 = bf16[2,4,64,256]{3,2,1,0} transpose(param.2), dimensions={0,2,3,1}
+  dot.12 = bf16[2,4,256,256]{3,2,1,0} dot(multiply.34, transpose.29), lhs_batch_dims={0,1}, lhs_contracting_dims={3}, rhs_batch_dims={0,1}, rhs_contracting_dims={2}
+  param.3 = bf16[2,4,256,256]{3,2,1,0} parameter(4), sharding={devices=[2,2,1,1]<=[4]}
+  add.17 = bf16[2,4,256,256]{3,2,1,0} add(dot.12, param.3)
+  constant.47 = bf16[] constant(-inf)
+  reduce.4 = bf16[2,4,256]{2,1,0} reduce(add.17, constant.47), dimensions={3}, to_apply=region_0.6
+  broadcast.127 = bf16[2,4,256,256]{3,2,1,0} broadcast(reduce.4), dimensions={0,1,2}
+  subtract.14 = bf16[2,4,256,256]{3,2,1,0} subtract(add.17, broadcast.127)
+  exponential.2 = bf16[2,4,256,256]{3,2,1,0} exponential(subtract.14)
+  convert.46 = f32[2,4,256,256]{3,2,1,0} convert(exponential.2)
+  constant.48 = f32[] constant(0)
+  reduce.5 = f32[2,4,256]{2,1,0} reduce(convert.46, constant.48), dimensions={3}, to_apply=region_1.10
+  convert.47 = bf16[2,4,256]{2,1,0} convert(reduce.5)
+  broadcast.128 = bf16[2,4,256,256]{3,2,1,0} broadcast(convert.47), dimensions={0,1,2}
+  divide.7 = bf16[2,4,256,256]{3,2,1,0} divide(exponential.2, broadcast.128)
+  broadcast.129 = f32[4096]{0} broadcast(constant.48), dimensions={}
+  constant.50 = u32[] constant(0)
+  broadcast.131 = u32[8192]{0} broadcast(constant.50), dimensions={}
+  broadcast.133 = u32[4096]{0} broadcast(constant.50), dimensions={}
+  iota.3 = u32[8192]{0} iota(), iota_dimension=0
+  slice.14 = u32[4096]{0} slice(iota.3), slice={[0:4096]}
+  slice.15 = u32[4096]{0} slice(iota.3), slice={[4096:8192]}
+  custom-call.3 = (u32[4096]{0}, u32[4096]{0}) custom-call(broadcast.133, broadcast.133, slice.14, slice.15), custom_call_target="cu_threefry2x32", operand_layout_constraints={u32[4096]{0}, u32[4096]{0}, u32[4096]{0}, u32[4096]{0}}, api_version=API_VERSION_STATUS_RETURNING, backend_config="\000\020\000\000\000\000\000\000"
+  get-tuple-element.6 = u32[4096]{0} get-tuple-element(custom-call.3), index=0
+  constant.115 = u32[1]{0} constant({0})
+  constant.52 = u32[4]{0} constant({0, 0, 1, 1})
+  partition-id = u32[] partition-id()
+  dynamic-slice.21 = u32[1]{0} dynamic-slice(constant.52, partition-id), dynamic_slice_sizes={1}
+  constant.116 = u32[1]{0} constant({1})
+  clamp.3 = u32[1]{0} clamp(constant.115, dynamic-slice.21, constant.116)
+  convert.48 = s32[1]{0} convert(clamp.3)
+  constant.117 = s32[1]{0} constant({2048})
+  multiply.35 = s32[1]{0} multiply(convert.48, constant.117)
+  bitcast.105 = s32[] bitcast(multiply.35)
+  dynamic-slice.22 = u32[2048]{0} dynamic-slice(get-tuple-element.6, bitcast.105), dynamic_slice_sizes={2048}
+  constant.58 = s32[4]{0} constant({0, 0, 1, 1})
+  dynamic-slice.23 = s32[1]{0} dynamic-slice(constant.58, partition-id), dynamic_slice_sizes={1}
+  multiply.36 = s32[1]{0} multiply(dynamic-slice.23, constant.117)
+  bitcast.108 = s32[] bitcast(multiply.36)
+  dynamic-update-slice.2 = u32[8192]{0} dynamic-update-slice(broadcast.131, dynamic-slice.22, bitcast.108)
+  get-tuple-element.7 = u32[4096]{0} get-tuple-element(custom-call.3), index=1
+  dynamic-slice.24 = u32[2048]{0} dynamic-slice(get-tuple-element.7, bitcast.105), dynamic_slice_sizes={2048}
+  constant.65 = s32[] constant(4096)
+  add.18 = s32[] add(bitcast.108, constant.65)
+  dynamic-update-slice.3 = u32[8192]{0} dynamic-update-slice(dynamic-update-slice.2, dynamic-slice.24, add.18)
+  all-reduce = u32[8192]{0} all-reduce(dynamic-update-slice.3), channel_id=1, replica_groups={{0,1,2,3}}, use_global_device_ids=true, to_apply=add.clone
+  constant.118 = s32[1]{0} constant({4096})
+  multiply.37 = s32[1]{0} multiply(dynamic-slice.23, constant.118)
+  bitcast.119 = s32[] bitcast(multiply.37)
+  dynamic-slice.25 = u32[4096]{0} dynamic-slice(all-reduce, bitcast.119), dynamic_slice_sizes={4096}
+  constant.69 = u32[] constant(9)
+  broadcast.134 = u32[4096]{0} broadcast(constant.69), dimensions={}
+  shift-right-logical.6 = u32[4096]{0} shift-right-logical(dynamic-slice.25, broadcast.134)
+  constant.70 = u32[] constant(1065353216)
+  broadcast.135 = u32[4096]{0} broadcast(constant.70), dimensions={}
+  or.5 = u32[4096]{0} or(shift-right-logical.6, broadcast.135)
+  bitcast-convert.5 = f32[4096]{0} bitcast-convert(or.5)
+  constant.71 = f32[] constant(-1)
+  broadcast.136 = f32[4096]{0} broadcast(constant.71), dimensions={}
+  add.19 = f32[4096]{0} add(bitcast-convert.5, broadcast.136)
+  maximum.6 = f32[4096]{0} maximum(broadcast.129, add.19)
+  constant.72 = f32[] constant(0.5)
+  broadcast.137 = f32[4096]{0} broadcast(constant.72), dimensions={}
+  compare.4 = pred[4096]{0} compare(maximum.6, broadcast.137), direction=LT
+  bitcast.135 = pred[2,8,256]{2,1,0} bitcast(compare.4)
+  convert.49 = bf16[2,8,256]{2,1,0} convert(bitcast.135)
+  constant.80 = s32[] constant(0)
+  constant.78 = s32[4]{0} constant({0, 4, 0, 4})
+  dynamic-slice.26 = s32[1]{0} dynamic-slice(constant.78, partition-id), dynamic_slice_sizes={1}
+  bitcast.181 = s32[] bitcast(dynamic-slice.26)
+  dynamic-slice.27 = bf16[2,4,256]{2,1,0} dynamic-slice(convert.49, constant.80, bitcast.181, constant.80), dynamic_slice_sizes={2,4,256}
+  broadcast.139 = bf16[2,4,256,256]{3,2,1,0} broadcast(dynamic-slice.27), dimensions={0,1,3}
+  multiply.38 = bf16[2,4,256,256]{3,2,1,0} multiply(divide.7, broadcast.139)
+  constant.93 = bf16[] constant(2)
+  broadcast.141 = bf16[2,4,256,256]{3,2,1,0} broadcast(constant.93), dimensions={}
+  multiply.39 = bf16[2,4,256,256]{3,2,1,0} multiply(multiply.38, broadcast.141)
+  dot.13 = bf16[2,4,64,256]{3,2,1,0} dot(transpose.26, multiply.39), lhs_batch_dims={0,1}, lhs_contracting_dims={3}, rhs_batch_dims={0,1}, rhs_contracting_dims={3}
+  transpose.31 = bf16[4,2,64,256]{3,2,1,0} transpose(dot.13), dimensions={1,0,2,3}
+  bitcast.154 = bf16[2,256,4,64]{1,3,0,2} bitcast(transpose.31)
+  all-gather = bf16[2,256,8,64]{1,3,0,2} all-gather(bitcast.154), channel_id=2, replica_groups={{0,1},{2,3}}, dimensions={2}, use_global_device_ids=true
+  bitcast.155 = bf16[8,2,64,256]{3,2,1,0} bitcast(all-gather)
+  transpose.32 = bf16[2,8,64,256]{3,2,1,0} transpose(bitcast.155), dimensions={1,0,2,3}
+  bitcast.157 = bf16[2,256,8,64]{1,3,2,0} bitcast(transpose.32)
+  all-gather.1 = bf16[4,256,8,64]{1,3,2,0} all-gather(bitcast.157), channel_id=3, replica_groups={{0,2},{1,3}}, dimensions={0}, use_global_device_ids=true
+  bitcast.236 = bf16[4,8,64,256]{3,2,1,0} bitcast(all-gather.1)
+  transpose.38 = bf16[4,256,8,64]{3,2,1,0} transpose(bitcast.236), dimensions={0,3,1,2}
+  param.4 = bf16[2,256,4,64]{3,2,1,0} parameter(3), sharding={devices=[2,1,2,1]<=[4]}
+  transpose.33 = bf16[2,4,256,64]{3,2,1,0} transpose(param.4), dimensions={0,2,1,3}
+  dot.14 = bf16[2,4,256,256]{3,2,1,0} dot(transpose.33, transpose.26), lhs_batch_dims={0,1}, lhs_contracting_dims={3}, rhs_batch_dims={0,1}, rhs_contracting_dims={2}
+  broadcast.142 = bf16[4096]{0} broadcast(constant.93), dimensions={}
+  constant.95 = bf16[] constant(0)
+  broadcast.143 = bf16[4096]{0} broadcast(constant.95), dimensions={}
+  select.4 = bf16[4096]{0} select(compare.4, broadcast.142, broadcast.143)
+  bitcast.176 = bf16[2,8,256]{2,1,0} bitcast(select.4)
+  dynamic-slice.28 = bf16[2,4,256]{2,1,0} dynamic-slice(bitcast.176, constant.80, bitcast.181, constant.80), dynamic_slice_sizes={2,4,256}
+  broadcast.145 = bf16[2,4,256,256]{3,2,1,0} broadcast(dynamic-slice.28), dimensions={0,1,3}
+  multiply.40 = bf16[2,4,256,256]{3,2,1,0} multiply(dot.14, broadcast.145)
+  divide.8 = bf16[2,4,256,256]{3,2,1,0} divide(multiply.40, broadcast.128)
+  constant.106 = bf16[] constant(1)
+  broadcast.146 = bf16[2,4,256]{2,1,0} broadcast(constant.106), dimensions={}
+  multiply.41 = bf16[2,4,256]{2,1,0} multiply(convert.47, convert.47)
+  divide.9 = bf16[2,4,256]{2,1,0} divide(broadcast.146, multiply.41)
+  broadcast.147 = bf16[2,4,256,256]{3,2,1,0} broadcast(divide.9), dimensions={0,1,2}
+  multiply.42 = bf16[2,4,256,256]{3,2,1,0} multiply(multiply.40, broadcast.147)
+  multiply.43 = bf16[2,4,256,256]{3,2,1,0} multiply(multiply.42, exponential.2)
+  reduce.6 = bf16[2,4,256]{2,1,0} reduce(multiply.43, constant.95), dimensions={3}, to_apply=region_2.65
+  negate.4 = bf16[2,4,256]{2,1,0} negate(reduce.6)
+  broadcast.148 = bf16[2,4,256,256]{3,2,1,0} broadcast(negate.4), dimensions={0,1,2}
+  add.20 = bf16[2,4,256,256]{3,2,1,0} add(divide.8, broadcast.148)
+  multiply.44 = bf16[2,4,256,256]{3,2,1,0} multiply(add.20, exponential.2)
+  transpose.34 = bf16[2,4,256,64]{3,2,1,0} transpose(param.2), dimensions={0,2,1,3}
+  dot.15 = bf16[2,4,256,64]{3,2,1,0} dot(multiply.44, transpose.34), lhs_batch_dims={0,1}, lhs_contracting_dims={3}, rhs_batch_dims={0,1}, rhs_contracting_dims={2}
+  multiply.45 = bf16[2,4,256,64]{3,2,1,0} multiply(dot.15, broadcast.126)
+  transpose.39 = bf16[2,256,4,64]{3,2,1,0} transpose(multiply.45), dimensions={0,2,1,3}
+  dot.16 = bf16[2,4,256,64]{3,2,1,0} dot(multiply.44, multiply.34), lhs_batch_dims={0,1}, lhs_contracting_dims={2}, rhs_batch_dims={0,1}, rhs_contracting_dims={2}
+  transpose.40 = bf16[2,256,4,64]{3,2,1,0} transpose(dot.16), dimensions={0,2,1,3}
+  transpose.36 = bf16[2,4,64,256]{3,2,1,0} transpose(param.4), dimensions={0,2,3,1}
+  dot.11 = bf16[2,4,64,256]{3,2,1,0} dot(transpose.36, multiply.39), lhs_batch_dims={0,1}, lhs_contracting_dims={3}, rhs_batch_dims={0,1}, rhs_contracting_dims={2}
+  transpose.41 = bf16[2,256,4,64]{3,2,1,0} transpose(dot.11), dimensions={0,3,1,2}
+  ROOT tuple.2 = (bf16[4,256,8,64]{3,2,1,0}, bf16[2,256,4,64]{3,2,1,0}, bf16[2,256,4,64]{3,2,1,0}, bf16[2,256,4,64]{3,2,1,0}) tuple(transpose.38, transpose.39, transpose.40, transpose.41)
+} // main.164_spmd
+)";
+  // Dropout bwd pattern not supported, should not lower fwd as well
+  TF_ASSERT_OK_AND_ASSIGN(auto m, ParseAndReturnVerifiedModule(module_str));
+  CudnnFusedMHARewriter fusedMhaRewriter{
+      GetCudaComputeCapability(),
+      GetCudnnVersionWithDbiasAndMaskBwdInputSupport()};
+  TF_ASSERT_OK(RunHloPass(&fusedMhaRewriter, m.get()).status());
+  SCOPED_TRACE(m->ToString());
+  // check if fwd graph has been restored with cloned activation
+  EXPECT_THAT(
+      m->entry_computation()->root_instruction(),
+      GmockMatch(m::Tuple(
+          m::Transpose(), m::Transpose(), m::Transpose(),
+          m::Transpose(m::Dot(
+              m::Op(), m::Op().WithPredicate([](const HloInstruction* instr) {
+                return instr->name() == "multiply.39.fmha_no_match_clone";
+              }))))));
+}
+
 }  // anonymous namespace
 }  // namespace gpu
 }  // namespace xla