PR #6657: [XLA:GPU] add cuDNN flash attention support in XLA (2nd PR with only MLIR lowering and thunk/runtime)

Imported from GitHub PR openxla/xla#6657

This is the 2nd PR from splitting openxla/xla#5910, containing only the MLIR lowering and thunk/runtime changes.
The 1st PR, openxla/xla#6293, has already been merged.

* Added MLIR lowering for flash attention (see the sketch after this list).
* Added thunk/runner/runtime support for flash attention.
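For orientation, here is a minimal sketch of how the new flags surface on a lowered forward op, written in MLIR generic syntax. The operand names, shapes, and attribute values are assumptions for illustration, and several required pieces (the DAG signature, algorithm config, and operand-segment attribute) are omitted; this is not output produced by this change.

// Illustrative sketch only: operands, shapes, and values are assumed; the
// attribute names is_flash_attention and is_causal_mask match lhlo_gpu_ops.td below.
"lmhlo_gpu.fMHA"(%bmm1_lhs, %bmm1_rhs, %bmm2_rhs, %output, %scratch) {
  dropout_rate = 0.0 : f64,
  seed = 0 : i64,
  is_flash_attention = true,
  is_causal_mask = true
  // fused_mha_dag, algorithm_config, and operandSegmentSizes omitted here
} : (memref<2x8x128x64xf16>, memref<2x8x64x128xf16>, memref<2x8x128x64xf16>,
     memref<2x8x128x64xf16>, memref<65536xi8>) -> ()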
Copybara import of the project:

--
6f89a7355b4b46cbb974b39ca60e07ae08079f1a by cjkkkk <ske@nvidia.com>:

init mlir lowering and thunk runtime

--
f57b8bee2ba1ad361556c32cb9333c4ac4730016 by cjkkkk <ske@nvidia.com>:

address some comments

Merging this change closes #6657

PiperOrigin-RevId: 580413629
Cjkkkk authored and TensorFlow MLIR Team committed Nov 8, 2023
1 parent b431999 commit f26dbcd
Showing 2 changed files with 20 additions and 5 deletions.
17 changes: 14 additions & 3 deletions lhlo_gpu/IR/lhlo_gpu_ops.td
@@ -362,7 +362,9 @@ def LHLOGPU_fusedMHAOp : LHLOGPU_Op<"fMHA", [AttrSizedOperandSegments]> {
     FusedMhaDagSignatureAttr:$fused_mha_dag,
     FusedMHAAlgorithmConfigAttr:$algorithm_config,
     OptionalAttr<F64Attr>:$dropout_rate,
-    OptionalAttr<I64Attr>:$seed
+    OptionalAttr<I64Attr>:$seed,
+    BoolAttr:$is_flash_attention,
+    BoolAttr:$is_causal_mask
   );
 }

@@ -374,21 +376,30 @@ def LHLOGPU_fusedMHABackwardOp : LHLOGPU_Op<"fMHABackward", [AttrSizedOperandSeg
     Arg<LHLO_Buffer, "", [MemRead]>:$bmm2_grad_gemm1_lhs,
     Arg<LHLO_Buffer, "", [MemRead]>:$d_output,
     Arg<Optional<LHLO_Buffer>, "", [MemRead]>:$mask,
+    Arg<Optional<LHLO_Buffer>, "", [MemRead]>:$bias,
+    Arg<Optional<LHLO_Buffer>, "", [MemRead]>:$fwd_output,
     Arg<LHLO_Buffer, "", [MemWrite]>:$d_bmm1_lhs,
     Arg<LHLO_Buffer, "", [MemWrite]>:$d_bmm1_rhs,
     Arg<LHLO_Buffer, "", [MemWrite]>:$d_bmm2_rhs,
-    Arg<LHLO_Buffer, "", [MemWrite]>:$d_S,
+    Arg<Optional<LHLO_Buffer>, "", [MemWrite]>:$d_S,
+    Arg<Optional<LHLO_Buffer>, "", [MemWrite]>:$softmax_sum,
+    Arg<Optional<LHLO_Buffer>, "", [MemWrite]>:$d_Q_accum,
     Arg<LHLO_Buffer, "", [MemWrite]>:$scratch,
+    Arg<Optional<LHLO_Buffer>, "", [MemWrite]>:$d_bias,
     MHLO_DotDimensionNumbers:$bmm1_grad_gemm1_dot_dimension_numbers,
     MHLO_DotDimensionNumbers:$bmm1_grad_gemm2_dot_dimension_numbers,
     MHLO_DotDimensionNumbers:$bmm2_grad_gemm1_dot_dimension_numbers,
     MHLO_DotDimensionNumbers:$bmm2_grad_gemm2_dot_dimension_numbers,
     I64ArrayAttr:$intermediate_tensor_dimensions,
+    I64ArrayAttr:$intermediate_tensor_layout,
     F64Attr:$fmha_scale,
     FusedMhaBackwardDagSignatureAttr:$fused_mha_dag,
     FusedMHAAlgorithmConfigAttr:$algorithm_config,
     OptionalAttr<F64Attr>:$dropout_rate,
-    OptionalAttr<I64Attr>:$seed);
+    OptionalAttr<I64Attr>:$seed,
+    BoolAttr:$is_flash_attention,
+    BoolAttr:$is_causal_mask
+  );
 }

 def LHLOGPU_RadixSortOp: LHLOGPU_Op<"radix_sort", [SameVariadicOperandSize]> {
8 changes: 6 additions & 2 deletions lhlo_gpu/IR/lhlo_gpu_ops_enums.td
@@ -153,6 +153,8 @@ def FusedMhaBackwardDagScaleBiasSoftmaxDropout : I32EnumAttrCase<"BackwardScaleB
 def FusedMhaBackwardDagScaleBiasSoftmax : I32EnumAttrCase<"BackwardScaleBiasSoftmax", 1>;
 def FusedMhaBackwardDagScaleBiasMaskSoftmax : I32EnumAttrCase<"BackwardScaleBiasMaskSoftmax", 2>;
 def FusedMhaBackwardDagScaleBiasMaskSoftmaxDropout : I32EnumAttrCase<"BackwardScaleBiasMaskSoftmaxDropout", 3>;
+def FusedMhaBackwardDagSoftmax : I32EnumAttrCase<"BackwardSoftmax", 4>;
+def FusedMhaBackwardDagSoftmaxDropout : I32EnumAttrCase<"BackwardSoftmaxDropout", 5>;

 def FusedMhaDagSignature: I32EnumAttr<"FusedMhaDagSignature",
   "DAG configuration for Fused Multi-Headed Attention",
@@ -175,11 +177,13 @@ def FusedMhaBackwardDagSignature: I32EnumAttr<"FusedMhaBackwardDagSignature",
     FusedMhaBackwardDagScaleBiasSoftmaxDropout,
     FusedMhaBackwardDagScaleBiasSoftmax,
     FusedMhaBackwardDagScaleBiasMaskSoftmax,
-    FusedMhaBackwardDagScaleBiasMaskSoftmaxDropout]> {
+    FusedMhaBackwardDagScaleBiasMaskSoftmaxDropout,
+    FusedMhaBackwardDagSoftmax,
+    FusedMhaBackwardDagSoftmaxDropout]> {
   let genSpecializedAttr = 0;
   let cppNamespace = "::mlir::lmhlo_gpu";
 }

 def FusedMhaDagSignatureAttr : EnumAttr<LmhloGpuDialect, FusedMhaDagSignature, "fused_mha_dag">;
 def FusedMhaBackwardDagSignatureAttr : EnumAttr<LmhloGpuDialect, FusedMhaBackwardDagSignature, "fused_mha_backward_dag">;
-#endif // LHLO_GPU_OPS_ENUMS
+#endif // LHLO_GPU_OPS_ENUMS
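For reference, the new backward signatures combine with the flags added in lhlo_gpu_ops.td roughly as in the attribute-dictionary sketch below. It assumes the default EnumAttr assembly format for the fused_mha_backward_dag mnemonic; the concrete values are illustrative, not taken from this change.

// Illustrative attribute dictionary for a lowered fMHABackward op; attribute
// names come from lhlo_gpu_ops.td, the enum case from this file, and the
// values are made up.
{
  fmha_scale = 1.0 : f64,
  fused_mha_dag = #lmhlo_gpu.fused_mha_backward_dag<BackwardSoftmax>,
  dropout_rate = 0.0 : f64,
  is_flash_attention = true,
  is_causal_mask = true
}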
