# PaddleMIX ppdiffusers Stable Diffusion 3 inference optimize #681

Status: Open. Wants to merge 58 commits into base branch `develop`.

## Commits (58, all by chang-wenbin)
- a6631e7 optimize SD3 (Aug 19, 2024)
- b0ea9ef optimize SD3 transformer_SD3 (Aug 19, 2024)
- f06a61a optimize SD3 transformer_SD3 (Aug 19, 2024)
- dcff90c update SD3 (Aug 20, 2024)
- 15c5e44 uodate triton &sim_SD3 (Aug 20, 2024)
- ab73a63 modify temb_silu && modify nvtx (Aug 20, 2024)
- ed2b7b1 modify linear from fused_linear (Aug 20, 2024)
- f4330d3 modify simplified_sd3 (Aug 20, 2024)
- cc1af0f add split_concat triton kernel (Aug 20, 2024)
- 70e6b6e modify split_concat triton kernel (Aug 21, 2024)
- 9543b11 update (Aug 21, 2024)
- 357b75a update transformer_sd3 (Aug 21, 2024)
- f54bf84 update transformer_sd3 (Aug 21, 2024)
- 3245b2f update triton & simplified_sd3 (Aug 21, 2024)
- 5516df6 update simplified_sd3 (Aug 22, 2024)
- 874d5d7 update simplified_sd3 (Aug 22, 2024)
- 111f4cd delete context_pre_only=False (Aug 22, 2024)
- 18777b6 modify triton_optimize (Aug 22, 2024)
- 7a288e4 modify triton_optimize (Aug 22, 2024)
- 840b153 modify triton_optimize (Aug 22, 2024)
- 95c9e47 modify triton_fuse & Modifying performance issues affected by CUDA sy… (Aug 22, 2024)
- 84a9e7a modify transformer_sd3 if optimize_prigin (Aug 23, 2024)
- 9dd918d update vae triton_split (Aug 23, 2024)
- 3a0b7e1 vae T5 d2s & transformer forward d2s (Aug 26, 2024)
- 6d02d79 update demo (Aug 26, 2024)
- 5d81b44 update five model d2s (Aug 26, 2024)
- 4bab118 update SD3 clip T5 vae (Aug 27, 2024)
- 5a14a0f update clip (Aug 27, 2024)
- cd2ef01 uodate T5 (Aug 27, 2024)
- 624168c uodate T5 (Aug 27, 2024)
- b009b9f update scheduling_flow_match_euler_discrete (Aug 27, 2024)
- 8caa10a update normalization (Aug 28, 2024)
- 377629a update normalization (Aug 28, 2024)
- 6863054 Merge remote-tracking branch 'upstream/develop' into SD3_PaddleMIX_819 (Aug 28, 2024)
- 15fda4e update SD3 (Aug 29, 2024)
- cb993c5 merge develop (Aug 30, 2024)
- 0e90eaf update cutlass gemm&fast_gelu (Sep 2, 2024)
- c5bb81f update per-mmdit (Sep 4, 2024)
- 2c8cc85 merge develop (Sep 4, 2024)
- 499752a update triton op split_concat (Sep 4, 2024)
- 1084f4a update embeddings (Sep 5, 2024)
- e3a5d7c merge (Sep 6, 2024)
- fa84559 recovery (Sep 6, 2024)
- 27c62f9 recovery (Sep 6, 2024)
- 951f7a6 merge (Sep 6, 2024)
- 9515323 update normalization (Sep 10, 2024)
- d61e4cb update dtype (Sep 10, 2024)
- d961a4a add SD3 doc (Sep 10, 2024)
- ac1e139 merge develop (Sep 18, 2024)
- 48c66a6 update SD3 doc (Sep 18, 2024)
- 24c3c9e add 'del transformer_blocks' (Sep 19, 2024)
- 422f33b update SD3 (Sep 19, 2024)
- c43d84f update SD3 (Sep 19, 2024)
- 9d03624 update Notes (Sep 19, 2024)
- ded06bf add Notes (Sep 19, 2024)
- d845da2 update demo (Sep 19, 2024)
- db6aad1 update doc (Sep 19, 2024)
- 3527954 update SD3 (Sep 19, 2024)
4 changes: 4 additions & 0 deletions `paddlemix/triton_ops/__init__.py`

The two new ops are exported alongside the existing Triton helpers:

```diff
@@ -21,6 +21,8 @@
     fused_rotary_emb,
     paddle_use_triton,
     rms_norm,
+    split_concat,
+    triton_split,
     weight_only_int8,
 )
 from .triton_utils import (
@@ -39,6 +41,8 @@
     "rms_norm",
     "get_dtype_str",
     "fused_rotary_emb",
+    "split_concat",
+    "triton_split",
 ]
 except:
     pass
```
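For orientation, here is a minimal call-site sketch of the two new exports. The shapes are assumptions taken from the SD3 512x512 deployment described in the README below, not part of this diff:

```python
import paddle

from paddlemix.triton_ops import split_concat, triton_split

h = 1536  # assumed joint-attention hidden size (SD3-medium: 24 heads x 64)
qkv = paddle.randn([2, 1024, 3 * h], dtype="float16")  # fused q/k/v, image stream
eqkv = paddle.randn([2, 154, 3 * h], dtype="float16")  # fused q/k/v, text stream

# One kernel launch instead of two splits plus three concats:
q, k, v = split_concat(qkv, eqkv)  # each [2, 1024 + 154, h]

# Split joint hidden states back into image and text streams:
x = paddle.randn([2, 1024 + 154, h], dtype="float16")
img, txt = triton_split(x, num_or_sections=[1024, 154], axis=1)
```

Both ops require a GPU build of Paddle with Triton available; on first call they JIT-compile and register a Paddle custom op whose name encodes the dtype and block size.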
289 changes: 289 additions & 0 deletions `paddlemix/triton_ops/triton_ops.py`

Two new operators are appended at the end of the file, after `fused_rotary_emb` (hunk `@@ -1580,3 +1580,292 @@`). First, `split_concat`: a C++ custom-op template that allocates the three outputs and registers shape/dtype inference, followed by the Triton kernel and its Python wrapper.

```python
########################### split concat ###############################
split_concat_template = (
"""
std::vector<paddle::Tensor> ${op_name}_func(
const paddle::Tensor &x,
const paddle::Tensor &y) {

int batch = x.dims()[0];

int seq_qkv = x.dims()[1];
int seq_eqkv = y.dims()[1];
int output_hidden = x.dims()[2] / 3;


auto qkv = get_tensor_ptr(x);
auto eqkv = get_tensor_ptr(y);


auto out0_tensor = paddle::empty({batch, seq_qkv+seq_eqkv, output_hidden}, x.dtype(), x.place());
auto out1_tensor = paddle::empty({batch, seq_qkv+seq_eqkv, output_hidden}, x.dtype(), x.place());
auto out2_tensor = paddle::empty({batch, seq_qkv+seq_eqkv, output_hidden}, x.dtype(), x.place());

auto out0 = get_tensor_ptr(out0_tensor);
auto out1 = get_tensor_ptr(out1_tensor);
auto out2 = get_tensor_ptr(out2_tensor);


auto run_stream = out0_tensor.stream();

"""
+ tune_and_invoke_part
+ """
return {out0_tensor, out1_tensor, out2_tensor};
}

std::vector<std::vector<int64_t>> ${op_name}_InferShape(
const std::vector<int64_t>& A_shape, const std::vector<int64_t>& B_shape) {

std::vector<int64_t> out_shape = {A_shape[0], A_shape[1]+B_shape[1], A_shape[2]/3};

return {out_shape, out_shape, out_shape};
}

std::vector<paddle::DataType> ${op_name}_InferDtype(const paddle::DataType& A_dtype) {
return {A_dtype, A_dtype, A_dtype};
}

PD_BUILD_OP(${op_name})
.Inputs({"x", "y"})
.Outputs({"out0_tensor", "out1_tensor", "out2_tensor"})
.SetKernelFn(PD_KERNEL(${op_name}_func))
.SetInferDtypeFn(PD_INFER_DTYPE(${op_name}_InferDtype))
.SetInferShapeFn(PD_INFER_SHAPE(${op_name}_InferShape));
"""
)


@paddle_use_triton(
    custom_op_template=split_concat_template,
    key=["1"],
)
def split_concat_kernel(
    out0,
    out1,
    out2,
    qkv,
    eqkv,
    batch,
    seq_qkv,
    seq_eqkv,
    output_hidden,
    BLOCK_SIZE: tl.constexpr,
):
    # Grid is (3, batch, seq_qkv + seq_eqkv): axis 0 selects the output
    # (0 -> q, 1 -> k, 2 -> v), axis 1 the batch, axis 2 the row in the
    # concatenated sequence.
    out_id = tl.program_id(axis=0)
    batch = tl.program_id(axis=1)
    out_row = tl.program_id(axis=2)
    if out_row < seq_qkv:
        # Row comes from x: x[batch, out_row, out_id*h : (out_id+1)*h].
        read_ptr = out_id * output_hidden + out_row * 3 * output_hidden + batch * seq_qkv * output_hidden * 3 + qkv
    else:
        # Row comes from y, offset by seq_qkv: y[batch, out_row - seq_qkv, ...].
        read_ptr = (
            out_id * output_hidden
            + (out_row - seq_qkv) * 3 * output_hidden
            + batch * seq_eqkv * output_hidden * 3
            + eqkv
        )

    read_offsets = tl.arange(0, BLOCK_SIZE)
    mask = read_offsets < output_hidden
    read_data = tl.load(read_ptr + read_offsets, mask=mask)

    real_output = out0
    if out_id == 1:
        real_output = out1
    elif out_id == 2:
        real_output = out2

    # Write the row into the concatenated output: out[batch, out_row, :].
    write_ptr = batch * (seq_qkv + seq_eqkv) * output_hidden + out_row * output_hidden + real_output + read_offsets

    tl.store(write_ptr, read_data, mask=mask)


def split_concat(x, y):
    assert len(x.shape) == 3
    assert len(y.shape) == 3

    assert x.shape[0] == y.shape[0]
    assert x.shape[2] == y.shape[2]

    batch = x.shape[0]
    seq_qkv = x.shape[1]
    hidd_x = x.shape[2]
    seq_eqkv = y.shape[1]
    output_hidden = hidd_x // 3
    BLOCK_SIZE = triton.next_power_of_2(output_hidden)
    op_name = "split_concat"
    op_name += get_dtype_str(x.dtype)
    op_name += f"_{BLOCK_SIZE}"

    if op_name not in OpProtoHolder.instance().op_proto_map.keys():
        # First call for this dtype/block size: run the kernel once so the
        # custom op gets JIT-compiled and registered.
        out0 = paddle.empty(shape=[batch, seq_qkv + seq_eqkv, output_hidden], dtype=x.dtype)
        out1 = paddle.empty(shape=[batch, seq_qkv + seq_eqkv, output_hidden], dtype=x.dtype)
        out2 = paddle.empty(shape=[batch, seq_qkv + seq_eqkv, output_hidden], dtype=x.dtype)
        grid = ("3", "batch", "seq_qkv + seq_eqkv")

        split_concat_kernel[(op_name, grid)](
            out0, out1, out2, x, y, batch, seq_qkv, seq_eqkv, output_hidden, BLOCK_SIZE=BLOCK_SIZE
        )

    if in_dynamic_or_pir_mode():
        print(f"== we are in dynamic mode, op_name: {op_name}")
        outs = _C_ops._run_custom_op(
            op_name,
            x,
            y,
        )
        return outs[0], outs[1], outs[2]
    else:
        print(f"== we are in dynamic to static mode, op_name: {op_name}")
        helper = LayerHelper(op_name, **locals())
        inputs = {
            "x": x,
            "y": y,
        }
        out0 = helper.create_variable_for_type_inference(dtype=x.dtype)
        out1 = helper.create_variable_for_type_inference(dtype=x.dtype)
        out2 = helper.create_variable_for_type_inference(dtype=x.dtype)

        helper.append_op(
            type=op_name,
            inputs=inputs,
            outputs={"out0_tensor": out0, "out1_tensor": out1, "out2_tensor": out2},
        )
        return out0, out1, out2
```


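Semantically, `split_concat(x, y)` fuses a three-way split of the last (fused q/k/v) axis with a concat of the two sequences. A plain-Paddle reference, useful only for checking the kernel and not part of the PR, could look like this:

```python
import paddle

def split_concat_reference(x, y):
    """Reference semantics for split_concat.

    x: [batch, seq_x, 3*h] fused q/k/v of one stream
    y: [batch, seq_y, 3*h] fused q/k/v of the other stream
    Returns q, k, v, each [batch, seq_x + seq_y, h].
    """
    qx, kx, vx = paddle.split(x, 3, axis=-1)  # each [batch, seq_x, h]
    qy, ky, vy = paddle.split(y, 3, axis=-1)  # each [batch, seq_y, h]
    q = paddle.concat([qx, qy], axis=1)
    k = paddle.concat([kx, ky], axis=1)
    v = paddle.concat([vx, vy], axis=1)
    return q, k, v
```

The Triton version writes each output row exactly once and never materializes the six intermediate slices, which is where the saving over the eager composition comes from.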
The second operator, `triton_split`, splits a sequence-concatenated tensor back into two streams along axis 1. Its C++ template, Triton kernel, and Python wrapper:

```python
########################### triton split ###############################
triton_split_template = (
"""
std::vector<paddle::Tensor> ${op_name}_func(
const paddle::Tensor &x,
const std::vector<int64_t> num_or_sections,
const int64_t axis) {

int output_batch = x.dims()[0];
int output_seq0 = num_or_sections[0];
int output_seq1 = num_or_sections[1];
int output_hidden = x.dims()[2];

auto out0_tensor = paddle::empty({output_batch, output_seq0, output_hidden}, x.dtype(), x.place());
auto out1_tensor = paddle::empty({output_batch, output_seq1, output_hidden}, x.dtype(), x.place());

auto out0 = get_tensor_ptr(out0_tensor);
auto out1 = get_tensor_ptr(out1_tensor);

auto input = get_tensor_ptr(x);

auto run_stream = out0_tensor.stream();

"""
+ tune_and_invoke_part
+ """
return {out0_tensor, out1_tensor};
}

std::vector<std::vector<int64_t>> ${op_name}_InferShape(
        const std::vector<int64_t>& A_shape) {

  // NOTE: the output sequence lengths are hard-coded for this SD3 deployment
  // (1024 image tokens, 154 text tokens). InferShape is only consulted for
  // static-graph shape inference; the kernel function above sizes its real
  // outputs from the num_or_sections attribute at run time.
  std::vector<int64_t> out_shape0 = {A_shape[0], 1024, A_shape[2]};
  std::vector<int64_t> out_shape1 = {A_shape[0], 154, A_shape[2]};

  return {out_shape0, out_shape1};
}

std::vector<paddle::DataType> ${op_name}_InferDtype(const paddle::DataType& A_dtype) {
return {A_dtype, A_dtype};
}

PD_BUILD_OP(${op_name})
.Inputs({"x"})
.Outputs({"out0_tensor", "out1_tensor"})
.SetKernelFn(PD_KERNEL(${op_name}_func))
.Attrs({"num_or_sections: std::vector<int64_t>", "axis: int64_t"})
.SetInferDtypeFn(PD_INFER_DTYPE(${op_name}_InferDtype))
.SetInferShapeFn(PD_INFER_SHAPE(${op_name}_InferShape));
"""
)


@paddle_use_triton(
    custom_op_template=triton_split_template,
    key=["1"],
)
def triton_split_kernel(
    out0,
    out1,
    input,
    output_seq0,
    output_seq1,
    output_batch,
    output_hidden,
    BLOCK_SIZE: tl.constexpr,
):
    # Grid is (output_batch, output_seq0 + output_seq1): each program copies
    # one input row into the matching row of out0 or out1.
    batch = tl.program_id(axis=0)
    out_row = tl.program_id(axis=1)
    read_ptr = out_row * output_hidden + batch * (output_seq0 + output_seq1) * output_hidden + input

    read_offsets = tl.arange(0, BLOCK_SIZE)
    mask = read_offsets < output_hidden
    read_data = tl.load(read_ptr + read_offsets, mask=mask)

    if out_row < output_seq0:
        write_ptr = batch * output_seq0 * output_hidden + out_row * output_hidden + out0 + read_offsets
    else:
        write_ptr = batch * output_seq1 * output_hidden + (out_row - output_seq0) * output_hidden + out1 + read_offsets

    tl.store(write_ptr, read_data, mask=mask)


def triton_split(x, num_or_sections=[-1, -1], axis=1):
    assert len(x.shape) == 3
    output_batch = x.shape[0]
    output_seq0 = num_or_sections[0]
    output_seq1 = num_or_sections[1]
    output_hidden = x.shape[2]

    BLOCK_SIZE = triton.next_power_of_2(output_hidden)
    op_name = "triton_split"
    op_name += get_dtype_str(x.dtype)
    op_name += f"_{BLOCK_SIZE}"

    if op_name not in OpProtoHolder.instance().op_proto_map.keys():
        # First call for this dtype/block size: run the kernel once so the
        # custom op gets JIT-compiled and registered.
        out0 = paddle.empty(shape=[output_batch, output_seq0, output_hidden], dtype=x.dtype)
        out1 = paddle.empty(shape=[output_batch, output_seq1, output_hidden], dtype=x.dtype)
        grid = ("output_batch", "output_seq0+output_seq1")

        # Pass the computed BLOCK_SIZE so the compiled kernel matches op_name
        # (this was hard-coded to 2048, which only covers hidden sizes up to 2048).
        triton_split_kernel[(op_name, grid)](
            out0, out1, x, output_seq0, output_seq1, output_batch, output_hidden, BLOCK_SIZE=BLOCK_SIZE
        )

    if in_dynamic_or_pir_mode():
        print(f"== we are in dynamic mode, op_name: {op_name}")
        outs = _C_ops._run_custom_op(
            op_name,
            x,
            num_or_sections,
            axis,
        )
        return outs[0], outs[1]
    else:
        print(f"== we are in dynamic to static mode, op_name: {op_name}")
        helper = LayerHelper(op_name, **locals())
        inputs = {
            "x": x,
        }
        out0 = helper.create_variable_for_type_inference(dtype=x.dtype)
        out1 = helper.create_variable_for_type_inference(dtype=x.dtype)

        helper.append_op(
            type=op_name,
            inputs=inputs,
            attrs={
                "num_or_sections": num_or_sections,
                "axis": axis,
            },
            outputs={"out0_tensor": out0, "out1_tensor": out1},
        )
        return out0, out1
```
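`triton_split` is a fused, shape-specialized equivalent of `paddle.split` along the sequence axis. A minimal sketch of the intended call, with sequence lengths assumed from the SD3 512x512 pipeline (1024 image tokens, 154 text tokens) and not part of the diff:

```python
import paddle

# Joint image+text hidden states; batch and hidden size here are assumptions.
x = paddle.randn([2, 1024 + 154, 1536], dtype="float16")

# Eager reference behaviour:
img_ref, txt_ref = paddle.split(x, num_or_sections=[1024, 154], axis=1)

# The custom op follows the same contract in a single kernel:
# img, txt = triton_split(x, num_or_sections=[1024, 154], axis=1)
```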
43 changes: 43 additions & 0 deletions `ppdiffusers/deploy/sd3/README.md` (new file)
# Stable Diffusion 3 High-Performance Inference

- Paddle Inference provides a high-performance inference implementation of the Stable Diffusion 3 model, improving inference performance by 70%+.

Environment setup:
```shell
# Install triton and make it compatible with paddle
python -m pip install triton
python -m pip install git+https://github.com/zhoutianzi666/UseTritonInPaddle.git
python -c "import use_triton_in_paddle; use_triton_in_paddle.make_triton_compatible_with_paddle()"

# Install the develop (nightly) build of paddle
python -m pip install --pre paddlepaddle-gpu -i https://www.paddlepaddle.org.cn/packages/nightly/cu123/

# Install PaddleNLP. Use a PaddleNLP build from after September 6, 2024; on that
# date a bug in PaddleNLP that affects this pipeline was fixed:
# https://github.com/PaddlePaddle/PaddleNLP/pull/9016/files
python -m pip install paddlenlp==3.0.0b1

# Point LD_LIBRARY_PATH at the TensorRT lib directory
export LD_LIBRARY_PATH=/your_TensorRT_dir/lib:$LD_LIBRARY_PATH

# Add the cutlass build directories to the library path
export LD_LIBRARY_PATH=/your_dir/Paddle/paddle/phi/kernels/fusion/cutlass/conv2d/build:$LD_LIBRARY_PATH
export LD_LIBRARY_PATH=/your_dir/Paddle/paddle/phi/kernels/fusion/cutlass/gemm_epilogue/build:$LD_LIBRARY_PATH
```

High-performance inference commands:
```shell
# Step 1: generate the FP32 Paddle model and, from it, build the FP16 TensorRT engine.
python text_to_image_generation-stable_diffusion_3.py --dtype float32 --height 512 --width 512 \
--num-inference-steps 50 --inference_optimize 1 \
--benchmark 1

# Step 2: run FP16 inference.
python text_to_image_generation-stable_diffusion_3.py --dtype float16 --height 512 --width 512 \
--num-inference-steps 50 --inference_optimize 1 \
--benchmark 1
```

- Performance measured on an NVIDIA A100-SXM4-40GB (latency, lower is better):

| Paddle Inference | PyTorch | Paddle dygraph |
| ---------------- | ------- | -------------- |
| 1.2 s | 1.78 s | 4.202 s |

Relative to the Paddle dygraph baseline, latency drops from 4.202 s to 1.2 s, roughly a 71% reduction, which matches the 70%+ claim above.