From f46b556b38862de27b9979ce125313db52365030 Mon Sep 17 00:00:00 2001
From: Josh Fromm
Date: Thu, 22 Apr 2021 08:44:43 -0700
Subject: [PATCH] [Relay][ONNX] 1-D global and adaptive pooling. (#7906)

* 1D adaptive pooling added and tested.

* Apply formatting.

* Add onnx integration and tests.

* Busted by lint.
---
 include/tvm/relay/attrs/nn.h                  |  18 ++-
 include/tvm/topi/nn/pooling.h                 |  15 ++
 python/tvm/relay/frontend/onnx.py             |  40 ++++-
 python/tvm/relay/op/nn/_nn.py                 |  10 ++
 python/tvm/relay/op/nn/nn.py                  | 153 ++++++++++++++++++
 .../tvm/topi/testing/adaptive_pool_python.py  |  37 +++--
 src/relay/op/nn/pooling.cc                    | 132 +++++++++++++++
 tests/python/frontend/onnx/test_forward.py    |  38 +++++
 tests/python/relay/test_op_level10.py         |  14 +-
 tests/python/relay/test_op_level2.py          |  75 ++++++---
 10 files changed, 495 insertions(+), 37 deletions(-)

diff --git a/include/tvm/relay/attrs/nn.h b/include/tvm/relay/attrs/nn.h
index 43704e39953c1..f4c47c2dae8fa 100644
--- a/include/tvm/relay/attrs/nn.h
+++ b/include/tvm/relay/attrs/nn.h
@@ -760,7 +760,22 @@ struct GlobalPool2DAttrs : public tvm::AttrsNode<GlobalPool2DAttrs> {
   }
 };
 
-/*! \brief Attributes for adaptive pool operator */
+/*! \brief Attributes for 1d adaptive pool operator */
+struct AdaptivePool1DAttrs : public tvm::AttrsNode<AdaptivePool1DAttrs> {
+  Array<IndexExpr> output_size;
+  std::string layout;
+
+  TVM_DECLARE_ATTRS(AdaptivePool1DAttrs, "relay.attrs.AdaptivePool1DAttrs") {
+    TVM_ATTR_FIELD(output_size).set_default(Array<IndexExpr>({})).describe("Output width.");
+    TVM_ATTR_FIELD(layout).set_default("NCW").describe(
+        "Dimension ordering of input data. Can be 'NCW', 'NWC', etc. "
+        "'N', 'C', 'W' stands for batch, channel, and width "
+        "dimensions respectively. Pooling is applied on the "
+        "'W' dimension.");
+  }
+};
+
+/*! \brief Attributes for 2d adaptive pool operator */
 struct AdaptivePool2DAttrs : public tvm::AttrsNode<AdaptivePool2DAttrs> {
   Array<IndexExpr> output_size;
   std::string layout;
@@ -777,6 +792,7 @@ struct AdaptivePool2DAttrs : public tvm::AttrsNode<AdaptivePool2DAttrs> {
   }
 };
 
+/*! \brief Attributes for 3d adaptive pool operator */
 struct AdaptivePool3DAttrs : public tvm::AttrsNode<AdaptivePool3DAttrs> {
   Array<IndexExpr> output_size;
   std::string layout;
diff --git a/include/tvm/topi/nn/pooling.h b/include/tvm/topi/nn/pooling.h
index 8c30e673b3049..e40759907e6ba 100644
--- a/include/tvm/topi/nn/pooling.h
+++ b/include/tvm/topi/nn/pooling.h
@@ -613,6 +613,21 @@ inline Tensor adaptive_pool3d(const Tensor& x, const Array<PrimExpr>& output_siz
   return adaptive_pool_impl(x, output_size, pool_type, {depth_axis, height_axis, width_axis});
 }
 
+/*!
+ * \brief Adaptively perform pooling on one dimensional data.
+ *        See the two dimensional version above for details.
+ * \param x The input tensor
+ * \param output_size Vector of one int: {output_width}
+ * \param pool_type The type of pooling operator
+ * \param layout The input layout. The default is "NCW".
+ */
+inline Tensor adaptive_pool1d(const Tensor& x, const Array<PrimExpr>& output_size,
+                              PoolType pool_type, const std::string& layout = "NCW") {
+  int width_axis = -1;
+  ICHECK(find_width(layout, &width_axis)) << "Unsupported layout " << layout;
+  return adaptive_pool_impl(x, output_size, pool_type, {width_axis});
+}
+
 /*!
  * \brief Perform global pooling on height and width dimension of data.
  *        It decides the height and width dimension according to the layout string,
diff --git a/python/tvm/relay/frontend/onnx.py b/python/tvm/relay/frontend/onnx.py
index b8cb1f602656f..4b159a5716892 100644
--- a/python/tvm/relay/frontend/onnx.py
+++ b/python/tvm/relay/frontend/onnx.py
@@ -501,6 +501,42 @@ def _impl_v1(cls, inputs, attr, params):
         return out
 
 
+class GlobalAveragePool(OnnxOpConverter):
+    """Operator converter for GlobalAveragePool"""
+
+    @classmethod
+    def _impl_v1(cls, inputs, attr, params):
+        rank = len(infer_shape(inputs[0]))
+        if rank == 3:
+            return _op.nn.global_avg_pool1d(inputs[0])
+        if rank == 4:
+            return _op.nn.global_avg_pool2d(inputs[0])
+        if rank == 5:
+            return _op.nn.global_avg_pool3d(inputs[0])
+        raise NotImplementedError(
+            "Global average pooling is only implemented for 1D, 2D, and 3D kernels, got %dD."
+            % (rank - 2),
+        )
+
+
+class GlobalMaxPool(OnnxOpConverter):
+    """Operator converter for GlobalMaxPool"""
+
+    @classmethod
+    def _impl_v1(cls, inputs, attr, params):
+        rank = len(infer_shape(inputs[0]))
+        if rank == 3:
+            return _op.nn.global_max_pool1d(inputs[0])
+        if rank == 4:
+            return _op.nn.global_max_pool2d(inputs[0])
+        if rank == 5:
+            return _op.nn.global_max_pool3d(inputs[0])
+        raise NotImplementedError(
+            "Global max pooling is only implemented for 1D, 2D, and 3D kernels, got %dD."
+            % (rank - 2),
+        )
+
+
 class Div(Elemwise):
     """Operator converter for Divide."""
 
@@ -2775,8 +2811,8 @@ def _get_convert_map(opset):
         "MaxUnpool": MaxUnpool.get_converter(opset),
         "Conv": Conv.get_converter(opset),
         "ConvTranspose": ConvTranspose.get_converter(opset),
-        "GlobalAveragePool": Renamer("global_avg_pool2d"),
-        "GlobalMaxPool": Renamer("global_max_pool2d"),
+        "GlobalAveragePool": GlobalAveragePool.get_converter(opset),
+        "GlobalMaxPool": GlobalMaxPool.get_converter(opset),
         "BatchNormalization": BatchNorm.get_converter(opset),
         "InstanceNormalization": InstanceNorm.get_converter(opset),
         # 'LpNormalization'
diff --git a/python/tvm/relay/op/nn/_nn.py b/python/tvm/relay/op/nn/_nn.py
index af64873ee9049..3d817c7378b54 100644
--- a/python/tvm/relay/op/nn/_nn.py
+++ b/python/tvm/relay/op/nn/_nn.py
@@ -504,6 +504,16 @@ def compute_contrib_conv3d_winograd_weight_transform(attrs, inputs, out_dtype):
 reg.register_pattern("nn.avg_pool2d_grad", OpPattern.OUT_ELEMWISE_FUSABLE)
 
 
+# adaptive_max_pool1d
+reg.register_schedule("nn.adaptive_max_pool1d", strategy.schedule_adaptive_pool)
+reg.register_pattern("nn.adaptive_max_pool1d", OpPattern.OUT_ELEMWISE_FUSABLE)
+
+
+# adaptive_avg_pool1d
+reg.register_schedule("nn.adaptive_avg_pool1d", strategy.schedule_adaptive_pool)
+reg.register_pattern("nn.adaptive_avg_pool1d", OpPattern.OUT_ELEMWISE_FUSABLE)
+
+
 # global_max_pool2d
 reg.register_schedule("nn.global_max_pool2d", strategy.schedule_adaptive_pool)
 reg.register_pattern("nn.global_max_pool2d", OpPattern.OUT_ELEMWISE_FUSABLE)
diff --git a/python/tvm/relay/op/nn/nn.py b/python/tvm/relay/op/nn/nn.py
index e5ca7e4a4717b..c449651f1130c 100644
--- a/python/tvm/relay/op/nn/nn.py
+++ b/python/tvm/relay/op/nn/nn.py
@@ -2964,6 +2964,94 @@ def space_to_depth(data, block_size, layout="NCHW"):
     return _make.space_to_depth(data, block_size, layout)
 
 
+def adaptive_max_pool1d(data, output_size=None, layout="NCW"):
+    r"""1D adaptive max pooling operator. This operator is experimental.
+
+    This operator takes data as input and does 1D max value calculation
+    across each window represented by W.
+
+    In the default case, where the data_layout is `NCW`,
+    a data Tensor with shape `(batch_size, in_channels, width)`
+    is pooled to produce an output Tensor with shape
+    (batch_size, in_channels, output_width).
+
+    The pooling kernel and stride sizes are automatically chosen to
+    match the desired output size.
+
+    For output_size:
+        If this argument is not provided, the input width will be used
+        as the output width.
+
+        If a single integer is provided for output_size, the output size is
+        (N x C x output_size) for any input (NCW).
+
+    Parameters
+    ----------
+    data : tvm.relay.Expr
+        The input data to the operator.
+
+    output_size : tuple of int, optional
+        Output width.
+
+    layout : str, optional
+        Layout of the input.
+
+    Returns
+    -------
+    result : tvm.relay.Expr
+        The computed result.
+    """
+    output_size = [] or output_size
+    if isinstance(output_size, int):
+        output_size = [output_size]
+    return _make.adaptive_max_pool1d(data, output_size, layout)
+
+
+def adaptive_avg_pool1d(data, output_size=None, layout="NCW"):
+    r"""1D adaptive average pooling operator. This operator is experimental.
+
+    This operator takes data as input and does 1D average value calculation
+    across each window represented by W.
+
+    In the default case, where the data_layout is `NCW`,
+    a data Tensor with shape `(batch_size, in_channels, width)`
+    is pooled to produce an output Tensor with shape
+    (batch_size, in_channels, output_width).
+
+    The pooling kernel and stride sizes are automatically chosen to
+    match the desired output size.
+
+    For output_size:
+        If this argument is not provided, the input width will be used
+        as the output width.
+
+        If a single integer is provided for output_size, the output size is
+        (N x C x output_size) for any input (NCW).
+
+    Parameters
+    ----------
+    data : tvm.relay.Expr
+        The input data to the operator.
+
+    output_size : tuple of int, optional
+        Output width.
+
+    layout : str, optional
+        Layout of the input.
+
+    Returns
+    -------
+    result : tvm.relay.Expr
+        The computed result.
+    """
+    output_size = [] or output_size
+    if isinstance(output_size, int):
+        output_size = [output_size]
+    return _make.adaptive_avg_pool1d(data, output_size, layout)
+
+
 def adaptive_max_pool2d(data, output_size=None, layout="NCHW"):
     r"""2D adaptive max pooling operator. This operator is experimental.
@@ -3142,6 +3230,71 @@ def adaptive_avg_pool3d(data, output_size=None, layout="NCDHW"):
     return _make.adaptive_avg_pool3d(data, output_size, layout)
 
 
+def global_max_pool1d(data, layout="NCW"):
+    r"""1D global maximum pooling operator.
+
+    This operator takes data as input and does 1D max value calculation
+    across each window represented by W.
+
+    In the default case, where the data_layout is `NCW`,
+    a data Tensor with shape `(batch_size, in_channels, width)`
+    produces an output Tensor according to the following rule,
+    with data of shape (b, c, w):
+
+    .. math::
+
+        \mbox{out}(b, c, 1) = \max_{n=0, \ldots, w} \mbox{data}(b, c, n)
+
+    Parameters
+    ----------
+    data : tvm.relay.Expr
+        The input data to the operator.
+
+    layout : str, optional
+        Layout of the input.
+
+    Returns
+    -------
+    result : tvm.relay.Expr
+        The computed result.
+    """
+    output_size = [1]
+    return _make.adaptive_max_pool1d(data, output_size, layout)
+
+
+def global_avg_pool1d(data, layout="NCW"):
+    r"""1D global average pooling operator.
+
+    This operator takes data as input and does 1D average value calculation
+    across each window represented by W.
+
+    In the default case, where the data_layout is `NCW`,
+    a data Tensor with shape `(batch_size, in_channels, width)`
+    produces an output Tensor according to the following rule,
+    with data of shape (b, c, w):
+
+    .. math::
+
+        \mbox{out}(b, c, 1) = \frac{1}{w} \sum_{n=0}^{w-1} \mbox{data}(b, c, n)
+
+    Parameters
+    ----------
+    data : tvm.relay.Expr
+        The input data to the operator.
+
+    layout : str, optional
+        Layout of the input.
+
+    Returns
+    -------
+    result : tvm.relay.Expr
+        The computed result.
+    """
+    output_size = [1]
+    return _make.adaptive_avg_pool1d(data, output_size, layout)
+
+
 def global_max_pool3d(data, layout="NCDHW"):
     r"""3D global maximum pooling operator.
 
diff --git a/python/tvm/topi/testing/adaptive_pool_python.py b/python/tvm/topi/testing/adaptive_pool_python.py
index 79f42c8f6dc69..dd8fadd71f144 100644
--- a/python/tvm/topi/testing/adaptive_pool_python.py
+++ b/python/tvm/topi/testing/adaptive_pool_python.py
@@ -27,6 +27,17 @@ def _end_index(index, odim, idim):
     return int(np.ceil((index + 1) * idim / odim))
 
 
+def _pool1d(in_size, out_size, np_data, np_op):
+    out = np.zeros(out_size).astype(np_data.dtype)
+    ow = out_size[0]
+    for l in range(ow):
+        l_start = _start_index(l, ow, in_size[0])
+        l_end = _end_index(l, ow, in_size[0])
+        l_sl = slice(l_start, l_end)
+        out[l] = np_op(np_data[l_sl])
+    return out
+
+
 def _pool2d(in_size, out_size, np_data, np_op):
     out = np.zeros(out_size).astype(np_data.dtype)
     oh, ow = out_size
@@ -61,8 +72,8 @@ def _pool3d(in_size, out_size, np_data, np_op):
     return out
 
 
-def adaptive_pool_nchw(np_data, out_size, pool_op, np_op):
-    """ The reference function for adaptive pool, nchw layout """
+def adaptive_pool_channel_first(np_data, out_size, pool_op, np_op):
+    """ The reference function for adaptive pool, channel first layout """
     ishape = np_data.shape
     n, c = ishape[:2]
     oshape = (n, c) + out_size
@@ -75,8 +86,8 @@ def adaptive_pool_nchw(np_data, out_size, pool_op, np_op):
     return np_out
 
 
-def adaptive_pool_nhwc(np_data, out_size, pool_op, np_op):
-    """ The reference function for adaptive pool, nhwc layout """
+def adaptive_pool_channel_last(np_data, out_size, pool_op, np_op):
+    """ The reference function for adaptive pool, channel last layout """
     ishape = np_data.shape
     n, c = ishape[0], ishape[-1]
     oshape = (n,) + out_size + (c,)
@@ -84,7 +95,9 @@ def adaptive_pool_nhwc(np_data, out_size, pool_op, np_op):
 
     for i in range(n):
         for j in range(c):
-            if len(out_size) == 2:
+            if len(out_size) == 1:
+                np_out[i, :, j] = pool_op(ishape[1:-1], out_size, np_data[i, :, j], np_op)
+            elif len(out_size) == 2:
                 np_out[i, :, :, j] = pool_op(ishape[1:-1], out_size, np_data[i, :, :, j], np_op)
             else:
                 np_out[i, :, :, :, j] = pool_op(
@@ -96,7 +109,11 @@ def adaptive_pool_nhwc(np_data, out_size, pool_op, np_op):
 
 def adaptive_pool(np_data, out_size, pool_type, layout):
     """ The reference function for adaptive pool, for 2d and 3d """
-    if len(out_size) == 2:
+    if isinstance(out_size, int):
+        out_size = (out_size,)
+    if len(out_size) == 1:
+        pool_op = _pool1d
+    elif len(out_size) == 2:
         pool_op = _pool2d
     else:
         assert len(out_size) == 3
@@ -104,8 +121,8 @@ def adaptive_pool(np_data, out_size, pool_type, layout):
 
     np_op = np.mean if pool_type == "avg" else np.max
 
-    if layout in ["NCHW", "NCDHW"]:
-        return adaptive_pool_nchw(np_data, out_size, pool_op, np_op)
+    if layout in ["NCW", "NCHW", "NCDHW"]:
+        return adaptive_pool_channel_first(np_data, out_size, pool_op, np_op)
 
-    assert layout in ["NHWC", "NDHWC"]
-    return adaptive_pool_nhwc(np_data, out_size, pool_op, np_op)
+    assert layout in ["NWC", "NHWC", "NDHWC"]
in ["NWC", "NHWC", "NDHWC"] + return adaptive_pool_channel_last(np_data, out_size, pool_op, np_op) diff --git a/src/relay/op/nn/pooling.cc b/src/relay/op/nn/pooling.cc index 0f38979a7ca10..cd7f6808845b0 100644 --- a/src/relay/op/nn/pooling.cc +++ b/src/relay/op/nn/pooling.cc @@ -349,6 +349,137 @@ RELAY_REGISTER_OP("nn.global_max_pool2d") .set_attr("FInferCorrectLayout", PoolInferCorrectLayout) .set_attr("FTVMCompute", GlobalPool2DCompute); +// relay.nn.adaptive_pool_1d +TVM_REGISTER_NODE_TYPE(AdaptivePool1DAttrs); + +bool AdaptivePool1DRel(const Array& types, int num_inputs, const Attrs& attrs, + const TypeReporter& reporter) { + ICHECK_EQ(types.size(), 2); + const auto* data = types[0].as(); + if (data == nullptr) { + return false; + } + const auto dshape = data->shape; + ICHECK_GE(dshape.size(), 1U) << "Pool2D only support input >= 1-D: input must have width"; + const auto* param = attrs.as(); + ICHECK(param != nullptr); + + Layout layout(param->layout); + ICHECK(layout.Contains(LayoutAxis::Get('W')) && !layout.Contains(LayoutAxis::Get('w'))) + << "Invalid layout " << layout << ". Pool1D layout must have W, which cannot be split"; + + const auto widx = layout.IndexOf(LayoutAxis::Get('W')); + Array oshape(dshape); + auto output_size = param->output_size; + ICHECK_LE(output_size.size(), 1U) << "output_size must have 1 element."; + IndexExpr output_width; + if (output_size.empty()) { + output_width = dshape[widx]; + } else { + output_width = output_size[0]; + } + + oshape.Set(widx, output_width); + + // assign output type + reporter->Assign(types[1], TensorType(oshape, data->dtype)); + return true; +} + +template +Array AdaptivePool1DCompute(const Attrs& attrs, const Array& inputs, + const Type& out_type) { + static const Layout kNCW("NCW"); + const auto* param = attrs.as(); + ICHECK(param != nullptr); + Layout layout(param->layout); + ICHECK(tir::BijectiveLayout(layout, kNCW).defined()) + << "Adaptive pool1d currently only supports layouts that are convertible from NCW"; + ICHECK_EQ(layout.IndexOf(LayoutAxis::Get('w')), -1) + << "Adaptive pool2d does not support input split on width"; + + ICHECK(inputs[0].ndim() == 3U || inputs[0].ndim() == 4U) + << "Pool1D only support 3-D input (e.g., NCW)" + << " or 4-D input (last dimension is a split of channel)"; + + auto output_size = param->output_size; + const auto widx = layout.IndexOf(LayoutAxis::Get('W')); + IndexExpr output_width; + if (output_size.empty()) { + output_width = inputs[0]->shape[widx]; + } else { + output_width = output_size[0]; + } + return Array{ + topi::nn::adaptive_pool1d(inputs[0], Array{output_width}, mode, layout.name())}; +} + +// relay.nn.adaptive_avg_pool1d +Expr MakeAdaptiveAvgPool1D(Expr data, Array output_size, String layout) { + auto attrs = make_object(); + attrs->output_size = std::move(output_size); + attrs->layout = std::move(layout); + static const Op& op = Op::Get("nn.adaptive_avg_pool1d"); + return Call(op, {data}, Attrs(attrs), {}); +} + +TVM_REGISTER_GLOBAL("relay.op.nn._make.adaptive_avg_pool1d").set_body_typed(MakeAdaptiveAvgPool1D); + +RELAY_REGISTER_OP("nn.adaptive_avg_pool1d") + .describe(R"code(Adaptive average pooling operation for 1D data. + +- **data**: This depends on the `layout` parameter. Input is 3D array of shape + (batch_size, channels, width) if `layout` is `NCW`. +- **output_size**: If this argument is not provided, input width will be used + as output width. + If an integer is provided for output_size, the output size is + (N x C x output_size) for any input (NCW). 
+- **out**: This depends on the `layout` parameter. Output is 3D array of shape
+           (batch_size, channels, output_width) if `layout` is `NCW`.
+
+)code" TVM_ADD_FILELINE)
+    .set_attrs_type<AdaptivePool1DAttrs>()
+    .set_num_inputs(1)
+    .add_argument("data", "Tensor", "The input tensor.")
+    .set_support_level(10)
+    .add_type_rel("AdaptiveAvgPool1D", AdaptivePool1DRel)
+    .set_attr<FInferCorrectLayout>("FInferCorrectLayout",
+                                   PoolInferCorrectLayout<AdaptivePool1DAttrs>)
+    .set_attr<FTVMCompute>("FTVMCompute", AdaptivePool1DCompute<topi::nn::kAvgPool>);
+
+// relay.nn.adaptive_max_pool1d
+Expr MakeAdaptiveMaxPool1D(Expr data, Array<IndexExpr> output_size, String layout) {
+  auto attrs = make_object<AdaptivePool1DAttrs>();
+  attrs->output_size = std::move(output_size);
+  attrs->layout = std::move(layout);
+  static const Op& op = Op::Get("nn.adaptive_max_pool1d");
+  return Call(op, {data}, Attrs(attrs), {});
+}
+
+TVM_REGISTER_GLOBAL("relay.op.nn._make.adaptive_max_pool1d").set_body_typed(MakeAdaptiveMaxPool1D);
+
+RELAY_REGISTER_OP("nn.adaptive_max_pool1d")
+    .describe(R"code(Adaptive max pooling operation for 1D data.
+
+- **data**: This depends on the `layout` parameter. Input is 3D array of shape
+            (batch_size, channels, width) if `layout` is `NCW`.
+- **output_size**: If this argument is not provided, input width will be used
+                   as output width.
+                   If an integer is provided for output_size, the output size is
+                   (N x C x output_size) for any input (NCW).
+- **out**: This depends on the `layout` parameter. Output is 3D array of shape
+           (batch_size, channels, output_width) if `layout` is `NCW`.
+
+)code" TVM_ADD_FILELINE)
+    .set_attrs_type<AdaptivePool1DAttrs>()
+    .set_num_inputs(1)
+    .add_argument("data", "Tensor", "The input tensor.")
+    .set_support_level(10)
+    .add_type_rel("AdaptiveMaxPool1D", AdaptivePool1DRel)
+    .set_attr<FInferCorrectLayout>("FInferCorrectLayout",
+                                   PoolInferCorrectLayout<AdaptivePool1DAttrs>)
+    .set_attr<FTVMCompute>("FTVMCompute", AdaptivePool1DCompute<topi::nn::kMaxPool>);
+
 // relay.nn.adaptive_pool_2d
 TVM_REGISTER_NODE_TYPE(AdaptivePool2DAttrs);
 
@@ -501,6 +632,7 @@ RELAY_REGISTER_OP("nn.adaptive_max_pool2d")
                                    PoolInferCorrectLayout<AdaptivePool2DAttrs>)
     .set_attr<FTVMCompute>("FTVMCompute", AdaptivePool2DCompute<topi::nn::kMaxPool>);
 
+// relay.nn.adaptive_pool3d
 TVM_REGISTER_NODE_TYPE(AdaptivePool3DAttrs);
 
 bool AdaptivePool3DRel(const Array<Type>& types, int num_inputs, const Attrs& attrs,
diff --git a/tests/python/frontend/onnx/test_forward.py b/tests/python/frontend/onnx/test_forward.py
index 783408e7c6d95..0a702c5696e04 100644
--- a/tests/python/frontend/onnx/test_forward.py
+++ b/tests/python/frontend/onnx/test_forward.py
@@ -2830,6 +2830,44 @@ def test_pooling():
     )
 
 
+def verify_global_pooling(x_shape, mode):
+    out_shape = x_shape[:2] + [1] * (len(x_shape) - 2)
+
+    if mode == "max":
+        node_type = "GlobalMaxPool"
+    elif mode == "average":
+        node_type = "GlobalAveragePool"
+    else:
+        raise ValueError("Pool method {} is not supported.".format(mode))
+
+    pool_node = helper.make_node(node_type, inputs=["x"], outputs=["y"])
+
+    graph = helper.make_graph(
+        [pool_node],
+        "global_pooling_test",
+        inputs=[helper.make_tensor_value_info("x", TensorProto.FLOAT, list(x_shape))],
+        outputs=[helper.make_tensor_value_info("y", TensorProto.FLOAT, list(out_shape))],
+    )
+
+    model = helper.make_model(graph, producer_name="global_pooling_test")
+    verify_with_ort(model, [x_shape], [out_shape], use_vm=False, convert_to_static=True)
+
+
+@tvm.testing.uses_gpu
+def test_global_pooling():
+    # Test each pooling mode across all N-D inputs.
+    for mode in ["average", "max"]:
+        # 1D Pooling (NCW)
+        verify_global_pooling([1, 8, 8], mode)
+        verify_global_pooling([4, 1, 4], mode)
+        # 2D Pooling (NCHW)
+        verify_global_pooling([1, 8, 8, 8], mode)
+        verify_global_pooling([4, 1, 6, 4], mode)
+        # 3D Pooling (NCDHW)
+        verify_global_pooling([1, 8, 6, 8, 8], mode)
+        verify_global_pooling([4, 1, 2, 6, 4], mode)
+
+
 def verify_mod(x_shape, y_shape, fmod, out_shape, dtype="float32"):
     x_np = np.random.uniform(-100.0, 100.0, x_shape).astype(dtype)
     y_np = np.random.uniform(-100.0, 100.0, y_shape).astype(dtype)
diff --git a/tests/python/relay/test_op_level10.py b/tests/python/relay/test_op_level10.py
index 597a1c69e8ee3..0faa31fdc0beb 100644
--- a/tests/python/relay/test_op_level10.py
+++ b/tests/python/relay/test_op_level10.py
@@ -439,18 +439,30 @@ def verify_adaptive_pool(dshape, out_size, pool_type, layout, dtype, opfunc):
     tvm.testing.assert_allclose(relay_out.asnumpy(), np_out, rtol=1e-5, atol=1e-5)
 
 
+def verify_adaptive_pool1d(dshape, out_size, pool_type, layout="NCW", dtype="float32"):
+    opfunc = relay.nn.adaptive_avg_pool1d if pool_type == "avg" else relay.nn.adaptive_max_pool1d
+    verify_adaptive_pool(dshape, out_size, pool_type, layout, dtype, opfunc)
+
+
 def verify_adaptive_pool2d(dshape, out_size, pool_type, layout="NCHW", dtype="float32"):
     opfunc = relay.nn.adaptive_avg_pool2d if pool_type == "avg" else relay.nn.adaptive_max_pool2d
     verify_adaptive_pool(dshape, out_size, pool_type, layout, dtype, opfunc)
 
 
-def verify_adaptive_pool3d(dshape, out_size, pool_type, layout="NCHW", dtype="float32"):
+def verify_adaptive_pool3d(dshape, out_size, pool_type, layout="NCDHW", dtype="float32"):
     opfunc = relay.nn.adaptive_avg_pool3d if pool_type == "avg" else relay.nn.adaptive_max_pool3d
     verify_adaptive_pool(dshape, out_size, pool_type, layout, dtype, opfunc)
 
 
 @tvm.testing.uses_gpu
 def test_adaptive_pool():
+    verify_adaptive_pool1d((1, 9, 224), (1), "max")
+    verify_adaptive_pool1d((1, 3, 224), (3), "avg")
+    verify_adaptive_pool1d((1, 3, 224), (3), "avg", dtype="int32")
+    verify_adaptive_pool1d((1, 14, 78), (13), "max")
+    verify_adaptive_pool1d((1, 5, 97), (96), "avg")
+    verify_adaptive_pool1d((1, 224, 3), (1), "max", layout="NWC")
+    verify_adaptive_pool1d((1, 3, 224), (3), "avg", layout="NWC")
     verify_adaptive_pool2d((1, 9, 224, 224), (1, 1), "max")
     verify_adaptive_pool2d((1, 3, 224, 224), (2, 3), "avg")
     verify_adaptive_pool2d((1, 3, 224, 224), (2, 3), "avg", dtype="int32")
diff --git a/tests/python/relay/test_op_level2.py b/tests/python/relay/test_op_level2.py
index 9640868458370..755618cccf06c 100644
--- a/tests/python/relay/test_op_level2.py
+++ b/tests/python/relay/test_op_level2.py
@@ -1007,37 +1007,66 @@ def test_pool2d():
     _test_global_pool2d(relay.nn.global_avg_pool2d, np.mean)
 
 
+def _test_pool1d(opfunc, pool_size=(2,), strides=(2,), padding=(0, 0), dtype="float32"):
+    n, c, w = te.var("n"), 10, 224
+    x = relay.var("x", relay.TensorType((n, c, w), "float32"))
+    y = opfunc(x, pool_size=(1,))
+    assert "pool_size=" in y.astext()
+    yy = run_infer_type(y)
+    assert yy.checked_type == relay.TensorType((n, 10, 224), "float32")
+    # test execution
+    dshape = (1, 3, 32)
+    for shape_dtype in ["int32", "int64"]:
+        x = relay.var("x", shape=[tvm.tir.IntImm(shape_dtype, x) for x in dshape], dtype=dtype)
+        pool_type = "max" if "max" in str(opfunc) else "avg"
+        y = opfunc(x, pool_size=pool_size, strides=strides, padding=padding)
+        func = relay.Function([x], y)
+        data = np.random.uniform(size=dshape).astype(dtype)
+        ref_res = tvm.topi.testing.pool1d_ncw_python(
+            data, (2,), (2,), (0, 0), (1, 3, 16), pool_type, False
+        )
+        for target, dev in tvm.testing.enabled_targets():
+            intrp1 = relay.create_executor("graph", device=dev, target=target)
+            op_res1 = intrp1.evaluate(func)(data)
+            tvm.testing.assert_allclose(op_res1.asnumpy(), ref_res, rtol=1e-5, atol=1e-5)
+
+
+def _test_global_pool1d(opfunc, reffunc):
+    n, c, w = te.size_var("n"), te.size_var("c"), 224
+    x = relay.var("x", relay.TensorType((n, w, c), "float32"))
+    y = opfunc(x, layout="NWC")
+    yy = run_infer_type(y)
+    assert yy.checked_type == relay.TensorType((n, 1, c), "float32")
+
+    n, c, w = te.size_var("n"), te.size_var("c"), te.size_var("w")
+    x = relay.var("x", relay.TensorType((n, c, w), "float32"))
+    y = opfunc(x)
+    yy = run_infer_type(y)
+    assert yy.checked_type == relay.TensorType((n, c, 1), "float32")
+    # test execution
+    dtype = "float32"
+    dshape = (1, 1024, 7)
+    x = relay.var("x", shape=dshape)
+    y = opfunc(x)
+    func = relay.Function([x], y)
+    data = np.random.uniform(size=dshape).astype(dtype)
+    ref_res = reffunc(data, axis=(2,), keepdims=True)
+    for target, dev in tvm.testing.enabled_targets():
+        intrp1 = relay.create_executor("graph", device=dev, target=target)
+        op_res1 = intrp1.evaluate(func)(data)
+        tvm.testing.assert_allclose(op_res1.asnumpy(), ref_res, rtol=1e-5, atol=1e-5)
+
+
 @tvm.testing.uses_gpu
 def test_pool1d():
-    def _test_pool1d(opfunc, pool_size=(2,), strides=(2,), padding=(0, 0), dtype="float32"):
-        n, c, w = te.var("n"), 10, 224
-        x = relay.var("x", relay.TensorType((n, c, w), "float32"))
-        y = opfunc(x, pool_size=(1,))
-        assert "pool_size=" in y.astext()
-        yy = run_infer_type(y)
-        assert yy.checked_type == relay.TensorType((n, 10, 224), "float32")
-        # test execution
-        dshape = (1, 3, 32)
-        for shape_dtype in ["int32", "int64"]:
-            x = relay.var("x", shape=[tvm.tir.IntImm(shape_dtype, x) for x in dshape], dtype=dtype)
-            pool_type = "max" if "max" in str(opfunc) else "avg"
-            y = opfunc(x, pool_size=pool_size, strides=strides, padding=padding)
-            func = relay.Function([x], y)
-            data = np.random.uniform(size=dshape).astype(dtype)
-            ref_res = tvm.topi.testing.pool1d_ncw_python(
-                data, (2,), (2,), (0, 0), (1, 3, 16), pool_type, False
-            )
-            for target, dev in tvm.testing.enabled_targets():
-                intrp1 = relay.create_executor("graph", device=dev, target=target)
-                op_res1 = intrp1.evaluate(func)(data)
-                tvm.testing.assert_allclose(op_res1.asnumpy(), ref_res, rtol=1e-5, atol=1e-5)
-
     _test_pool1d(relay.nn.max_pool1d)
     _test_pool1d(relay.nn.max_pool1d, dtype="int32")
     _test_pool1d(relay.nn.max_pool1d, pool_size=2, strides=2, padding=0)
     _test_pool1d(relay.nn.avg_pool1d)
     _test_pool1d(relay.nn.avg_pool1d, dtype="int32")
     _test_pool1d(relay.nn.avg_pool1d, pool_size=2, strides=2, padding=0)
+    _test_global_pool1d(relay.nn.global_max_pool1d, np.max)
+    _test_global_pool1d(relay.nn.global_avg_pool1d, np.mean)
 
 
 @tvm.testing.uses_gpu
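
Usage sketch (not part of the patch): the short Python snippet below shows how the 1-D operators introduced by this change can be exercised from the Relay API, mirroring the style of the tests above. It assumes a TVM build that includes this patch; the shapes, target, and variable names are illustrative only.

import numpy as np
import tvm
from tvm import relay

# 1-D adaptive average pooling followed by 1-D global max pooling, NCW layout.
data = relay.var("data", shape=(1, 16, 32), dtype="float32")
pooled = relay.nn.adaptive_avg_pool1d(data, output_size=(8,))  # -> (1, 16, 8)
reduced = relay.nn.global_max_pool1d(pooled)                   # -> (1, 16, 1)
func = relay.Function([data], reduced)

# Run on CPU through the graph executor, as the relay tests in this patch do.
np_data = np.random.uniform(size=(1, 16, 32)).astype("float32")
intrp = relay.create_executor("graph", device=tvm.cpu(), target="llvm")
out = intrp.evaluate(func)(np_data)
print(out.shape)  # expected: (1, 16, 1)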