From 45bed88eb49e3cd4a63ace1c241fbef84f3b6cb3 Mon Sep 17 00:00:00 2001
From: An Wang
Date: Fri, 27 May 2022 18:11:11 -0700
Subject: [PATCH] [Pass] Add MaxPool, AvgPool to FoldExplicitPadding (#11494)

* fold first steps
* spitballing
* check pad is really optd away
* new pool test passes
* stuff
* refactoring midway
* things actually kinda work
* complete tests
* lint and complete tests
* clean
* fix comments
---
 include/tvm/relay/attrs/nn.h                  |  14 +-
 src/relay/transforms/fold_explicit_padding.cc | 206 +++++++++++++++---
 .../relay/test_pass_fold_explicit_padding.py  | 180 ++++++++++++++-
 3 files changed, 351 insertions(+), 49 deletions(-)

diff --git a/include/tvm/relay/attrs/nn.h b/include/tvm/relay/attrs/nn.h
index 7386c25f1a5a..ff611d1f44db 100644
--- a/include/tvm/relay/attrs/nn.h
+++ b/include/tvm/relay/attrs/nn.h
@@ -891,10 +891,9 @@ struct MaxPool1DAttrs : public tvm::AttrsNode<MaxPool1DAttrs> {
       .set_default(Array<IndexExpr>({0}))
       .describe(
           "If padding is non-zero, then the input is implicitly zero-padded"
-          "Padding support both symmetric and asymmetric as"
-          "one int : same padding used on all sides"
-          "three int : back, bottom, right will use same padding as front, top, left"
-          "six int : padding width in the order of (front, top, left, back, bottom, right)");
+          "Padding supports both symmetric and asymmetric as"
+          "one int : same padding used on each side"
+          "two int : indicates left padding, right padding");
   TVM_ATTR_FIELD(layout).set_default("NCW").describe(
       "Dimension ordering of input data. Can be 'NCW', 'NWC', etc."
       "'N', 'C', 'W' stands for batch, channel, and width"
@@ -933,10 +932,9 @@ struct AvgPool1DAttrs : public tvm::AttrsNode<AvgPool1DAttrs> {
       .set_default(Array<IndexExpr>({0}))
       .describe(
           "If padding is non-zero, then the input is implicitly zero-padded"
-          "Padding support both symmetric and asymmetric as"
-          "one int : same padding used on all sides"
-          "three int : back, bottom, right will use same padding as front, top, left"
-          "six int : padding width in the order of (front, top, left, back, bottom, right)");
+          "Padding supports both symmetric and asymmetric as"
+          "one int : same padding used on each side"
+          "two int : indicates left padding, right padding");
   TVM_ATTR_FIELD(layout).set_default("NCW").describe(
       "Dimension ordering of input data. Can be 'NCW', 'NHC', etc."
       "'N', 'C', 'W' stands for batch, channel, and width"
diff --git a/src/relay/transforms/fold_explicit_padding.cc b/src/relay/transforms/fold_explicit_padding.cc
index 60b52c170abb..c60f36c7540e 100644
--- a/src/relay/transforms/fold_explicit_padding.cc
+++ b/src/relay/transforms/fold_explicit_padding.cc
@@ -22,11 +22,15 @@
  * \brief A pass for folding explicit pads into other ops.
  */
 
+#include
 #include
 #include
 #include
 #include
+#include
 #include
+#include
+#include
 
 #include "../op/tensor/transform.h"
 #include "pattern_utils.h"
@@ -35,46 +39,70 @@
 namespace tvm {
 namespace relay {
 
 /*!
- * \brief SimplifyConvPad matches a pad followed by a conv/convtranspose/pool/etc
+ * \brief SimplifyExplicitPad matches a pad followed by a conv/maxpool/avgpool
  * with a pad attribute and merges the padding into the kernel.
  */
-class SimplifyConvPad {
+class SimplifyExplicitPad {
  public:
   DFPattern pattern() const { return pattern_; }
 
-  SimplifyConvPad() {
+  SimplifyExplicitPad() {
     x_ = IsWildcard();
-    w_ = IsWildcard();
     pad_ = IsOp("nn.pad")({x_, IsWildcard()});
+
+    // pad->conv patterns
+    w_ = IsWildcard();
     conv1d_ = IsOp("nn.conv1d");
     conv2d_ = IsOp("nn.conv2d");
     conv3d_ = IsOp("nn.conv3d");
-
-    conv_ = (conv1d_ || conv2d_ || conv3d_)({pad_, w_});
+    contrib_conv2d_nchwc_ = IsOp("nn.contrib_conv2d_NCHWc");
+    conv_ = (conv1d_ || conv2d_ || conv3d_ || contrib_conv2d_nchwc_)({pad_, w_});
 
     input_zero_point_ = IsWildcard();
     kernel_zero_point_ = IsWildcard();
     input_scale_ = IsWildcard();
     kernel_scale_ = IsWildcard();
-
     qconv2d_ = IsOp("qnn.conv2d")(
         {pad_, w_, input_zero_point_, kernel_zero_point_, input_scale_, kernel_scale_});
 
-    pattern_ = conv_ || qconv2d_;
+    // pad->pool patterns
+    avg_pool1d_ = IsOp("nn.avg_pool1d");
+    avg_pool2d_ = IsOp("nn.avg_pool2d");
+    avg_pool3d_ = IsOp("nn.avg_pool3d");
+    max_pool1d_ = IsOp("nn.max_pool1d");
+    max_pool2d_ = IsOp("nn.max_pool2d");
+    max_pool3d_ = IsOp("nn.max_pool3d");
+    max_pool_ = max_pool1d_ || max_pool2d_ || max_pool3d_;
+    pool_ = (max_pool_ || avg_pool1d_ || avg_pool2d_ || avg_pool3d_)({pad_});
+
+    pattern_ = conv_ || qconv2d_ || pool_;
   }
 
   template <typename T>
-  Attrs MakeConvAttrs(const T* old_attrs, const Array<IndexExpr> padding) const {
-    ICHECK(old_attrs);
+  Array<IndexExpr> get_combined_padding(const T* old_attrs, Array<IndexExpr> padding) const {
     ICHECK(padding.size() == old_attrs->padding.size())
         << "Number of dimensions to pad and convolution padding attributes should have the same "
           "extent";
-    auto new_attrs = make_object<T>();
     Array<IndexExpr> combined_padding;
     for (size_t i = 0; i < padding.size(); ++i) {
       combined_padding.push_back(padding[i] + old_attrs->padding[i]);
     }
+    return combined_padding;
+  }
+
+  template <typename T>
+  Attrs MakeConvAttrs(const PadAttrs* param, const T* old_attrs) const {
+    // Creates attrs from old_attrs with fields shared by 1D, 2D, 3D conv attrs
+    ICHECK(old_attrs);
+    ICHECK(param);
+    auto padding = get_padding(param, old_attrs->data_layout);
+    if (!padding) {
+      return Attrs();
+    }
+    auto combined_padding = get_combined_padding(old_attrs, padding.value());
+
+    auto new_attrs = make_object<T>();
     new_attrs->strides = old_attrs->strides;
     new_attrs->padding = combined_padding;
     new_attrs->dilation = old_attrs->dilation;
@@ -89,22 +117,85 @@ class SimplifyConvPad {
   }
 
   template <typename T>
-  Attrs GetAttrs(const PadAttrs* param, const T* attrs) const {
+  Attrs MakeConv2D3DAttrs(const PadAttrs* param, const T* old_attrs) const {
+    // Propagate additional Conv2D- and Conv3D-specific attrs
+    auto attrs = MakeConvAttrs(param, old_attrs);
+    if (!attrs.defined()) {
+      return Attrs();
+    }
+
+    T* new_attrs = const_cast<T*>(attrs.template as<T>());
+    new_attrs->auto_scheduler_rewritten_layout = old_attrs->auto_scheduler_rewritten_layout;
+    return attrs;
+  }
+
+  template <typename T>
+  Attrs MakePoolAttrs(const PadAttrs* param, const T* old_attrs) const {
+    // Creates attrs from old_attrs with fields shared by 1D, 2D, 3D pool attrs
+    ICHECK(old_attrs);
     ICHECK(param);
-    ICHECK(attrs);
-    ICHECK(attrs->data_layout.size() == param->pad_width.size())
+    auto padding = get_padding(param, old_attrs->layout);
+    if (!padding) {
+      return Attrs();
+    }
+    auto combined_padding = get_combined_padding(old_attrs, padding.value());
+
+    auto new_attrs = make_object<T>();
+    new_attrs->pool_size = old_attrs->pool_size;
+    new_attrs->strides = old_attrs->strides;
+    new_attrs->dilation = old_attrs->dilation;
+    new_attrs->padding = combined_padding;
+    new_attrs->layout = old_attrs->layout;
+    new_attrs->out_layout = old_attrs->out_layout;
+    new_attrs->ceil_mode = old_attrs->ceil_mode;
+    return Attrs(new_attrs);
+  }
+
+  template <typename T>
+  Attrs MakeAvgPoolAttrs(const PadAttrs* param, const T* old_attrs) const {
+    // Propagate additional AvgPool-specific attrs
+    auto attrs = MakePoolAttrs(param, old_attrs);
+    if (!attrs.defined()) {
+      return attrs;
+    }
+
+    T* new_attrs = const_cast<T*>(attrs.template as<T>());
+    new_attrs->count_include_pad = old_attrs->count_include_pad;
+    if (!new_attrs->count_include_pad) {
+      // AvgPool's divisor doesn't include padding, so don't fold the explicit pad
+      // unless all original pad items are 0.
+      for (IndexExpr pad : old_attrs->padding) {
+        const IntImmNode* maybe_int_imm = pad.as<IntImmNode>();
+        if (!maybe_int_imm || maybe_int_imm->value != 0) {
+          // Return undefined attrs to signal that we don't want to fold explicit pad
+          return Attrs();
+        }
+      }
+      // Turn on `count_include_pad` to preserve original pad first, then pool behavior
+      // where AvgPool's divisor implicitly includes padding.
+      new_attrs->count_include_pad = true;
+    }
+
+    return attrs;
+  }
+
+  static const Optional<Array<IndexExpr>> get_padding(const PadAttrs* param,
+                                                      std::string data_layout) {
+    // Gets spatial axes padding from the given PadAttrs `param`. If padding
+    // is non-zero on non-spatial axes, return NullOpt.
+    ICHECK(param);
+    ICHECK(data_layout.size() == param->pad_width.size())
         << "Data Layout and padding attributes should have the same extent";
 
-    std::string data_layout = attrs->data_layout;
     std::set<char> image_dims({'H', 'W', 'D'});
     Array<IndexExpr> padding;
     // If we're padding a non-spatial dimension, don't simplify
-    // Convolution can only pad on spatial axes
+    // Convolution/Pool can only pad on spatial axes
     for (size_t i = 0; i < param->pad_width.size(); ++i) {
       if (!image_dims.count(data_layout[i])) {
         for (size_t j = 0; j < param->pad_width[i].size(); ++j) {
           if (param->pad_width[i][j] != 0) {
-            return Attrs();
+            return NullOpt;
           }
         }
       }
@@ -116,8 +207,7 @@ class SimplifyConvPad {
         }
       }
     }
-
-    return MakeConvAttrs(attrs, padding);
+    return padding;
   }
 
   Expr callback(const Expr& pre, const Expr& post,
@@ -131,40 +221,75 @@ class SimplifyConvPad {
     ICHECK(param);
     auto x = node_map[x_][0];
-    auto w = node_map[w_][0];
-
     // Possibly perform more optimizations if the pad_value is 0
     const Expr& pv = pad_node->args[1];
     const ConstantNode* pad_value = pv.as<ConstantNode>();
+    auto pad_scalar = ToScalar(pad_value->data);
+
     if (node_map.find(qconv2d_) != node_map.end()) {
-      Attrs attrs = GetAttrs(param, call_node->attrs.as<Conv2DAttrs>());
+      Attrs attrs = MakeConv2D3DAttrs(param, call_node->attrs.as<Conv2DAttrs>());
+      if (!attrs.defined()) {
+        return post;
+      }
       auto input_zero_point = node_map[input_zero_point_][0];
       auto kernel_zero_point = node_map[kernel_zero_point_][0];
       auto input_scale = node_map[input_scale_][0];
       auto kernel_scale = node_map[kernel_scale_][0];
       // Fold Padding and QNN Convolution only if pad value == input zero point.
       if (IsEqualScalar(input_zero_point, pv)) {
+        auto w = node_map[w_][0];
         return Call(call_node->op,
                     {x, w, input_zero_point, kernel_zero_point, input_scale, kernel_scale}, attrs,
                     call_node->type_args, call_node->span);
-      } else {
-        return post;
       }
-    } else if (param->pad_mode == "constant" && pad_value && ToScalar(pad_value->data) == 0.0) {
+      return post;
+    }
+
+    if (param->pad_mode == "constant" && pad_value) {
       Attrs attrs;
-      if (node_map.count(conv1d_)) {
-        attrs = GetAttrs(param, call_node->attrs.as<Conv1DAttrs>());
-      } else if (node_map.count(conv2d_)) {
-        attrs = GetAttrs(param, call_node->attrs.as<Conv2DAttrs>());
-      } else if (node_map.count(conv3d_)) {
-        attrs = GetAttrs(param, call_node->attrs.as<Conv3DAttrs>());
-      } else {
-        return post;
+      if (pad_scalar == 0.0) {
+        // Fold Padding and Conv/AvgPool only if pad_value == 0.
+        if (node_map.count(conv_)) {
+          if (node_map.count(conv1d_)) {
+            attrs = MakeConvAttrs(param, call_node->attrs.as<Conv1DAttrs>());
+          } else if (node_map.count(conv2d_)) {
+            attrs = MakeConv2D3DAttrs(param, call_node->attrs.as<Conv2DAttrs>());
+          } else if (node_map.count(conv3d_)) {
+            attrs = MakeConv2D3DAttrs(param, call_node->attrs.as<Conv3DAttrs>());
+          }
+          if (!attrs.defined()) {
+            return post;
+          }
+          auto w = node_map[w_][0];
+          return Call(call_node->op, {x, w}, attrs, call_node->type_args, call_node->span);
+        } else if (node_map.count(avg_pool1d_)) {
+          attrs = MakeAvgPoolAttrs(param, call_node->attrs.as<AvgPool1DAttrs>());
+        } else if (node_map.count(avg_pool2d_)) {
+          attrs = MakeAvgPoolAttrs(param, call_node->attrs.as<AvgPool2DAttrs>());
+        } else if (node_map.count(avg_pool3d_)) {
+          attrs = MakeAvgPoolAttrs(param, call_node->attrs.as<AvgPool3DAttrs>());
+        }
+      } else if (node_map.count(max_pool_)) {
+        // Fold Padding and MaxPool only if pad_value is the min possible value for the dtype
+        auto min_value = tvm::min_value(tvm::runtime::DataType(pad_value->data->dtype));
+        const FloatImmNode* maybe_min_float = min_value.as<FloatImmNode>();
+        const IntImmNode* maybe_min_int = min_value.as<IntImmNode>();
+
+        if ((maybe_min_float && pad_scalar == maybe_min_float->value) ||
+            (maybe_min_int && pad_scalar == maybe_min_int->value)) {
+          if (node_map.count(max_pool1d_)) {
+            attrs = MakePoolAttrs(param, call_node->attrs.as<MaxPool1DAttrs>());
+          } else if (node_map.count(max_pool2d_)) {
+            attrs = MakePoolAttrs(param, call_node->attrs.as<MaxPool2DAttrs>());
+          } else if (node_map.count(max_pool3d_)) {
+            attrs = MakePoolAttrs(param, call_node->attrs.as<MaxPool3DAttrs>());
+          }
+        }
       }
       if (!attrs.defined()) {
         return post;
       }
-      return Call(call_node->op, {x, w}, attrs, call_node->type_args, call_node->span);
+      return Call(call_node->op, {x}, attrs, call_node->type_args, call_node->span);
     }
     return post;
   }
@@ -183,18 +308,27 @@ class SimplifyConvPad {
   DFPattern conv1d_;
   DFPattern conv2d_;
   DFPattern conv3d_;
+  DFPattern contrib_conv2d_nchwc_;
   DFPattern qconv2d_;
   DFPattern input_zero_point_;
   DFPattern kernel_zero_point_;
   DFPattern input_scale_;
   DFPattern kernel_scale_;
+  /*! \brief Pattern pool */
+  DFPattern pool_;
+  DFPattern avg_pool1d_;
+  DFPattern avg_pool2d_;
+  DFPattern avg_pool3d_;
+  DFPattern max_pool1d_;
+  DFPattern max_pool2d_;
+  DFPattern max_pool3d_;
+  DFPattern max_pool_;
 };
 
 class SimplifyExplicitPadding {
  public:
   explicit SimplifyExplicitPadding(IRModule mod) : mod_(mod) {
-    CreateCallback(SimplifyConvPad());
-    // TODO(mbrookhart): ConvTranspose(Pad(x)), Pool(Pad(x))
+    CreateCallback(SimplifyExplicitPad());
   }
   template <typename T>
   void CreateCallback(const T& pattern) {
diff --git a/tests/python/relay/test_pass_fold_explicit_padding.py b/tests/python/relay/test_pass_fold_explicit_padding.py
index 2887c0774b21..41e2500d4ffa 100644
--- a/tests/python/relay/test_pass_fold_explicit_padding.py
+++ b/tests/python/relay/test_pass_fold_explicit_padding.py
@@ -25,7 +25,7 @@ def test_simplify_conv_pad():
     convs = [relay.nn.conv1d, relay.nn.conv2d, relay.nn.conv3d]
 
-    def validate(ndim, pad_width, pad_value, pad_mode, orig_padding, layout):
+    def validate(ndim, pad_width, pad_value, pad_mode, orig_padding, layout, no_fold=False):
         if layout[1] == "C":
             shape = [1, 3] + [10] * ndim
             wshape = [8, 3] + [3] * ndim
@@ -69,6 +69,10 @@ def validate(ndim, pad_width, pad_value, pad_mode, orig_padding, layout):
         mod1 = tvm.IRModule.from_expr(conv)
         mod2 = tvm.IRModule.from_expr(zz)
 
+        if not no_fold:
+            op_freqs = relay.analysis.list_op_freqs(mod2)
+            assert "nn.pad" not in op_freqs
+
         with tvm.transform.PassContext():
             func1 = relay.create_executor(
                 "vm", mod=mod1, device=tvm.cpu(), target="llvm"
@@ -76,11 +80,13 @@ def validate(ndim, pad_width, pad_value, pad_mode, orig_padding, layout):
         func2 = relay.create_executor("vm", mod=mod2, device=tvm.cpu(), target="llvm").evaluate()
         x_np = np.random.rand(*shape).astype("float32")
         w_np = np.random.rand(*wshape).astype("float32")
+
         result1 = func1(x_np, w_np)
         result2 = func2(x_np, w_np)
 
         tvm.testing.assert_allclose(result1.numpy(), result2.numpy(), rtol=1e-5, atol=1e-5)
 
+    # Test fold cases
     for orig_pad in [[0, 0], [2, 0], [0, 2]]:
         for i_pad in [[0, 0], [1, 1], [1, 0]]:
             for ndim in [1, 2, 3]:
@@ -95,12 +101,175 @@ def validate(ndim, pad_width, pad_value, pad_mode, orig_padding, layout):
                         padding = [[0, 0]] * 2 + [i_pad] * ndim
 
                     validate(ndim, padding, 0, "constant", orig_pad * ndim, layout)
+
+    # Test no fold cases
     ndim = 2
-    validate(ndim, [[0, 0]] * 2 + [i_pad] * ndim, 1, "constant", orig_pad * ndim, "NCHW")
-    validate(ndim, [[0, 0]] * 2 + [i_pad] * ndim, 0, "edge", orig_pad * ndim, "NCHW")
+    # Conv only folds when pad_value=0
+    validate(
+        ndim, [[0, 0]] * 2 + [i_pad] * ndim, 1, "constant", orig_pad * ndim, "NCHW", no_fold=True
+    )
+    # Conv only folds when pad's pad_mode="constant"
+    validate(ndim, [[0, 0]] * 2 + [i_pad] * ndim, 0, "edge", orig_pad * ndim, "NCHW", no_fold=True)
+
+
+def get_min_value(dtype):
+    if np.issubdtype(dtype, np.floating):
+        return np.finfo(dtype).min
+    elif np.issubdtype(dtype, np.integer):
+        return np.iinfo(dtype).min
+    else:
+        raise ValueError("Cannot get min value for dtypes that are not integer or not floating")
+
+
+def test_simplify_pool_pad():
+    max_pools = [relay.nn.max_pool1d, relay.nn.max_pool2d, relay.nn.max_pool3d]
+    avg_pools = [relay.nn.avg_pool1d, relay.nn.avg_pool2d, relay.nn.avg_pool3d]
+
+    def validate(
+        pools,
+        ndim,
+        pad_width,
+        pad_value,
+        orig_padding,
+        layout,
+        pool_size,
+        pad_mode="constant",
+        dtype="float32",
+        no_fold=False,
+        **kwargs,
+    ):
+        pad_value_const = relay.const(pad_value, dtype=dtype)
+
+        if layout[1] == "C":
+            shape = [1, 3] + [10] * ndim
+        elif layout[-1] == "C":
+            shape = [1] + [10] * ndim + [3]
+        else:
+            raise ValueError("This test only supports NC* and N*C")
+
+        x = relay.var("x", shape=shape, dtype=dtype)
+        pad = relay.nn.pad(x, pad_width, pad_value_const, pad_mode)
+        if layout[1] == "C":
+            pool = pools[ndim - 1](pad, padding=orig_padding, pool_size=pool_size, **kwargs)
+        else:
+            pool = pools[ndim - 1](
+                pad, padding=orig_padding, layout=layout, pool_size=pool_size, **kwargs
+            )
+
+        if pools == max_pools:
+            foldable_pad_value = get_min_value(dtype)
+        else:
+            foldable_pad_value = 0
+
+        if pad_mode == "constant" and pad_value == foldable_pad_value:
+            new_padding = []
+            for j in range(2):
+                for i in range(len(pad_width)):
+                    if layout[i] in ["D", "H", "W"]:
+                        new_padding.append(pad_width[i][j])
+            for i in range(len(new_padding)):
+                new_padding[i] += orig_padding[i]
+
+            if pools == avg_pools and all(v == 0 for v in orig_padding):
+                # If the orig padding for AvgPool is all zero and the pad op to fold
+                # has non-zero pad width, the resultant folded AvgPool will have
+                # count_include_pad=True so AvgPool's divisor is agnostic of pad boundaries
+                kwargs["count_include_pad"] = True
+            if layout[1] == "C":
+                after = pools[ndim - 1](x, padding=new_padding, pool_size=pool_size, **kwargs)
+            else:
+                after = pools[ndim - 1](
+                    x, padding=new_padding, layout=layout, pool_size=pool_size, **kwargs
+                )
+        else:
+            after = pool
+
+        zz = run_opt_pass(pool, transform.FoldExplicitPadding())
+        expected = run_opt_pass(after, transform.InferType())
+
+        assert tvm.ir.structural_equal(zz, expected)
+
+        mod1 = tvm.IRModule.from_expr(pool)
+        mod2 = tvm.IRModule.from_expr(zz)
+
+        if not no_fold:
+            op_freqs = relay.analysis.list_op_freqs(mod2)
+            assert "nn.pad" not in op_freqs
+
+        with tvm.transform.PassContext():
+            func1 = relay.create_executor(
+                "vm", mod=mod1, device=tvm.cpu(), target="llvm"
+            ).evaluate()
+
+            func2 = relay.create_executor("vm", mod=mod2, device=tvm.cpu(), target="llvm").evaluate()
+            x_np = np.random.rand(*shape).astype(dtype)
+
+            result1 = func1(x_np)
+            result2 = func2(x_np)
+
+            tvm.testing.assert_allclose(result1.numpy(), result2.numpy(), rtol=1e-5, atol=1e-5)
+
+    # Test fold cases
+    float_min_val = get_min_value("float32")
+    for orig_pad in [[0, 0], [2, 0]]:
+        for i_pad in [[1, 1], [1, 0]]:
+            for ndim in [1, 2, 3]:
+                for channels_last in [0, 1]:
+                    if channels_last:
+                        layout = "NDHWC"
+                        layout = layout[0:1] + layout[4 - ndim : 4] + layout[-1:]
+                        padding = [[0, 0]] + [i_pad] * ndim + [[0, 0]]
+                    else:
+                        layout = "NCDHW"
+                        layout = layout[0:2] + layout[5 - ndim :]
+                        padding = [[0, 0]] * 2 + [i_pad] * ndim
+
+                    validate(max_pools, ndim, padding, float_min_val, orig_pad * ndim, layout, 2)
+
+    # Check Pool pad folding when pad width on pad op is all zero.
+    validate(max_pools, 1, [[0, 0], [0, 0], [0, 0]], float_min_val, [2, 0], "NCW", 2)
+    # Check MaxPool pad folding with int dtype
+    int_min_val = get_min_value("int32")
+    validate(
+        max_pools,
+        2,
+        [[0, 0], [0, 0], [0, 2], [2, 0]],
+        int_min_val,
+        [2, 0, 0, 0],
+        "NCHW",
+        2,
+        dtype="int32",
+    )
+    # Fold when original AvgPool has its own padding but count_include_pad=True
+    validate(
+        avg_pools,
+        2,
+        [[0, 0], [0, 0], [0, 2], [2, 0]],
+        0,
+        [0, 0, 1, 0],
+        "NCHW",
+        2,
+        count_include_pad=True,
+    )
+    # Fold when count_include_pad=False but original AvgPool has no orig padding
+    validate(avg_pools, 2, [[0, 0], [0, 0], [0, 2], [2, 0]], 0, [0, 0, 0, 0], "NCHW", 2)
+
+    # Test no fold cases
+    # AvgPool only folds pad when count_include_pad (False by default) is True
+    validate(
+        avg_pools, 2, [[0, 0], [0, 0], [0, 2], [2, 0]], 0, [0, 0, 0, 0], "NCHW", 2, no_fold=True
+    )
+    # MaxPool only folds pad when pad_value is the min for its dtype
+    validate(max_pools, 1, [[0, 0], [0, 0], [0, 2]], 0, [0, 0], "NCHW", 2, no_fold=True)
+    # AvgPool only folds pad when pad_value=0
+    validate(avg_pools, 1, [[0, 0], [0, 0], [0, 2]], 1, [0, 0], "NCHW", 2, no_fold=True)
+    # Pools only fold when pad_mode="constant"
+    validate(
+        avg_pools, 1, [[0, 0], [0, 0], [0, 2]], 0, [0, 0], "NCHW", 2, pad_mode="edge", no_fold=True
+    )
 
 
-def fold_pad_qconv2d():
+def test_fold_pad_qconv2d():
     def before():
         x = relay.var("x", shape=(1, 56, 56, 64), dtype="int8")
         weight = relay.var("weight", shape=(3, 3, 64, 64), dtype="int8")
@@ -174,5 +343,6 @@ def get_expr():
 
 if __name__ == "__main__":
     test_simplify_conv_pad()
-    fold_pad_qconv2d()
+    test_simplify_pool_pad()
+    test_fold_pad_qconv2d()
     test_pad_qconv2d_no_fold()
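
Reviewer note, not part of the patch: a minimal sketch of the new fold this change enables, assuming a TVM build that includes this patch. It only uses ops and the FoldExplicitPadding pass already exercised by the tests above; the shape and pool size are illustrative.

# Not part of the patch: sketch of folding an explicit nn.pad into nn.max_pool2d.
import numpy as np
import tvm
from tvm import relay

x = relay.var("x", shape=(1, 3, 10, 10), dtype="float32")
# MaxPool can only absorb a pad whose fill value is the dtype's minimum.
pad_value = relay.const(np.finfo("float32").min, dtype="float32")
padded = relay.nn.pad(x, [[0, 0], [0, 0], [1, 1], [1, 1]], pad_value)
pool = relay.nn.max_pool2d(padded, pool_size=2)

mod = tvm.IRModule.from_expr(pool)
mod = relay.transform.InferType()(mod)
mod = relay.transform.FoldExplicitPadding()(mod)
# The pad op is gone; max_pool2d now carries padding=(1, 1, 1, 1).
print(mod)

For avg_pool* the fold only happens when the pad value is 0, and MakeAvgPoolAttrs additionally refuses to fold when count_include_pad=False and the pool already has non-zero padding of its own, since folding would otherwise change the divisor.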