Skip to content

Commit

Permalink
use while body back pointer to find causal mask
Browse files Browse the repository at this point in the history
  • Loading branch information
Cjkkkk committed Dec 5, 2023
1 parent 490d0a3 commit 90e765f
Showing 1 changed file with 10 additions and 22 deletions.
32 changes: 10 additions & 22 deletions xla/service/gpu/cudnn_fused_mha_rewriter.cc
Original file line number Diff line number Diff line change
Expand Up @@ -478,30 +478,18 @@ bool IsCausalMaskPattern(HloInstruction* mask) {
causal_mask_pattern_fwd_remat, causal_mask_pattern_fwd,
causal_mask_pattern_bwd);
if (Match(mask, causal_mask_pattern)) {
if (param != nullptr) {
if (param != nullptr && param->parent()->IsWhileBodyComputation()) {
// need to track to outside of the while loop body to find the real mask.
auto while_instr = param->parent()->WhileCallInstruction();
auto mask_index = gte->tuple_index();
auto comp = param->parent();
auto mod = comp->parent();
auto name = comp->name();
auto entry_computation = mod->entry_computation();
bool is_causal_mask = true;
for (HloInstruction* instr :
entry_computation->MakeInstructionPostOrder()) {
if (instr->opcode() == HloOpcode::kWhile &&
instr->while_body()->name() == name &&
instr->operand(0)->opcode() == HloOpcode::kTuple) {
auto actual_mask =
instr->mutable_operand(0)->mutable_operand(mask_index);
auto causal_mask_pattern_fwd =
OptionalBitcast(m::Convert(m::MinimumAnyOrder(
m::Op(),
OptionalBitcast(m::MinimumAnyOrder(
m::Op(), m::Broadcast(OptionalBitcast(causal_mask)))))));
is_causal_mask &= Match(actual_mask, causal_mask_pattern_fwd);
}
}
return is_causal_mask;
auto actual_mask =
while_instr->mutable_operand(0)->mutable_operand(mask_index);
auto causal_mask_pattern_fwd =
OptionalBitcast(m::Convert(m::MinimumAnyOrder(
m::Op(),
OptionalBitcast(m::MinimumAnyOrder(
m::Op(), m::Broadcast(OptionalBitcast(causal_mask)))))));
return Match(actual_mask, causal_mask_pattern_fwd);
}
return true;
}
Expand Down

0 comments on commit 90e765f

Please sign in to comment.