Skip to content

Commit

Permalink
[xpu] refactor fc int31 for KL2; test=develop
Browse files Browse the repository at this point in the history
  • Loading branch information
newway committed Nov 1, 2021
1 parent f1eba3a commit 3a06aeb
Showing 1 changed file with 20 additions and 20 deletions.
40 changes: 20 additions & 20 deletions lite/kernels/xpu/__xpu__fc_compute.cc
Original file line number Diff line number Diff line change
Expand Up @@ -126,26 +126,26 @@ void XPUFcCompute::Run() {
}
// TODO(weihaoji): remove fc_int31 and fc_int16 after xpu fc wrapper refactor
if (param.precision == "int31") {
int r = xdnn::fc_int31(
ctx.GetRawContext(), /* context */
false, /* TransA */
true, /* TransB */
m, /* m */
n, /* n */
k, /* k */
1.0f, /* alpha */
param.input->data<float>(), /* A */
nullptr, /* max_a ptr */
reinterpret_cast<const float*>(quant_weight_guard_->addr_), /* B */
w_max, /* max_b */
0.0f, /* beta */
param.output->mutable_data<float>(TARGET(kXPU)), /* C */
nullptr, /* max_c ptr */
bias, /* bias */
act /* act_type */);
CHECK_EQ(r, 0);
r = xdnn::findmax<float>(
ctx.GetRawContext(), param.output->data<float>(), m * n, output_max);
int r = xdnn::fc_fusion<float, float, float, int>(
ctx.GetRawContext(), // ctx
param.input->data<float>(), // x
reinterpret_cast<const float*>(quant_weight_guard_->addr_), // w
param.output->mutable_data<float>(TARGET(kXPU)), // y
m, // m
n, // n
k, // k
false, // x_trans
true, // w_trans
input_max, // x_maxptr
reinterpret_cast<const float*>(weight_max_guard_->addr_), // w_maxptr
output_max, // y_maxptr
k, // ldx
k, // ldw
n, // ldy
1.0f, // alpha
0.0f, // beta
bias, // bias
act);
CHECK_EQ(r, 0);
} else if (param.precision == "int16") {
int r = 0;
Expand Down

0 comments on commit 3a06aeb

Please sign in to comment.