diff --git a/xla/service/gpu/cudnn_fused_mha_rewriter.cc b/xla/service/gpu/cudnn_fused_mha_rewriter.cc index 7c258a95737ecc..07ff15a5c1d4f2 100644 --- a/xla/service/gpu/cudnn_fused_mha_rewriter.cc +++ b/xla/service/gpu/cudnn_fused_mha_rewriter.cc @@ -478,30 +478,18 @@ bool IsCausalMaskPattern(HloInstruction* mask) { causal_mask_pattern_fwd_remat, causal_mask_pattern_fwd, causal_mask_pattern_bwd); if (Match(mask, causal_mask_pattern)) { - if (param != nullptr) { + if (param != nullptr && param->parent()->IsWhileBodyComputation()) { // need to track to outside of the while loop body to find the real mask. + auto while_instr = param->parent()->WhileCallInstruction(); auto mask_index = gte->tuple_index(); - auto comp = param->parent(); - auto mod = comp->parent(); - auto name = comp->name(); - auto entry_computation = mod->entry_computation(); - bool is_causal_mask = true; - for (HloInstruction* instr : - entry_computation->MakeInstructionPostOrder()) { - if (instr->opcode() == HloOpcode::kWhile && - instr->while_body()->name() == name && - instr->operand(0)->opcode() == HloOpcode::kTuple) { - auto actual_mask = - instr->mutable_operand(0)->mutable_operand(mask_index); - auto causal_mask_pattern_fwd = - OptionalBitcast(m::Convert(m::MinimumAnyOrder( - m::Op(), - OptionalBitcast(m::MinimumAnyOrder( - m::Op(), m::Broadcast(OptionalBitcast(causal_mask))))))); - is_causal_mask &= Match(actual_mask, causal_mask_pattern_fwd); - } - } - return is_causal_mask; + auto actual_mask = + while_instr->mutable_operand(0)->mutable_operand(mask_index); + auto causal_mask_pattern_fwd = + OptionalBitcast(m::Convert(m::MinimumAnyOrder( + m::Op(), + OptionalBitcast(m::MinimumAnyOrder( + m::Op(), m::Broadcast(OptionalBitcast(causal_mask))))))); + return Match(actual_mask, causal_mask_pattern_fwd); } return true; }