[cherry pick] full api fp16 support and quant_dequant_pass fix #9654

Merged · 4 commits · Nov 11, 2022
lite/api/cxx_api.cc (+56, -0)
@@ -23,6 +23,9 @@
 
 #include "lite/api/paddle_use_passes.h"
 #include "lite/utils/io.h"
+#ifdef ENABLE_ARM_FP16
+#include "lite/backends/arm/math/fp16/type_trans_fp16.h"
+#endif
 
 namespace paddle {
 namespace lite {
@@ -297,6 +300,54 @@ const cpp::ProgramDesc &Predictor::program_desc() const {
 }
 const RuntimeProgram &Predictor::runtime_program() const { return *program_; }
 
+#ifdef ENABLE_ARM_FP16
+typedef __fp16 float16_t;
+void Predictor::WeightFP32ToFP16() {
+  std::shared_ptr<const cpp::ProgramDesc> program_desc = program_desc_;
+  std::vector<std::string> fp16_ops{"conv2d",
+                                    "depthwise_conv2d",
+                                    "conv2d_transpose",
+                                    "fc",
+                                    "mul",
+                                    "matmul",
+                                    "matmul_v2",
+                                    "gru",
+                                    "sequence_conv",
+                                    "elementwise_add",
+                                    "elementwise_sub",
+                                    "elementwise_div",
+                                    "elementwise_mul",
+                                    "prelu"};
+  for (size_t i = 0; i < program_desc->BlocksSize(); i++) {
+    auto *block = program_desc->GetBlock<cpp::BlockDesc>(i);
+    for (size_t k = 0; k < block->OpsSize(); ++k) {
+      auto *op_desc = block->GetOp<cpp::OpDesc>(k);
+      std::string op_type = op_desc->Type();
+      auto iter = std::find(fp16_ops.begin(), fp16_ops.end(), op_type);
+      if (iter != fp16_ops.end()) {
+        auto input_names = op_desc->input_vars();
+        for (auto &input_name : input_names) {
+          std::string input_weight_name = input_name + "_fp16";
+          if (op_desc->HasAttr(input_weight_name)) {  // the input is fp16
+            Tensor tmp_tensor;
+            auto input_tensor =
+                scope_->FindVar(input_name)->GetMutable<lite::Tensor>();
+            if (input_tensor->precision() != PRECISION(kFloat)) continue;
+            tmp_tensor.CopyDataFrom(*input_tensor);
+            input_tensor->clear();
+            input_tensor->set_precision(PRECISION(kFP16));
+            float16_t *fp_data = input_tensor->mutable_data<float16_t>();
+            const float *in_data = tmp_tensor.data<float>();
+            lite::arm::math::fp16::fp32_to_fp16(
+                in_data, fp_data, input_tensor->numel());
+          }
+        }
+      }
+    }
+  }
+}
+#endif  // ENABLE_ARM_FP16
+
 void Predictor::Build(const lite_api::CxxConfig &config,
                       const std::vector<Place> &valid_places,
                       const std::vector<std::string> &passes,
@@ -413,6 +464,11 @@ void Predictor::Build(const std::shared_ptr<cpp::ProgramDesc> &program_desc,
 
   // Update the runtime program to program_desc only once
   program_->SaveRuntimProgramIntoProgramDesc(program_desc_);
+
+#ifdef ENABLE_ARM_FP16
+  // Convert FP32 weights of supported ops to FP16
+  WeightFP32ToFP16();
+#endif
 }
 
 void Predictor::GenRuntimeProgram() {
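A note on the helper called above: `lite::arm::math::fp16::fp32_to_fp16` is not part of this diff. From the call site it is assumed to be an element-wise FP32-to-FP16 narrowing; a minimal scalar sketch (function name illustrative, the real routine presumably vectorizes with NEON intrinsics):

```cpp
// Minimal scalar sketch of the assumed behavior of fp32_to_fp16; the name
// and signature here are illustrative, not the library's actual API.
#include <cstdint>

typedef __fp16 float16_t;  // ARM half-precision type, as in the diff above

void fp32_to_fp16_sketch(const float *in, float16_t *out, int64_t numel) {
  for (int64_t i = 0; i < numel; ++i) {
    // Narrowing cast rounds to the nearest representable FP16 value.
    out[i] = static_cast<float16_t>(in[i]);
  }
}
```

This is also why `WeightFP32ToFP16` copies each weight into `tmp_tensor` first: `mutable_data<float16_t>()` repurposes the original tensor's storage, so the FP32 source data must survive in the temporary while the conversion writes into the reallocated buffer.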
lite/api/cxx_api.h (+3, -0)
@@ -252,6 +252,9 @@ class LITE_API Predictor {
 
   void ClearTensorArray(
       const std::shared_ptr<const cpp::ProgramDesc>& program_desc);
+#ifdef ENABLE_ARM_FP16
+  void WeightFP32ToFP16();
+#endif
 
  private:
   std::shared_ptr<cpp::ProgramDesc> program_desc_;
lite/core/optimizer/mir/fusion/quant_dequant_fuse_pass.cc (+5, -2)
@@ -71,8 +71,11 @@ void QuantDequantFusePass::Apply(const std::unique_ptr<SSAGraph>& graph) {
 
   // process new quant op pass: quantize_linear and dequantize_linear
   // pass1: input+quantize_linear+dequantize_linear --> input
-  fusion::QuantDequantLinearOpFuser quant_dequant_linear_fuser;
-  quant_dequant_linear_fuser(graph.get());
+  for (auto share_zero_point : {true, false}) {
+    fusion::QuantDequantLinearOpFuser quant_dequant_linear_fuser(
+        share_zero_point);
+    quant_dequant_linear_fuser(graph.get());
+  }
   // pass2: weight+dequantize_linear --> weight
   fusion::DequantLinearOpFuser dequantize_linear_fuser;
   dequantize_linear_fuser(graph.get());
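Why the fuser now runs twice: a `FuseBase` pattern matches one fixed graph shape, and exported models may either share a single ZeroPoint variable between the paired quantize_linear/dequantize_linear ops or give each op its own (the Scale is shared in both cases, per the `BuildPattern` change below). A comment-only illustration of the two shapes, assumed from that change:

```cpp
// Illustrative only: the ZeroPoint wiring in the two shapes matched by the
// two passes (Scale feeds both ops in both cases).
//
//   share_zero_point == true:
//     zp ----> quantize_linear ----> dequantize_linear
//       \_________________________________^  (same zp variable)
//
//   share_zero_point == false:
//     zp0 ---> quantize_linear ----> dequantize_linear <--- zp1
//
// Either way the matched pair is folded away, leaving the quant input wired
// directly to the consumers of dequant_op_out.
```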
lite/core/optimizer/mir/fusion/quant_dequant_op_fuser.cc (+14, -6)
@@ -640,9 +640,6 @@ void QuantDequantLinearOpFuser::BuildPattern() {
           ->assert_is_op_input("quantize_linear", "ZeroPoint");
   auto* quant_op_output =
       VarNode("quant_op_output")->assert_is_op_output("quantize_linear", "Y");
-  auto* dequant_op_zero_point =
-      VarNode("dequant_op_zero_point")
-          ->assert_is_op_input("dequantize_linear", "ZeroPoint");
   auto* dequant_op_out =
       VarNode("dequant_op_out")->assert_is_op_output("dequantize_linear", "Y");
 
@@ -653,9 +650,19 @@
 
   quant_op->LinksFrom({quant_op_input, quant_op_scale, quant_op_zero_point})
       .LinksTo({quant_op_output});
-  dequant_op
-      ->LinksFrom({quant_op_output, quant_op_scale, dequant_op_zero_point})
-      .LinksTo({dequant_op_out});
+
+  if (shared_zero_point_) {
+    dequant_op
+        ->LinksFrom({quant_op_output, quant_op_scale, quant_op_zero_point})
+        .LinksTo({dequant_op_out});
+  } else {
+    auto* dequant_op_zero_point =
+        VarNode("dequant_op_zero_point")
+            ->assert_is_op_input("dequantize_linear", "ZeroPoint");
+    dequant_op
+        ->LinksFrom({quant_op_output, quant_op_scale, dequant_op_zero_point})
+        .LinksTo({dequant_op_out});
+  }
   VLOG(4) << "QuantDequantLinearOpFuser";
 }
 
@@ -705,6 +712,7 @@ void QuantDequantLinearOpFuser::InsertNewNode(SSAGraph* graph,
     if (!out_scale_node->IsStmt()) continue;
     auto* out_scale_scope = out_scale_node->stmt()->op()->scope();
     auto* out_scale_op_info = out_scale_node->stmt()->op_info();
+    if (out_scale_op_info->Type() != "quantize_linear") continue;
     if (!out_scale_op_info->HasInput("Scale")) continue;
     std::string out_scale_name = out_scale_op_info->Input("Scale").front();
     auto* out_scale_tensor =
lite/core/optimizer/mir/fusion/quant_dequant_op_fuser.h (+4, -1)
@@ -140,7 +140,9 @@ class DynamicQuantOpFuser : public FuseBase {
  */
 class QuantDequantLinearOpFuser : public FuseBase {
  public:
-  QuantDequantLinearOpFuser() {}
+  explicit QuantDequantLinearOpFuser(const bool shared_zero_point) {
+    shared_zero_point_ = shared_zero_point;
+  }
   void BuildPattern() override;
   void InsertNewNode(SSAGraph* graph, const key2nodes_t& matched) override;
 
@@ -152,6 +154,7 @@ class QuantDequantLinearOpFuser : public FuseBase {
       "mul",
       "matmul",
       "matmul_v2"};
+  bool shared_zero_point_{};
 };
 
 /* The pattern like "dequantize_linear_op + quantized_op "
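Two small C++ notes on this header change: `bool shared_zero_point_{};` value-initializes the flag to `false`, and `explicit` keeps a bare `bool` from silently converting into a fuser object at a call site. A hypothetical illustration (the helper function is assumed, not part of the PR):

```cpp
// Hypothetical call site, illustrative only.
void RunFuser(const fusion::QuantDequantLinearOpFuser &fuser);

void Example() {
  // RunFuser(true);  // rejected thanks to `explicit`; a bare bool here
  //                  // would almost certainly be a bug
  RunFuser(fusion::QuantDequantLinearOpFuser(true));  // intent is visible
}
```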
lite/core/optimizer/mir/memory_optimize_pass.cc (+9, -4)
@@ -147,12 +147,17 @@ void MemoryOptimizePass::CollectLifeCycleByDevice(
     }
     if (inplace) {
       for (auto& in_param_name : inplace_op_node->second.first) {
-        const auto& in_arg_names = op_info->Input(in_param_name);
-        invalid_var_names.insert(in_arg_names.begin(), in_arg_names.end());
+        if (op_info->HasInput(in_param_name)) {
+          const auto& in_arg_names = op_info->Input(in_param_name);
+          invalid_var_names.insert(in_arg_names.begin(), in_arg_names.end());
+        }
       }
       for (auto& out_param_name : inplace_op_node->second.second) {
-        const auto& out_arg_names = op_info->Output(out_param_name);
-        invalid_var_names.insert(out_arg_names.begin(), out_arg_names.end());
+        if (op_info->HasOutput(out_param_name)) {
+          const auto& out_arg_names = op_info->Output(out_param_name);
+          invalid_var_names.insert(out_arg_names.begin(),
+                                   out_arg_names.end());
+        }
       }
     }
   }
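The memory_optimize_pass change is a defensive fix: an op listed in the inplace table may omit an optional input or output parameter, and the previous unguarded `op_info->Input(...)`/`Output(...)` lookups assumed the parameter name was always present. A standalone sketch of the guarded pattern, using a stand-in `OpInfoLike` type (assumed for illustration; the real `OpInfo` interface differs):

```cpp
#include <map>
#include <set>
#include <string>
#include <vector>

// Stand-in for the real OpInfo interface, assumed for illustration.
struct OpInfoLike {
  std::map<std::string, std::vector<std::string>> inputs;
  bool HasInput(const std::string &param) const {
    return inputs.count(param) > 0;
  }
  const std::vector<std::string> &Input(const std::string &param) const {
    return inputs.at(param);  // throws if absent -- hence the guard
  }
};

// Collect argument names only for parameters the op actually carries,
// mirroring the HasInput/HasOutput guards added in this PR.
void CollectInvalidVars(const OpInfoLike &op_info,
                        const std::vector<std::string> &param_names,
                        std::set<std::string> *invalid_var_names) {
  for (const auto &name : param_names) {
    if (!op_info.HasInput(name)) continue;  // optional param may be absent
    const auto &args = op_info.Input(name);
    invalid_var_names->insert(args.begin(), args.end());
  }
}
```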