From f26dbcdf8d837f9939c658e676a9c6bd37594080 Mon Sep 17 00:00:00 2001
From: Shanbin Ke
Date: Tue, 7 Nov 2023 23:11:47 -0800
Subject: [PATCH] PR #6657: [XLA:GPU] add cuDNN flash attention support in XLA (2nd PR with only MLIR lowering and thunk/runtime)

Imported from GitHub PR https://github.com/openxla/xla/pull/6657

This is the 2nd PR of splitting https://github.com/openxla/xla/pull/5910,
containing only the MLIR lowering and thunk/runtime changes. The 1st PR,
https://github.com/openxla/xla/pull/6293, has been merged.

* Added MLIR lowering for flash attention.
* Added thunk/runner/runtime support for flash attention.

Copybara import of the project:

--
6f89a7355b4b46cbb974b39ca60e07ae08079f1a by cjkkkk:

init mlir lowering and thunk runtime

--
f57b8bee2ba1ad361556c32cb9333c4ac4730016 by cjkkkk:

address some comments

Merging this change closes #6657

PiperOrigin-RevId: 580413629
---
 lhlo_gpu/IR/lhlo_gpu_ops.td       | 17 ++++++++++++++---
 lhlo_gpu/IR/lhlo_gpu_ops_enums.td |  8 ++++++--
 2 files changed, 20 insertions(+), 5 deletions(-)

diff --git a/lhlo_gpu/IR/lhlo_gpu_ops.td b/lhlo_gpu/IR/lhlo_gpu_ops.td
index 3091ed06f..e56d2964d 100644
--- a/lhlo_gpu/IR/lhlo_gpu_ops.td
+++ b/lhlo_gpu/IR/lhlo_gpu_ops.td
@@ -362,7 +362,9 @@ def LHLOGPU_fusedMHAOp : LHLOGPU_Op<"fMHA", [AttrSizedOperandSegments]> {
     FusedMhaDagSignatureAttr:$fused_mha_dag,
     FusedMHAAlgorithmConfigAttr:$algorithm_config,
     OptionalAttr<F64Attr>:$dropout_rate,
-    OptionalAttr<I64Attr>:$seed
+    OptionalAttr<I64Attr>:$seed,
+    BoolAttr:$is_flash_attention,
+    BoolAttr:$is_causal_mask
   );
 }
 
@@ -374,21 +376,30 @@ def LHLOGPU_fusedMHABackwardOp : LHLOGPU_Op<"fMHABackward", [AttrSizedOperandSeg
     Arg<LHLO_Buffer, "", [MemRead]>:$bmm2_grad_gemm1_lhs,
     Arg<LHLO_Buffer, "", [MemRead]>:$d_output,
     Arg<Optional<LHLO_Buffer>, "", [MemRead]>:$mask,
+    Arg<Optional<LHLO_Buffer>, "", [MemRead]>:$bias,
+    Arg<Optional<LHLO_Buffer>, "", [MemRead]>:$fwd_output,
     Arg<LHLO_Buffer, "", [MemWrite]>:$d_bmm1_lhs,
     Arg<LHLO_Buffer, "", [MemWrite]>:$d_bmm1_rhs,
     Arg<LHLO_Buffer, "", [MemWrite]>:$d_bmm2_rhs,
-    Arg<LHLO_Buffer, "", [MemWrite]>:$d_S,
+    Arg<Optional<LHLO_Buffer>, "", [MemWrite]>:$d_S,
+    Arg<Optional<LHLO_Buffer>, "", [MemWrite]>:$softmax_sum,
+    Arg<Optional<LHLO_Buffer>, "", [MemWrite]>:$d_Q_accum,
     Arg<LHLO_Buffer, "", [MemWrite]>:$scratch,
     Arg<Optional<LHLO_Buffer>, "", [MemWrite]>:$d_bias,
     MHLO_DotDimensionNumbers:$bmm1_grad_gemm1_dot_dimension_numbers,
     MHLO_DotDimensionNumbers:$bmm1_grad_gemm2_dot_dimension_numbers,
     MHLO_DotDimensionNumbers:$bmm2_grad_gemm1_dot_dimension_numbers,
     MHLO_DotDimensionNumbers:$bmm2_grad_gemm2_dot_dimension_numbers,
+    I64ArrayAttr:$intermediate_tensor_dimensions,
+    I64ArrayAttr:$intermediate_tensor_layout,
     F64Attr:$fmha_scale,
     FusedMhaBackwardDagSignatureAttr:$fused_mha_dag,
     FusedMHAAlgorithmConfigAttr:$algorithm_config,
     OptionalAttr<F64Attr>:$dropout_rate,
-    OptionalAttr<I64Attr>:$seed);
+    OptionalAttr<I64Attr>:$seed,
+    BoolAttr:$is_flash_attention,
+    BoolAttr:$is_causal_mask
+  );
 }
 
 def LHLOGPU_RadixSortOp: LHLOGPU_Op<"radix_sort", [SameVariadicOperandSize]> {
diff --git a/lhlo_gpu/IR/lhlo_gpu_ops_enums.td b/lhlo_gpu/IR/lhlo_gpu_ops_enums.td
index 8ab0646a4..7ce614e43 100644
--- a/lhlo_gpu/IR/lhlo_gpu_ops_enums.td
+++ b/lhlo_gpu/IR/lhlo_gpu_ops_enums.td
@@ -153,6 +153,8 @@ def FusedMhaBackwardDagScaleBiasSoftmaxDropout : I32EnumAttrCase<"BackwardScaleBiasSoftmaxDropout",
 def FusedMhaBackwardDagScaleBiasSoftmax : I32EnumAttrCase<"BackwardScaleBiasSoftmax", 1>;
 def FusedMhaBackwardDagScaleBiasMaskSoftmax : I32EnumAttrCase<"BackwardScaleBiasMaskSoftmax", 2>;
 def FusedMhaBackwardDagScaleBiasMaskSoftmaxDropout : I32EnumAttrCase<"BackwardScaleBiasMaskSoftmaxDropout", 3>;
+def FusedMhaBackwardDagSoftmax : I32EnumAttrCase<"BackwardSoftmax", 4>;
+def FusedMhaBackwardDagSoftmaxDropout : I32EnumAttrCase<"BackwardSoftmaxDropout", 5>;
 
 def FusedMhaDagSignature: I32EnumAttr<"FusedMhaDagSignature",
   "DAG configuration for Fused Multi-Headed Attention",
@@ -175,11 +177,13 @@ def FusedMhaBackwardDagSignature: I32EnumAttr<"FusedMhaBackwardDagSignature",
     FusedMhaBackwardDagScaleBiasSoftmaxDropout,
     FusedMhaBackwardDagScaleBiasSoftmax,
     FusedMhaBackwardDagScaleBiasMaskSoftmax,
-    FusedMhaBackwardDagScaleBiasMaskSoftmaxDropout]> {
+    FusedMhaBackwardDagScaleBiasMaskSoftmaxDropout,
+    FusedMhaBackwardDagSoftmax,
+    FusedMhaBackwardDagSoftmaxDropout]> {
   let genSpecializedAttr = 0;
   let cppNamespace = "::mlir::lmhlo_gpu";
 }
 
 def FusedMhaDagSignatureAttr : EnumAttr<LmhloGpuDialect, FusedMhaDagSignature, "fused_mha_dag">;
 def FusedMhaBackwardDagSignatureAttr : EnumAttr<LmhloGpuDialect, FusedMhaBackwardDagSignature, "fused_mha_backward_dag">;
-#endif // LHLO_GPU_OPS_ENUMS
\ No newline at end of file
+#endif // LHLO_GPU_OPS_ENUMS
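
The forward and backward fused-MHA ops both gain two plain boolean attributes,
is_flash_attention and is_causal_mask, which the thunk/runtime side reads to select the cuDNN
flash-attention path and to request causal masking. Because they are ordinary MLIR attributes,
the lowering that emits these ops only needs to attach them with the standard builder API. The
C++ sketch below is illustrative only and is not code from this patch: the helper name
SetFlashAttentionAttrs and the use of a generic mlir::Operation* handle are assumptions; only
the attribute names come from the TableGen declarations above.

    #include "mlir/IR/Builders.h"
    #include "mlir/IR/Operation.h"

    // Attach the two flash-attention attributes introduced by this patch to an
    // already-created lmhlo_gpu fused-MHA op (forward or backward).
    static void SetFlashAttentionAttrs(mlir::Operation* op,
                                       bool is_flash_attention,
                                       bool is_causal_mask) {
      mlir::OpBuilder b(op->getContext());
      // BoolAttr:$is_flash_attention - selects the cuDNN flash-attention kernels.
      op->setAttr("is_flash_attention", b.getBoolAttr(is_flash_attention));
      // BoolAttr:$is_causal_mask - asks the runtime to apply a causal mask.
      op->setAttr("is_causal_mask", b.getBoolAttr(is_causal_mask));
    }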
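
The enum changes extend FusedMhaBackwardDagSignature with BackwardSoftmax and
BackwardSoftmaxDropout, covering backward flash-attention graphs that carry no scale/bias/mask
nodes. Below is a small C++ sketch of how a lowering could map the optional inputs it found to
one of these signatures; the enumerant names and the ::mlir::lmhlo_gpu namespace come from
lhlo_gpu_ops_enums.td, but the include path and the mask/bias/dropout-to-signature mapping are
assumptions for illustration, not the rule used by the actual XLA matcher.

    #include "lhlo_gpu/IR/lhlo_gpu_ops.h"  // assumed header exposing the generated enum

    namespace lg = mlir::lmhlo_gpu;

    // Choose a backward DAG signature from the optional inputs that are present.
    static lg::FusedMhaBackwardDagSignature PickBackwardDag(bool has_mask, bool has_bias,
                                                            bool has_dropout) {
      if (!has_mask && !has_bias) {
        // New flash-attention cases added by this patch.
        return has_dropout ? lg::FusedMhaBackwardDagSignature::BackwardSoftmaxDropout
                           : lg::FusedMhaBackwardDagSignature::BackwardSoftmax;
      }
      if (has_mask) {
        return has_dropout
                   ? lg::FusedMhaBackwardDagSignature::BackwardScaleBiasMaskSoftmaxDropout
                   : lg::FusedMhaBackwardDagSignature::BackwardScaleBiasMaskSoftmax;
      }
      return has_dropout ? lg::FusedMhaBackwardDagSignature::BackwardScaleBiasSoftmaxDropout
                         : lg::FusedMhaBackwardDagSignature::BackwardScaleBiasSoftmax;
    }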