From 3a06aebb6ae49f7cdd16d951752dc0f23efd7754 Mon Sep 17 00:00:00 2001 From: newway Date: Mon, 1 Nov 2021 14:20:43 +0800 Subject: [PATCH] [xpu] refactor fc int31 for KL2; test=develop --- lite/kernels/xpu/__xpu__fc_compute.cc | 40 +++++++++++++-------------- 1 file changed, 20 insertions(+), 20 deletions(-) diff --git a/lite/kernels/xpu/__xpu__fc_compute.cc b/lite/kernels/xpu/__xpu__fc_compute.cc index ac257a7f9b9..02cc66b6273 100644 --- a/lite/kernels/xpu/__xpu__fc_compute.cc +++ b/lite/kernels/xpu/__xpu__fc_compute.cc @@ -126,26 +126,26 @@ void XPUFcCompute::Run() { } // TODO(weihaoji): remove fc_int31 and fc_int16 after xpu fc wrapper refactor if (param.precision == "int31") { - int r = xdnn::fc_int31( - ctx.GetRawContext(), /* context */ - false, /* TransA */ - true, /* TransB */ - m, /* m */ - n, /* n */ - k, /* k */ - 1.0f, /* alpha */ - param.input->data(), /* A */ - nullptr, /* max_a ptr */ - reinterpret_cast(quant_weight_guard_->addr_), /* B */ - w_max, /* max_b */ - 0.0f, /* beta */ - param.output->mutable_data(TARGET(kXPU)), /* C */ - nullptr, /* max_c ptr */ - bias, /* bias */ - act /* act_type */); - CHECK_EQ(r, 0); - r = xdnn::findmax( - ctx.GetRawContext(), param.output->data(), m * n, output_max); + int r = xdnn::fc_fusion( + ctx.GetRawContext(), // ctx + param.input->data(), // x + reinterpret_cast(quant_weight_guard_->addr_), // w + param.output->mutable_data(TARGET(kXPU)), // y + m, // m + n, // n + k, // k + false, // x_trans + true, // w_trans + input_max, // x_maxptr + reinterpret_cast(weight_max_guard_->addr_), // w_maxptr + output_max, // y_maxptr + k, // ldx + k, // ldw + n, // ldy + 1.0f, // alpha + 0.0f, // beta + bias, // bias + act); CHECK_EQ(r, 0); } else if (param.precision == "int16") { int r = 0;