Commit
fix fwd graph restore
Cjkkkk committed Jan 11, 2024
1 parent 1804894 · commit 28071dc
Showing 2 changed files with 206 additions and 17 deletions.
44 changes: 27 additions & 17 deletions xla/service/gpu/cudnn_fused_mha_rewriter.cc
@@ -1791,11 +1791,12 @@ absl::StatusOr<bool> FuseBwdMultiHeadedAttentionBlock(
return true;
}

-Status RestoreFwdGraph(HloComputation* comp, HloInstruction* fwd_fmha_call,
-                       HloInstruction* bmm2, HloInstruction* activation,
-                       HloInstruction* original_bmm2_producer0,
-                       HloInstruction* original_bmm2_producer1,
-                       std::vector<HloInstruction*>& original_activation_producers) {
+Status RestoreFwdGraph(
+    HloComputation* comp, HloInstruction* fwd_fmha_call, HloInstruction* bmm2,
+    HloInstruction* activation, HloInstruction* original_bmm2_producer0,
+    HloInstruction* original_bmm2_producer1,
+    std::vector<HloInstruction*>& original_activation_producers,
+    bool bmm_2_need_canonicalization) {
// If backward pattern is not matched, we need to restore the
// original graph structure.
// Replacing new GTEs added by forward FMHA call with cloned old
@@ -1811,14 +1812,19 @@ Status RestoreFwdGraph(HloComputation* comp, HloInstruction* fwd_fmha_call,
// to use the newly cloned activation.
HloInstruction* lhs = activation == original_bmm2_producer0
                          ? cloned_activation
-                         : original_bmm2_producer1;
+                         : original_bmm2_producer0;
HloInstruction* rhs = activation == original_bmm2_producer0
                          ? original_bmm2_producer1
                          : cloned_activation;
HloInstruction* cloned_bmm2 = comp->AddInstruction(
bmm2->CloneWithNewOperands(bmm2->shape(), {lhs, rhs}, suffix));

-TF_RETURN_IF_ERROR(comp->ReplaceInstruction(output_gte, cloned_bmm2));
+if (bmm_2_need_canonicalization) {
+  TF_RET_CHECK(output_gte->users()[0]->opcode() == HloOpcode::kTranspose);
+  TF_RETURN_IF_ERROR(
+      comp->ReplaceInstruction(output_gte->users()[0], cloned_bmm2));
+} else {
+  TF_RETURN_IF_ERROR(comp->ReplaceInstruction(output_gte, cloned_bmm2));
+}
TF_RETURN_IF_ERROR(
comp->ReplaceInstruction(activation_gte, cloned_activation));
return OkStatus();
@@ -1881,6 +1887,7 @@ absl::StatusOr<bool> CudnnFusedMHARewriter::Run(
HloInstruction* original_bmm2_producer1 =
matched_result.matched_bmm_2->mutable_operand(1);

+  HloInstruction* original_bmm2 = matched_result.matched_bmm_2;
std::vector<HloInstruction*> original_activation_producers;
for (HloInstruction* operand : activation->mutable_operands()) {
original_activation_producers.push_back(operand);
@@ -1921,9 +1928,10 @@ absl::StatusOr<bool> CudnnFusedMHARewriter::Run(
VLOG(2) << "Backward pattern not matching, skipping.";
// restore fwd graph if bwd pattern match failed
TF_RETURN_IF_ERROR(
-    RestoreFwdGraph(comp, fwd_fmha_call, matched_result.matched_bmm_2,
-        activation, original_bmm2_producer0, original_bmm2_producer1,
-        original_activation_producers));
+    RestoreFwdGraph(comp, fwd_fmha_call, original_bmm2, activation,
+                    original_bmm2_producer0, original_bmm2_producer1,
+                    original_activation_producers,
+                    matched_result.need_canonicalization));
continue;
}
// if fwd uses mask input, then bwd needs cudnn 8.9.1 to take in a mask
@@ -1934,9 +1942,10 @@ absl::StatusOr<bool> CudnnFusedMHARewriter::Run(
stream_executor::dnn::VersionInfo(8, 9, 1))) {
// restore fwd graph if bwd pattern match failed
TF_RETURN_IF_ERROR(
-    RestoreFwdGraph(comp, fwd_fmha_call, matched_result.matched_bmm_2,
-        activation, original_bmm2_producer0, original_bmm2_producer1,
-        original_activation_producers));
+    RestoreFwdGraph(comp, fwd_fmha_call, original_bmm2, activation,
+                    original_bmm2_producer0, original_bmm2_producer1,
+                    original_activation_producers,
+                    matched_result.need_canonicalization));
continue;
}
// check if dbias exist and the cudnn version is > 8.9.1. We
@@ -1948,9 +1957,10 @@ absl::StatusOr<bool> CudnnFusedMHARewriter::Run(
stream_executor::dnn::VersionInfo(8, 9, 1))) {
// restore fwd graph if bwd pattern match failed
TF_RETURN_IF_ERROR(
-    RestoreFwdGraph(comp, fwd_fmha_call, matched_result.matched_bmm_2,
-        activation, original_bmm2_producer0, original_bmm2_producer1,
-        original_activation_producers));
+    RestoreFwdGraph(comp, fwd_fmha_call, original_bmm2, activation,
+                    original_bmm2_producer0, original_bmm2_producer1,
+                    original_activation_producers,
+                    matched_result.need_canonicalization));
continue;
}
// Canonicalize gemms
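For orientation, here is a minimal, self-contained sketch of the control flow this change adds to RestoreFwdGraph. It uses a toy Node type and stand-in helpers (RestoreFwdOutput, a local ReplaceInstruction) rather than XLA's HloInstruction/HloComputation APIs, so it only illustrates the branch, not the real implementation: when bmm2 required canonicalization, the forward FMHA call's output get-tuple-element is consumed through a transpose, so restoring the original forward graph must replace that transpose with the cloned bmm2 instead of replacing the GTE itself. The commit also corrects the lhs fallback operand from original_bmm2_producer1 to original_bmm2_producer0 and threads matched_result.need_canonicalization into each RestoreFwdGraph call site.

// Toy sketch of the fixed restore branch (assumed names: Node,
// RestoreFwdOutput, ReplaceInstruction; not XLA types).
#include <cassert>
#include <string>
#include <vector>

struct Node {
  std::string op;                 // e.g. "get-tuple-element", "transpose", "dot"
  std::vector<Node*> users;       // consumers of this node's value
  Node* replaced_with = nullptr;  // records what this node was replaced by
};

// Stand-in for HloComputation::ReplaceInstruction in this toy model.
void ReplaceInstruction(Node* old_node, Node* new_node) {
  old_node->replaced_with = new_node;
}

// Mirrors the added branch: replace either the GTE or its transpose user.
void RestoreFwdOutput(Node* output_gte, Node* cloned_bmm2,
                      bool bmm_2_need_canonicalization) {
  if (bmm_2_need_canonicalization) {
    // Canonicalization inserted a transpose after the fwd call's output GTE.
    assert(!output_gte->users.empty() &&
           output_gte->users[0]->op == "transpose");
    ReplaceInstruction(output_gte->users[0], cloned_bmm2);
  } else {
    ReplaceInstruction(output_gte, cloned_bmm2);
  }
}

int main() {
  Node cloned_bmm2;
  cloned_bmm2.op = "dot";
  Node transpose;
  transpose.op = "transpose";
  Node gte;
  gte.op = "get-tuple-element";
  gte.users.push_back(&transpose);

  RestoreFwdOutput(&gte, &cloned_bmm2, /*bmm_2_need_canonicalization=*/true);
  assert(transpose.replaced_with == &cloned_bmm2);  // transpose was replaced
  assert(gte.replaced_with == nullptr);             // the GTE itself was not
  return 0;
}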
179 changes: 179 additions & 0 deletions xla/service/gpu/cudnn_fused_mha_rewriter_test.cc
@@ -4623,6 +4623,185 @@ main {
EXPECT_EQ(config.is_flash_attention(), true);
EXPECT_EQ(config.is_causal_mask(), true);
}

TEST_F(CudnnFusedMhaRewriterTestHloTest,
BF16TrainingBmm2CanonicalizationRestoreFwdGraph) {
const char* module_str = R"(
HloModule pjit__unnamed_function_, entry_computation_layout={(bf16[2,256,4,64]{3,2,1,0}, bf16[2,256,4,64]{3,2,1,0}, bf16[2,256,4,64]{3,2,1,0}, bf16[2,256,4,64]{3,2,1,0}, bf16[2,4,256,256]{3,2,1,0})->(bf16[4,256,8,64]{3,2,1,0}, bf16[2,256,4,64]{3,2,1,0}, bf16[2,256,4,64]{3,2,1,0}, bf16[2,256,4,64]{3,2,1,0})}, allow_spmd_sharding_propagation_to_output={false,false,false,false}, num_partitions=4
region_0.6 {
Arg_0.7 = bf16[] parameter(0)
Arg_1.8 = bf16[] parameter(1)
ROOT maximum.5 = bf16[] maximum(Arg_0.7, Arg_1.8)
}
region_1.10 {
Arg_0.11 = f32[] parameter(0)
Arg_1.12 = f32[] parameter(1)
ROOT add.14 = f32[] add(Arg_0.11, Arg_1.12)
}
add.clone {
x.1 = u32[] parameter(0)
y.1 = u32[] parameter(1)
ROOT add.15 = u32[] add(x.1, y.1)
}
region_2.65 {
Arg_0.66 = bf16[] parameter(0)
Arg_1.67 = bf16[] parameter(1)
ROOT add.16 = bf16[] add(Arg_0.66, Arg_1.67)
}
ENTRY main.164_spmd {
param = bf16[2,256,4,64]{3,2,1,0} parameter(2), sharding={devices=[2,1,2,1]<=[4]}
transpose.26 = bf16[2,4,64,256]{3,2,1,0} transpose(param), dimensions={0,2,3,1}
param.1 = bf16[2,256,4,64]{3,2,1,0} parameter(0), sharding={devices=[2,1,2,1]<=[4]}
transpose.27 = bf16[2,4,256,64]{3,2,1,0} transpose(param.1), dimensions={0,2,1,3}
constant.46 = bf16[] constant(0.5)
broadcast.126 = bf16[2,4,256,64]{3,2,1,0} broadcast(constant.46), dimensions={}
multiply.34 = bf16[2,4,256,64]{3,2,1,0} multiply(transpose.27, broadcast.126)
param.2 = bf16[2,256,4,64]{3,2,1,0} parameter(1), sharding={devices=[2,1,2,1]<=[4]}
transpose.29 = bf16[2,4,64,256]{3,2,1,0} transpose(param.2), dimensions={0,2,3,1}
dot.12 = bf16[2,4,256,256]{3,2,1,0} dot(multiply.34, transpose.29), lhs_batch_dims={0,1}, lhs_contracting_dims={3}, rhs_batch_dims={0,1}, rhs_contracting_dims={2}
param.3 = bf16[2,4,256,256]{3,2,1,0} parameter(4), sharding={devices=[2,2,1,1]<=[4]}
add.17 = bf16[2,4,256,256]{3,2,1,0} add(dot.12, param.3)
constant.47 = bf16[] constant(-inf)
reduce.4 = bf16[2,4,256]{2,1,0} reduce(add.17, constant.47), dimensions={3}, to_apply=region_0.6
broadcast.127 = bf16[2,4,256,256]{3,2,1,0} broadcast(reduce.4), dimensions={0,1,2}
subtract.14 = bf16[2,4,256,256]{3,2,1,0} subtract(add.17, broadcast.127)
exponential.2 = bf16[2,4,256,256]{3,2,1,0} exponential(subtract.14)
convert.46 = f32[2,4,256,256]{3,2,1,0} convert(exponential.2)
constant.48 = f32[] constant(0)
reduce.5 = f32[2,4,256]{2,1,0} reduce(convert.46, constant.48), dimensions={3}, to_apply=region_1.10
convert.47 = bf16[2,4,256]{2,1,0} convert(reduce.5)
broadcast.128 = bf16[2,4,256,256]{3,2,1,0} broadcast(convert.47), dimensions={0,1,2}
divide.7 = bf16[2,4,256,256]{3,2,1,0} divide(exponential.2, broadcast.128)
broadcast.129 = f32[4096]{0} broadcast(constant.48), dimensions={}
constant.50 = u32[] constant(0)
broadcast.131 = u32[8192]{0} broadcast(constant.50), dimensions={}
broadcast.133 = u32[4096]{0} broadcast(constant.50), dimensions={}
iota.3 = u32[8192]{0} iota(), iota_dimension=0
slice.14 = u32[4096]{0} slice(iota.3), slice={[0:4096]}
slice.15 = u32[4096]{0} slice(iota.3), slice={[4096:8192]}
custom-call.3 = (u32[4096]{0}, u32[4096]{0}) custom-call(broadcast.133, broadcast.133, slice.14, slice.15), custom_call_target="cu_threefry2x32", operand_layout_constraints={u32[4096]{0}, u32[4096]{0}, u32[4096]{0}, u32[4096]{0}}, api_version=API_VERSION_STATUS_RETURNING, backend_config="\000\020\000\000\000\000\000\000"
get-tuple-element.6 = u32[4096]{0} get-tuple-element(custom-call.3), index=0
constant.115 = u32[1]{0} constant({0})
constant.52 = u32[4]{0} constant({0, 0, 1, 1})
partition-id = u32[] partition-id()
dynamic-slice.21 = u32[1]{0} dynamic-slice(constant.52, partition-id), dynamic_slice_sizes={1}
constant.116 = u32[1]{0} constant({1})
clamp.3 = u32[1]{0} clamp(constant.115, dynamic-slice.21, constant.116)
convert.48 = s32[1]{0} convert(clamp.3)
constant.117 = s32[1]{0} constant({2048})
multiply.35 = s32[1]{0} multiply(convert.48, constant.117)
bitcast.105 = s32[] bitcast(multiply.35)
dynamic-slice.22 = u32[2048]{0} dynamic-slice(get-tuple-element.6, bitcast.105), dynamic_slice_sizes={2048}
constant.58 = s32[4]{0} constant({0, 0, 1, 1})
dynamic-slice.23 = s32[1]{0} dynamic-slice(constant.58, partition-id), dynamic_slice_sizes={1}
multiply.36 = s32[1]{0} multiply(dynamic-slice.23, constant.117)
bitcast.108 = s32[] bitcast(multiply.36)
dynamic-update-slice.2 = u32[8192]{0} dynamic-update-slice(broadcast.131, dynamic-slice.22, bitcast.108)
get-tuple-element.7 = u32[4096]{0} get-tuple-element(custom-call.3), index=1
dynamic-slice.24 = u32[2048]{0} dynamic-slice(get-tuple-element.7, bitcast.105), dynamic_slice_sizes={2048}
constant.65 = s32[] constant(4096)
add.18 = s32[] add(bitcast.108, constant.65)
dynamic-update-slice.3 = u32[8192]{0} dynamic-update-slice(dynamic-update-slice.2, dynamic-slice.24, add.18)
all-reduce = u32[8192]{0} all-reduce(dynamic-update-slice.3), channel_id=1, replica_groups={{0,1,2,3}}, use_global_device_ids=true, to_apply=add.clone
constant.118 = s32[1]{0} constant({4096})
multiply.37 = s32[1]{0} multiply(dynamic-slice.23, constant.118)
bitcast.119 = s32[] bitcast(multiply.37)
dynamic-slice.25 = u32[4096]{0} dynamic-slice(all-reduce, bitcast.119), dynamic_slice_sizes={4096}
constant.69 = u32[] constant(9)
broadcast.134 = u32[4096]{0} broadcast(constant.69), dimensions={}
shift-right-logical.6 = u32[4096]{0} shift-right-logical(dynamic-slice.25, broadcast.134)
constant.70 = u32[] constant(1065353216)
broadcast.135 = u32[4096]{0} broadcast(constant.70), dimensions={}
or.5 = u32[4096]{0} or(shift-right-logical.6, broadcast.135)
bitcast-convert.5 = f32[4096]{0} bitcast-convert(or.5)
constant.71 = f32[] constant(-1)
broadcast.136 = f32[4096]{0} broadcast(constant.71), dimensions={}
add.19 = f32[4096]{0} add(bitcast-convert.5, broadcast.136)
maximum.6 = f32[4096]{0} maximum(broadcast.129, add.19)
constant.72 = f32[] constant(0.5)
broadcast.137 = f32[4096]{0} broadcast(constant.72), dimensions={}
compare.4 = pred[4096]{0} compare(maximum.6, broadcast.137), direction=LT
bitcast.135 = pred[2,8,256]{2,1,0} bitcast(compare.4)
convert.49 = bf16[2,8,256]{2,1,0} convert(bitcast.135)
constant.80 = s32[] constant(0)
constant.78 = s32[4]{0} constant({0, 4, 0, 4})
dynamic-slice.26 = s32[1]{0} dynamic-slice(constant.78, partition-id), dynamic_slice_sizes={1}
bitcast.181 = s32[] bitcast(dynamic-slice.26)
dynamic-slice.27 = bf16[2,4,256]{2,1,0} dynamic-slice(convert.49, constant.80, bitcast.181, constant.80), dynamic_slice_sizes={2,4,256}
broadcast.139 = bf16[2,4,256,256]{3,2,1,0} broadcast(dynamic-slice.27), dimensions={0,1,3}
multiply.38 = bf16[2,4,256,256]{3,2,1,0} multiply(divide.7, broadcast.139)
constant.93 = bf16[] constant(2)
broadcast.141 = bf16[2,4,256,256]{3,2,1,0} broadcast(constant.93), dimensions={}
multiply.39 = bf16[2,4,256,256]{3,2,1,0} multiply(multiply.38, broadcast.141)
dot.13 = bf16[2,4,64,256]{3,2,1,0} dot(transpose.26, multiply.39), lhs_batch_dims={0,1}, lhs_contracting_dims={3}, rhs_batch_dims={0,1}, rhs_contracting_dims={3}
transpose.31 = bf16[4,2,64,256]{3,2,1,0} transpose(dot.13), dimensions={1,0,2,3}
bitcast.154 = bf16[2,256,4,64]{1,3,0,2} bitcast(transpose.31)
all-gather = bf16[2,256,8,64]{1,3,0,2} all-gather(bitcast.154), channel_id=2, replica_groups={{0,1},{2,3}}, dimensions={2}, use_global_device_ids=true
bitcast.155 = bf16[8,2,64,256]{3,2,1,0} bitcast(all-gather)
transpose.32 = bf16[2,8,64,256]{3,2,1,0} transpose(bitcast.155), dimensions={1,0,2,3}
bitcast.157 = bf16[2,256,8,64]{1,3,2,0} bitcast(transpose.32)
all-gather.1 = bf16[4,256,8,64]{1,3,2,0} all-gather(bitcast.157), channel_id=3, replica_groups={{0,2},{1,3}}, dimensions={0}, use_global_device_ids=true
bitcast.236 = bf16[4,8,64,256]{3,2,1,0} bitcast(all-gather.1)
transpose.38 = bf16[4,256,8,64]{3,2,1,0} transpose(bitcast.236), dimensions={0,3,1,2}
param.4 = bf16[2,256,4,64]{3,2,1,0} parameter(3), sharding={devices=[2,1,2,1]<=[4]}
transpose.33 = bf16[2,4,256,64]{3,2,1,0} transpose(param.4), dimensions={0,2,1,3}
dot.14 = bf16[2,4,256,256]{3,2,1,0} dot(transpose.33, transpose.26), lhs_batch_dims={0,1}, lhs_contracting_dims={3}, rhs_batch_dims={0,1}, rhs_contracting_dims={2}
broadcast.142 = bf16[4096]{0} broadcast(constant.93), dimensions={}
constant.95 = bf16[] constant(0)
broadcast.143 = bf16[4096]{0} broadcast(constant.95), dimensions={}
select.4 = bf16[4096]{0} select(compare.4, broadcast.142, broadcast.143)
bitcast.176 = bf16[2,8,256]{2,1,0} bitcast(select.4)
dynamic-slice.28 = bf16[2,4,256]{2,1,0} dynamic-slice(bitcast.176, constant.80, bitcast.181, constant.80), dynamic_slice_sizes={2,4,256}
broadcast.145 = bf16[2,4,256,256]{3,2,1,0} broadcast(dynamic-slice.28), dimensions={0,1,3}
multiply.40 = bf16[2,4,256,256]{3,2,1,0} multiply(dot.14, broadcast.145)
divide.8 = bf16[2,4,256,256]{3,2,1,0} divide(multiply.40, broadcast.128)
constant.106 = bf16[] constant(1)
broadcast.146 = bf16[2,4,256]{2,1,0} broadcast(constant.106), dimensions={}
multiply.41 = bf16[2,4,256]{2,1,0} multiply(convert.47, convert.47)
divide.9 = bf16[2,4,256]{2,1,0} divide(broadcast.146, multiply.41)
broadcast.147 = bf16[2,4,256,256]{3,2,1,0} broadcast(divide.9), dimensions={0,1,2}
multiply.42 = bf16[2,4,256,256]{3,2,1,0} multiply(multiply.40, broadcast.147)
multiply.43 = bf16[2,4,256,256]{3,2,1,0} multiply(multiply.42, exponential.2)
reduce.6 = bf16[2,4,256]{2,1,0} reduce(multiply.43, constant.95), dimensions={3}, to_apply=region_2.65
negate.4 = bf16[2,4,256]{2,1,0} negate(reduce.6)
broadcast.148 = bf16[2,4,256,256]{3,2,1,0} broadcast(negate.4), dimensions={0,1,2}
add.20 = bf16[2,4,256,256]{3,2,1,0} add(divide.8, broadcast.148)
multiply.44 = bf16[2,4,256,256]{3,2,1,0} multiply(add.20, exponential.2)
transpose.34 = bf16[2,4,256,64]{3,2,1,0} transpose(param.2), dimensions={0,2,1,3}
dot.15 = bf16[2,4,256,64]{3,2,1,0} dot(multiply.44, transpose.34), lhs_batch_dims={0,1}, lhs_contracting_dims={3}, rhs_batch_dims={0,1}, rhs_contracting_dims={2}
multiply.45 = bf16[2,4,256,64]{3,2,1,0} multiply(dot.15, broadcast.126)
transpose.39 = bf16[2,256,4,64]{3,2,1,0} transpose(multiply.45), dimensions={0,2,1,3}
dot.16 = bf16[2,4,256,64]{3,2,1,0} dot(multiply.44, multiply.34), lhs_batch_dims={0,1}, lhs_contracting_dims={2}, rhs_batch_dims={0,1}, rhs_contracting_dims={2}
transpose.40 = bf16[2,256,4,64]{3,2,1,0} transpose(dot.16), dimensions={0,2,1,3}
transpose.36 = bf16[2,4,64,256]{3,2,1,0} transpose(param.4), dimensions={0,2,3,1}
dot.11 = bf16[2,4,64,256]{3,2,1,0} dot(transpose.36, multiply.39), lhs_batch_dims={0,1}, lhs_contracting_dims={3}, rhs_batch_dims={0,1}, rhs_contracting_dims={2}
transpose.41 = bf16[2,256,4,64]{3,2,1,0} transpose(dot.11), dimensions={0,3,1,2}
ROOT tuple.2 = (bf16[4,256,8,64]{3,2,1,0}, bf16[2,256,4,64]{3,2,1,0}, bf16[2,256,4,64]{3,2,1,0}, bf16[2,256,4,64]{3,2,1,0}) tuple(transpose.38, transpose.39, transpose.40, transpose.41)
} // main.164_spmd
)";
// Dropout bwd pattern not supported, should not lower fwd as well
TF_ASSERT_OK_AND_ASSIGN(auto m, ParseAndReturnVerifiedModule(module_str));
CudnnFusedMHARewriter fusedMhaRewriter{
GetCudaComputeCapability(),
GetCudnnVersionWithDbiasAndMaskBwdInputSupport()};
TF_ASSERT_OK(RunHloPass(&fusedMhaRewriter, m.get()).status());
SCOPED_TRACE(m->ToString());
// check if fwd graph has been restored with cloned activation
EXPECT_THAT(
m->entry_computation()->root_instruction(),
GmockMatch(m::Tuple(
m::Transpose(), m::Transpose(), m::Transpose(),
m::Transpose(m::Dot(
m::Op(), m::Op().WithPredicate([](const HloInstruction* instr) {
return instr->name() == "multiply.39.fmha_no_match_clone";
}))))));
}

} // anonymous namespace
} // namespace gpu
} // namespace xla
