[XPU] Add multi-head self/cross attention fused ops. #10037

Merged: 1 commit merged on Mar 2, 2023
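What this PR adds, as reflected in the diff below: declarations for two new graph passes (__xpu__multihead_self_attn_fuse_pass and __xpu__multihead_cross_attn_fuse_pass) in paddle_use_passes.h, their placement in the default XPU optimizer pipeline in optimizer.cc, and two new XPU kernels (__xpu__multihead_self_attn, __xpu__multihead_cross_attn) registered for FP32 and FP16. The kernels' Run() bodies are left commented out pending adaptation to the XFT interface. As background, the sketch below is a plain CPU reference of single-head scaled-dot-product attention; it is not the fused XPU kernel and deliberately omits multi-head splitting, the layer-norm (LNScale/LNBias) and bias handling, and quantization. All shapes and values are illustrative.

#include <algorithm>
#include <cmath>
#include <cstdio>
#include <vector>

using Mat = std::vector<std::vector<float>>;  // row-major [rows][cols]

static Mat MatMul(const Mat& a, const Mat& b) {
  Mat c(a.size(), std::vector<float>(b[0].size(), 0.f));
  for (size_t i = 0; i < a.size(); ++i)
    for (size_t k = 0; k < b.size(); ++k)
      for (size_t j = 0; j < b[0].size(); ++j) c[i][j] += a[i][k] * b[k][j];
  return c;
}

static Mat Transpose(const Mat& a) {
  Mat t(a[0].size(), std::vector<float>(a.size()));
  for (size_t i = 0; i < a.size(); ++i)
    for (size_t j = 0; j < a[0].size(); ++j) t[j][i] = a[i][j];
  return t;
}

static void SoftmaxRows(Mat* m) {
  for (auto& row : *m) {
    float mx = row[0];
    float sum = 0.f;
    for (float v : row) mx = std::max(mx, v);
    for (float& v : row) {
      v = std::exp(v - mx);
      sum += v;
    }
    for (float& v : row) v /= sum;
  }
}

// For self-attention, q, k and v are all projected from "Input"; for
// cross-attention, k and v are projected from "Embedding" instead (inferred
// from the kernels' input bindings below).
static Mat Attention(const Mat& q, const Mat& k, const Mat& v) {
  Mat scores = MatMul(q, Transpose(k));  // [seqlen, kv_seqlen]
  const float scale = 1.f / std::sqrt(static_cast<float>(q[0].size()));
  for (auto& row : scores)
    for (float& s : row) s *= scale;     // scale by 1/sqrt(size_per_head)
  SoftmaxRows(&scores);
  return MatMul(scores, v);              // [seqlen, size_per_head]
}

int main() {
  Mat q = {{1.f, 0.f}, {0.f, 1.f}};  // [seqlen=2, size_per_head=2]
  Mat k = {{1.f, 0.f}, {0.f, 1.f}};
  Mat v = {{0.5f, 0.5f}, {1.f, -1.f}};
  Mat out = Attention(q, k, v);
  std::printf("out[0] = (%.3f, %.3f)\n", out[0][0], out[0][1]);
  return 0;
}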
2 changes: 2 additions & 0 deletions lite/api/paddle_use_passes.h
@@ -80,6 +80,8 @@ USE_MIR_PASS(assign_value_calc_offline_pass);
USE_MIR_PASS(__xpu__graph_dedup_pass);
USE_MIR_PASS(__xpu__resnet_fuse_pass);
USE_MIR_PASS(__xpu__gn_silu_fuse_pass);
USE_MIR_PASS(__xpu__multihead_cross_attn_fuse_pass);
USE_MIR_PASS(__xpu__multihead_self_attn_fuse_pass);
USE_MIR_PASS(__xpu__multi_encoder_fuse_pass);
USE_MIR_PASS(__xpu__embedding_with_eltwise_add_fuse_pass);
USE_MIR_PASS(__xpu__fc_fuse_pass);
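USE_MIR_PASS here only declares that the binary uses the two new passes; the point of such a macro is typically to reference a symbol defined next to the pass registration so the linker does not discard that object file. The sketch below illustrates that touch-symbol idiom with hypothetical names (PassRegistry, REGISTER_DEMO_PASS, USE_DEMO_PASS); it is not Paddle-Lite's actual macro definition.

#include <iostream>
#include <map>
#include <string>

struct PassRegistry {
  static std::map<std::string, bool>& Passes() {
    static std::map<std::string, bool> passes;  // populated by registrars at load time
    return passes;
  }
};

// Lives in the pass's own .cc: defines a touch function and a static registrar.
#define REGISTER_DEMO_PASS(name)                   \
  bool demo_pass_touch_##name() { return true; }   \
  static bool demo_pass_registered_##name =        \
      (PassRegistry::Passes()[#name] = true)

// Lives in the "use" header: referencing the touch function from the final
// binary keeps the registrar's object file from being dropped by the linker.
#define USE_DEMO_PASS(name)                        \
  extern bool demo_pass_touch_##name();            \
  static bool demo_pass_used_##name = demo_pass_touch_##name()

REGISTER_DEMO_PASS(__xpu__multihead_self_attn_fuse_pass);
USE_DEMO_PASS(__xpu__multihead_self_attn_fuse_pass);

int main() {
  std::cout << "registered passes: " << PassRegistry::Passes().size() << "\n";
  return 0;
}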

Two large diffs are not rendered by default (likely the sources of the two new fuse passes, which do not appear elsewhere in this diff).

2 changes: 2 additions & 0 deletions lite/core/optimizer/optimizer.cc
@@ -202,6 +202,8 @@ std::unique_ptr<RuntimeProgram> RunDefaultOptimizer(
"__xpu__mmdnn_fuse_pass",
"__xpu__bigru_fuse_pass",
"__xpu__roformer_relative_pos_fuse_pass",
"__xpu__multihead_self_attn_fuse_pass",
"__xpu__multihead_cross_attn_fuse_pass",
"__xpu__quick_gelu_fuse_pass",
"__xpu__gn_silu_fuse_pass",
"__xpu__multi_encoder_fuse_pass",
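The two fusers are slotted into RunDefaultOptimizer's ordered pass list ahead of __xpu__quick_gelu_fuse_pass and __xpu__multi_encoder_fuse_pass, so they see the graph before those later fusions rewrite it. Conceptually, the optimizer just walks this list in order; the sketch below is a simplified stand-in for that idea, not Paddle-Lite's actual Optimizer.

#include <functional>
#include <iostream>
#include <map>
#include <string>
#include <vector>

struct Graph {
  int fused_ops = 0;  // stand-in for real graph state
};

int main() {
  std::map<std::string, std::function<void(Graph*)>> registry = {
      {"__xpu__multihead_self_attn_fuse_pass", [](Graph* g) { ++g->fused_ops; }},
      {"__xpu__multihead_cross_attn_fuse_pass", [](Graph* g) { ++g->fused_ops; }},
      {"__xpu__quick_gelu_fuse_pass", [](Graph* g) { ++g->fused_ops; }},
  };
  // Order mirrors the excerpt from RunDefaultOptimizer above.
  const std::vector<std::string> pipeline = {
      "__xpu__multihead_self_attn_fuse_pass",
      "__xpu__multihead_cross_attn_fuse_pass",
      "__xpu__quick_gelu_fuse_pass",
  };
  Graph g;
  for (const auto& name : pipeline) {
    auto it = registry.find(name);
    if (it != registry.end()) it->second(&g);
  }
  std::cout << "passes applied: " << g.fused_ops << "\n";
  return 0;
}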
3 changes: 3 additions & 0 deletions lite/kernels/xpu/CMakeLists.txt
@@ -133,6 +133,9 @@ add_kernel(__xpu__bigru_compute_xpu XPU extra SRCS __xpu__bigru_compute.cc)
add_kernel(__xpu__dynamic_lstm_compute_xpu XPU extra SRCS __xpu__dynamic_lstm_compute.cc)
add_kernel(__xpu__multi_softmax_compute_xpu XPU extra SRCS __xpu__multi_softmax_compute.cc)
add_kernel(__xpu__gn_silu_compute_xpu XPU extra SRCS __xpu__gn_silu_compute.cc)
add_kernel(__xpu__multihead_self_attn_compute_xpu XPU extra SRCS __xpu__multihead_self_attn_compute.cc)
add_kernel(__xpu__multihead_cross_attn_compute_xpu XPU extra SRCS __xpu__multihead_cross_attn_compute.cc)

if(XPU_WITH_XFT)
add_kernel(fusion_decoding_compute_xpu XPU extra SRCS fusion_decoding_compute.cc)
add_kernel(fusion_unified_decoding_compute_xpu XPU extra SRCS fusion_unified_decoding_compute.cc)
150 changes: 150 additions & 0 deletions lite/kernels/xpu/__xpu__multihead_cross_attn_compute.cc
@@ -0,0 +1,150 @@
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "lite/kernels/xpu/__xpu__multihead_cross_attn_compute.h"
#include <vector>
#include "lite/backends/xpu/xpu_header_sitter.h"
#include "lite/core/op_registry.h"

namespace paddle {
namespace lite {
namespace kernels {
namespace xpu {

template <typename T>
static std::vector<const T*> prepare_weight(
const std::vector<lite::Tensor*>& fc_weight) {
std::vector<const T*> result;
for (auto* weight : fc_weight) {
result.push_back(reinterpret_cast<const T*>(weight->data<float>()));
}
return result;
}

template <typename InType, PrecisionType PType>
void XPUMhcaCompute<InType, PType>::PrepareWeightMax(
const std::vector<lite::Tensor*>& weight_max,
int max_ptr_len,
std::vector<const float*>* max_xpu_ptrs) {
int max_value_num = 0;
for (auto max_tensor : weight_max) {
max_value_num += max_tensor->numel();
}
VLOG(3) << "Total weight max value number: " << max_value_num;
weight_max_guard_ =
TargetWrapperXPU::MallocScratchPad(max_value_num * sizeof(float));
float* weight_max_ptr = reinterpret_cast<float*>(weight_max_guard_->addr_);

int offset = 0;
for (auto max_tensor : weight_max) {
float* cur_weight_max_ptr = weight_max_ptr + offset;
auto len = max_tensor->numel();
VLOG(6) << "weight max value: " << max_tensor->data<float>()[0] << " "
<< max_tensor->data<float>()[len - 1];
std::vector<float> cpu_max(max_ptr_len, max_tensor->data<float>()[0]);
lite::TargetWrapperXPU::MemcpySync(cur_weight_max_ptr,
cpu_max.data(),
sizeof(float) * max_ptr_len,
IoDirection::HtoD);
max_xpu_ptrs->push_back(cur_weight_max_ptr);
offset += max_ptr_len;
}
}

template <typename InType, PrecisionType PType>
void XPUMhcaCompute<InType, PType>::PrepareForRun() {
auto& ctx = this->ctx_->template As<XPUContext>();
auto& param = this->template Param<param_t>();
// prepare bias
for (auto* fc_bias : param.fc_bias) {
arg_fc_bias_.push_back(fc_bias->template data<float>());
}
// prepare scale
for (auto* ln_scale : param.ln_scale) {
arg_ln_scale_.push_back(ln_scale->template data<float>());
}
// prepare ln_bias
for (auto* ln_bias : param.ln_bias) {
arg_ln_bias_.push_back(ln_bias->template data<float>());
}
arg_fc_weight_int16_ = prepare_weight<int16_t>(param.fc_weight);
const int XPU_QUANT_SCALE_NUM = ctx.GetRawContext()->max_ptr_size();
PrepareWeightMax(param.weight_max, XPU_QUANT_SCALE_NUM, &fc_weight_max_);
}

template <typename InType, PrecisionType PType>
void XPUMhcaCompute<InType, PType>::Run() {
// TODO(shenyijun): The compute of this op will be adapted to XFT interface
// later on.
//
// auto& param = this->template Param<param_t>();
// auto& ctx = this->ctx_->template As<XPUContext>();
// const InType* in = param.input->template data<InType>();
// const InType* embedding = param.embedding->template data<InType>();
// InType* out = param.output->template mutable_data<InType>(TARGET(kXPU));
// int batch = static_cast<int>(param.input->dims()[0]);
// int seqlen = static_cast<int>(param.input->dims()[1]);
// int embedding_seq = static_cast<int>(param.embedding->dims()[1]);
// int r = xdnn::unet_mhca_fusion<InType, int16_t, InType, int16_t>(
// ctx.GetRawContext(),
// in,
// embedding,
// *(XPUMhcaCompute::GetWeight<int16_t>()),
// out,
// arg_fc_bias_,
// arg_ln_scale_,
// arg_ln_bias_,
// fc_weight_max_,
// batch,
// param.head_num,
// param.size_per_head,
// seqlen,
// param.hidden_dim,
// embedding_seq,
// param.embedding_dim);
// CHECK_EQ(r, 0);
}

} // namespace xpu
} // namespace kernels
} // namespace lite
} // namespace paddle

namespace xpu = paddle::lite::kernels::xpu;

using XPUMhca_FP32 = xpu::XPUMhcaCompute<float, PRECISION(kFloat)>;
using XPUMhca_FP16 = xpu::XPUMhcaCompute<float16, PRECISION(kFP16)>;

REGISTER_LITE_KERNEL(
__xpu__multihead_cross_attn, kXPU, kFloat, kNCHW, XPUMhca_FP32, def)
.BindInput("Input", {LiteType::GetTensorTy(TARGET(kXPU))})
.BindInput("Embedding", {LiteType::GetTensorTy(TARGET(kXPU))})
.BindInput("FCWeight", {LiteType::GetTensorTy(TARGET(kXPU))})
.BindInput("FCBias", {LiteType::GetTensorTy(TARGET(kXPU))})
.BindInput("LNScale", {LiteType::GetTensorTy(TARGET(kXPU))})
.BindInput("LNBias", {LiteType::GetTensorTy(TARGET(kXPU))})
.BindOutput("Output", {LiteType::GetTensorTy(TARGET(kXPU))})
.Finalize();
REGISTER_LITE_KERNEL(
__xpu__multihead_cross_attn, kXPU, kFP16, kNCHW, XPUMhca_FP16, def)
.BindInput("Input", {LiteType::GetTensorTy(TARGET(kXPU), PRECISION(kFP16))})
.BindInput("Embedding",
{LiteType::GetTensorTy(TARGET(kXPU), PRECISION(kFP16))})
.BindInput("FCWeight", {LiteType::GetTensorTy(TARGET(kXPU))})
.BindInput("FCBias", {LiteType::GetTensorTy(TARGET(kXPU))})
.BindInput("LNScale", {LiteType::GetTensorTy(TARGET(kXPU))})
.BindInput("LNBias", {LiteType::GetTensorTy(TARGET(kXPU))})
.BindOutput("Output",
{LiteType::GetTensorTy(TARGET(kXPU), PRECISION(kFP16))})
.Finalize();
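The only logic that is active in this kernel today is PrepareWeightMax(): for each FC weight it takes that weight's max value (element [0] of the corresponding max tensor), replicates it max_ptr_len times, and copies the fixed-length block into device scratch memory, recording one device pointer per weight. The sketch below mirrors that broadcast on the CPU with illustrative names and sizes; the vector stands in for the scratch pad that the real code obtains via TargetWrapperXPU::MallocScratchPad and fills with MemcpySync.

#include <cstdio>
#include <vector>

int main() {
  const int max_ptr_len = 6;  // ctx->max_ptr_size() on the real device
  const std::vector<float> weight_maxes = {0.12f, 0.34f, 0.56f};  // one per FC weight

  std::vector<float> packed;    // stand-in for weight_max_guard_'s device memory
  std::vector<size_t> offsets;  // stand-in for the pointers pushed into fc_weight_max_
  for (float m : weight_maxes) {
    offsets.push_back(packed.size());
    packed.insert(packed.end(), max_ptr_len, m);  // broadcast scalar into max_ptr_len slots
  }

  for (size_t i = 0; i < offsets.size(); ++i) {
    std::printf("weight %zu -> offset %zu, value %.2f\n", i, offsets[i],
                packed[offsets[i]]);
  }
  return 0;
}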
62 changes: 62 additions & 0 deletions lite/kernels/xpu/__xpu__multihead_cross_attn_compute.h
@@ -0,0 +1,62 @@
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#pragma once

#include <vector>
#include "lite/backends/xpu/xpu_header_sitter.h"
#include "lite/core/kernel.h"
#include "lite/core/op_registry.h"

namespace paddle {
namespace lite {
namespace kernels {
namespace xpu {

template <typename InType, PrecisionType PType>
class XPUMhcaCompute : public KernelLite<TARGET(kXPU), PType> {
public:
using param_t = operators::XPUMhcaParam;

virtual void PrepareForRun();

virtual void Run();

virtual ~XPUMhcaCompute() = default;

private:
std::vector<const int16_t *> arg_fc_weight_int16_;
std::vector<const float *> arg_fc_bias_;
std::vector<const float *> arg_ln_scale_;
std::vector<const float *> arg_ln_bias_;
std::vector<const float *> fc_weight_max_;
XPUScratchPadGuard weight_max_guard_;

template <typename T>
std::vector<const T *> *GetWeight() {
LOG(FATAL) << "Invalid Weight Type";
return nullptr;
}

std::vector<const int16_t *> *GetWeight() { return &arg_fc_weight_int16_; }

void PrepareWeightMax(const std::vector<lite::Tensor *> &weight_max,
int max_ptr_len,
std::vector<const float *> *max_xpu_ptrs);
};

} // namespace xpu
} // namespace kernels
} // namespace lite
} // namespace paddle
141 changes: 141 additions & 0 deletions lite/kernels/xpu/__xpu__multihead_self_attn_compute.cc
@@ -0,0 +1,141 @@
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "lite/kernels/xpu/__xpu__multihead_self_attn_compute.h"
#include <vector>
#include "lite/backends/xpu/xpu_header_sitter.h"
#include "lite/core/op_registry.h"

namespace paddle {
namespace lite {
namespace kernels {
namespace xpu {

template <typename T>
static std::vector<const T*> prepare_weight(
const std::vector<lite::Tensor*>& fc_weight) {
std::vector<const T*> result;
for (auto* weight : fc_weight) {
result.push_back(reinterpret_cast<const T*>(weight->data<float>()));
}
return result;
}

template <typename InType, PrecisionType PType>
void XPUMhsaCompute<InType, PType>::PrepareWeightMax(
const std::vector<lite::Tensor*>& weight_max,
int max_ptr_len,
std::vector<const float*>* max_xpu_ptrs) {
int max_value_num = 0;
for (auto max_tensor : weight_max) {
max_value_num += max_tensor->numel();
}
VLOG(3) << "Total weight max value number: " << max_value_num;
weight_max_guard_ =
TargetWrapperXPU::MallocScratchPad(max_value_num * sizeof(float));
float* weight_max_ptr = reinterpret_cast<float*>(weight_max_guard_->addr_);

int offset = 0;
for (auto max_tensor : weight_max) {
float* cur_weight_max_ptr = weight_max_ptr + offset;
auto len = max_tensor->numel();
VLOG(6) << "weight max value: " << max_tensor->data<float>()[0] << " "
<< max_tensor->data<float>()[len - 1];
std::vector<float> cpu_max(max_ptr_len, max_tensor->data<float>()[0]);
lite::TargetWrapperXPU::MemcpySync(cur_weight_max_ptr,
cpu_max.data(),
sizeof(float) * max_ptr_len,
IoDirection::HtoD);
max_xpu_ptrs->push_back(cur_weight_max_ptr);
offset += max_ptr_len;
}
}

template <typename InType, PrecisionType PType>
void XPUMhsaCompute<InType, PType>::PrepareForRun() {
auto& ctx = this->ctx_->template As<XPUContext>();
auto& param = this->template Param<param_t>();
// prepare bias
for (auto* fc_bias : param.fc_bias) {
arg_fc_bias_.push_back(fc_bias->template data<float>());
}
// prepare scale
for (auto* ln_scale : param.ln_scale) {
arg_ln_scale_.push_back(ln_scale->template data<float>());
}
// prepare ln_bias
for (auto* ln_bias : param.ln_bias) {
arg_ln_bias_.push_back(ln_bias->template data<float>());
}
arg_fc_weight_int16_ = prepare_weight<int16_t>(param.fc_weight);
const int XPU_QUANT_SCALE_NUM = ctx.GetRawContext()->max_ptr_size();
PrepareWeightMax(param.weight_max, XPU_QUANT_SCALE_NUM, &fc_weight_max_);
}

template <typename InType, PrecisionType PType>
void XPUMhsaCompute<InType, PType>::Run() {
// TODO(shenyijun): The compute of this op will be adapted to XFT interface
// later on.
//
// auto& param = this->template Param<param_t>();
// auto& ctx = this->ctx_->template As<XPUContext>();
// const InType* in = param.input->template data<InType>();
// InType* out = param.output->template mutable_data<InType>(TARGET(kXPU));
// int batch = static_cast<int>(param.input->dims()[0]);
// int seqlen = static_cast<int>(param.input->dims()[1]);
// int r = xdnn::unet_mhsa_fusion<InType, int16_t, InType, int16_t>(
// ctx.GetRawContext(),
// in,
// *(XPUMhsaCompute::GetWeight<int16_t>()),
// out,
// arg_fc_bias_,
// arg_ln_scale_,
// arg_ln_bias_,
// fc_weight_max_,
// batch,
// param.head_num,
// param.size_per_head,
// seqlen,
// param.hidden_dim);
// CHECK_EQ(r, 0);
}

} // namespace xpu
} // namespace kernels
} // namespace lite
} // namespace paddle

namespace xpu = paddle::lite::kernels::xpu;

using XPUMhsa_FP32 = xpu::XPUMhsaCompute<float, PRECISION(kFloat)>;
using XPUMhsa_FP16 = xpu::XPUMhsaCompute<float16, PRECISION(kFP16)>;
REGISTER_LITE_KERNEL(
__xpu__multihead_self_attn, kXPU, kFloat, kNCHW, XPUMhsa_FP32, def)
.BindInput("Input", {LiteType::GetTensorTy(TARGET(kXPU))})
.BindInput("FCWeight", {LiteType::GetTensorTy(TARGET(kXPU))})
.BindInput("FCBias", {LiteType::GetTensorTy(TARGET(kXPU))})
.BindInput("LNScale", {LiteType::GetTensorTy(TARGET(kXPU))})
.BindInput("LNBias", {LiteType::GetTensorTy(TARGET(kXPU))})
.BindOutput("Output", {LiteType::GetTensorTy(TARGET(kXPU))})
.Finalize();
REGISTER_LITE_KERNEL(
__xpu__multihead_self_attn, kXPU, kFP16, kNCHW, XPUMhsa_FP16, def)
.BindInput("Input", {LiteType::GetTensorTy(TARGET(kXPU), PRECISION(kFP16))})
.BindInput("FCWeight", {LiteType::GetTensorTy(TARGET(kXPU))})
.BindInput("FCBias", {LiteType::GetTensorTy(TARGET(kXPU))})
.BindInput("LNScale", {LiteType::GetTensorTy(TARGET(kXPU))})
.BindInput("LNBias", {LiteType::GetTensorTy(TARGET(kXPU))})
.BindOutput("Output",
{LiteType::GetTensorTy(TARGET(kXPU), PRECISION(kFP16))})
.Finalize();