Support qnn.conv2d in FoldExplicitPadding (#10982)
* wip support pad + qnn.conv2d folding

* works

* Added test but structural equality is failing

* fixed structural equality test using map_free_vars=True
masahi committed Apr 13, 2022
1 parent a2d973d commit c7cca39
Showing 2 changed files with 74 additions and 4 deletions.
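Why the fold is valid (an editor's note for context, not part of the commit; shapes and names in the sketch are illustrative): in affine quantization, real = scale * (q - zero_point), so a quantized value equal to the input zero point encodes the real value 0.0. Padding the quantized input with its zero point is therefore equivalent to the implicit zero padding performed by qnn.conv2d's own padding attribute, which is what lets the pass absorb the explicit nn.pad. A minimal sketch of the rewrite in action:

    import tvm
    from tvm import relay

    x = relay.var("x", shape=(1, 56, 56, 64), dtype="int8")
    w = relay.var("w", shape=(3, 3, 64, 64), dtype="int8")
    zp = 10  # input zero point: q == zp encodes the real value 0.0
    pad = relay.nn.pad(x, [[0, 0], [1, 1], [1, 1], [0, 0]], pad_value=zp)
    conv = relay.qnn.op.conv2d(
        pad, w,
        relay.const(zp, "int32"), relay.const(1, "int32"),
        relay.const(1, "float32"), relay.const(1, "float32"),
        channels=64, kernel_size=(3, 3), padding=(0, 0),
        data_layout="NHWC", kernel_layout="HWIO",
    )
    mod = tvm.IRModule.from_expr(conv)
    mod = tvm.transform.Sequential(
        [relay.transform.InferType(), relay.transform.FoldExplicitPadding()]
    )(mod)
    # The nn.pad is folded away; qnn.conv2d now carries the equivalent
    # padding=(1, 1) attribute, as the new test below verifies.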
33 changes: 29 additions & 4 deletions src/relay/transforms/fold_explicit_padding.cc
@@ -49,8 +49,18 @@ class SimplifyConvPad {
     conv1d_ = IsOp("nn.conv1d");
     conv2d_ = IsOp("nn.conv2d");
     conv3d_ = IsOp("nn.conv3d");
+
     conv_ = (conv1d_ || conv2d_ || conv3d_)({pad_, w_});
-    pattern_ = conv_;
+
+    input_zero_point_ = IsWildcard();
+    kernel_zero_point_ = IsWildcard();
+    input_scale_ = IsWildcard();
+    kernel_scale_ = IsWildcard();
+
+    qconv2d_ = IsOp("qnn.conv2d")(
+        {pad_, w_, input_zero_point_, kernel_zero_point_, input_scale_, kernel_scale_});
+
+    pattern_ = conv_ || qconv2d_;
   }
 
   template <typename T>
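For readers more familiar with the Python dataflow-pattern API than the C++ one, the pattern built in this constructor corresponds roughly to the following sketch (an equivalent-in-spirit reconstruction with invented variable names, not code from the commit):

    from tvm.relay.dataflow_pattern import is_op, wildcard

    x, w = wildcard(), wildcard()
    pad = is_op("nn.pad")(x, wildcard())  # second argument is the pad value
    conv = (is_op("nn.conv1d") | is_op("nn.conv2d") | is_op("nn.conv3d"))(pad, w)
    qconv2d = is_op("qnn.conv2d")(pad, w, wildcard(), wildcard(), wildcard(), wildcard())
    pattern = conv | qconv2d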
@@ -121,9 +131,21 @@ class SimplifyConvPad {
     ICHECK(param);
     Array<Expr> args = pad_node->args;
 
+    auto x = node_map[x_][0];
+    auto w = node_map[w_][0];
+
     // Possibly perform more optimizations if the pad_value is 0
     const ConstantNode* pad_value = args[1].as<ConstantNode>();
-    if (param->pad_mode == "constant" && pad_value && ToScalar(pad_value->data) == 0.0) {
+    if (node_map.find(qconv2d_) != node_map.end()) {
+      Attrs attrs = GetAttrs(param, call_node->attrs.as<Conv2DAttrs>());
+      auto input_zero_point = node_map[input_zero_point_][0];
+      auto kernel_zero_point = node_map[kernel_zero_point_][0];
+      auto input_scale = node_map[input_scale_][0];
+      auto kernel_scale = node_map[kernel_scale_][0];
+      return Call(call_node->op,
+                  {x, w, input_zero_point, kernel_zero_point, input_scale, kernel_scale}, attrs,
+                  call_node->type_args, call_node->span);
+    } else if (param->pad_mode == "constant" && pad_value && ToScalar(pad_value->data) == 0.0) {
       Attrs attrs;
       if (node_map.count(conv1d_)) {
         attrs = GetAttrs(param, call_node->attrs.as<Conv1DAttrs>());
@@ -137,8 +159,6 @@
       if (!attrs.defined()) {
         return post;
       }
-      auto x = node_map[x_][0];
-      auto w = node_map[w_][0];
       return Call(call_node->op, {x, w}, attrs, call_node->type_args, call_node->span);
     }
     return post;
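The attribute merging itself happens in the GetAttrs helper (outside these hunks): the explicit pad widths on the spatial axes are added onto the convolution's existing padding attribute. A hypothetical stand-alone illustration of that arithmetic for the NHWC case (this helper is not part of the commit):

    def merge_padding_nhwc(pad_width, conv_padding):
        # pad_width: nn.pad's [[0, 0], [top, bottom], [left, right], [0, 0]]
        # conv_padding: conv2d's (top, left, bottom, right)
        top, bottom = pad_width[1]
        left, right = pad_width[2]
        t, l, b, r = conv_padding
        return (t + top, l + left, b + bottom, r + right)

    # [[0,0],[1,1],[1,1],[0,0]] merged into zero padding gives (1, 1, 1, 1),
    # i.e. the padding=(1, 1) expected by the new test below.
    assert merge_padding_nhwc([[0, 0], [1, 1], [1, 1], [0, 0]], (0, 0, 0, 0)) == (1, 1, 1, 1)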
@@ -158,6 +178,11 @@ class SimplifyConvPad {
   DFPattern conv1d_;
   DFPattern conv2d_;
   DFPattern conv3d_;
+  DFPattern qconv2d_;
+  DFPattern input_zero_point_;
+  DFPattern kernel_zero_point_;
+  DFPattern input_scale_;
+  DFPattern kernel_scale_;
 };
 
 class SimplifyExplicitPadding {
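The C++ rewrite callback above also has a Python counterpart, DFPatternCallback, which is handy for prototyping this kind of fold. A minimal sketch of the mechanics (class name and details are invented for illustration; the commit's actual implementation is the C++ above):

    from tvm.relay.dataflow_pattern import DFPatternCallback, is_op, wildcard

    class FoldPadIntoQnnConv2dSketch(DFPatternCallback):
        def __init__(self):
            super().__init__()
            self.x = wildcard()
            self.w = wildcard()
            self.pad = is_op("nn.pad")(self.x, wildcard())
            # zero points and scales, in qnn.conv2d argument order
            self.qparams = [wildcard() for _ in range(4)]
            self.pattern = is_op("qnn.conv2d")(self.pad, self.w, *self.qparams)

        def callback(self, pre, post, node_map):
            # node_map maps each sub-pattern to the expressions it matched,
            # mirroring the node_map[x_][0] lookups in the C++ callback.
            x = node_map[self.x][0]
            w = node_map[self.w][0]
            qparams = [node_map[p][0] for p in self.qparams]
            # A real implementation would rebuild qnn.conv2d(x, w, *qparams)
            # with the pad widths merged into its padding attribute;
            # returning `post` unchanged keeps this sketch side-effect free.
            return post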
45 changes: 45 additions & 0 deletions tests/python/relay/test_pass_fold_explicit_padding.py
@@ -100,5 +100,50 @@ def validate(ndim, pad_width, pad_value, pad_mode, orig_padding, layout):
             validate(ndim, [[0, 0]] * 2 + [i_pad] * ndim, 0, "edge", orig_pad * ndim, "NCHW")
 
 
+def fold_pad_qconv2d():
+    def before():
+        x = relay.var("x", shape=(1, 56, 56, 64), dtype="int8")
+        weight = relay.var("weight", shape=(3, 3, 64, 64), dtype="int8")
+        input_zero_point = 10
+        pad = relay.nn.pad(x, [[0, 0], [1, 1], [1, 1], [0, 0]], pad_value=input_zero_point)
+        return relay.qnn.op.conv2d(
+            pad,
+            weight,
+            relay.const(input_zero_point, "int32"),
+            relay.const(1, "int32"),
+            relay.const(1, "float32"),
+            relay.const(1, "float32"),
+            channels=64,
+            kernel_size=(3, 3),
+            padding=(0, 0),
+            data_layout="NHWC",
+            kernel_layout="HWIO",
+        )
+
+    def expected():
+        x = relay.var("x", shape=(1, 56, 56, 64), dtype="int8")
+        weight = relay.var("weight", shape=(3, 3, 64, 64), dtype="int8")
+        input_zero_point = 10
+        return relay.qnn.op.conv2d(
+            x,
+            weight,
+            relay.const(input_zero_point, "int32"),
+            relay.const(1, "int32"),
+            relay.const(1, "float32"),
+            relay.const(1, "float32"),
+            channels=64,
+            kernel_size=(3, 3),
+            padding=(1, 1),
+            data_layout="NHWC",
+            kernel_layout="HWIO",
+        )
+
+    a = run_opt_pass(before(), relay.transform.FoldExplicitPadding())
+    b = run_opt_pass(expected(), transform.InferType())
+
+    assert tvm.ir.structural_equal(a, b, map_free_vars=True), "Actual = \n" + str(a)
+
+
 if __name__ == "__main__":
     test_simplify_conv_pad()
+    fold_pad_qconv2d()
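On the map_free_vars=True in the assert (the fix mentioned in the last commit-message bullet): before() and expected() each create their own free relay.var("x"), and structural equality treats distinct free variables as unequal unless it is allowed to map them onto one another. A stand-alone illustration (not from the commit):

    import tvm
    from tvm import relay

    x = relay.var("x", shape=(1,), dtype="float32")
    y = relay.var("x", shape=(1,), dtype="float32")  # same name, distinct node
    assert not tvm.ir.structural_equal(relay.add(x, x), relay.add(y, y))
    assert tvm.ir.structural_equal(relay.add(x, x), relay.add(y, y), map_free_vars=True)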
