Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[OpenCL]support opencl expand #8078

Merged
merged 6 commits into from
Dec 31, 2021
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 7 additions & 9 deletions lite/backends/opencl/cl_kernel/image/expand_kernel.cl
Original file line number Diff line number Diff line change
Expand Up @@ -201,25 +201,23 @@ __kernel void expend_cn(__private const int OUT_C,
__private const int output_height,

__read_only image2d_t input,
__write_only image2d_t output,
__private const int n_times,
__private const int c_times,
__private const int h_times,
__private const int w_times) {
__write_only image2d_t output) {
const int out_c = get_global_id(0);
const int out_w = get_global_id(1);
const int out_nh = get_global_id(2);

if (out_c >= OUT_C || out_w >= OUT_W || out_nh >= OUT_NH) {
return;
}

const int IN_N = IN_NH / input_height;
const int OUT_N = OUT_NH / output_height;
const int out_n = out_nh / output_height;
const int out_h = out_nh % output_height;
const int in_c = out_c;
const int in_w = out_w / w_times;
const int in_h = out_h / h_times;
const int in_n = out_n / n_times;
const int in_w = out_w % input_width;
const int in_h = out_h % input_height;
const int in_n = out_n % IN_N;

const int in_nh = in_n * input_height + in_h;

int2 output_pos = (int2)(out_c * OUT_W + out_w, out_nh);
Expand Down
40 changes: 20 additions & 20 deletions lite/kernels/opencl/expand_image_compute.cc
Original file line number Diff line number Diff line change
Expand Up @@ -48,18 +48,18 @@ class ExpandComputeImage2D : public KernelLite<TARGET(kOpenCL),
CHECK(expand_times[1] == 1) << "expand image do not support expend c now";

// do not confuse with these cases.it is use to support expend c in future
if (in_dims[1] == 1) {
kernel_func_name_ = "expend_c1";
} else if (in_dims[1] == 2) {
kernel_func_name_ = "expend_c2";
} else if (in_dims[1] == 3) {
kernel_func_name_ = "expend_c3";
} else if (in_dims[1] == 4) {
kernel_func_name_ = "expend_c4";
} else {
kernel_func_name_ = "expend_cn";
}

// if (in_dims[1] == 1) {
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

这里的注释都删除掉吧

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

已改

// kernel_func_name_ = "expend_c1";
// } else if (in_dims[1] == 2) {
// kernel_func_name_ = "expend_c2";
// } else if (in_dims[1] == 3) {
// kernel_func_name_ = "expend_c3";
// } else if (in_dims[1] == 4) {
// kernel_func_name_ = "expend_c4";
// } else {
// kernel_func_name_ = "expend_cn";
// }
kernel_func_name_ = "expend_cn";
VLOG(1) << "kernel_func_name_:" << kernel_func_name_;
auto& context = ctx_->As<OpenCLContext>();
context.cl_context()->AddKernel(kernel_func_name_,
Expand Down Expand Up @@ -161,14 +161,14 @@ class ExpandComputeImage2D : public KernelLite<TARGET(kOpenCL),
status = kernel.setArg(11, *out_img);
CL_CHECK_FATAL(status);

status = kernel.setArg(12, expand_times_n);
CL_CHECK_FATAL(status);
status = kernel.setArg(13, expand_times_c);
CL_CHECK_FATAL(status);
status = kernel.setArg(14, expand_times_h);
CL_CHECK_FATAL(status);
status = kernel.setArg(15, expand_times_w);
CL_CHECK_FATAL(status);
// status = kernel.setArg(12, expand_times_n);
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

同上

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

已改

// CL_CHECK_FATAL(status);
// status = kernel.setArg(13, expand_times_c);
// CL_CHECK_FATAL(status);
// status = kernel.setArg(14, expand_times_h);
// CL_CHECK_FATAL(status);
// status = kernel.setArg(15, expand_times_w);
// CL_CHECK_FATAL(status);

status = EnqueueNDRangeKernel(context,
kernel,
Expand Down
38 changes: 33 additions & 5 deletions lite/tests/unittest_py/op/test_expand_op.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,20 @@ def __init__(self, *args, **kwargs):
Place(TargetType.Host, PrecisionType.FP32, DataLayoutType.NCHW)
]
self.enable_testing_on_place(thread=[1, 4], places=host_places)
opencl_places = [
Place(TargetType.OpenCL, PrecisionType.FP16,
DataLayoutType.ImageDefault), Place(
TargetType.OpenCL, PrecisionType.FP16,
DataLayoutType.ImageFolder),
Place(TargetType.OpenCL, PrecisionType.FP32, DataLayoutType.NCHW),
Place(TargetType.OpenCL, PrecisionType.Any,
DataLayoutType.ImageDefault), Place(
TargetType.OpenCL, PrecisionType.Any,
DataLayoutType.ImageFolder),
Place(TargetType.OpenCL, PrecisionType.Any, DataLayoutType.NCHW),
Place(TargetType.Host, PrecisionType.FP32)
]
self.enable_testing_on_place(places=opencl_places)

def is_program_valid(self,
program_config: ProgramConfig,
Expand All @@ -47,10 +61,20 @@ def is_program_valid(self,
return True

def sample_program_configs(self, draw):
in_shape = draw(
st.lists(
st.integers(
min_value=1, max_value=8), min_size=3, max_size=4))
if self.get_target() == "OpenCL":
in_shape = draw(
st.lists(
st.integers(
min_value=1, max_value=8),
min_size=4,
max_size=4))
else:
in_shape = draw(
st.lists(
st.integers(
min_value=1, max_value=8),
min_size=2,
max_size=4))
expand_shape = draw(
st.lists(
st.integers(
Expand Down Expand Up @@ -108,6 +132,10 @@ def gnerate_inputs(with_tensor):
min_value=1, max_value=8),
min_size=len(in_shape),
max_size=len(in_shape)))
if self.get_target() == "OpenCL":
with_tensor = False
attr_shape[1] = 1

inputs = gnerate_inputs(with_tensor)
expand_op = OpConfig(
type="expand",
Expand All @@ -129,7 +157,7 @@ def add_ignore_pass_case(self):
pass

def test(self, *args, **kwargs):
self.run_and_statis(quant=False, max_examples=300)
self.run_and_statis(quant=False, max_examples=100)


if __name__ == "__main__":
Expand Down