add gpu backend to fmha e2e tests && address some format issues
Cjkkkk committed Dec 5, 2023
1 parent 90e765f commit a3e5905
Showing 3 changed files with 19 additions and 12 deletions.
19 changes: 10 additions & 9 deletions xla/service/gpu/cudnn_fused_mha_rewriter.cc
@@ -423,19 +423,20 @@ StatusOr<bool> IsSupportedBMM2(const HloInstruction* bmm_2,
 
 StatusOr<bool> IsFlashAttention(HloInstruction* bmm_1, bool is_causal_mask,
                                 absl::string_view custom_call_name) {
+  const DotDimensionNumbers& dnums = bmm_1->dot_dimension_numbers();
   TF_ASSIGN_OR_RETURN(
       std::vector<int64_t> seq_q_dims,
       GetNonContractingDims(
           bmm_1->operand(0)->shape(),
-          bmm_1->dot_dimension_numbers().lhs_batch_dimensions(),
-          bmm_1->dot_dimension_numbers().lhs_contracting_dimensions()));
+          dnums.lhs_batch_dimensions(),
+          dnums.lhs_contracting_dimensions()));
 
   TF_ASSIGN_OR_RETURN(
       std::vector<int64_t> seq_k_dims,
       GetNonContractingDims(
           bmm_1->operand(1)->shape(),
-          bmm_1->dot_dimension_numbers().rhs_batch_dimensions(),
-          bmm_1->dot_dimension_numbers().rhs_contracting_dimensions()));
+          dnums.rhs_batch_dimensions(),
+          dnums.rhs_contracting_dimensions()));
 
   std::vector<int64_t> seq_q =
       GetDimensionVector(bmm_1->operand(0)->shape().dimensions(), seq_q_dims);
@@ -445,7 +446,7 @@ StatusOr<bool> IsFlashAttention(HloInstruction* bmm_1, bool is_causal_mask,
 
   std::vector<int64_t> hidden_dim = GetDimensionVector(
       bmm_1->operand(0)->shape().dimensions(),
-      bmm_1->dot_dimension_numbers().lhs_contracting_dimensions());
+      dnums.lhs_contracting_dimensions());
   // for now, seq_q and seq_k should be equal for flash attention to work
   // flash attention only supports fixed topology so we check if custom call is
   // such topology by checking custom_call_name
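
The trailing comment describes the gate that follows in the collapsed remainder of this hunk: flash attention is only chosen when the query and key sequence lengths agree and the custom call names one of the fixed supported topologies. As a rough standalone illustration of that kind of check (the function, its parameters, and the placeholder call-target strings below are assumptions, not the rewriter's actual code):

#include <algorithm>
#include <cstdint>
#include <iterator>
#include <string_view>
#include <vector>

// Rough sketch of the gate described above; placeholder names throughout.
bool LooksLikeFlashAttention(const std::vector<int64_t>& seq_q,
                             const std::vector<int64_t>& seq_k,
                             std::string_view custom_call_name) {
  // for now, seq_q and seq_k should be equal for flash attention to work
  if (seq_q != seq_k) return false;
  // flash attention only supports fixed topologies, so require a known target
  static constexpr std::string_view kFixedTopologyCalls[] = {
      "placeholder$fmha_softmax", "placeholder$fmha_scale_bias_softmax"};
  return std::find(std::begin(kFixedTopologyCalls),
                   std::end(kFixedTopologyCalls),
                   custom_call_name) != std::end(kFixedTopologyCalls);
}
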
@@ -1379,12 +1380,12 @@ StatusOr<HloInstruction*> FuseFwdMultiHeadedAttentionBlock(
   if (is_training) {
     activation_output = bmm_2->mutable_operand(0);
     // Sometimes activation output is bitcast, the actual activation is the
-    // second user of the producer of bmm_2's first operand.
+    // other user of the producer of bmm_2's first operand.
     if (activation_output->user_count() < 2 &&
         activation_output->opcode() == HloOpcode::kBitcast) {
       HloInstruction* producer = activation_output->mutable_operand(0);
       TF_RET_CHECK(producer->user_count() == 2);
-      HloInstruction* bmm2_grad2_user = producer->UserId(activation_output) == 0
+      HloInstruction* bmm2_grad2_user = producer->users()[0] == activation_output
                                             ? producer->users()[1]
                                             : producer->users()[0];
       // might be (transpose) - bmm2_grad2
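
The replacement above swaps the UserId-based lookup for a direct comparison against users()[0]: since the producer is checked to have exactly two users, whichever user is not the bitcast must be the one consumed by bmm2_grad2. A small standalone sketch of that selection, using a simplified stand-in type rather than XLA's HloInstruction:

#include <array>
#include <cassert>

// Simplified stand-in for an instruction with exactly two users; illustrative
// only, not XLA's HloInstruction.
struct Node {
  std::array<Node*, 2> users{{nullptr, nullptr}};
};

// Mirrors the new selection logic: given a producer with exactly two users,
// return the user that is not `known_user` (here, the bitcast).
Node* OtherUser(Node* producer, Node* known_user) {
  assert(producer->users[0] == known_user || producer->users[1] == known_user);
  return producer->users[0] == known_user ? producer->users[1]
                                          : producer->users[0];
}
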
@@ -1834,9 +1835,9 @@ StatusOr<bool> CudnnFusedMHARewriter::Run(
       original_activation_producers.push_back(operand);
     }
     // We make sure no attention block is matched and replaced twice here
-    if (matched_bmm1.find(matched_result.matched_bmm_1) != matched_bmm1.end())
+    if (!matched_bmm1.insert(matched_result.matched_bmm_1).second) {
       continue;
-    matched_bmm1.insert(matched_result.matched_bmm_1);
+    }
     // If we need to canonicalize the bmm, we will assign the newly
     // canonicalized bmm to bmm_2.
     if (matched_result.need_canonicalization) {
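
The final hunk folds the earlier find-then-insert sequence into a single insert: insert() returns a pair whose .second member is false when the key was already present, so an attention block that has already been matched and replaced is skipped in one step. A minimal sketch of the idiom, assuming an ordinary std::unordered_set of pointers rather than whatever set type the rewriter actually uses:

#include <unordered_set>
#include <vector>

// Illustrative only: visit each match at most once. insert() returns a
// pair<iterator, bool>; .second is false when the key was already present,
// which is exactly the "matched and replaced twice" case being skipped.
void ProcessEachMatchOnce(const std::vector<int*>& matches) {
  std::unordered_set<int*> seen;
  for (int* match : matches) {
    if (!seen.insert(match).second) {
      continue;  // this match was already handled
    }
    // ... rewrite the newly seen match here ...
  }
}
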
4 changes: 2 additions & 2 deletions xla/service/gpu/cudnn_fused_mha_transpose_fusion.cc
@@ -465,8 +465,8 @@ StatusOr<bool> FusePrologueTransposeWithcuDNNFMHA(HloComputation* comp) {
     if (!config.is_flash_attention()) {
       TF_ASSIGN_OR_RETURN(changed,
                           FuseArgPrologueTransposeWithcuDNNFMHA(
-                              fmha, 4, true /*is_lhs*/,
-                              true /*should_contracting_be_fastest*/));
+                              fmha, 4, true /*is_lhs=*/,
+                              true /*should_contracting_be_fastest=*/));
     }
 
     if (changed && VLOG_IS_ON(2)) {
8 changes: 7 additions & 1 deletion xla/service/gpu/tests/BUILD
@@ -886,10 +886,16 @@ xla_test(
     ],
 )
 
-xla_cc_test(
+xla_test(
     name = "gpu_fused_mha_test",
     srcs = ["gpu_fused_mha_test.cc"],
     tags = tf_cuda_tests_tags(),
+    backend_tags = {"gpu": [
+        "requires-gpu-sm80",
+    ]},
+    backends = [
+        "gpu",
+    ],
     deps = [
         ":gpu_codegen_test",
         "//xla:array4d",
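
The BUILD change moves the end-to-end suite from xla_cc_test to the backend-aware xla_test macro, requests the gpu backend, and tags it requires-gpu-sm80 so the test is only scheduled on Ampere-class (SM 8.0+) devices. A runtime guard inside the test is a common complementary measure; the sketch below shows one way that can look in plain GoogleTest, where GetComputeCapabilityMajor() is a hypothetical stub rather than an XLA or CUDA API:

#include <gtest/gtest.h>

// Hypothetical helper, stubbed so the sketch compiles; a real test would query
// the device (for example via cudaGetDeviceProperties). Not an XLA API.
int GetComputeCapabilityMajor() { return 8; }

TEST(GpuFusedMhaE2e, SkipsOnPreAmpereDevices) {
  // Mirrors the intent of the requires-gpu-sm80 tag: the fused-MHA paths
  // exercised by this suite need SM 8.0 (Ampere) or newer.
  if (GetComputeCapabilityMajor() < 8) {
    GTEST_SKIP() << "cuDNN fused MHA requires an sm80+ GPU";
  }
  // ... run the fused-MHA end-to-end comparison here ...
}
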
