PR #6657: [XLA:GPU] add cuDNN flash attention support in XLA (2nd PR with only MLIR lowering and thunk/runtime)

Imported from GitHub PR #6657

This is the 2nd PR of splitting #5910, containing only the MLIR lowering and the thunk/runtime support.
The 1st PR, #6293, has already been merged.

* Added MLIR lowering for flash attention (the custom-call target selection is sketched just below this list).
* Added thunk/runner/runtime support for flash attention.
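A minimal standalone sketch of that target selection, condensed from the FusedAttentionBackwardLowering changes further down in this diff. The free function, the stand-in enum, and the optional return are illustrative; the real code is an MLIR rewrite pattern that emits an op error on a mismatched operand count:

// Sketch only: flash-attention backward ops are routed to a dedicated
// custom-call prefix, and the suffix is picked from the fused-MHA DAG kind
// plus the expected operand count (12 for BMM_Softmax_BMM, 13 for
// BMM_Bias_Softmax_BMM).
#include <optional>
#include <string>

enum class BackwardDag { kSoftmax, kScaleBiasSoftmax };  // stand-in enum

std::optional<std::string> BackwardCustomCallTarget(bool is_flash_attention,
                                                    BackwardDag dag,
                                                    int num_operands) {
  std::string target = is_flash_attention
                           ? "xla.gpu.flash.attention.backward."
                           : "xla.gpu.fused.attention.backward.";
  if (is_flash_attention) {
    if (dag == BackwardDag::kSoftmax && num_operands == 12)
      return target + "scale.softmax";
    if (dag == BackwardDag::kScaleBiasSoftmax && num_operands == 13)
      return target + "scale.bias.softmax";
    return std::nullopt;  // the real pattern reports an op error here
  }
  // Non-flash ops keep the pre-existing dispatch on operand count
  // (e.g. 10 operands -> "scale.softmax"), elided here.
  return std::nullopt;
}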
Copybara import of the project:

--
6f89a73 by cjkkkk <ske@nvidia.com>:

init mlir lowering and thunk runtime

--
f57b8be by cjkkkk <ske@nvidia.com>:

address some comments

Merging this change closes #6657

COPYBARA_INTEGRATE_REVIEW=#6657 from Cjkkkk:flash_attention_mhlo_lowering f57b8be
PiperOrigin-RevId: 580413629
Cjkkkk authored and copybara-github committed Nov 8, 2023
1 parent 3f04af0 commit fa114ef
Showing 12 changed files with 670 additions and 177 deletions.
51 changes: 47 additions & 4 deletions xla/mlir/backends/gpu/transforms/lmhlo_gpu_to_gpu_runtime.cc
@@ -741,7 +741,8 @@ class FusedAttentionForwardLowering
set_attr("fmha_scale", op.getFmhaScaleAttr());
set_attr("dropout_rate", op.getDropoutRateAttr());
set_attr("seed", op.getSeedAttr());

set_attr("is_flash_attention", op.getIsFlashAttentionAttr());
set_attr("is_causal_mask", op.getIsCausalMaskAttr());
set_attr("fused_mha_dag", op.getFusedMhaDagAttr());
set_attr("algorithm_config", op.getAlgorithmConfigAttr());
set_attr("bmm1_dot_dimension_numbers", op.getBmm1DotDimensionNumbers());
@@ -784,8 +785,10 @@ template <typename FusedDotAttentionBackward>
class FusedAttentionBackwardLowering
: public OpRewritePattern<FusedDotAttentionBackward> {
private:
static constexpr const char kCustomCallTarget[] =
static constexpr const char kFusedAttentionCustomCallTarget[] =
"xla.gpu.fused.attention.backward.";
static constexpr const char kFlashAttentionCustomCallTarget[] =
"xla.gpu.flash.attention.backward.";

public:
explicit FusedAttentionBackwardLowering(MLIRContext* ctx, UidGenerator& uid,
@@ -797,11 +800,36 @@ class FusedAttentionBackwardLowering
LogicalResult matchAndRewrite(FusedDotAttentionBackward op,
PatternRewriter& rewriter) const override {
// Get the custom call target.
std::string fused_attention = kCustomCallTarget;
bool is_flash_attention = op.getIsFlashAttention();
std::string fused_attention = is_flash_attention
? kFlashAttentionCustomCallTarget
: kFusedAttentionCustomCallTarget;
auto num_operands = op.getNumOperands();
switch (op.getFusedMhaDag()) {
case mlir::lmhlo_gpu::FusedMhaBackwardDagSignature::BackwardSoftmax:
if (is_flash_attention) {
if (num_operands == 12) {
fused_attention += "scale.softmax";
} else {
return op.emitOpError(
"unexpected number of operands for flash attention backward - "
"BMM_Softmax_BMM");
}
}
break;

case mlir::lmhlo_gpu::FusedMhaBackwardDagSignature::
BackwardScaleBiasSoftmax:
if (is_flash_attention) {
if (num_operands == 13) {
fused_attention += "scale.bias.softmax";
} else {
return op.emitOpError(
"unexpected number of operands for flash attention backward - "
"BMM_Bias_Softmax_BMM");
}
break;
}
if (num_operands == 10) {
fused_attention += "scale.softmax";
} else if (num_operands == 11) {
@@ -877,7 +905,8 @@ class FusedAttentionBackwardLowering
set_attr("fmha_scale", op.getFmhaScaleAttr());
set_attr("dropout_rate", op.getDropoutRateAttr());
set_attr("seed", op.getSeedAttr());

set_attr("is_flash_attention", op.getIsFlashAttentionAttr());
set_attr("is_causal_mask", op.getIsCausalMaskAttr());
set_attr("fused_mha_dag", op.getFusedMhaDagAttr());
set_attr("algorithm_config", op.getAlgorithmConfigAttr());
set_attr("bmm1_grad_gemm1_dot_dimension_numbers",
@@ -889,6 +918,20 @@ class FusedAttentionBackwardLowering
set_attr("bmm2_grad_gemm2_dot_dimension_numbers",
op.getBmm2GradGemm2DotDimensionNumbers());

auto set_xi64 = [&](StringRef name, mlir::ArrayAttr array) {
int rank = array.size();
SmallVector<int64_t> values;
for (int i = 0; i < rank; i++) {
mlir::IntegerAttr attr = array[i].dyn_cast<mlir::IntegerAttr>();
values.push_back(attr.getInt());
}
set_attr(name, b.getI64TensorAttr(values));
};

set_xi64("intermediate_tensor_dimensions",
op.getIntermediateTensorDimensions());
set_xi64("intermediate_tensor_layout", op.getIntermediateTensorLayout());

// Erase the original fused dot attention operation.
rewriter.eraseOp(op);

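The set_xi64 lambda above converts the new I64ArrayAttr operands (intermediate tensor dimensions and layout) into dense i64 tensor attributes for the runtime custom call. A self-contained sketch of the same conversion, assuming the standard MLIR builder API; the helper name and free-function form are illustrative:

// Illustrative: turn an ArrayAttr of IntegerAttr into a dense i64 tensor
// attribute, as set_xi64 does for the flash-attention backward lowering.
#include <cstdint>
#include "llvm/ADT/SmallVector.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/BuiltinAttributes.h"

mlir::DenseIntElementsAttr ToI64TensorAttr(mlir::Builder& b,
                                           mlir::ArrayAttr array) {
  llvm::SmallVector<int64_t> values;
  values.reserve(array.size());
  for (mlir::Attribute a : array)
    values.push_back(a.cast<mlir::IntegerAttr>().getInt());
  return b.getI64TensorAttr(values);
}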
17 changes: 14 additions & 3 deletions xla/mlir_hlo/lhlo_gpu/IR/lhlo_gpu_ops.td
@@ -362,7 +362,9 @@ def LHLOGPU_fusedMHAOp : LHLOGPU_Op<"fMHA", [AttrSizedOperandSegments]> {
FusedMhaDagSignatureAttr:$fused_mha_dag,
FusedMHAAlgorithmConfigAttr:$algorithm_config,
OptionalAttr<F64Attr>:$dropout_rate,
OptionalAttr<I64Attr>:$seed
OptionalAttr<I64Attr>:$seed,
BoolAttr:$is_flash_attention,
BoolAttr:$is_causal_mask
);
}

@@ -374,21 +376,30 @@ def LHLOGPU_fusedMHABackwardOp : LHLOGPU_Op<"fMHABackward", [AttrSizedOperandSegments]> {
Arg<LHLO_Buffer, "", [MemRead]>:$bmm2_grad_gemm1_lhs,
Arg<LHLO_Buffer, "", [MemRead]>:$d_output,
Arg<Optional<LHLO_Buffer>, "", [MemRead]>:$mask,
Arg<Optional<LHLO_Buffer>, "", [MemRead]>:$bias,
Arg<Optional<LHLO_Buffer>, "", [MemRead]>:$fwd_output,
Arg<LHLO_Buffer, "", [MemWrite]>:$d_bmm1_lhs,
Arg<LHLO_Buffer, "", [MemWrite]>:$d_bmm1_rhs,
Arg<LHLO_Buffer, "", [MemWrite]>:$d_bmm2_rhs,
Arg<LHLO_Buffer, "", [MemWrite]>:$d_S,
Arg<Optional<LHLO_Buffer>, "", [MemWrite]>:$d_S,
Arg<Optional<LHLO_Buffer>, "", [MemWrite]>:$softmax_sum,
Arg<Optional<LHLO_Buffer>, "", [MemWrite]>:$d_Q_accum,
Arg<LHLO_Buffer, "", [MemWrite]>:$scratch,
Arg<Optional<LHLO_Buffer>, "", [MemWrite]>:$d_bias,
MHLO_DotDimensionNumbers:$bmm1_grad_gemm1_dot_dimension_numbers,
MHLO_DotDimensionNumbers:$bmm1_grad_gemm2_dot_dimension_numbers,
MHLO_DotDimensionNumbers:$bmm2_grad_gemm1_dot_dimension_numbers,
MHLO_DotDimensionNumbers:$bmm2_grad_gemm2_dot_dimension_numbers,
I64ArrayAttr:$intermediate_tensor_dimensions,
I64ArrayAttr:$intermediate_tensor_layout,
F64Attr:$fmha_scale,
FusedMhaBackwardDagSignatureAttr:$fused_mha_dag,
FusedMHAAlgorithmConfigAttr:$algorithm_config,
OptionalAttr<F64Attr>:$dropout_rate,
OptionalAttr<I64Attr>:$seed);
OptionalAttr<I64Attr>:$seed,
BoolAttr:$is_flash_attention,
BoolAttr:$is_causal_mask
);
}

def LHLOGPU_RadixSortOp: LHLOGPU_Op<"radix_sort", [SameVariadicOperandSize]> {
8 changes: 6 additions & 2 deletions xla/mlir_hlo/lhlo_gpu/IR/lhlo_gpu_ops_enums.td
@@ -153,6 +153,8 @@ def FusedMhaBackwardDagScaleBiasSoftmaxDropout : I32EnumAttrCase<"BackwardScaleB
def FusedMhaBackwardDagScaleBiasSoftmax : I32EnumAttrCase<"BackwardScaleBiasSoftmax", 1>;
def FusedMhaBackwardDagScaleBiasMaskSoftmax : I32EnumAttrCase<"BackwardScaleBiasMaskSoftmax", 2>;
def FusedMhaBackwardDagScaleBiasMaskSoftmaxDropout : I32EnumAttrCase<"BackwardScaleBiasMaskSoftmaxDropout", 3>;
def FusedMhaBackwardDagSoftmax : I32EnumAttrCase<"BackwardSoftmax", 4>;
def FusedMhaBackwardDagSoftmaxDropout : I32EnumAttrCase<"BackwardSoftmaxDropout", 5>;

def FusedMhaDagSignature: I32EnumAttr<"FusedMhaDagSignature",
"DAG configuration for Fused Multi-Headed Attention",
@@ -175,11 +177,13 @@ def FusedMhaBackwardDagSignature: I32EnumAttr<"FusedMhaBackwardDagSignature",
FusedMhaBackwardDagScaleBiasSoftmaxDropout,
FusedMhaBackwardDagScaleBiasSoftmax,
FusedMhaBackwardDagScaleBiasMaskSoftmax,
FusedMhaBackwardDagScaleBiasMaskSoftmaxDropout]> {
FusedMhaBackwardDagScaleBiasMaskSoftmaxDropout,
FusedMhaBackwardDagSoftmax,
FusedMhaBackwardDagSoftmaxDropout]> {
let genSpecializedAttr = 0;
let cppNamespace = "::mlir::lmhlo_gpu";
}

def FusedMhaDagSignatureAttr : EnumAttr<LmhloGpuDialect, FusedMhaDagSignature, "fused_mha_dag">;
def FusedMhaBackwardDagSignatureAttr : EnumAttr<LmhloGpuDialect, FusedMhaBackwardDagSignature, "fused_mha_backward_dag">;
#endif // LHLO_GPU_OPS_ENUMS
6 changes: 6 additions & 0 deletions xla/service/gpu/backend_configs.proto
@@ -176,4 +176,10 @@ message CudnnfMHABackendConfig {

// Random seed used by dropout
int64 seed = 15;

// Is flash attention
bool is_flash_attention = 20;

// Is causal mask
bool is_causal_mask = 21;
}
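For context, this is how a caller might set the two new fields when building the backend config. The protobuf-generated C++ setters and the xla::gpu namespace are assumed from the file location; this usage is not part of the commit:

// Sketch only: populate the new flash-attention fields (numbers 20 and 21
// above) using the usual protobuf-generated set_* accessors.
xla::gpu::CudnnfMHABackendConfig config;
config.set_is_flash_attention(true);  // take the flash-attention path
config.set_is_causal_mask(false);     // no implicit causal mask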
64 changes: 38 additions & 26 deletions xla/service/gpu/fused_mha_thunk.cc
@@ -61,6 +61,15 @@ FusedMultiHeadedAttentionRunner& FusedMHAThunk::GetOrCreateRunner(
return *it->second;
}

std::optional<se::DeviceMemoryBase> AssignBufferIfNotNull(
const BufferAllocations& buffer_allocations,
BufferAllocation::Slice& slice) {
return slice.allocation() != nullptr
? std::optional<se::DeviceMemoryBase>{buffer_allocations
.GetDeviceAddress(slice)}
: std::nullopt;
}

Status FusedMHAThunk::ExecuteOnStream(const ExecuteParams& params) {
const auto& buffer_allocations = *params.buffer_allocations;
se::DeviceMemoryBase lhs_bmm1_buffer =
@@ -74,19 +83,12 @@ Status FusedMHAThunk::ExecuteOnStream(const ExecuteParams& params) {
se::DeviceMemoryBase scratch_buffer =
buffer_allocations.GetDeviceAddress(scratch_buffer_);

std::optional<se::DeviceMemoryBase> mask_buffer;
if (mask_buffer_.allocation() != nullptr) {
mask_buffer = buffer_allocations.GetDeviceAddress(mask_buffer_);
}
std::optional<se::DeviceMemoryBase> bias_buffer;
if (bias_buffer_.allocation() != nullptr) {
bias_buffer = buffer_allocations.GetDeviceAddress(bias_buffer_);
}

std::optional<se::DeviceMemoryBase> activation_buffer;
if (activation_buffer_.allocation() != nullptr) {
activation_buffer = buffer_allocations.GetDeviceAddress(activation_buffer_);
}
std::optional<se::DeviceMemoryBase> mask_buffer =
AssignBufferIfNotNull(buffer_allocations, mask_buffer_);
std::optional<se::DeviceMemoryBase> bias_buffer =
AssignBufferIfNotNull(buffer_allocations, bias_buffer_);
std::optional<se::DeviceMemoryBase> activation_buffer =
AssignBufferIfNotNull(buffer_allocations, activation_buffer_);

RunFusedMHAOptions opts;
opts.runner_cache = &GetOrCreateRunner(params.stream);
@@ -109,7 +111,9 @@ FusedMHABackwardThunk::FusedMHABackwardThunk(
BufferAllocation::Slice d_output, BufferAllocation::Slice scratch,
BufferAllocation::Slice d_bmm1_lhs, BufferAllocation::Slice d_bmm1_rhs,
BufferAllocation::Slice d_bmm2_rhs, BufferAllocation::Slice d_s,
BufferAllocation::Slice mask, BufferAllocation::Slice d_bias)
BufferAllocation::Slice softmax_sum, BufferAllocation::Slice d_Q_accum,
BufferAllocation::Slice mask, BufferAllocation::Slice d_bias,
BufferAllocation::Slice fwd_output, BufferAllocation::Slice bias)
: Thunk(Kind::kFusedMHA, thunk_info),
bmm1_grad_gemm1_rhs_buffer_(bmm1_grad_gemm1_rhs),
bmm1_grad_gemm2_rhs_buffer_(bmm1_grad_gemm2_rhs),
@@ -121,8 +125,12 @@ FusedMHABackwardThunk::FusedMHABackwardThunk(
d_bmm1_rhs_buffer_(d_bmm1_rhs),
d_bmm2_rhs_buffer_(d_bmm2_rhs),
d_s_buffer_(d_s),
softmax_sum_buffer_(softmax_sum),
d_Q_accum_buffer_(d_Q_accum),
mask_buffer_(mask),
d_bias_buffer_(d_bias),
fwd_output_buffer_(fwd_output),
bias_buffer_(bias),
config_(std::move(config)) {}

FusedMultiHeadedAttentionBackwardRunner&
@@ -169,18 +177,21 @@ Status FusedMHABackwardThunk::ExecuteOnStream(const ExecuteParams& params) {
se::DeviceMemoryBase d_bmm2_rhs_buffer =
buffer_allocations.GetDeviceAddress(d_bmm2_rhs_buffer_);

se::DeviceMemoryBase d_S_buffer =
buffer_allocations.GetDeviceAddress(d_s_buffer_);
std::optional<se::DeviceMemoryBase> d_s_buffer =
AssignBufferIfNotNull(buffer_allocations, d_s_buffer_);
std::optional<se::DeviceMemoryBase> softmax_sum_buffer =
AssignBufferIfNotNull(buffer_allocations, softmax_sum_buffer_);
std::optional<se::DeviceMemoryBase> d_Q_accum_buffer =
AssignBufferIfNotNull(buffer_allocations, d_Q_accum_buffer_);
std::optional<se::DeviceMemoryBase> mask_buffer =
AssignBufferIfNotNull(buffer_allocations, mask_buffer_);
std::optional<se::DeviceMemoryBase> d_bias_buffer =
AssignBufferIfNotNull(buffer_allocations, d_bias_buffer_);
std::optional<se::DeviceMemoryBase> fwd_output_buffer =
AssignBufferIfNotNull(buffer_allocations, fwd_output_buffer_);
std::optional<se::DeviceMemoryBase> bias_buffer =
AssignBufferIfNotNull(buffer_allocations, bias_buffer_);

std::optional<se::DeviceMemoryBase> mask_buffer;
if (mask_buffer_.allocation() != nullptr) {
mask_buffer = buffer_allocations.GetDeviceAddress(mask_buffer_);
}

std::optional<se::DeviceMemoryBase> d_bias_buffer;
if (d_bias_buffer_.allocation() != nullptr) {
d_bias_buffer = buffer_allocations.GetDeviceAddress(d_bias_buffer_);
}
RunFusedMHABackwardOptions opts;

opts.runner_cache = &GetOrCreateRunner(params.stream);
@@ -189,7 +200,8 @@ Status FusedMHABackwardThunk::ExecuteOnStream(const ExecuteParams& params) {
config_, bmm1_grad_gemm1_rhs_buffer, bmm1_grad_gemm2_rhs_buffer,
bmm2_grad_gemm1_lhs_buffer, bmm2_grad_gemm2_rhs_buffer, d_output_buffer,
scratch_buffer, d_bmm1_lhs_buffer, d_bmm1_rhs_buffer, d_bmm2_rhs_buffer,
d_S_buffer, mask_buffer, d_bias_buffer, params.stream, opts));
d_s_buffer, softmax_sum_buffer, d_Q_accum_buffer, mask_buffer,
d_bias_buffer, fwd_output_buffer, bias_buffer, params.stream, opts));
if (!params.stream->ok()) {
return InternalError("FusedMHABackwardThunk::ExecuteOnStream failed.");
}
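The repeated per-buffer null checks in both thunks are now funneled through the new AssignBufferIfNotNull helper. A minimal usage sketch, reusing the types declared in the diff above; the surrounding lines are illustrative, not part of the commit:

// Illustrative: resolve an optional operand slice to device memory. The
// optional stays empty when the slice has no backing allocation, which is how
// absent operands (mask, bias, fwd_output, softmax_sum, d_Q_accum, ...) are
// signalled to the fused/flash attention runner.
std::optional<se::DeviceMemoryBase> fwd_output =
    AssignBufferIfNotNull(buffer_allocations, fwd_output_buffer_);
if (fwd_output.has_value()) {
  // e.g. hand *fwd_output to the flash-attention backward call.
}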
12 changes: 10 additions & 2 deletions xla/service/gpu/fused_mha_thunk.h
@@ -91,9 +91,13 @@ class FusedMHABackwardThunk : public Thunk {
BufferAllocation::Slice d_bmm1_lhs_slice,
BufferAllocation::Slice d_bmm1_rhs_slice,
BufferAllocation::Slice d_bmm2_rhs_slice,
BufferAllocation::Slice d_S_slice,
BufferAllocation::Slice d_s_slice,
BufferAllocation::Slice softmax_sum_slice,
BufferAllocation::Slice d_Q_accum_slice,
BufferAllocation::Slice mask_slice,
BufferAllocation::Slice d_bias_slice);
BufferAllocation::Slice d_bias_slice,
BufferAllocation::Slice fwd_output_slice,
BufferAllocation::Slice bias_slice);

FusedMHABackwardThunk(const FusedMHABackwardThunk&) = delete;
FusedMHABackwardThunk& operator=(const FusedMHABackwardThunk&) = delete;
@@ -111,8 +115,12 @@ class FusedMHABackwardThunk : public Thunk {
BufferAllocation::Slice d_bmm1_rhs_buffer_;
BufferAllocation::Slice d_bmm2_rhs_buffer_;
BufferAllocation::Slice d_s_buffer_;
BufferAllocation::Slice softmax_sum_buffer_;
BufferAllocation::Slice d_Q_accum_buffer_;
BufferAllocation::Slice mask_buffer_;
BufferAllocation::Slice d_bias_buffer_;
BufferAllocation::Slice fwd_output_buffer_;
BufferAllocation::Slice bias_buffer_;

FusedMultiHeadedAttentionBackwardRunner& GetOrCreateRunner(
const stream_executor::Stream* stream);
(Diffs for the remaining changed files are not shown.)
