diff --git a/src/relay/backend/contrib/cmsisnn/relay_to_tir.cc b/src/relay/backend/contrib/cmsisnn/relay_to_tir.cc
index 8afbd91ce37c..7f1582a5dbab 100644
--- a/src/relay/backend/contrib/cmsisnn/relay_to_tir.cc
+++ b/src/relay/backend/contrib/cmsisnn/relay_to_tir.cc
@@ -156,10 +156,12 @@ class RelayToTIRVisitor : public MixedModeMutator {
                                         ToArg(dilation_w), ToArg(dilation_h), ToArg(clip_min), ToArg(clip_max)};
 
-    // layout NHWC
+    // CMSIS-NN data structure "cmsis_nn_dims" for ifm expects input layout as NHWC
+    // This is the same layout we expect in Relay
     Array<PrimExpr> input_shape = conv2d_call->args[0]->type_as<TensorTypeNode>()->shape;
 
-    // OHWI for Conv2D and IHWO for depthwise
+    // CMSIS-NN data structure "cmsis_nn_dims" for weights expects following layouts
+    // OHWI for Conv2D and IHWO for Depthwise convolutions
     Array<PrimExpr> filter_shape = conv2d_call->args[1]->type_as<TensorTypeNode>()->shape;
 
     Array<PrimExpr> bias_shape{1, 1, 1, out_channels};
 
@@ -179,7 +181,10 @@ class RelayToTIRVisitor : public MixedModeMutator {
     std::string cmsisnn_api = "arm_convolve_wrapper_s8";
     if (depth_multiplier != -1) {
       cmsisnn_api = "arm_depthwise_conv_wrapper_s8";
-      Array<PrimExpr> depthwise_filter_shape{1, filter_shape[0], filter_shape[1], out_channels};
+      int filter_pos_h = kernel_layout.find("H");
+      int filter_pos_w = kernel_layout.find("W");
+      Array<PrimExpr> depthwise_filter_shape{1, filter_shape[filter_pos_h],
+                                             filter_shape[filter_pos_w], out_channels};
       filter_shape = depthwise_filter_shape;
     }
 
diff --git a/tests/python/contrib/test_cmsisnn/utils.py b/tests/python/contrib/test_cmsisnn/utils.py
index 145dbf4b499c..ca2738c06687 100644
--- a/tests/python/contrib/test_cmsisnn/utils.py
+++ b/tests/python/contrib/test_cmsisnn/utils.py
@@ -86,17 +86,20 @@ def make_module(func):
     return mod
 
 
-def get_same_padding(data, kernel, dilation, stride, cmsisnn_padding=True):
-    """Provides CMSIS-NN padding when output dim == input dim"""
+def get_same_padding(in_shape, kernel, dilation, stride):
+    """
+    Provides CMSIS-NN padding when output dim == input dim.
+    This is TFLu's "SAME" padding case.
+    """
     dilated_kernel_h = dilation[0] * (kernel[0] - 1) + 1
     dilated_kernel_w = dilation[1] * (kernel[1] - 1) + 1
-    out = int(math.ceil(float(data[0]) / float(stride[0])))
-    pad = max(0, (out - 1) * stride[0] + dilated_kernel_h - data[0])
-    pad_top, pad_bottom = (pad, 0) if cmsisnn_padding else (0, pad)
+    out = int(math.ceil(float(in_shape[0]) / float(stride[0])))
+    pad = max(0, (out - 1) * stride[0] + dilated_kernel_h - in_shape[0])
+    pad_top, pad_bottom = (pad, 0)
 
-    out = int(math.ceil(float(data[1]) / float(stride[1])))
-    pad = max(0, (out - 1) * stride[1] + dilated_kernel_w - data[1])
-    pad_left, pad_right = (pad, 0) if cmsisnn_padding else (0, pad)
+    out = int(math.ceil(float(in_shape[1]) / float(stride[1])))
+    pad = max(0, (out - 1) * stride[1] + dilated_kernel_w - in_shape[1])
+    pad_left, pad_right = (pad, 0)
 
     return [pad_top, pad_left, pad_bottom, pad_right]
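
Side note for reviewers: below is a minimal standalone sketch (not part of this patch) of how the revised `get_same_padding` helper behaves. The output-dimension arithmetic at the bottom is a hypothetical check added here purely for illustration, assuming stride (1, 1) and the top/left padding placement used in the patch.

```python
import math


def get_same_padding(in_shape, kernel, dilation, stride):
    """Mirror of the patched helper: 'SAME' padding, placed on top/left."""
    dilated_kernel_h = dilation[0] * (kernel[0] - 1) + 1
    dilated_kernel_w = dilation[1] * (kernel[1] - 1) + 1

    out = int(math.ceil(float(in_shape[0]) / float(stride[0])))
    pad = max(0, (out - 1) * stride[0] + dilated_kernel_h - in_shape[0])
    pad_top, pad_bottom = (pad, 0)

    out = int(math.ceil(float(in_shape[1]) / float(stride[1])))
    pad = max(0, (out - 1) * stride[1] + dilated_kernel_w - in_shape[1])
    pad_left, pad_right = (pad, 0)

    return [pad_top, pad_left, pad_bottom, pad_right]


# Hypothetical sanity check: with stride (1, 1), "SAME" padding must keep the
# spatial dims unchanged, which is what TFLu's SAME case guarantees.
in_shape, kernel, dilation, stride = (28, 28), (3, 3), (2, 2), (1, 1)
pad_top, pad_left, pad_bottom, pad_right = get_same_padding(in_shape, kernel, dilation, stride)
out_h = (in_shape[0] + pad_top + pad_bottom - (dilation[0] * (kernel[0] - 1) + 1)) // stride[0] + 1
out_w = (in_shape[1] + pad_left + pad_right - (dilation[1] * (kernel[1] - 1) + 1)) // stride[1] + 1
assert (out_h, out_w) == in_shape  # pads are [4, 4, 0, 0] for the dilated 5x5 kernel
```

The patch keeps the `(pad, 0)` top/left placement rather than TFLite's bottom/right bias for odd totals; the total padding amount matches the SAME computation either way, so a stride-1 convolution still preserves the spatial dims.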