
Commit

address some comments and fix wrong layout for softmax stat if O is not [batch, num_heads, seq, head] layout
Cjkkkk committed Jan 11, 2024
1 parent 48ffd45 commit 542fe9f
Showing 3 changed files with 24 additions and 12 deletions.
4 changes: 2 additions & 2 deletions xla/service/gpu/cudnn_fused_mha_rewriter.cc
@@ -1552,7 +1552,6 @@ absl::StatusOr<bool> FuseBwdMultiHeadedAttentionBlock(
HloInstruction* lhs_bmm2_grad_gemm1;
HloInstruction* rhs_bmm2_grad_gemm2;
HloInstruction* d_output_grad;
HloInstruction* fwd_act;

DotDimensionNumbers orig_bmm1_grad1_config =
bmm_1_grad_1->dot_dimension_numbers();
@@ -1587,6 +1586,7 @@ absl::StatusOr<bool> FuseBwdMultiHeadedAttentionBlock(
// Forward activation
// if it is not flash attention, fwd activation is the P tensor
// else it is the softmax_stats
HloInstruction* fwd_act;
if (fwd_config.is_flash_attention()) {
auto fwd_act_index = 2;
fwd_act = comp->AddInstruction(HloInstruction::CreateGetTupleElement(
@@ -1835,7 +1835,7 @@ absl::StatusOr<bool> CudnnFusedMHARewriter::Run(
matched_result.matched_custom_call_name, debug_options));

if (!is_mha_module_supported) continue;
// flash attention require cuDNN 8.9.3 to run non-fused QKV
// flash attention requires cuDNN 8.9.3 to run non-fused QKV
// once we have fused QKV support, we can relax this contraint
if (matched_result.is_flash_attention &&
!IsComputeCapabilityAndCudnnSupported(
16 changes: 8 additions & 8 deletions xla/service/gpu/cudnn_fused_mha_transpose_fusion.cc
@@ -102,8 +102,8 @@ absl::StatusOr<bool> FuseArgPrologueTransposeWithcuDNNFMHA(
absl::Span<const int64_t> checked_dims;
std::vector<int64_t> checked_dims_vec;

// `should_contracting_be_fastest` means if contracting dim is the hidden
// dim. cuDNN requires hidden dim to be the fastest dim. fwd bmm1 and bwd
// `should_contracting_be_fastest` means if contracting dim is the head
// dim. cuDNN requires head dim to be the fastest dim. fwd bmm1 and bwd
// bmm2grad1 should set this value to true.
if (should_contracting_be_fastest) {
checked_dims = is_lhs ? new_bmm_dot_dims.lhs_contracting_dimensions()
@@ -133,10 +133,10 @@ absl::StatusOr<bool> FuseArgPrologueTransposeWithcuDNNFMHA(
new_bmm_checked_dims[i] = std::distance(inverse_perm.begin(), itr);
}
// We want to make sure that making the argument to transpose, an input to
// fmha, doesn't break cuDNN constraint that the checked dimensions of
// corresponding operand of BMM has the fastest moving dimension.
// fmha, doesn't break cuDNN constraint that the head dim of
// corresponding operand of BMM is the fastest moving dimension.
// One exception is the forward activation which doesn't have the constraint
// that the fastest dim has to be 64.
// since it does not have head dim.
absl::Span<const int64_t> minor_to_major_bmm =
transpose_arg_operand->shape().layout().minor_to_major();
if ((minor_to_major_bmm[0] != new_bmm_checked_dims[0]) &&
@@ -516,7 +516,7 @@ use(FMHA_out_t)
absl::StatusOr<bool> FuseEpilogueTransposeWithcuDNNFMHA(HloComputation* comp) {
bool changed = false;

auto onlyOneGTEWithSpecIndex = [](const HloInstruction* instr,
auto only_one_gte_with_spec_index = [](const HloInstruction* instr,
int64_t index) {
int count = 0;
for (auto user : instr->users()) {
@@ -548,7 +548,7 @@ absl::StatusOr<bool> FuseEpilogueTransposeWithcuDNNFMHA(HloComputation* comp) {
if (Match(instr, fwd_pattern)) {
// check if only one gte with such index exist
int64_t tuple_index = gte->tuple_index();
if (!onlyOneGTEWithSpecIndex(fmha, tuple_index)) continue;
if (!only_one_gte_with_spec_index(fmha, tuple_index)) continue;

std::vector<int64_t> inverse_perm =
InversePermutation(transpose->dimensions());
@@ -600,7 +600,7 @@ absl::StatusOr<bool> FuseEpilogueTransposeWithcuDNNFMHA(HloComputation* comp) {
} else if (Match(instr, bwd_pattern)) {
// check if only one gte with such index exist
int64_t operand_tuple_idx = gte->tuple_index();
if (!onlyOneGTEWithSpecIndex(fmha, operand_tuple_idx)) continue;
if (!only_one_gte_with_spec_index(fmha, operand_tuple_idx)) continue;

std::vector<int64_t> inverse_perm =
InversePermutation(transpose->dimensions());
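
For reference, the constraint described in the comments above can be illustrated with a small standalone sketch: map the BMM operand's contracting (head) dim through the transpose's inverse permutation, then require that the mapped dim is the operand layout's fastest-moving (minor-most) dim, since cuDNN wants the head dim fastest. The permutation, dim indices, and layout below are hypothetical examples, not values taken from the pass.

#include <algorithm>
#include <cstdint>
#include <iostream>
#include <iterator>
#include <vector>

int main() {
  // Hypothetical transpose: [batch, seq, heads, head_dim] -> [batch, heads, seq, head_dim].
  std::vector<int64_t> perm = {0, 2, 1, 3};
  std::vector<int64_t> inverse_perm(perm.size());
  for (size_t i = 0; i < perm.size(); ++i) {
    inverse_perm[perm[i]] = static_cast<int64_t>(i);
  }

  // Hypothetical contracting (head) dim of the BMM operand; its post-transpose
  // position is its index in inverse_perm, mirroring the std::distance lookup above.
  int64_t checked_dim = 3;
  auto itr = std::find(inverse_perm.begin(), inverse_perm.end(), checked_dim);
  int64_t new_checked_dim = std::distance(inverse_perm.begin(), itr);

  // cuDNN requires the head dim to be the fastest-moving (minor-most) dim.
  std::vector<int64_t> minor_to_major = {3, 2, 1, 0};
  bool head_dim_fastest = (minor_to_major[0] == new_checked_dim);
  std::cout << (head_dim_fastest ? "head dim fastest: transpose can be fused"
                                 : "head dim not fastest: skip fusion")
            << std::endl;
  return 0;
}

In the pass itself this check, together with the forward-activation exception noted above, gates whether the transpose is folded into the FMHA operand.
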
16 changes: 14 additions & 2 deletions xla/stream_executor/cuda/cuda_dnn.cc
@@ -6958,7 +6958,7 @@ GetCudnnFlashAttentionBackwardOperationGraph(
d_output_descriptor.GetCudnnCompatibleDimensions(false);
std::vector<int64_t> do_strides =
d_output_descriptor.GetCudnnCompatibleStrides(false);

VLOG(2) << "\n cuDNN compatible d_output_dims: "
<< absl::StrJoin(do_dims, ",")
<< "\n cuDNN compatible d_output_strides: "
@@ -7144,9 +7144,21 @@ GetCudnnFlashAttentionBackwardOperationGraph(
CreateCudnnTensor(p_dims, p_strides, CudnnfMHAUid::VIRTUAL_ID + 104,
dnn::DataType::kFloat, 1, -1,
/*is_virtual*/ true));

std::vector<int64_t> p_reduction_dims(p_dims.begin(), p_dims.end() - 1);
p_reduction_dims.push_back(1);

// Divide every stride by the last dim value.
std::vector<int64_t> p_reduction_strides;
p_reduction_strides.reserve(p_strides.size());
int64_t p_reduced_dim_len = p_dims.back();
for (auto stride : p_strides) {
p_reduction_strides.push_back(stride / p_reduced_dim_len);
}

TF_ASSIGN_OR_RETURN(
auto tensor_softmax_stats,
CreateCudnnTensor(do_reduction_dims, do_reduction_strides,
CreateCudnnTensor(p_reduction_dims, p_reduction_strides,
CudnnfMHAUid::P_ID, dnn::DataType::kFloat, 1, -1));

TF_ASSIGN_OR_RETURN(auto sub_desc,
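
The last hunk above is the core of the layout fix: the softmax-stat tensor is now described by dims and strides derived from the P tensor (its reduced last dimension collapsed to 1, every stride divided by the reduced length) rather than by the d_output-based reduction dims/strides, so the stat layout stays consistent with P even when O is not in [batch, num_heads, seq, head] layout. A minimal standalone sketch of that derivation, with hypothetical dims and strides rather than values from the commit:

#include <cstdint>
#include <iostream>
#include <vector>

int main() {
  // Hypothetical P tensor: [batch, num_heads, seq_q, seq_kv], row-major strides.
  std::vector<int64_t> p_dims = {2, 8, 128, 128};
  std::vector<int64_t> p_strides = {131072, 16384, 128, 1};

  // Softmax stats: collapse the reduced (last) dim to 1.
  std::vector<int64_t> p_reduction_dims(p_dims.begin(), p_dims.end() - 1);
  p_reduction_dims.push_back(1);

  // Divide every stride by the reduced dim's length, as in the hunk above.
  std::vector<int64_t> p_reduction_strides;
  p_reduction_strides.reserve(p_strides.size());
  const int64_t p_reduced_dim_len = p_dims.back();
  for (int64_t stride : p_strides) {
    p_reduction_strides.push_back(stride / p_reduced_dim_len);
  }

  // Yields dims {2, 8, 128, 1} and strides {1024, 128, 1, 0}; the stride of the
  // collapsed size-1 dim does not matter for addressing.
  for (int64_t d : p_reduction_dims) std::cout << d << ' ';
  std::cout << '\n';
  for (int64_t s : p_reduction_strides) std::cout << s << ' ';
  std::cout << '\n';
  return 0;
}
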
