diff --git a/xla/mlir/backends/gpu/transforms/lmhlo_gpu_to_gpu_runtime.cc b/xla/mlir/backends/gpu/transforms/lmhlo_gpu_to_gpu_runtime.cc
index d3d52e4516b46..3a74ed46de424 100644
--- a/xla/mlir/backends/gpu/transforms/lmhlo_gpu_to_gpu_runtime.cc
+++ b/xla/mlir/backends/gpu/transforms/lmhlo_gpu_to_gpu_runtime.cc
@@ -741,7 +741,8 @@ class FusedAttentionForwardLowering
     set_attr("fmha_scale", op.getFmhaScaleAttr());
     set_attr("dropout_rate", op.getDropoutRateAttr());
     set_attr("seed", op.getSeedAttr());
-
+    set_attr("is_flash_attention", op.getIsFlashAttentionAttr());
+    set_attr("is_causal_mask", op.getIsCausalMaskAttr());
     set_attr("fused_mha_dag", op.getFusedMhaDagAttr());
     set_attr("algorithm_config", op.getAlgorithmConfigAttr());
     set_attr("bmm1_dot_dimension_numbers", op.getBmm1DotDimensionNumbers());
@@ -784,8 +785,10 @@ template <typename FusedDotAttentionBackward>
 class FusedAttentionBackwardLowering
     : public OpRewritePattern<FusedDotAttentionBackward> {
  private:
-  static constexpr const char kCustomCallTarget[] =
+  static constexpr const char kFusedAttentionCustomCallTarget[] =
       "xla.gpu.fused.attention.backward.";
+  static constexpr const char kFlashAttentionCustomCallTarget[] =
+      "xla.gpu.flash.attention.backward.";

  public:
   explicit FusedAttentionBackwardLowering(MLIRContext* ctx, UidGenerator& uid,
@@ -797,11 +800,36 @@ class FusedAttentionBackwardLowering
   LogicalResult matchAndRewrite(FusedDotAttentionBackward op,
                                 PatternRewriter& rewriter) const override {
     // Get the custom call target.
-    std::string fused_attention = kCustomCallTarget;
+    bool is_flash_attention = op.getIsFlashAttention();
+    std::string fused_attention = is_flash_attention
+                                      ? kFlashAttentionCustomCallTarget
+                                      : kFusedAttentionCustomCallTarget;
     auto num_operands = op.getNumOperands();
     switch (op.getFusedMhaDag()) {
+      case mlir::lmhlo_gpu::FusedMhaBackwardDagSignature::BackwardSoftmax:
+        if (is_flash_attention) {
+          if (num_operands == 12) {
+            fused_attention += "scale.softmax";
+          } else {
+            return op.emitOpError(
+                "unexpected number of operands for flash attention backward - "
+                "BMM_Softmax_BMM");
+          }
+        }
+        break;
+      case mlir::lmhlo_gpu::FusedMhaBackwardDagSignature::
           BackwardScaleBiasSoftmax:
+        if (is_flash_attention) {
+          if (num_operands == 13) {
+            fused_attention += "scale.bias.softmax";
+          } else {
+            return op.emitOpError(
+                "unexpected number of operands for flash attention backward - "
+                "BMM_Bias_Softmax_BMM");
+          }
+          break;
+        }
         if (num_operands == 10) {
           fused_attention += "scale.softmax";
         } else if (num_operands == 11) {
@@ -877,7 +905,8 @@ class FusedAttentionBackwardLowering
     set_attr("fmha_scale", op.getFmhaScaleAttr());
     set_attr("dropout_rate", op.getDropoutRateAttr());
     set_attr("seed", op.getSeedAttr());
-
+    set_attr("is_flash_attention", op.getIsFlashAttentionAttr());
+    set_attr("is_causal_mask", op.getIsCausalMaskAttr());
     set_attr("fused_mha_dag", op.getFusedMhaDagAttr());
     set_attr("algorithm_config", op.getAlgorithmConfigAttr());
     set_attr("bmm1_grad_gemm1_dot_dimension_numbers",
              op.getBmm1GradGemm1DotDimensionNumbers());
@@ -889,6 +918,20 @@ class FusedAttentionBackwardLowering
     set_attr("bmm2_grad_gemm2_dot_dimension_numbers",
              op.getBmm2GradGemm2DotDimensionNumbers());

+    auto set_xi64 = [&](StringRef name, mlir::ArrayAttr array) {
+      int rank = array.size();
+      SmallVector<int64_t> values;
+      for (int i = 0; i < rank; i++) {
+        mlir::IntegerAttr attr = array[i].dyn_cast<mlir::IntegerAttr>();
+        values.push_back(attr.getInt());
+      }
+      set_attr(name, b.getI64TensorAttr(values));
+    };
+
+    set_xi64("intermediate_tensor_dimensions",
+             op.getIntermediateTensorDimensions());
+    set_xi64("intermediate_tensor_layout", op.getIntermediateTensorLayout());
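A sketch, not part of the patch: the set_xi64 lambda above repacks an I64ArrayAttr (for example intermediate_tensor_dimensions) into a dense i64 tensor attribute before it is attached to the runtime custom call. A self-contained equivalent using only standard MLIR APIs could look like the following; the helper name ToI64TensorAttr is illustrative.

#include "llvm/ADT/SmallVector.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/BuiltinAttributes.h"

// Repack an ArrayAttr of IntegerAttrs into a 1-D DenseIntElementsAttr of i64,
// mirroring what set_xi64 attaches to the custom call.
static mlir::DenseIntElementsAttr ToI64TensorAttr(mlir::Builder& b,
                                                  mlir::ArrayAttr array) {
  llvm::SmallVector<int64_t> values;
  values.reserve(array.size());
  for (mlir::Attribute a : array)
    values.push_back(a.cast<mlir::IntegerAttr>().getInt());
  return b.getI64TensorAttr(values);
}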
+ // Erase the original fused dot attention operation. rewriter.eraseOp(op); diff --git a/xla/mlir_hlo/lhlo_gpu/IR/lhlo_gpu_ops.td b/xla/mlir_hlo/lhlo_gpu/IR/lhlo_gpu_ops.td index 3091ed06f4fec..e56d2964d767a 100644 --- a/xla/mlir_hlo/lhlo_gpu/IR/lhlo_gpu_ops.td +++ b/xla/mlir_hlo/lhlo_gpu/IR/lhlo_gpu_ops.td @@ -362,7 +362,9 @@ def LHLOGPU_fusedMHAOp : LHLOGPU_Op<"fMHA", [AttrSizedOperandSegments]> { FusedMhaDagSignatureAttr:$fused_mha_dag, FusedMHAAlgorithmConfigAttr:$algorithm_config, OptionalAttr:$dropout_rate, - OptionalAttr:$seed + OptionalAttr:$seed, + BoolAttr:$is_flash_attention, + BoolAttr:$is_causal_mask ); } @@ -374,21 +376,30 @@ def LHLOGPU_fusedMHABackwardOp : LHLOGPU_Op<"fMHABackward", [AttrSizedOperandSeg Arg:$bmm2_grad_gemm1_lhs, Arg:$d_output, Arg, "", [MemRead]>:$mask, + Arg, "", [MemRead]>:$bias, + Arg, "", [MemRead]>:$fwd_output, Arg:$d_bmm1_lhs, Arg:$d_bmm1_rhs, Arg:$d_bmm2_rhs, - Arg:$d_S, + Arg, "", [MemWrite]>:$d_S, + Arg, "", [MemWrite]>:$softmax_sum, + Arg, "", [MemWrite]>:$d_Q_accum, Arg:$scratch, Arg, "", [MemWrite]>:$d_bias, MHLO_DotDimensionNumbers:$bmm1_grad_gemm1_dot_dimension_numbers, MHLO_DotDimensionNumbers:$bmm1_grad_gemm2_dot_dimension_numbers, MHLO_DotDimensionNumbers:$bmm2_grad_gemm1_dot_dimension_numbers, MHLO_DotDimensionNumbers:$bmm2_grad_gemm2_dot_dimension_numbers, + I64ArrayAttr:$intermediate_tensor_dimensions, + I64ArrayAttr:$intermediate_tensor_layout, F64Attr:$fmha_scale, FusedMhaBackwardDagSignatureAttr:$fused_mha_dag, FusedMHAAlgorithmConfigAttr:$algorithm_config, OptionalAttr:$dropout_rate, - OptionalAttr:$seed); + OptionalAttr:$seed, + BoolAttr:$is_flash_attention, + BoolAttr:$is_causal_mask + ); } def LHLOGPU_RadixSortOp: LHLOGPU_Op<"radix_sort", [SameVariadicOperandSize]> { diff --git a/xla/mlir_hlo/lhlo_gpu/IR/lhlo_gpu_ops_enums.td b/xla/mlir_hlo/lhlo_gpu/IR/lhlo_gpu_ops_enums.td index 8ab0646a44a6d..7ce614e43b859 100644 --- a/xla/mlir_hlo/lhlo_gpu/IR/lhlo_gpu_ops_enums.td +++ b/xla/mlir_hlo/lhlo_gpu/IR/lhlo_gpu_ops_enums.td @@ -153,6 +153,8 @@ def FusedMhaBackwardDagScaleBiasSoftmaxDropout : I32EnumAttrCase<"BackwardScaleB def FusedMhaBackwardDagScaleBiasSoftmax : I32EnumAttrCase<"BackwardScaleBiasSoftmax", 1>; def FusedMhaBackwardDagScaleBiasMaskSoftmax : I32EnumAttrCase<"BackwardScaleBiasMaskSoftmax", 2>; def FusedMhaBackwardDagScaleBiasMaskSoftmaxDropout : I32EnumAttrCase<"BackwardScaleBiasMaskSoftmaxDropout", 3>; +def FusedMhaBackwardDagSoftmax : I32EnumAttrCase<"BackwardSoftmax", 4>; +def FusedMhaBackwardDagSoftmaxDropout : I32EnumAttrCase<"BackwardSoftmaxDropout", 5>; def FusedMhaDagSignature: I32EnumAttr<"FusedMhaDagSignature", "DAG configuration for Fused Multi-Headed Attention", @@ -175,11 +177,13 @@ def FusedMhaBackwardDagSignature: I32EnumAttr<"FusedMhaBackwardDagSignature", FusedMhaBackwardDagScaleBiasSoftmaxDropout, FusedMhaBackwardDagScaleBiasSoftmax, FusedMhaBackwardDagScaleBiasMaskSoftmax, - FusedMhaBackwardDagScaleBiasMaskSoftmaxDropout]> { + FusedMhaBackwardDagScaleBiasMaskSoftmaxDropout, + FusedMhaBackwardDagSoftmax, + FusedMhaBackwardDagSoftmaxDropout]> { let genSpecializedAttr = 0; let cppNamespace = "::mlir::lmhlo_gpu"; } def FusedMhaDagSignatureAttr : EnumAttr; def FusedMhaBackwardDagSignatureAttr : EnumAttr; -#endif // LHLO_GPU_OPS_ENUMS \ No newline at end of file +#endif // LHLO_GPU_OPS_ENUMS diff --git a/xla/service/gpu/backend_configs.proto b/xla/service/gpu/backend_configs.proto index c9bdbd009f000..27b1fc1283580 100644 --- a/xla/service/gpu/backend_configs.proto +++ b/xla/service/gpu/backend_configs.proto 
@@ -176,4 +176,10 @@ message CudnnfMHABackendConfig { // Random seed used by dropout int64 seed = 15; + + // Is flash attention + bool is_flash_attention = 20; + + // Is causal mask + bool is_causal_mask = 21; } diff --git a/xla/service/gpu/fused_mha_thunk.cc b/xla/service/gpu/fused_mha_thunk.cc index 96562a8e2f940..f0ba6f3fbd177 100644 --- a/xla/service/gpu/fused_mha_thunk.cc +++ b/xla/service/gpu/fused_mha_thunk.cc @@ -61,6 +61,15 @@ FusedMultiHeadedAttentionRunner& FusedMHAThunk::GetOrCreateRunner( return *it->second; } +std::optional AssignBufferIfNotNull( + const BufferAllocations& buffer_allocations, + BufferAllocation::Slice& slice) { + return slice.allocation() != nullptr + ? std::optional{buffer_allocations + .GetDeviceAddress(slice)} + : std::nullopt; +} + Status FusedMHAThunk::ExecuteOnStream(const ExecuteParams& params) { const auto& buffer_allocations = *params.buffer_allocations; se::DeviceMemoryBase lhs_bmm1_buffer = @@ -74,19 +83,12 @@ Status FusedMHAThunk::ExecuteOnStream(const ExecuteParams& params) { se::DeviceMemoryBase scratch_buffer = buffer_allocations.GetDeviceAddress(scratch_buffer_); - std::optional mask_buffer; - if (mask_buffer_.allocation() != nullptr) { - mask_buffer = buffer_allocations.GetDeviceAddress(mask_buffer_); - } - std::optional bias_buffer; - if (bias_buffer_.allocation() != nullptr) { - bias_buffer = buffer_allocations.GetDeviceAddress(bias_buffer_); - } - - std::optional activation_buffer; - if (activation_buffer_.allocation() != nullptr) { - activation_buffer = buffer_allocations.GetDeviceAddress(activation_buffer_); - } + std::optional mask_buffer = + AssignBufferIfNotNull(buffer_allocations, mask_buffer_); + std::optional bias_buffer = + AssignBufferIfNotNull(buffer_allocations, bias_buffer_); + std::optional activation_buffer = + AssignBufferIfNotNull(buffer_allocations, activation_buffer_); RunFusedMHAOptions opts; opts.runner_cache = &GetOrCreateRunner(params.stream); @@ -109,7 +111,9 @@ FusedMHABackwardThunk::FusedMHABackwardThunk( BufferAllocation::Slice d_output, BufferAllocation::Slice scratch, BufferAllocation::Slice d_bmm1_lhs, BufferAllocation::Slice d_bmm1_rhs, BufferAllocation::Slice d_bmm2_rhs, BufferAllocation::Slice d_s, - BufferAllocation::Slice mask, BufferAllocation::Slice d_bias) + BufferAllocation::Slice softmax_sum, BufferAllocation::Slice d_Q_accum, + BufferAllocation::Slice mask, BufferAllocation::Slice d_bias, + BufferAllocation::Slice fwd_output, BufferAllocation::Slice bias) : Thunk(Kind::kFusedMHA, thunk_info), bmm1_grad_gemm1_rhs_buffer_(bmm1_grad_gemm1_rhs), bmm1_grad_gemm2_rhs_buffer_(bmm1_grad_gemm2_rhs), @@ -121,8 +125,12 @@ FusedMHABackwardThunk::FusedMHABackwardThunk( d_bmm1_rhs_buffer_(d_bmm1_rhs), d_bmm2_rhs_buffer_(d_bmm2_rhs), d_s_buffer_(d_s), + softmax_sum_buffer_(softmax_sum), + d_Q_accum_buffer_(d_Q_accum), mask_buffer_(mask), d_bias_buffer_(d_bias), + fwd_output_buffer_(fwd_output), + bias_buffer_(bias), config_(std::move(config)) {} FusedMultiHeadedAttentionBackwardRunner& @@ -169,18 +177,21 @@ Status FusedMHABackwardThunk::ExecuteOnStream(const ExecuteParams& params) { se::DeviceMemoryBase d_bmm2_rhs_buffer = buffer_allocations.GetDeviceAddress(d_bmm2_rhs_buffer_); - se::DeviceMemoryBase d_S_buffer = - buffer_allocations.GetDeviceAddress(d_s_buffer_); + std::optional d_s_buffer = + AssignBufferIfNotNull(buffer_allocations, d_s_buffer_); + std::optional softmax_sum_buffer = + AssignBufferIfNotNull(buffer_allocations, softmax_sum_buffer_); + std::optional d_Q_accum_buffer = + 
AssignBufferIfNotNull(buffer_allocations, d_Q_accum_buffer_); + std::optional mask_buffer = + AssignBufferIfNotNull(buffer_allocations, mask_buffer_); + std::optional d_bias_buffer = + AssignBufferIfNotNull(buffer_allocations, d_bias_buffer_); + std::optional fwd_output_buffer = + AssignBufferIfNotNull(buffer_allocations, fwd_output_buffer_); + std::optional bias_buffer = + AssignBufferIfNotNull(buffer_allocations, bias_buffer_); - std::optional mask_buffer; - if (mask_buffer_.allocation() != nullptr) { - mask_buffer = buffer_allocations.GetDeviceAddress(mask_buffer_); - } - - std::optional d_bias_buffer; - if (d_bias_buffer_.allocation() != nullptr) { - d_bias_buffer = buffer_allocations.GetDeviceAddress(d_bias_buffer_); - } RunFusedMHABackwardOptions opts; opts.runner_cache = &GetOrCreateRunner(params.stream); @@ -189,7 +200,8 @@ Status FusedMHABackwardThunk::ExecuteOnStream(const ExecuteParams& params) { config_, bmm1_grad_gemm1_rhs_buffer, bmm1_grad_gemm2_rhs_buffer, bmm2_grad_gemm1_lhs_buffer, bmm2_grad_gemm2_rhs_buffer, d_output_buffer, scratch_buffer, d_bmm1_lhs_buffer, d_bmm1_rhs_buffer, d_bmm2_rhs_buffer, - d_S_buffer, mask_buffer, d_bias_buffer, params.stream, opts)); + d_s_buffer, softmax_sum_buffer, d_Q_accum_buffer, mask_buffer, + d_bias_buffer, fwd_output_buffer, bias_buffer, params.stream, opts)); if (!params.stream->ok()) { return InternalError("FusedMHABackwardThunk::ExecuteOnStream failed."); } diff --git a/xla/service/gpu/fused_mha_thunk.h b/xla/service/gpu/fused_mha_thunk.h index a1db1d23e16c9..a0d9e58aa0e64 100644 --- a/xla/service/gpu/fused_mha_thunk.h +++ b/xla/service/gpu/fused_mha_thunk.h @@ -91,9 +91,13 @@ class FusedMHABackwardThunk : public Thunk { BufferAllocation::Slice d_bmm1_lhs_slice, BufferAllocation::Slice d_bmm1_rhs_slice, BufferAllocation::Slice d_bmm2_rhs_slice, - BufferAllocation::Slice d_S_slice, + BufferAllocation::Slice d_s_slice, + BufferAllocation::Slice softmax_sum_slice, + BufferAllocation::Slice d_Q_accum_slice, BufferAllocation::Slice mask_slice, - BufferAllocation::Slice d_bias_slice); + BufferAllocation::Slice d_bias_slice, + BufferAllocation::Slice fwd_output_slice, + BufferAllocation::Slice bias_slice); FusedMHABackwardThunk(const FusedMHABackwardThunk&) = delete; FusedMHABackwardThunk& operator=(const FusedMHABackwardThunk&) = delete; @@ -111,8 +115,12 @@ class FusedMHABackwardThunk : public Thunk { BufferAllocation::Slice d_bmm1_rhs_buffer_; BufferAllocation::Slice d_bmm2_rhs_buffer_; BufferAllocation::Slice d_s_buffer_; + BufferAllocation::Slice softmax_sum_buffer_; + BufferAllocation::Slice d_Q_accum_buffer_; BufferAllocation::Slice mask_buffer_; BufferAllocation::Slice d_bias_buffer_; + BufferAllocation::Slice fwd_output_buffer_; + BufferAllocation::Slice bias_buffer_; FusedMultiHeadedAttentionBackwardRunner& GetOrCreateRunner( const stream_executor::Stream* stream); diff --git a/xla/service/gpu/gpu_fused_mha_runner.cc b/xla/service/gpu/gpu_fused_mha_runner.cc index 74c3d839750a5..a796f0b5652ca 100644 --- a/xla/service/gpu/gpu_fused_mha_runner.cc +++ b/xla/service/gpu/gpu_fused_mha_runner.cc @@ -83,8 +83,8 @@ Status RunFusedMHA(GpufMHAParams params, se::Stream *stream, params.config->activation, dropout_rate, seed, - false, - false}; + params.config->is_flash_attention, + params.config->is_causal_mask}; TF_ASSIGN_OR_RETURN(auto *runner, lazy_runner->GetOrCreateRunner(config, stream)); return (*runner)(stream, options.profile_result, scratch_memory, @@ -201,20 +201,21 @@ void AssignSeed(GpufMHAConfig &config, } template -Status 
RunFusedMHABackward(GpufMHABackwardParams params, se::Stream *stream, - RunFusedMHABackwardOptions options, - DeviceMemory bmm1_grad_gemm1_rhs_buffer, - DeviceMemory bmm1_grad_gemm2_rhs_buffer, - DeviceMemory bmm2_grad_gemm1_lhs_buffer, - DeviceMemory bmm2_grad_gemm2_rhs_buffer, - DeviceMemory d_output_buffer, - DeviceMemory d_bmm1_lhs_buffer, - DeviceMemory d_bmm1_rhs_buffer, - DeviceMemory d_bmm2_rhs_buffer, - DeviceMemory d_s_buffer, - DeviceMemoryBase mask_buffer, - DeviceMemoryBase d_bias_buffer, - DeviceMemoryBase scratch_memory) { +Status RunFusedMHABackward( + GpufMHABackwardParams params, se::Stream *stream, + RunFusedMHABackwardOptions options, + DeviceMemory bmm1_grad_gemm1_rhs_buffer, + DeviceMemory bmm1_grad_gemm2_rhs_buffer, + DeviceMemory bmm2_grad_gemm1_lhs_buffer, + DeviceMemory bmm2_grad_gemm2_rhs_buffer, + DeviceMemory d_output_buffer, + DeviceMemory d_bmm1_lhs_buffer, + DeviceMemory d_bmm1_rhs_buffer, + DeviceMemory d_bmm2_rhs_buffer, DeviceMemoryBase d_s_buffer, + DeviceMemoryBase softmax_buffer, DeviceMemoryBase d_Q_accum_buffer, + DeviceMemoryBase mask_buffer, DeviceMemoryBase d_bias_buffer, + DeviceMemoryBase fwd_output_buffer, DeviceMemoryBase bias_buffer, + DeviceMemoryBase scratch_memory) { se::dnn::LazyOpRunner *lazy_runner = options.runner_cache->AsFusedMHABackwardRunner(); std::optional> @@ -223,6 +224,7 @@ Status RunFusedMHABackward(GpufMHABackwardParams params, se::Stream *stream, local_runner.emplace(params.config->algorithm); lazy_runner = &*local_runner; } + // FMHA TODO: add GetDNNFusedMHAKindFromCudnnfMHAKind here TF_ASSIGN_OR_RETURN(se::dnn::FusedMHAKind kind, GetDNNFusedMHAKindFromCudnnfMHAKind(params.config->kind)); std::optional dropout_rate; @@ -239,27 +241,25 @@ Status RunFusedMHABackward(GpufMHABackwardParams params, se::Stream *stream, if (params.config->seed) { seed = *params.config->seed; } - // TODO: set is_flash_attention to real value, set it to false for now - se::dnn::FusedMHABackwardOp::Config config{ - kind, - scale, - params.config->bmm1_grad_gemm1_rhs, - params.config->bmm1_grad_gemm2_rhs, - params.config->bmm2_grad_gemm1_lhs, - params.config->bmm2_grad_gemm2_rhs, - params.config->d_output, - params.config->d_bmm1_lhs, - params.config->d_bmm1_rhs, - params.config->d_bmm2_rhs, - std::optional(params.config->d_s), - params.config->mask, - params.config->d_bias, - std::nullopt, - std::nullopt, - dropout_rate, - seed, - false, - false}; + se::dnn::FusedMHABackwardOp::Config config{kind, + scale, + params.config->bmm1_grad_gemm1_rhs, + params.config->bmm1_grad_gemm2_rhs, + params.config->bmm2_grad_gemm1_lhs, + params.config->bmm2_grad_gemm2_rhs, + params.config->d_output, + params.config->d_bmm1_lhs, + params.config->d_bmm1_rhs, + params.config->d_bmm2_rhs, + params.config->d_s, + params.config->mask, + params.config->d_bias, + params.config->fwd_output, + params.config->bias, + dropout_rate, + seed, + params.config->is_flash_attention, + params.config->is_causal_mask}; TF_ASSIGN_OR_RETURN(auto *runner, lazy_runner->GetOrCreateRunner(config, stream)); // TODO: pass in real softmax_sum, dQ_accum, fwd_output @@ -267,9 +267,10 @@ Status RunFusedMHABackward(GpufMHABackwardParams params, se::Stream *stream, bmm1_grad_gemm1_rhs_buffer, bmm1_grad_gemm2_rhs_buffer, bmm2_grad_gemm1_lhs_buffer, bmm2_grad_gemm2_rhs_buffer, d_output_buffer, d_bmm1_lhs_buffer, d_bmm1_rhs_buffer, - d_bmm2_rhs_buffer, d_s_buffer, se::DeviceMemoryBase(), - se::DeviceMemoryBase(), mask_buffer, d_bias_buffer, - se::DeviceMemoryBase(), se::DeviceMemoryBase()); + 
d_bmm2_rhs_buffer, d_s_buffer, softmax_buffer, + d_Q_accum_buffer, mask_buffer, d_bias_buffer, + fwd_output_buffer, bias_buffer); + return OkStatus(); } template @@ -292,7 +293,20 @@ Status RunGpuFMHABackwardImpl(const GpufMHABackwardParams ¶ms, se::DeviceMemory(params.d_bmm1_rhs_buffer); auto d_bmm2_rhs_buffer = se::DeviceMemory(params.d_bmm2_rhs_buffer); - auto d_s_buffer = se::DeviceMemory(params.d_s_buffer); + + // optional buffers + auto d_s_buffer = params.d_s_buffer.has_value() + ? se::DeviceMemory(*params.d_s_buffer) + : se::DeviceMemoryBase(); + auto softmax_sum_buffer = + params.softmax_sum_buffer.has_value() + ? se::DeviceMemory(*params.softmax_sum_buffer) + : se::DeviceMemoryBase(); + + auto d_Q_accum_buffer = + params.d_Q_accum_buffer.has_value() + ? se::DeviceMemory(*params.d_Q_accum_buffer) + : se::DeviceMemoryBase(); auto mask_buffer = params.mask_buffer.has_value() ? se::DeviceMemory(*params.mask_buffer) @@ -302,6 +316,15 @@ Status RunGpuFMHABackwardImpl(const GpufMHABackwardParams ¶ms, ? se::DeviceMemory(*params.d_bias_buffer) : se::DeviceMemoryBase(); + auto fwd_output_buffer = + params.fwd_output_buffer.has_value() + ? se::DeviceMemory(*params.fwd_output_buffer) + : se::DeviceMemoryBase(); + + auto bias_buffer = params.bias_buffer.has_value() + ? se::DeviceMemory(*params.bias_buffer) + : se::DeviceMemoryBase(); + se::dnn::AlgorithmDesc algorithm = params.config->algorithm; if (options.runner_cache) { algorithm = options.runner_cache->ToAlgorithmDesc(); @@ -322,8 +345,9 @@ Status RunGpuFMHABackwardImpl(const GpufMHABackwardParams ¶ms, params, stream, options, bmm1_grad_gemm1_rhs_buffer, bmm1_grad_gemm2_rhs_buffer, bmm2_grad_gemm1_lhs_buffer, bmm2_grad_gemm2_rhs_buffer, d_output_buffer, d_bmm1_lhs_buffer, - d_bmm1_rhs_buffer, d_bmm2_rhs_buffer, d_s_buffer, mask_buffer, - d_bias_buffer, scratch_memory); + d_bmm1_rhs_buffer, d_bmm2_rhs_buffer, d_s_buffer, softmax_sum_buffer, + d_Q_accum_buffer, mask_buffer, d_bias_buffer, fwd_output_buffer, + bias_buffer, scratch_memory); break; default: return InternalError("Invalid cuDNN fMHA kind"); @@ -428,6 +452,8 @@ Status RunGpuFMHABackwardImpl(const GpufMHABackwardParams ¶ms, bias_shape.layout().minor_to_major()); } config.kind = desc.kind; + config.is_flash_attention = desc.is_flash_attention; + config.is_causal_mask = desc.is_causal_mask; const CudnnfMHABackendConfig &backend_config = desc.backend_config; config.algorithm = se::dnn::AlgorithmDesc(backend_config.algorithm()); @@ -449,7 +475,6 @@ Status RunGpuFMHABackwardImpl(const GpufMHABackwardParams ¶ms, const Shape &d_bmm1_lhs_shape = desc.d_bmm1_lhs_shape; const Shape &d_bmm1_rhs_shape = desc.d_bmm1_rhs_shape; const Shape &d_bmm2_rhs_shape = desc.d_bmm2_rhs_shape; - // Get DNN dtype from primtive types TF_ASSIGN_OR_RETURN(DataType bmm1_grad_gemm1_rhs_type, GetDNNDataTypeFromPrimitiveType( @@ -537,7 +562,6 @@ Status RunGpuFMHABackwardImpl(const GpufMHABackwardParams ¶ms, if (desc.d_bias_shape) { const Shape &d_bias_shape = *desc.d_bias_shape; - // Get DNN dtype from primtive types TF_ASSIGN_OR_RETURN(DataType d_bias_type, GetDNNDataTypeFromPrimitiveType( d_bias_shape.element_type())); @@ -553,7 +577,27 @@ Status RunGpuFMHABackwardImpl(const GpufMHABackwardParams ¶ms, config.mask = TensorDescriptor::For(mask_type, mask_shape.dimensions(), mask_shape.layout().minor_to_major()); } + if (desc.fwd_output_shape) { + const Shape &fwd_output_shape = *desc.fwd_output_shape; + TF_ASSIGN_OR_RETURN( + DataType fwd_output_type, + 
GetDNNDataTypeFromPrimitiveType(fwd_output_shape.element_type())); + config.fwd_output = + TensorDescriptor::For(fwd_output_type, fwd_output_shape.dimensions(), + fwd_output_shape.layout().minor_to_major()); + } + + if (desc.bias_shape) { + const Shape &bias_shape = *desc.bias_shape; + TF_ASSIGN_OR_RETURN(DataType bias_type, GetDNNDataTypeFromPrimitiveType( + bias_shape.element_type())); + config.bias = TensorDescriptor::For(bias_type, bias_shape.dimensions(), + bias_shape.layout().minor_to_major()); + } + config.kind = desc.kind; + config.is_flash_attention = desc.is_flash_attention; + config.is_causal_mask = desc.is_causal_mask; const CudnnfMHABackendConfig &backend_config = desc.backend_config; config.algorithm = se::dnn::AlgorithmDesc(backend_config.algorithm()); @@ -601,9 +645,14 @@ Status RunGpuFMHABackwardImpl(const GpufMHABackwardParams ¶ms, se::DeviceMemoryBase d_output_buffer, se::DeviceMemoryBase d_bmm1_lhs_buffer, se::DeviceMemoryBase d_bmm1_rhs_buffer, - se::DeviceMemoryBase d_bmm2_rhs_buffer, se::DeviceMemoryBase d_s_buffer, + se::DeviceMemoryBase d_bmm2_rhs_buffer, + std::optional d_s_buffer, + std::optional softmax_sum_buffer, + std::optional d_Q_accum_buffer, std::optional mask_buffer, - std::optional d_bias_buffer) { + std::optional d_bias_buffer, + std::optional fwd_output_buffer, + std::optional bias_buffer) { GpufMHABackwardParams params; params.config = &config; params.bmm1_grad_gemm1_rhs_buffer = bmm1_grad_gemm1_rhs_buffer; @@ -615,9 +664,12 @@ Status RunGpuFMHABackwardImpl(const GpufMHABackwardParams ¶ms, params.d_bmm1_rhs_buffer = d_bmm1_rhs_buffer; params.d_bmm2_rhs_buffer = d_bmm2_rhs_buffer; params.d_s_buffer = d_s_buffer; + params.softmax_sum_buffer = softmax_sum_buffer; + params.d_Q_accum_buffer = d_Q_accum_buffer; params.mask_buffer = mask_buffer; params.d_bias_buffer = d_bias_buffer; - + params.fwd_output_buffer = fwd_output_buffer; + params.bias_buffer = bias_buffer; return params; } @@ -651,28 +703,32 @@ Status RunGpuFMHA(const GpufMHAConfig &fmha_config, return OkStatus(); } -Status RunGpuFMHABackward(const GpufMHABackwardConfig &fmha_config, - se::DeviceMemoryBase bmm1_grad_gemm1_rhs_buffer, - se::DeviceMemoryBase bmm1_grad_gemm2_rhs_buffer, - se::DeviceMemoryBase bmm2_grad_gemm1_lhs_buffer, - se::DeviceMemoryBase bmm2_grad_gemm2_rhs_buffer, - se::DeviceMemoryBase d_output_buffer, - se::DeviceMemoryBase scratch_buffer, - se::DeviceMemoryBase d_bmm1_lhs_buffer, - se::DeviceMemoryBase d_bmm1_rhs_buffer, - se::DeviceMemoryBase d_bmm2_rhs_buffer, - se::DeviceMemoryBase d_s_buffer, - std::optional mask_buffer, - std::optional d_bias_buffer, - se::Stream *stream, - RunFusedMHABackwardOptions options) { +Status RunGpuFMHABackward( + const GpufMHABackwardConfig &fmha_config, + se::DeviceMemoryBase bmm1_grad_gemm1_rhs_buffer, + se::DeviceMemoryBase bmm1_grad_gemm2_rhs_buffer, + se::DeviceMemoryBase bmm2_grad_gemm1_lhs_buffer, + se::DeviceMemoryBase bmm2_grad_gemm2_rhs_buffer, + se::DeviceMemoryBase d_output_buffer, se::DeviceMemoryBase scratch_buffer, + se::DeviceMemoryBase d_bmm1_lhs_buffer, + se::DeviceMemoryBase d_bmm1_rhs_buffer, + se::DeviceMemoryBase d_bmm2_rhs_buffer, + std::optional d_s_buffer, + std::optional softmax_sum_buffer, + std::optional d_Q_accum_buffer, + std::optional mask_buffer, + std::optional d_bias_buffer, + std::optional fwd_output_buffer, + std::optional bias_buffer, se::Stream *stream, + RunFusedMHABackwardOptions options) { TF_ASSIGN_OR_RETURN( GpufMHABackwardParams params, GpufMHABackwardParams::For( fmha_config, 
bmm1_grad_gemm1_rhs_buffer, bmm1_grad_gemm2_rhs_buffer, bmm2_grad_gemm1_lhs_buffer, bmm2_grad_gemm2_rhs_buffer, d_output_buffer, d_bmm1_lhs_buffer, d_bmm1_rhs_buffer, - d_bmm2_rhs_buffer, d_s_buffer, mask_buffer, d_bias_buffer)); + d_bmm2_rhs_buffer, d_s_buffer, softmax_sum_buffer, d_Q_accum_buffer, + mask_buffer, d_bias_buffer, fwd_output_buffer, bias_buffer)); PrimitiveType input_primitive_type = fmha_config.input_type; switch (input_primitive_type) { case F16: diff --git a/xla/service/gpu/gpu_fused_mha_runner.h b/xla/service/gpu/gpu_fused_mha_runner.h index 637a3c474c4f7..041993431030c 100644 --- a/xla/service/gpu/gpu_fused_mha_runner.h +++ b/xla/service/gpu/gpu_fused_mha_runner.h @@ -46,6 +46,8 @@ namespace gpu { struct GpufMHADescriptor { CudnnfMHAKind kind; CudnnfMHABackendConfig backend_config; + bool is_flash_attention; + bool is_causal_mask; Shape lhs_bmm1_shape; Shape rhs_bmm1_shape; Shape rhs_bmm2_shape; @@ -62,6 +64,8 @@ struct GpufMHADescriptor { struct GpufMHABackwardDescriptor { CudnnfMHAKind kind; CudnnfMHABackendConfig backend_config; + bool is_flash_attention; + bool is_causal_mask; Shape bmm1_grad_gemm1_rhs_shape; Shape bmm1_grad_gemm2_rhs_shape; Shape bmm2_grad_gemm1_lhs_shape; @@ -75,8 +79,11 @@ struct GpufMHABackwardDescriptor { DotDimensionNumbers bmm2_grad_gemm1_dnums; DotDimensionNumbers bmm2_grad_gemm2_dnums; + std::optional d_s_shape; + std::optional fwd_output_shape; std::optional mask_shape; std::optional d_bias_shape; + std::optional bias_shape; }; // Structure to describe static properties of a GPU fused Multi-Headed // Attention. @@ -91,7 +98,8 @@ struct GpufMHAConfig { std::optional seed; se::dnn::AlgorithmDesc algorithm; - + bool is_flash_attention; + bool is_causal_mask; // bias -> [1, num_attn_heads, q_seq_len, kv_seq_len] // mask -> [batch_size, 1, q_seq_len, kv_seq_len] se::dnn::MatmulTensorDescriptor lhs_bmm1; @@ -119,7 +127,8 @@ struct GpufMHABackwardConfig { std::optional seed; se::dnn::AlgorithmDesc algorithm; - + bool is_flash_attention; + bool is_causal_mask; // mask -> [batch_size, 1, q_seq_len, kv_seq_len] // d_bias -> [1, num_heads, q_seq_len, kv_seq_len] se::dnn::MatmulTensorDescriptor bmm1_grad_gemm1_rhs; @@ -130,9 +139,11 @@ struct GpufMHABackwardConfig { se::dnn::TensorDescriptor d_bmm1_lhs; se::dnn::TensorDescriptor d_bmm1_rhs; se::dnn::TensorDescriptor d_bmm2_rhs; - se::dnn::TensorDescriptor d_s; - std::optional d_bias; + std::optional d_s; std::optional mask; + std::optional d_bias; + std::optional fwd_output; + std::optional bias; }; // Implementation struct exposed for debugging and log analysis. 
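A sketch, not part of the patch: with d_s, softmax_sum, d_Q_accum, fwd_output and bias now optional, RunGpuFMHABackwardImpl repeatedly maps a std::optional<se::DeviceMemoryBase> to a typed buffer when present and to a null base handle when absent. The same pattern, factored into a hypothetical helper (assuming the StreamExecutor DeviceMemory types already used in this file):

#include <optional>

template <typename ElementType>
se::DeviceMemoryBase TypedOrNull(
    const std::optional<se::DeviceMemoryBase>& buffer) {
  // A typed DeviceMemory slices down to DeviceMemoryBase; an empty optional
  // becomes a default-constructed (null) handle.
  return buffer.has_value() ? se::DeviceMemory<ElementType>(*buffer)
                            : se::DeviceMemoryBase();
}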
@@ -165,9 +176,14 @@ struct GpufMHABackwardParams { se::DeviceMemoryBase d_output_buffer, se::DeviceMemoryBase d_bmm1_lhs_buffer, se::DeviceMemoryBase d_bmm1_rhs_buffer, - se::DeviceMemoryBase d_bmm2_rhs_buffer, se::DeviceMemoryBase d_s_buffer, + se::DeviceMemoryBase d_bmm2_rhs_buffer, + std::optional d_s_buffer, + std::optional softmax_sum_buffer, + std::optional d_Q_accum_buffer, std::optional mask_buffer, - std::optional d_bias_buffer); + std::optional d_bias_buffer, + std::optional fwd_output_buffer, + std::optional bias_buffer); const GpufMHABackwardConfig* config; // Not owned se::DeviceMemoryBase bmm1_grad_gemm1_rhs_buffer; @@ -178,9 +194,13 @@ struct GpufMHABackwardParams { se::DeviceMemoryBase d_bmm1_lhs_buffer; se::DeviceMemoryBase d_bmm1_rhs_buffer; se::DeviceMemoryBase d_bmm2_rhs_buffer; - se::DeviceMemoryBase d_s_buffer; - std::optional d_bias_buffer; + std::optional d_s_buffer; + std::optional softmax_sum_buffer; + std::optional d_Q_accum_buffer; std::optional mask_buffer; + std::optional d_bias_buffer; + std::optional fwd_output_buffer; + std::optional bias_buffer; }; class FusedMultiHeadedAttentionRunner { @@ -371,20 +391,24 @@ Status RunGpuFMHA(const GpufMHAConfig& fmha_config, std::optional activation_buffer, se::Stream* stream, RunFusedMHAOptions = {}); -Status RunGpuFMHABackward(const GpufMHABackwardConfig& fmha_config, - se::DeviceMemoryBase bmm1_grad_gemm1_rhs_buffer, - se::DeviceMemoryBase bmm1_grad_gemm2_rhs_buffer, - se::DeviceMemoryBase bmm2_grad_gemm1_lhs_buffer, - se::DeviceMemoryBase bmm2_grad_gemm2_rhs_buffer, - se::DeviceMemoryBase d_output_buffer, - se::DeviceMemoryBase scratch_buffer, - se::DeviceMemoryBase d_bmm1_lhs_buffer, - se::DeviceMemoryBase d_bmm1_rhs_buffer, - se::DeviceMemoryBase d_bmm2_rhs_buffer, - se::DeviceMemoryBase d_s_buffer, - std::optional mask_buffer, - std::optional d_bias_buffer, - se::Stream* stream, RunFusedMHABackwardOptions = {}); +Status RunGpuFMHABackward( + const GpufMHABackwardConfig& fmha_config, + se::DeviceMemoryBase bmm1_grad_gemm1_rhs_buffer, + se::DeviceMemoryBase bmm1_grad_gemm2_rhs_buffer, + se::DeviceMemoryBase bmm2_grad_gemm1_lhs_buffer, + se::DeviceMemoryBase bmm2_grad_gemm2_rhs_buffer, + se::DeviceMemoryBase d_output_buffer, se::DeviceMemoryBase scratch_buffer, + se::DeviceMemoryBase d_bmm1_lhs_buffer, + se::DeviceMemoryBase d_bmm1_rhs_buffer, + se::DeviceMemoryBase d_bmm2_rhs_buffer, + std::optional d_s_buffer, + std::optional softmax_sum_buffer, + std::optional d_Q_accum_buffer, + std::optional mask_buffer, + std::optional d_bias_buffer, + std::optional fwd_output_buffer, + std::optional bias_buffer, se::Stream* stream, + RunFusedMHABackwardOptions = {}); std::string ToString(const GpufMHAConfig& config); diff --git a/xla/service/gpu/ir_emitter_unnested.cc b/xla/service/gpu/ir_emitter_unnested.cc index 5816a68493458..02ec904ca4a65 100644 --- a/xla/service/gpu/ir_emitter_unnested.cc +++ b/xla/service/gpu/ir_emitter_unnested.cc @@ -244,6 +244,13 @@ StatusOr AsCudnnBackwardfMHAKind( case mlir::lmhlo_gpu::FusedMhaBackwardDagSignature:: BackwardScaleBiasMaskSoftmaxDropout: return xla::gpu::CudnnfMHAKind::kBackwardScaleBiasMaskSoftmaxDropout; + break; + case mlir::lmhlo_gpu::FusedMhaBackwardDagSignature::BackwardSoftmax: + return xla::gpu::CudnnfMHAKind::kBackwardSoftmax; + break; + case mlir::lmhlo_gpu::FusedMhaBackwardDagSignature::BackwardSoftmaxDropout: + return xla::gpu::CudnnfMHAKind::kBackwardSoftmaxDropout; + break; default: return xla::InternalError("Unsupported fused_mha_backward_dag_signature"); } @@ -1200,6 
+1207,11 @@ Status IrEmitterUnnested::EmitFusedMHAThunk(mlir::Operation* op) { ShapeUtil::MakeShapeWithDenseLayout( GetShape(fmha.getOutput()).element_type(), intermediate_tensor_dims_array, intermediate_tensor_layout_array); + + // set if flash attention here + descriptor.is_flash_attention = fmha.getIsFlashAttention(); + // set if causal mask here + descriptor.is_causal_mask = fmha.getIsCausalMask(); return OkStatus(); }; @@ -1227,9 +1239,9 @@ Status IrEmitterUnnested::EmitFusedMHABackwardThunk(mlir::Operation* op) { GpufMHABackwardDescriptor descriptor; BufferAllocation::Slice bmm1_grad_gemm1_rhs_slice, bmm1_grad_gemm2_rhs_slice, bmm2_grad_gemm1_lhs_slice, bmm2_grad_gemm2_rhs_slice, d_output_slice, - scratch_slice, mask_slice; + scratch_slice, mask_slice, fwd_output_slice, bias_slice; BufferAllocation::Slice d_bmm1_lhs_slice, d_bmm1_rhs_slice, d_bmm2_rhs_slice, - d_S_slice, d_bias_slice; + d_s_slice, softmax_sum_slice, d_Q_accum_slice, d_bias_slice; auto populate_common = [&](auto fmha) -> Status { descriptor.backend_config.set_fmha_scale( @@ -1259,6 +1271,10 @@ Status IrEmitterUnnested::EmitFusedMHABackwardThunk(mlir::Operation* op) { algorithm->mutable_workspace_size()->set_value(workspace_size); } + // set if flash attention here + descriptor.is_flash_attention = fmha.getIsFlashAttention(); + // set if causal mask here + descriptor.is_causal_mask = fmha.getIsCausalMask(); descriptor.bmm1_grad_gemm1_dnums = ConvertDotDimensionNumbers(fmha.getBmm1GradGemm1DotDimensionNumbers()); descriptor.bmm1_grad_gemm2_dnums = @@ -1282,10 +1298,31 @@ Status IrEmitterUnnested::EmitFusedMHABackwardThunk(mlir::Operation* op) { TF_ASSIGN_OR_RETURN(bmm1_grad_gemm2_rhs_slice, GetAllocationSlice(fmha.getBmm1GradGemm2Rhs())); - descriptor.bmm2_grad_gemm1_lhs_shape = ShapeUtil::MakeShapeWithDenseLayout( - GetShape(fmha.getBmm2GradGemm1Lhs()).element_type(), - GetShape(fmha.getBmm2GradGemm1Lhs()).dimensions(), - GetShape(fmha.getBmm2GradGemm1Lhs()).layout().minor_to_major()); + // fwd activation + // fmha.getBmm2GradGemm1Lhs() could be bmm2_grad_gemm1_lhs for regular + // attention or softmax stats for flash attention here we set the shape to + // be bmm2_grad_gemm1_lhs even it is flash attention + if (descriptor.is_flash_attention) { + // flash attention TODO: make sure the layout is correct for + // bmm2_grad_gemm1_lhs + TF_ASSIGN_OR_RETURN(auto intermediate_tensor_dims_array, + ConvertMlirArrayAttrToInt64Array( + fmha.getIntermediateTensorDimensions())); + TF_ASSIGN_OR_RETURN( + auto intermediate_tensor_layout_array, + ConvertMlirArrayAttrToInt64Array(fmha.getIntermediateTensorLayout())); + + descriptor.bmm2_grad_gemm1_lhs_shape = + ShapeUtil::MakeShapeWithDenseLayout( + GetShape(fmha.getDOutput()).element_type(), + intermediate_tensor_dims_array, intermediate_tensor_layout_array); + } else { + descriptor.bmm2_grad_gemm1_lhs_shape = + ShapeUtil::MakeShapeWithDenseLayout( + GetShape(fmha.getBmm2GradGemm1Lhs()).element_type(), + GetShape(fmha.getBmm2GradGemm1Lhs()).dimensions(), + GetShape(fmha.getBmm2GradGemm1Lhs()).layout().minor_to_major()); + } TF_ASSIGN_OR_RETURN(bmm2_grad_gemm1_lhs_slice, GetAllocationSlice(fmha.getBmm2GradGemm1Lhs())); @@ -1324,7 +1361,13 @@ Status IrEmitterUnnested::EmitFusedMHABackwardThunk(mlir::Operation* op) { TF_ASSIGN_OR_RETURN(scratch_slice, GetAllocationSlice(fmha.getScratch())); - TF_ASSIGN_OR_RETURN(d_S_slice, GetAllocationSlice(fmha.getD_S())); + if (fmha.getD_S() != nullptr) { + descriptor.d_s_shape = ShapeUtil::MakeShapeWithDenseLayout( + 
GetShape(fmha.getD_S()).element_type(), + GetShape(fmha.getD_S()).dimensions(), + GetShape(fmha.getD_S()).layout().minor_to_major()); + TF_ASSIGN_OR_RETURN(d_s_slice, GetAllocationSlice(fmha.getD_S())); + } if (fmha.getDBias() != nullptr) { descriptor.d_bias_shape = ShapeUtil::MakeShapeWithDenseLayout( @@ -1348,6 +1391,33 @@ Status IrEmitterUnnested::EmitFusedMHABackwardThunk(mlir::Operation* op) { TF_ASSIGN_OR_RETURN(mask_slice, GetAllocationSlice(fmha.getMask())); } + // add flash attention backward related slice here + if (fmha.getBias() != nullptr) { + descriptor.bias_shape = ShapeUtil::MakeShapeWithDenseLayout( + GetShape(fmha.getBias()).element_type(), + GetShape(fmha.getBias()).dimensions(), + GetShape(fmha.getBias()).layout().minor_to_major()); + TF_ASSIGN_OR_RETURN(bias_slice, GetAllocationSlice(fmha.getBias())); + } + + if (fmha.getSoftmaxSum() != nullptr) { + TF_ASSIGN_OR_RETURN(softmax_sum_slice, + GetAllocationSlice(fmha.getSoftmaxSum())); + } + + if (fmha.getD_QAccum() != nullptr) { + TF_ASSIGN_OR_RETURN(d_Q_accum_slice, + GetAllocationSlice(fmha.getD_QAccum())); + } + + if (fmha.getFwdOutput() != nullptr) { + descriptor.fwd_output_shape = ShapeUtil::MakeShapeWithDenseLayout( + GetShape(fmha.getFwdOutput()).element_type(), + GetShape(fmha.getFwdOutput()).dimensions(), + GetShape(fmha.getFwdOutput()).layout().minor_to_major()); + TF_ASSIGN_OR_RETURN(fwd_output_slice, + GetAllocationSlice(fmha.getFwdOutput())); + } return OkStatus(); }; @@ -1369,7 +1439,8 @@ Status IrEmitterUnnested::EmitFusedMHABackwardThunk(mlir::Operation* op) { bmm1_grad_gemm1_rhs_slice, bmm1_grad_gemm2_rhs_slice, bmm2_grad_gemm1_lhs_slice, bmm2_grad_gemm2_rhs_slice, d_output_slice, scratch_slice, d_bmm1_lhs_slice, d_bmm1_rhs_slice, d_bmm2_rhs_slice, - d_S_slice, mask_slice, d_bias_slice)); + d_s_slice, softmax_sum_slice, d_Q_accum_slice, mask_slice, d_bias_slice, + fwd_output_slice, bias_slice)); return OkStatus(); } diff --git a/xla/service/gpu/nvptx_compiler.cc b/xla/service/gpu/nvptx_compiler.cc index ed41927807cd9..5aff01f0c94e9 100644 --- a/xla/service/gpu/nvptx_compiler.cc +++ b/xla/service/gpu/nvptx_compiler.cc @@ -227,11 +227,11 @@ Status NVPTXCompiler::OptimizeHloPostLayoutAssignment( mha_fusion_pipeline.AddPass(); mha_fusion_pipeline.AddPass(); } + mha_fusion_pipeline.AddPass(/*is_layout_sensitive=*/true); mha_fusion_pipeline.AddPass>( alg_sim_options); mha_fusion_pipeline.AddPass(/*is_layout_sensitive=*/true); - // Rewrite Multi-Headed Attention modules to Fused MHA custom-calls. 
if (stream_exec) { mha_fusion_pipeline.AddPass( diff --git a/xla/service/gpu/runtime/fused_attention.cc b/xla/service/gpu/runtime/fused_attention.cc index 3174445ef1ee5..9dbd3e6aa1907 100644 --- a/xla/service/gpu/runtime/fused_attention.cc +++ b/xla/service/gpu/runtime/fused_attention.cc @@ -114,6 +114,10 @@ static auto EncodeFusedAttentionBackwardDAGSignature( lmhlo_gpu::FusedMhaBackwardDagSignature signature) { switch (signature) { // backward + case mlir::lmhlo_gpu::FusedMhaBackwardDagSignature::BackwardSoftmax: + return xla::gpu::CudnnfMHAKind::kBackwardSoftmax; + case mlir::lmhlo_gpu::FusedMhaBackwardDagSignature::BackwardSoftmaxDropout: + return xla::gpu::CudnnfMHAKind::kBackwardSoftmaxDropout; case mlir::lmhlo_gpu::FusedMhaBackwardDagSignature:: BackwardScaleBiasSoftmax: return xla::gpu::CudnnfMHAKind::kBackwardScaleBiasSoftmax; @@ -193,8 +197,8 @@ static GpufMHADescriptor GetGpufMHADescriptor( absl::Span intermediate_tensor_dimensions, absl::Span intermediate_tensor_layout, AlgorithmConfig algo, DotDimensionNumbers bmm1_dot_dimension_numbers, - DotDimensionNumbers bmm2_dot_dimension_numbers, - std::optional dropout = std::nullopt) { + DotDimensionNumbers bmm2_dot_dimension_numbers, bool is_flash_attention, + bool is_causal_mask, std::optional dropout = std::nullopt) { GpufMHADescriptor descriptor; descriptor.backend_config.set_fmha_scale(fmha_scale); @@ -250,7 +254,8 @@ static GpufMHADescriptor GetGpufMHADescriptor( } descriptor.kind = kind; - + descriptor.is_flash_attention = is_flash_attention; + descriptor.is_causal_mask = is_causal_mask; return descriptor; } @@ -262,11 +267,19 @@ static GpufMHABackwardDescriptor GetGpufMHABackwardDescriptor( std::optional mask, std::optional d_bias, StridedMemrefView d_bmm1_lhs, StridedMemrefView d_bmm1_rhs, StridedMemrefView d_bmm2_rhs, - StridedMemrefView d_S, double fmha_scale, AlgorithmConfig algo, + std::optional d_S, + std::optional softmax_sum, + std::optional d_Q_accum, + std::optional fwd_output, + std::optional bias, double fmha_scale, + AlgorithmConfig algo, DotDimensionNumbers bmm1_grad_gemm1_dot_dimension_numbers, DotDimensionNumbers bmm1_grad_gemm2_dot_dimension_numbers, DotDimensionNumbers bmm2_grad_gemm1_dot_dimension_numbers, DotDimensionNumbers bmm2_grad_gemm2_dot_dimension_numbers, + absl::Span intermediate_tensor_dimensions, + absl::Span intermediate_tensor_layout, + bool is_flash_attention, bool is_causal_mask, std::optional dropout_attrs = std::nullopt) { GpufMHABackwardDescriptor descriptor; descriptor.backend_config.set_fmha_scale(fmha_scale); @@ -313,7 +326,15 @@ static GpufMHABackwardDescriptor GetGpufMHABackwardDescriptor( descriptor.bmm1_grad_gemm1_rhs_shape = apply_shape(bmm1_grad_gemm1_rhs); descriptor.bmm1_grad_gemm2_rhs_shape = apply_shape(bmm1_grad_gemm2_rhs); descriptor.bmm2_grad_gemm2_rhs_shape = apply_shape(bmm2_grad_gemm2_rhs); - descriptor.bmm2_grad_gemm1_lhs_shape = apply_shape(bmm2_grad_gemm1_lhs); + if (is_flash_attention) { + // if it is flash attention then bmm2_grad_gemm1_lhs will be softmax_stats + // instead of P we need to use real P layout + descriptor.bmm2_grad_gemm1_lhs_shape = ShapeUtil::MakeShapeWithDenseLayout( + descriptor.bmm2_grad_gemm2_rhs_shape.element_type(), + intermediate_tensor_dimensions, intermediate_tensor_layout); + } else { + descriptor.bmm2_grad_gemm1_lhs_shape = apply_shape(bmm2_grad_gemm1_lhs); + } descriptor.d_output_shape = apply_shape(d_output); descriptor.d_bmm1_lhs_shape = apply_shape(d_bmm1_lhs); @@ -326,14 +347,20 @@ static GpufMHABackwardDescriptor 
GetGpufMHABackwardDescriptor( if (d_bias.has_value()) { descriptor.d_bias_shape = apply_shape(*d_bias); } - + if (fwd_output.has_value()) { + descriptor.fwd_output_shape = apply_shape(*fwd_output); + } + if (bias.has_value()) { + descriptor.bias_shape = apply_shape(*bias); + } if (dropout_attrs.has_value()) { descriptor.backend_config.set_dropout_rate(dropout_attrs->dropout_rate); descriptor.backend_config.set_seed(dropout_attrs->seed); } descriptor.kind = kind; - + descriptor.is_flash_attention = is_flash_attention; + descriptor.is_causal_mask = is_causal_mask; return descriptor; } @@ -344,7 +371,8 @@ static absl::Status FusedAttentionForwardImpl( StridedMemrefView rhs_bmm2, std::optional mask, std::optional bias, StridedMemrefView output, FlatMemrefView scratch, std::optional activation, - int64_t uid, double fmha_scale, + int64_t uid, double fmha_scale, bool is_flash_attention, + bool is_causal_mask, absl::Span intermediate_tensor_dimensions, absl::Span intermediate_tensor_layout, DotDimensionNumbers bmm1_dot_dimension_numbers, @@ -364,7 +392,7 @@ static absl::Status FusedAttentionForwardImpl( fmha_scale, intermediate_tensor_dimensions, intermediate_tensor_layout, algorithm_config, bmm1_dot_dimension_numbers, bmm2_dot_dimension_numbers, - dropout_attrs); + is_flash_attention, is_causal_mask, dropout_attrs); StatusOr config = GpufMHAConfig::For(descriptor); if (!config.ok()) return tsl::ToAbslStatus(config.status()); @@ -414,10 +442,17 @@ static absl::Status FusedAttentionBackwardImpl( StridedMemrefView bmm1_grad_gemm2_rhs, StridedMemrefView bmm2_grad_gemm2_rhs, StridedMemrefView bmm2_grad_gemm1_lhs, StridedMemrefView d_output, - std::optional mask, StridedMemrefView d_bmm1_lhs, + std::optional mask, + std::optional bias, + std::optional fwd_output, StridedMemrefView d_bmm1_lhs, StridedMemrefView d_bmm1_rhs, StridedMemrefView d_bmm2_rhs, - StridedMemrefView d_S, FlatMemrefView scratch, + std::optional d_S, + std::optional softmax_sum, + std::optional d_Q_accum, FlatMemrefView scratch, std::optional d_bias, int64_t uid, double fmha_scale, + bool is_flash_attention, bool is_causal_mask, + absl::Span intermediate_tensor_dimensions, + absl::Span intermediate_tensor_layout, DotDimensionNumbers bmm1_grad_gemm1_dot_dimension_numbers, DotDimensionNumbers bmm1_grad_gemm2_dot_dimension_numbers, DotDimensionNumbers bmm2_grad_gemm1_dot_dimension_numbers, @@ -436,12 +471,13 @@ static absl::Status FusedAttentionBackwardImpl( GpufMHABackwardDescriptor descriptor = GetGpufMHABackwardDescriptor( kind, bmm1_grad_gemm1_rhs, bmm1_grad_gemm2_rhs, bmm2_grad_gemm2_rhs, bmm2_grad_gemm1_lhs, d_output, mask, d_bias, d_bmm1_lhs, d_bmm1_rhs, - d_bmm2_rhs, d_S, fmha_scale, algorithm_config, - bmm1_grad_gemm1_dot_dimension_numbers, + d_bmm2_rhs, d_S, softmax_sum, d_Q_accum, fwd_output, bias, + fmha_scale, algorithm_config, bmm1_grad_gemm1_dot_dimension_numbers, bmm1_grad_gemm2_dot_dimension_numbers, bmm2_grad_gemm1_dot_dimension_numbers, - bmm2_grad_gemm2_dot_dimension_numbers, dropout_attrs); - + bmm2_grad_gemm2_dot_dimension_numbers, + intermediate_tensor_dimensions, intermediate_tensor_layout, + is_flash_attention, is_causal_mask, dropout_attrs); StatusOr config = GpufMHABackwardConfig::For(descriptor); if (!config.ok()) return tsl::ToAbslStatus(config.status()); @@ -463,9 +499,13 @@ static absl::Status FusedAttentionBackwardImpl( se::DeviceMemoryBase d_bmm1_lhs_buffer = GetDeviceAddress(d_bmm1_lhs); se::DeviceMemoryBase d_bmm1_rhs_buffer = GetDeviceAddress(d_bmm1_rhs); se::DeviceMemoryBase d_bmm2_rhs_buffer = 
GetDeviceAddress(d_bmm2_rhs); - se::DeviceMemoryBase d_S_buffer = GetDeviceAddress(d_S); se::DeviceMemoryBase scratch_buffer = GetDeviceAddress(scratch); + se::DeviceMemoryBase d_S_buffer; + if (d_S.has_value()) { + d_S_buffer = GetDeviceAddress(*d_S); + } + se::DeviceMemoryBase mask_buffer; if (mask.has_value()) { mask_buffer = GetDeviceAddress(*mask); @@ -476,6 +516,26 @@ static absl::Status FusedAttentionBackwardImpl( d_bias_buffer = GetDeviceAddress(*d_bias); } + se::DeviceMemoryBase softmax_sum_buffer; + if (softmax_sum.has_value()) { + softmax_sum_buffer = GetDeviceAddress(*softmax_sum); + } + + se::DeviceMemoryBase d_Q_accum_buffer; + if (d_Q_accum.has_value()) { + d_Q_accum_buffer = GetDeviceAddress(*d_Q_accum); + } + + se::DeviceMemoryBase fwd_output_buffer; + if (fwd_output.has_value()) { + fwd_output_buffer = GetDeviceAddress(*fwd_output); + } + + se::DeviceMemoryBase bias_buffer; + if (bias.has_value()) { + bias_buffer = GetDeviceAddress(*bias); + } + RunFusedMHABackwardOptions opts; opts.runner_cache = &(*fda)->runner; @@ -484,7 +544,9 @@ static absl::Status FusedAttentionBackwardImpl( (*fda)->config, bmm1_grad_gemm1_rhs_buffer, bmm1_grad_gemm2_rhs_buffer, bmm2_grad_gemm1_lhs_buffer, bmm2_grad_gemm2_rhs_buffer, d_output_buffer, scratch_buffer, d_bmm1_lhs_buffer, d_bmm1_rhs_buffer, d_bmm2_rhs_buffer, - d_S_buffer, mask_buffer, d_bias_buffer, run_options->stream(), opts); + d_S_buffer, softmax_sum_buffer, d_Q_accum_buffer, mask_buffer, + d_bias_buffer, fwd_output_buffer, bias_buffer, run_options->stream(), + opts); if (!st.ok() || !run_options->stream()->ok()) { return tsl::ToAbslStatus(st); } @@ -500,6 +562,8 @@ auto BindFusedAttentionAttributes(runtime::CustomCallBinding binding) { return std::move(binding) .template Attr("uid") .template Attr("fmha_scale") + .template Attr("is_flash_attention") + .template Attr("is_causal_mask") .template Attr>( "intermediate_tensor_dimensions") .template Attr>("intermediate_tensor_layout") @@ -805,6 +869,11 @@ auto BindFusedAttentionBackwardAttributes( return std::move(binding) .template Attr("uid") .template Attr("fmha_scale") + .template Attr("is_flash_attention") + .template Attr("is_causal_mask") + .template Attr>( + "intermediate_tensor_dimensions") + .template Attr>("intermediate_tensor_layout") .template Attr( "bmm1_grad_gemm1_dot_dimension_numbers") .template Attr( @@ -822,11 +891,11 @@ auto FusedAttentionBackwardCall(const char* name) { .UserData() .UserData() .State("uid") - .Arg() // bmm1_grad_gemm1_rhs - .Arg() // bmm1_grad_gemm2_rhs - .Arg() // bmm2_grad_gemm2_rhs - .Arg() // bmm2_grad_gemm1_lhs - .Arg(); + .Arg() // bmm1_grad_gemm1_rhs + .Arg() // bmm1_grad_gemm2_rhs + .Arg() // bmm2_grad_gemm2_rhs + .Arg() // bmm2_grad_gemm1_lhs + .Arg(); // d_output } XLA_RUNTIME_DEFINE_CUSTOM_CALL( @@ -836,10 +905,14 @@ XLA_RUNTIME_DEFINE_CUSTOM_CALL( FusedAttentionBackwardCall( "xla.gpu.fused.attention.backward.scale.dbias.softmax") .Value(std::optional()) // mask + .Value(std::optional()) // bias + .Value(std::optional()) // fwd_output .Arg() // d_bmm1_lhs .Arg() // d_bmm1_rhs .Arg() // d_bmm2_rhs .Arg() // d_S + .Value(std::optional()) // softmax_sum + .Value(std::optional()) // d_Q_accum .Arg() // scratch .Arg() // d_bias ) @@ -854,10 +927,14 @@ XLA_RUNTIME_DEFINE_CUSTOM_CALL( FusedAttentionBackwardCall( "xla.gpu.fused.attention.backward.scale.softmax") .Value(std::optional()) // mask + .Value(std::optional()) // bias + .Value(std::optional()) // fwd_output .Arg() // d_bmm1_lhs .Arg() // d_bmm1_rhs .Arg() // d_bmm2_rhs .Arg() // d_S + 
.Value(std::optional()) // softmax_sum + .Value(std::optional()) // d_Q_accum .Arg() // scratch .Value(std::optional()) // d_bias ) @@ -872,10 +949,14 @@ XLA_RUNTIME_DEFINE_CUSTOM_CALL( FusedAttentionBackwardCall( "xla.gpu.fused.attention.backward.scale.dbias.softmax.dropout") .Value(std::optional()) // mask + .Value(std::optional()) // bias + .Value(std::optional()) // fwd_output .Arg() // d_bmm1_lhs .Arg() // d_bmm1_rhs .Arg() // d_bmm2_rhs .Arg() // d_S + .Value(std::optional()) // softmax_sum + .Value(std::optional()) // d_Q_accum .Arg() // scratch .Arg() // d_bias ) @@ -890,10 +971,14 @@ XLA_RUNTIME_DEFINE_CUSTOM_CALL( FusedAttentionBackwardCall( "xla.gpu.fused.attention.backward.scale.softmax.dropout") .Value(std::optional()) // mask + .Value(std::optional()) // bias + .Value(std::optional()) // fwd_output .Arg() // d_bmm1_lhs .Arg() // d_bmm1_rhs .Arg() // d_bmm2_rhs .Arg() // d_S + .Value(std::optional()) // softmax_sum + .Value(std::optional()) // d_Q_accum .Arg() // scratch .Value(std::optional()) // d_bias ) @@ -907,13 +992,17 @@ XLA_RUNTIME_DEFINE_CUSTOM_CALL( BindFusedAttentionBackwardAttributes( FusedAttentionBackwardCall( "xla.gpu.fused.attention.backward.scale.dbias.mask.softmax") - .Arg() // mask - .Arg() // d_bmm1_lhs - .Arg() // d_bmm1_rhs - .Arg() // d_bmm2_rhs - .Arg() // d_S - .Arg() // scratch - .Arg() // d_bias + .Arg() // mask + .Value(std::optional()) // bias + .Value(std::optional()) // fwd_output + .Arg() // d_bmm1_lhs + .Arg() // d_bmm1_rhs + .Arg() // d_bmm2_rhs + .Arg() // d_S + .Value(std::optional()) // softmax_sum + .Value(std::optional()) // d_Q_accum + .Arg() // scratch + .Arg() // d_bias ) .Value(std::optional()) // dropout_rate .Value(std::optional()) // seed @@ -926,10 +1015,14 @@ XLA_RUNTIME_DEFINE_CUSTOM_CALL( FusedAttentionBackwardCall( "xla.gpu.fused.attention.backward.scale.mask.softmax") .Arg() // mask + .Value(std::optional()) // bias + .Value(std::optional()) // fwd_output .Arg() // d_bmm1_lhs .Arg() // d_bmm1_rhs .Arg() // d_bmm2_rhs .Arg() // d_S + .Value(std::optional()) // softmax_sum + .Value(std::optional()) // d_Q_accum .Arg() // scratch .Value(std::optional()) // d_bias ) @@ -943,13 +1036,17 @@ XLA_RUNTIME_DEFINE_CUSTOM_CALL( BindFusedAttentionBackwardAttributes( FusedAttentionBackwardCall( "xla.gpu.fused.attention.backward.scale.dbias.mask.softmax.dropout") - .Arg() // mask - .Arg() // d_bmm1_lhs - .Arg() // d_bmm1_rhs - .Arg() // d_bmm2_rhs - .Arg() // d_S - .Arg() // scratch - .Arg() // d_bias + .Arg() // mask + .Value(std::optional()) // bias + .Value(std::optional()) // fwd_output + .Arg() // d_bmm1_lhs + .Arg() // d_bmm1_rhs + .Arg() // d_bmm2_rhs + .Arg() // d_S + .Value(std::optional()) // softmax_sum + .Value(std::optional()) // d_Q_accum + .Arg() // scratch + .Arg() // d_bias ) .Attr("dropout_rate") // dropout_rate .Attr("seed") // seed @@ -962,16 +1059,66 @@ XLA_RUNTIME_DEFINE_CUSTOM_CALL( FusedAttentionBackwardCall( "xla.gpu.fused.attention.backward.scale.mask.softmax.dropout") .Arg() // mask + .Value(std::optional()) // bias + .Value(std::optional()) // fwd_output .Arg() // d_bmm1_lhs .Arg() // d_bmm1_rhs .Arg() // d_bmm2_rhs .Arg() // d_S + .Value(std::optional()) // softmax_sum + .Value(std::optional()) // d_Q_accum .Arg() // scratch .Value(std::optional()) // d_bias ) .Attr("dropout_rate") // dropout_rate .Attr("seed") // seed ); + +// flash attention backward custom call +XLA_RUNTIME_DEFINE_CUSTOM_CALL( + FlashAttentionScaleBiasSoftmaxBackward, + FunctionWrapper(), checks, + BindFusedAttentionBackwardAttributes( + 
FusedAttentionBackwardCall( + "xla.gpu.flash.attention.backward.scale.bias.softmax") + .Value(std::optional()) // mask + .Arg() // bias + .Arg() // fwd_output + .Arg() // d_bmm1_lhs + .Arg() // d_bmm1_rhs + .Arg() // d_bmm2_rhs + .Value(std::optional()) // d_S + .Arg() // softmax_sum + .Arg() // d_Q_accum + .Arg() // scratch + .Value(std::optional()) // d_bias + ) + .Value(std::optional()) // dropout_rate + .Value(std::optional()) // seed +); + +XLA_RUNTIME_DEFINE_CUSTOM_CALL( + FlashAttentionScaleSoftmaxBackward, + FunctionWrapper(), checks, + BindFusedAttentionBackwardAttributes( + FusedAttentionBackwardCall( + "xla.gpu.flash.attention.backward.scale.softmax") + .Value(std::optional()) // mask + .Value(std::optional()) // bias + .Arg() // fwd_output + .Arg() // d_bmm1_lhs + .Arg() // d_bmm1_rhs + .Arg() // d_bmm2_rhs + .Value(std::optional()) // d_S + .Arg() // softmax_sum + .Arg() // d_Q_accum + .Arg() // scratch + .Value(std::optional()) // d_bias + ) + .Value(std::optional()) // dropout_rate + .Value(std::optional()) // seed +); + //===----------------------------------------------------------------------===// // cuBLASLt custom calls bindings and registration. //===----------------------------------------------------------------------===// @@ -1040,6 +1187,14 @@ void RegisterFusedAttentionBackwardCustomCalls( FusedAttentionScaleBiasMaskSoftmaxDropoutBackward); registry.Register(fused_attention("scale.mask.softmax.dropout"), FusedAttentionScaleMaskSoftmaxDropoutBackward); + // flash attention bwd + auto flash_attention = [](std::string name) { + return "xla.gpu.flash.attention.backward." + name; + }; + registry.Register(flash_attention("scale.bias.softmax"), + FlashAttentionScaleBiasSoftmaxBackward); + registry.Register(flash_attention("scale.softmax"), + FlashAttentionScaleSoftmaxBackward); } } // namespace gpu } // namespace xla diff --git a/xla/translate/mhlo_to_lhlo_with_xla/mhlo_to_lhlo_with_xla.cc b/xla/translate/mhlo_to_lhlo_with_xla/mhlo_to_lhlo_with_xla.cc index ac7c223b87426..1a340c51cb326 100644 --- a/xla/translate/mhlo_to_lhlo_with_xla/mhlo_to_lhlo_with_xla.cc +++ b/xla/translate/mhlo_to_lhlo_with_xla/mhlo_to_lhlo_with_xla.cc @@ -780,6 +780,12 @@ AsLhloFusedMhaBackwardDagSignature(xla::gpu::CudnnfMHAKind kind) { return lmhlo_gpu::FusedMhaBackwardDagSignature:: BackwardScaleBiasMaskSoftmaxDropout; break; + case xla::gpu::CudnnfMHAKind::kBackwardSoftmax: + return lmhlo_gpu::FusedMhaBackwardDagSignature::BackwardSoftmax; + break; + case xla::gpu::CudnnfMHAKind::kBackwardSoftmaxDropout: + return lmhlo_gpu::FusedMhaBackwardDagSignature::BackwardSoftmaxDropout; + break; default: return xla::InternalError("unknown cudnn fmha bwd kind"); } @@ -1270,13 +1276,17 @@ tsl::StatusOr LhloDialectEmitter::EmitDnnfMHA( has_activation ? 
1 : 0}; op->setAttr(op.getOperandSegmentSizeAttr(), builder_.getDenseI32ArrayAttr(operand_sizes)); + // set is flash attention here + op.setIsFlashAttentionAttr( + builder_.getBoolAttr(config.is_flash_attention())); + // set is causal mask here + op.setIsCausalMaskAttr(builder_.getBoolAttr(config.is_causal_mask())); return op.getOperation(); }; llvm::SmallVector operands; TF_RETURN_IF_ERROR(GetOrCreateView(custom_call->operand(0), &operands)); TF_RETURN_IF_ERROR(GetOrCreateView(custom_call->operand(1), &operands)); TF_RETURN_IF_ERROR(GetOrCreateView(custom_call->operand(2), &operands)); - switch (kind) { case xla::gpu::CudnnfMHAKind::kBmmBmm: case xla::gpu::CudnnfMHAKind::kSoftmax: { @@ -1365,8 +1375,11 @@ tsl::StatusOr LhloDialectEmitter::EmitDnnfMHABackward( TF_ASSIGN_OR_RETURN(const xla::gpu::CudnnfMHAKind kind, xla::gpu::GetCudnnfMHAKind(custom_call)); - bool has_dbias = custom_call->shape().tuple_shapes().size() == 6; + bool is_flash_attention = config.is_flash_attention(); + bool has_dbias = + custom_call->shape().tuple_shapes().size() == 6 && !is_flash_attention; bool has_mask = false; + bool has_bias = false; auto set_common_fmha_backward_attributes = [&, this](auto op) -> tsl::StatusOr { @@ -1384,13 +1397,44 @@ tsl::StatusOr LhloDialectEmitter::EmitDnnfMHABackward( op.setBmm2GradGemm2DotDimensionNumbersAttr(GetDotDimensionNumbersAttr( builder_, config.bmm2_grad_gemm2_dot_dimension_numbers())); + auto intermediate_tensor_shape = Shape(config.intermediate_tensor_shape()); + auto arrayref = [](absl::Span array) { + return llvm::ArrayRef{array.data(), array.size()}; + }; + auto intermediate_tensor_dims = builder_.getI64ArrayAttr( + arrayref(intermediate_tensor_shape.dimensions())); + op.setIntermediateTensorDimensionsAttr(intermediate_tensor_dims); + + auto intermediate_tensor_layout = builder_.getI64ArrayAttr( + arrayref(intermediate_tensor_shape.layout().minor_to_major())); + op.setIntermediateTensorLayoutAttr(intermediate_tensor_layout); + op.setFmhaScaleAttr(builder_.getF64FloatAttr(config.fmha_scale())); - int32_t operand_sizes[] = {1, 1, 1, 1, 1, has_mask ? 1 : 0, - 1, 1, 1, 1, 1, has_dbias ? 1 : 0}; + int32_t operand_sizes[] = {1, + 1, + 1, + 1, + 1, + has_mask ? 1 : 0, + has_bias ? 1 : 0, + is_flash_attention ? 1 : 0, // fwd_output + 1, + 1, + 1, + is_flash_attention ? 0 : 1, // d_S + is_flash_attention ? 1 : 0, // softmax_sum + is_flash_attention ? 1 : 0, // d_Q_accum + 1, + has_dbias ? 
1 : 0}; op->setAttr(op.getOperandSegmentSizeAttr(), builder_.getDenseI32ArrayAttr(operand_sizes)); + // set is flash attention here + op.setIsFlashAttentionAttr( + builder_.getBoolAttr(config.is_flash_attention())); + // set is causal mask here + op.setIsCausalMaskAttr(builder_.getBoolAttr(config.is_causal_mask())); const auto& algorithm = config.algorithm(); std::vector knob_ids; std::vector knob_values; @@ -1406,7 +1450,7 @@ tsl::StatusOr LhloDialectEmitter::EmitDnnfMHABackward( return op.getOperation(); }; - llvm::SmallVector operands; + llvm::SmallVector operands; TF_RETURN_IF_ERROR(GetOrCreateView(custom_call->operand(0), &operands)); TF_RETURN_IF_ERROR(GetOrCreateView(custom_call->operand(1), &operands)); TF_RETURN_IF_ERROR(GetOrCreateView(custom_call->operand(2), &operands)); @@ -1415,15 +1459,35 @@ tsl::StatusOr LhloDialectEmitter::EmitDnnfMHABackward( switch (kind) { case xla::gpu::CudnnfMHAKind::kBackwardBmmBmm: - case xla::gpu::CudnnfMHAKind::kBackwardSoftmax: + case xla::gpu::CudnnfMHAKind::kBackwardSoftmax: { + // push fwd output for bwd here if it is flash attention + if (config.is_flash_attention()) { + TF_RETURN_IF_ERROR(GetOrCreateView(custom_call->operand(5), &operands)); + } + TF_RETURN_IF_ERROR(GetOrCreateView(custom_call, &operands)); + auto fmha_backward = CreateOpWithoutAttrs( + custom_call, operands); + return set_common_fmha_backward_attributes(fmha_backward); + } case xla::gpu::CudnnfMHAKind::kBackwardScaleBiasSoftmax: { + // push fwd output for bwd here if it is flash attention + if (config.is_flash_attention()) { + has_bias = true; + TF_RETURN_IF_ERROR(GetOrCreateView(custom_call->operand(5), &operands)); + TF_RETURN_IF_ERROR(GetOrCreateView(custom_call->operand(6), &operands)); + } TF_RETURN_IF_ERROR(GetOrCreateView(custom_call, &operands)); auto fmha_backward = CreateOpWithoutAttrs( custom_call, operands); return set_common_fmha_backward_attributes(fmha_backward); } - case xla::gpu::CudnnfMHAKind::kBackwardSoftmaxDropout: case xla::gpu::CudnnfMHAKind::kBackwardScaleBiasSoftmaxDropout: { + // push fwd output for bwd here if it is flash attention + if (config.is_flash_attention()) { + has_bias = true; + TF_RETURN_IF_ERROR(GetOrCreateView(custom_call->operand(5), &operands)); + TF_RETURN_IF_ERROR(GetOrCreateView(custom_call->operand(6), &operands)); + } TF_RETURN_IF_ERROR(GetOrCreateView(custom_call, &operands)); auto fmha_backward = CreateOpWithoutAttrs( custom_call, operands); @@ -1433,9 +1497,26 @@ tsl::StatusOr LhloDialectEmitter::EmitDnnfMHABackward( return set_common_fmha_backward_attributes(fmha_backward); } - case xla::gpu::CudnnfMHAKind::kBackwardScaleMaskSoftmax: + case xla::gpu::CudnnfMHAKind::kBackwardScaleMaskSoftmax: { + TF_RETURN_IF_ERROR(GetOrCreateView(custom_call->operand(5), &operands)); + // push fwd output for bwd here if it is flash attention + if (config.is_flash_attention()) { + TF_RETURN_IF_ERROR(GetOrCreateView(custom_call->operand(6), &operands)); + } + TF_RETURN_IF_ERROR(GetOrCreateView(custom_call, &operands)); + has_mask = true; + auto fmha_backward = CreateOpWithoutAttrs( + custom_call, operands); + return set_common_fmha_backward_attributes(fmha_backward); + } case xla::gpu::CudnnfMHAKind::kBackwardScaleBiasMaskSoftmax: { TF_RETURN_IF_ERROR(GetOrCreateView(custom_call->operand(5), &operands)); + // push fwd output for bwd here if it is flash attention + if (config.is_flash_attention()) { + has_bias = true; + TF_RETURN_IF_ERROR(GetOrCreateView(custom_call->operand(6), &operands)); + 
TF_RETURN_IF_ERROR(GetOrCreateView(custom_call->operand(7), &operands)); + } TF_RETURN_IF_ERROR(GetOrCreateView(custom_call, &operands)); has_mask = true; auto fmha_backward = CreateOpWithoutAttrs( @@ -1443,9 +1524,31 @@ tsl::StatusOr LhloDialectEmitter::EmitDnnfMHABackward( return set_common_fmha_backward_attributes(fmha_backward); } - case xla::gpu::CudnnfMHAKind::kBackwardScaleMaskSoftmaxDropout: + case xla::gpu::CudnnfMHAKind::kBackwardScaleMaskSoftmaxDropout: { + TF_RETURN_IF_ERROR(GetOrCreateView(custom_call->operand(5), &operands)); + // push fwd output for bwd here if it is flash attention + if (config.is_flash_attention()) { + TF_RETURN_IF_ERROR(GetOrCreateView(custom_call->operand(6), &operands)); + } + TF_RETURN_IF_ERROR(GetOrCreateView(custom_call, &operands)); + has_mask = true; + auto fmha_backward = CreateOpWithoutAttrs( + custom_call, operands); + fmha_backward.setDropoutRateAttr( + builder_.getF64FloatAttr(config.dropout_rate())); + fmha_backward.setSeedAttr(builder_.getI64IntegerAttr(config.seed())); + return set_common_fmha_backward_attributes(fmha_backward); + } case xla::gpu::CudnnfMHAKind::kBackwardScaleBiasMaskSoftmaxDropout: { TF_RETURN_IF_ERROR(GetOrCreateView(custom_call->operand(5), &operands)); + // push fwd output for bwd here if it is flash attention + if (config.is_flash_attention()) { + has_bias = true; + TF_RETURN_IF_ERROR( + GetOrCreateView(custom_call->operand(6), &operands)); // bias + TF_RETURN_IF_ERROR( + GetOrCreateView(custom_call->operand(7), &operands)); // fwd_output + } TF_RETURN_IF_ERROR(GetOrCreateView(custom_call, &operands)); has_mask = true; auto fmha_backward = CreateOpWithoutAttrs(
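A worked example, not part of the patch: fMHABackward is declared with AttrSizedOperandSegments, so the operand_sizes arrays above record, per declared operand, how many SSA values were actually pushed, with 0 marking an absent optional operand. For a flash-attention BackwardScaleBiasSoftmax op (mask, d_S and d_bias absent; bias, fwd_output, softmax_sum and d_Q_accum present) the segments work out as follows:

// One entry per operand, in the order declared in lhlo_gpu_ops.td.
const int32_t operand_sizes[] = {
    /*bmm1_grad_gemm1_rhs=*/1, /*bmm1_grad_gemm2_rhs=*/1,
    /*bmm2_grad_gemm2_rhs=*/1, /*bmm2_grad_gemm1_lhs=*/1,
    /*d_output=*/1,            /*mask=*/0,
    /*bias=*/1,                /*fwd_output=*/1,
    /*d_bmm1_lhs=*/1,          /*d_bmm1_rhs=*/1,
    /*d_bmm2_rhs=*/1,          /*d_S=*/0,
    /*softmax_sum=*/1,         /*d_Q_accum=*/1,
    /*scratch=*/1,             /*d_bias=*/0};
// The present operands sum to 13, which is exactly the operand count the
// lowering pass accepts for the flash attention "scale.bias.softmax" target.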