From c7cca3913a79724e61d02be7cc8d0e3111936350 Mon Sep 17 00:00:00 2001
From: Masahiro Masuda
Date: Wed, 13 Apr 2022 09:23:22 +0900
Subject: [PATCH] Support `qnn.conv2d` in FoldExplicitPadding (#10982)

* wip support pad + qnn.conv2d folding

* works

* Added test but structural equality is failing

* fixed structural equality test using map_free_vars=True
---
 src/relay/transforms/fold_explicit_padding.cc | 33 ++++++++++++--
 .../relay/test_pass_fold_explicit_padding.py  | 45 +++++++++++++++++++
 2 files changed, 74 insertions(+), 4 deletions(-)

diff --git a/src/relay/transforms/fold_explicit_padding.cc b/src/relay/transforms/fold_explicit_padding.cc
index b600953e0765..6aac995e35a7 100644
--- a/src/relay/transforms/fold_explicit_padding.cc
+++ b/src/relay/transforms/fold_explicit_padding.cc
@@ -49,8 +49,18 @@ class SimplifyConvPad {
     conv1d_ = IsOp("nn.conv1d");
     conv2d_ = IsOp("nn.conv2d");
     conv3d_ = IsOp("nn.conv3d");
+
     conv_ = (conv1d_ || conv2d_ || conv3d_)({pad_, w_});
-    pattern_ = conv_;
+
+    input_zero_point_ = IsWildcard();
+    kernel_zero_point_ = IsWildcard();
+    input_scale_ = IsWildcard();
+    kernel_scale_ = IsWildcard();
+
+    qconv2d_ = IsOp("qnn.conv2d")(
+        {pad_, w_, input_zero_point_, kernel_zero_point_, input_scale_, kernel_scale_});
+
+    pattern_ = conv_ || qconv2d_;
   }
 
   template <typename T>
@@ -121,9 +131,21 @@ class SimplifyConvPad {
     ICHECK(param);
 
     Array<Expr> args = pad_node->args;
+    auto x = node_map[x_][0];
+    auto w = node_map[w_][0];
+
     // Possibly perform more optimizations if the pad_value is 0
     const ConstantNode* pad_value = args[1].as<ConstantNode>();
-    if (param->pad_mode == "constant" && pad_value && ToScalar(pad_value->data) == 0.0) {
+    if (node_map.find(qconv2d_) != node_map.end()) {
+      Attrs attrs = GetAttrs(param, call_node->attrs.as<Conv2DAttrs>());
+      auto input_zero_point = node_map[input_zero_point_][0];
+      auto kernel_zero_point = node_map[kernel_zero_point_][0];
+      auto input_scale = node_map[input_scale_][0];
+      auto kernel_scale = node_map[kernel_scale_][0];
+      return Call(call_node->op,
+                  {x, w, input_zero_point, kernel_zero_point, input_scale, kernel_scale}, attrs,
+                  call_node->type_args, call_node->span);
+    } else if (param->pad_mode == "constant" && pad_value && ToScalar(pad_value->data) == 0.0) {
       Attrs attrs;
       if (node_map.count(conv1d_)) {
         attrs = GetAttrs(param, call_node->attrs.as<Conv1DAttrs>());
@@ -137,8 +159,6 @@ class SimplifyConvPad {
       if (!attrs.defined()) {
         return post;
       }
-      auto x = node_map[x_][0];
-      auto w = node_map[w_][0];
       return Call(call_node->op, {x, w}, attrs, call_node->type_args, call_node->span);
     }
     return post;
@@ -158,6 +178,11 @@ class SimplifyConvPad {
   DFPattern conv1d_;
   DFPattern conv2d_;
   DFPattern conv3d_;
+  DFPattern qconv2d_;
+  DFPattern input_zero_point_;
+  DFPattern kernel_zero_point_;
+  DFPattern input_scale_;
+  DFPattern kernel_scale_;
 };
 
 class SimplifyExplicitPadding {
diff --git a/tests/python/relay/test_pass_fold_explicit_padding.py b/tests/python/relay/test_pass_fold_explicit_padding.py
index effebaaf1e8b..48b5e510d0a9 100644
--- a/tests/python/relay/test_pass_fold_explicit_padding.py
+++ b/tests/python/relay/test_pass_fold_explicit_padding.py
@@ -100,5 +100,50 @@ def validate(ndim, pad_width, pad_value, pad_mode, orig_padding, layout):
                 validate(ndim, [[0, 0]] * 2 + [i_pad] * ndim, 0, "edge", orig_pad * ndim, "NCHW")
 
 
+def test_fold_pad_qconv2d():
+    def before():
+        x = relay.var("x", shape=(1, 56, 56, 64), dtype="int8")
+        weight = relay.var("weight", shape=(3, 3, 64, 64), dtype="int8")
+        input_zero_point = 10
+        pad = relay.nn.pad(x, [[0, 0], [1, 1], [1, 1], [0, 0]], pad_value=input_zero_point)
+        return relay.qnn.op.conv2d(
+            pad,
+            weight,
+            relay.const(input_zero_point, "int32"),
+            relay.const(1, "int32"),
+            relay.const(1, "float32"),
+            relay.const(1, "float32"),
+            channels=64,
+            kernel_size=(3, 3),
+            padding=(0, 0),
+            data_layout="NHWC",
+            kernel_layout="HWIO",
+        )
+
+    def expected():
+        x = relay.var("x", shape=(1, 56, 56, 64), dtype="int8")
+        weight = relay.var("weight", shape=(3, 3, 64, 64), dtype="int8")
+        input_zero_point = 10
+        return relay.qnn.op.conv2d(
+            x,
+            weight,
+            relay.const(input_zero_point, "int32"),
+            relay.const(1, "int32"),
+            relay.const(1, "float32"),
+            relay.const(1, "float32"),
+            channels=64,
+            kernel_size=(3, 3),
+            padding=(1, 1),
+            data_layout="NHWC",
+            kernel_layout="HWIO",
+        )
+
+    a = run_opt_pass(before(), relay.transform.FoldExplicitPadding())
+    b = run_opt_pass(expected(), transform.InferType())
+
+    assert tvm.ir.structural_equal(a, b, map_free_vars=True), "Actual = \n" + str(a)
+
+
 if __name__ == "__main__":
     test_simplify_conv_pad()
+    test_fold_pad_qconv2d()
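
A minimal reviewer-facing sketch of the intended fold, not part of the patch: it mirrors the shapes and quantization parameters of test_fold_pad_qconv2d above and assumes the pass is driven through the public relay.transform.FoldExplicitPadding() API rather than the test's run_opt_pass helper.

import tvm
from tvm import relay

# Build the nn.pad -> qnn.conv2d chain that the new qconv2d_ pattern targets.
x = relay.var("x", shape=(1, 56, 56, 64), dtype="int8")
w = relay.var("w", shape=(3, 3, 64, 64), dtype="int8")
pad = relay.nn.pad(x, [[0, 0], [1, 1], [1, 1], [0, 0]], pad_value=10)
conv = relay.qnn.op.conv2d(
    pad,
    w,
    relay.const(10, "int32"),  # input zero point; the test pads with this same value
    relay.const(1, "int32"),  # kernel zero point
    relay.const(1, "float32"),  # input scale
    relay.const(1, "float32"),  # kernel scale
    channels=64,
    kernel_size=(3, 3),
    padding=(0, 0),
    data_layout="NHWC",
    kernel_layout="HWIO",
)

mod = tvm.IRModule.from_expr(conv)
mod = relay.transform.InferType()(mod)
mod = relay.transform.FoldExplicitPadding()(mod)
print(mod)  # the nn.pad call is folded away; the padding now lives on the qnn.conv2d attrs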