Skip to content

Commit

Permalink
[XPU] Fixed the mode error in pad3d (PaddlePaddle#9506)
Browse files Browse the repository at this point in the history
  • Loading branch information
wbn03 authored and newway committed Nov 19, 2022
1 parent a16416f commit c5f8551
Showing 1 changed file with 64 additions and 43 deletions.
107 changes: 64 additions & 43 deletions lite/kernels/xpu/pad3d_compute.cc
Original file line number Diff line number Diff line change
Expand Up @@ -36,50 +36,71 @@ void Pad3dCompute<T>::Run() {
auto* in_data = x->template data<T>();
auto* out = param.Out;
T* out_data = out->template mutable_data<T>(TARGET(kXPU));
bool is_ncdhw;
int n, c, d, h, w;
if (data_format == "NCDHW") {
is_ncdhw = true;
n = in_dims[0];
c = in_dims[1];
d = in_dims[2];
h = in_dims[3];
w = in_dims[4];
} else if (data_format == "NDHWC") {
is_ncdhw = false;
n = in_dims[0];
c = in_dims[4];
d = in_dims[1];
h = in_dims[2];
w = in_dims[3];
} else {
LOG(FATAL) << "xpu unsupport data_format: " << data_format;
}
// trans pad format
std::vector<int> padding(6);
padding[0] = pads[4];
padding[1] = pads[5];
padding[2] = pads[2];
padding[3] = pads[3];
padding[4] = pads[0];
padding[5] = pads[1];

if (mode == "reflect" || mode == "constant" || mode == "replicate" ||
mode == "circular") {
if (data_format == "NCDHW") {
std::vector<int> pad_left = {0, 0, pads[4], pads[2], pads[0]};
std::vector<int> pad_right = {0, 0, pads[5], pads[3], pads[1]};

int n_shape = in_dims[0];
int c_shape = in_dims[1];
int d_shape = in_dims[2];
int h_shape = in_dims[3];
int w_shape = in_dims[4];

std::vector<int> xshape = {n_shape, c_shape, d_shape, h_shape, w_shape};

int r = xdnn::pad<T>(ctx.GetRawContext(),
in_data,
out_data,
xshape,
pad_left,
pad_right,
value);
CHECK_EQ(r, 0);
} else if (data_format == "NDHWC") {
std::vector<int> pad_left = {0, pads[4], pads[2], pads[0], 0};
std::vector<int> pad_right = {0, pads[5], pads[3], pads[1], 0};

int n_shape = in_dims[0];
int d_shape = in_dims[1];
int h_shape = in_dims[2];
int w_shape = in_dims[3];
int c_shape = in_dims[4];
std::vector<int> xshape = {n_shape, d_shape, h_shape, w_shape, c_shape};

int r = xdnn::pad<T>(ctx.GetRawContext(),
in_data,
out_data,
xshape,
pad_left,
pad_right,
value);
CHECK_EQ(r, 0);
}

if (mode == "constant") {
int r = xdnn::constant_pad3d<T>(ctx.GetRawContext(),
in_data,
out_data,
n,
c,
d,
h,
w,
padding,
value,
is_ncdhw);
CHECK_EQ(r, 0);
} else if (mode == "reflect") {
int r = xdnn::reflection_pad3d<T>(ctx.GetRawContext(),
in_data,
out_data,
n,
c,
d,
h,
w,
padding,
is_ncdhw);
CHECK_EQ(r, 0);
} else if (mode == "replicate") {
int r = xdnn::replication_pad3d<T>(ctx.GetRawContext(),
in_data,
out_data,
n,
c,
d,
h,
w,
padding,
is_ncdhw);
CHECK_EQ(r, 0);
} else {
LOG(FATAL) << "xpu unsupport mode: " << mode;
}
Expand Down

0 comments on commit c5f8551

Please sign in to comment.