Skip to content

Commit

Permalink
[Relay][ONNX] 1-D global and adaptive pooling. (apache#7906)
Browse files Browse the repository at this point in the history
* 1D adaptive pooling added and tested.

* Apply formatting.

* Add onnx integration and tests.

* Busted by lint.
  • Loading branch information
Josh Fromm authored and Trevor Morris committed May 6, 2021
1 parent 90c862c commit f46b556
Show file tree
Hide file tree
Showing 10 changed files with 495 additions and 37 deletions.
18 changes: 17 additions & 1 deletion include/tvm/relay/attrs/nn.h
Original file line number Diff line number Diff line change
Expand Up @@ -760,7 +760,22 @@ struct GlobalPool2DAttrs : public tvm::AttrsNode<GlobalPool2DAttrs> {
}
};

/*! \brief Attributes for adaptive pool operator */
/*!
 * \brief Attributes for 1d adaptive pool operator
 *
 * Holds the requested output width and the layout string that identifies
 * which axis of the input is the width ('W') axis.
 */
struct AdaptivePool1DAttrs : public tvm::AttrsNode<AdaptivePool1DAttrs> {
  /*! \brief Requested output width; defaults to an empty array. */
  Array<IndexExpr> output_size;
  /*! \brief Dimension ordering of the input data; defaults to "NCW". */
  std::string layout;

  TVM_DECLARE_ATTRS(AdaptivePool1DAttrs, "relay.attrs.AdaptivePool1DAttrs") {
    TVM_ATTR_FIELD(output_size).set_default(Array<IndexExpr>({})).describe("Output width.");
    // Adjacent string literals are concatenated with no separator, so every
    // fragment except the last must end with its own trailing space.
    TVM_ATTR_FIELD(layout).set_default("NCW").describe(
        "Dimension ordering of input data. Can be 'NCW', 'NWC', etc. "
        "'N', 'C', 'W' stand for batch, channel, and width "
        "dimensions respectively. Pooling is applied on the "
        "'W' dimension.");
  }
};

/*! \brief Attributes for 2d adaptive pool operator */
struct AdaptivePool2DAttrs : public tvm::AttrsNode<AdaptivePool2DAttrs> {
Array<IndexExpr> output_size;
std::string layout;
Expand All @@ -777,6 +792,7 @@ struct AdaptivePool2DAttrs : public tvm::AttrsNode<AdaptivePool2DAttrs> {
}
};

/*! \brief Attributes for 3d adaptive pool operator */
struct AdaptivePool3DAttrs : public tvm::AttrsNode<AdaptivePool3DAttrs> {
Array<IndexExpr> output_size;
std::string layout;
Expand Down
15 changes: 15 additions & 0 deletions include/tvm/topi/nn/pooling.h
Original file line number Diff line number Diff line change
Expand Up @@ -613,6 +613,21 @@ inline Tensor adaptive_pool3d(const Tensor& x, const Array<PrimExpr>& output_siz
return adaptive_pool_impl(x, output_size, pool_type, {depth_axis, height_axis, width_axis});
}

/*!
* \brief Adaptively perform pooling on one dimensional data.
* See the two dimensional version above for details.
*
* The width axis is looked up from the layout string with find_width; if
* find_width reports failure the ICHECK below fails with an
* "Unsupported layout" message. The pooling itself is delegated to
* adaptive_pool_impl over that single axis.
*
* \param x The input tensor
* \param output_size Vector of one int: {output_width}
* \param pool_type The type of pooling operator
* \param layout The input layout. The default is "NCW".
* \return The tensor produced by adaptive_pool_impl for the width axis.
*/
inline Tensor adaptive_pool1d(const Tensor& x, const Array<PrimExpr>& output_size,
PoolType pool_type, const std::string& layout = "NCW") {
// Sentinel value; overwritten by find_width when the layout has a width axis.
int width_axis = -1;
ICHECK(find_width(layout, &width_axis)) << "Unsupported layout " << layout;
// Only the width axis is pooled in the 1D case.
return adaptive_pool_impl(x, output_size, pool_type, {width_axis});
}

/*!
* \brief Perform global pooling on height and width dimension of data.
* It decides the height and width dimension according to the layout string,
Expand Down
40 changes: 38 additions & 2 deletions python/tvm/relay/frontend/onnx.py
Original file line number Diff line number Diff line change
Expand Up @@ -501,6 +501,42 @@ def _impl_v1(cls, inputs, attr, params):
return out


class GlobalAveragePool(OnnxOpConverter):
    """Operator converter for GlobalAveragePool.

    Picks the Relay global average pooling op (1D, 2D or 3D) that matches
    the rank of the first input.
    """

    @classmethod
    def _impl_v1(cls, inputs, attr, params):
        # Rank counts the batch and channel axes plus the spatial axes.
        ndim = len(infer_shape(inputs[0]))
        op_name = {
            3: "global_avg_pool1d",
            4: "global_avg_pool2d",
            5: "global_avg_pool3d",
        }.get(ndim)
        if op_name is None:
            # ndim - 2 is the number of spatial dimensions reported to the user.
            raise NotImplementedError(
                "Global average pooling is only implemented for 1D, 2D, and 3D kernels, got %dD."
                % (ndim - 2),
            )
        return getattr(_op.nn, op_name)(inputs[0])


class GlobalMaxPool(OnnxOpConverter):
    """Operator converter for GlobalMaxPool.

    Picks the Relay global max pooling op (1D, 2D or 3D) that matches
    the rank of the first input.
    """

    @classmethod
    def _impl_v1(cls, inputs, attr, params):
        # Rank counts the batch and channel axes plus the spatial axes.
        ndim = len(infer_shape(inputs[0]))
        op_name = {
            3: "global_max_pool1d",
            4: "global_max_pool2d",
            5: "global_max_pool3d",
        }.get(ndim)
        if op_name is None:
            # ndim - 2 is the number of spatial dimensions reported to the user.
            raise NotImplementedError(
                "Global max pooling is only implemented for 1D, 2D, and 3D kernels, got %dD."
                % (ndim - 2),
            )
        return getattr(_op.nn, op_name)(inputs[0])


class Div(Elemwise):
"""Operator converter for Divide."""

Expand Down Expand Up @@ -2775,8 +2811,8 @@ def _get_convert_map(opset):
"MaxUnpool": MaxUnpool.get_converter(opset),
"Conv": Conv.get_converter(opset),
"ConvTranspose": ConvTranspose.get_converter(opset),
"GlobalAveragePool": Renamer("global_avg_pool2d"),
"GlobalMaxPool": Renamer("global_max_pool2d"),
"GlobalAveragePool": GlobalAveragePool.get_converter(opset),
"GlobalMaxPool": GlobalMaxPool.get_converter(opset),
"BatchNormalization": BatchNorm.get_converter(opset),
"InstanceNormalization": InstanceNorm.get_converter(opset),
# 'LpNormalization'
Expand Down
10 changes: 10 additions & 0 deletions python/tvm/relay/op/nn/_nn.py
Original file line number Diff line number Diff line change
Expand Up @@ -504,6 +504,16 @@ def compute_contrib_conv3d_winograd_weight_transform(attrs, inputs, out_dtype):
reg.register_pattern("nn.avg_pool2d_grad", OpPattern.OUT_ELEMWISE_FUSABLE)


# adaptive_max_pool1d
# Uses the same schedule strategy and op pattern as the other pooling ops
# registered in this section.
reg.register_schedule("nn.adaptive_max_pool1d", strategy.schedule_adaptive_pool)
reg.register_pattern("nn.adaptive_max_pool1d", OpPattern.OUT_ELEMWISE_FUSABLE)


# adaptive_avg_pool1d
# Registered identically to adaptive_max_pool1d above: shared adaptive-pool
# schedule strategy, OUT_ELEMWISE_FUSABLE pattern.
reg.register_schedule("nn.adaptive_avg_pool1d", strategy.schedule_adaptive_pool)
reg.register_pattern("nn.adaptive_avg_pool1d", OpPattern.OUT_ELEMWISE_FUSABLE)


# global_max_pool2d
# NOTE(review): the global pooling op also goes through
# strategy.schedule_adaptive_pool - presumably because it is scheduled like an
# adaptive pool; confirm against the strategy implementation.
reg.register_schedule("nn.global_max_pool2d", strategy.schedule_adaptive_pool)
reg.register_pattern("nn.global_max_pool2d", OpPattern.OUT_ELEMWISE_FUSABLE)
Expand Down
153 changes: 153 additions & 0 deletions python/tvm/relay/op/nn/nn.py
Original file line number Diff line number Diff line change
Expand Up @@ -2964,6 +2964,94 @@ def space_to_depth(data, block_size, layout="NCHW"):
return _make.space_to_depth(data, block_size, layout)


def adaptive_max_pool1d(data, output_size=None, layout="NCW"):
    r"""1D adaptive max pooling operator. This operator is experimental.

    This operator takes data as input and does 1D max value calculation
    across each window represented by W.

    In the default case, where the data_layout is `NCW`
    a data Tensor with shape `(batch_size, in_channels, width)`,
    to produce an output Tensor with shape
    (batch_size, in_channels, output_width).

    The pooling kernel and stride sizes are automatically chosen for
    desired output sizes.

    For output_size:
        If this argument is not provided, input width will be used
        as output width.

        If a single integer is provided for output_size, the output size is
        (N x C x output_size) for any input (NCW).

    Parameters
    ----------
    data : tvm.relay.Expr
        The input data to the operator.

    output_size : int or tuple of int, optional
        Output width.

    layout : str, optional
        Layout of the input.

    Returns
    -------
    result : tvm.relay.Expr
        The computed result.
    """
    # `[] or output_size` always evaluates to `output_size`, so None was never
    # replaced. Substitute the empty list explicitly; an explicit `is None`
    # test keeps an integer 0 from being treated as "not provided".
    if output_size is None:
        output_size = []
    elif isinstance(output_size, int):
        output_size = [output_size]
    return _make.adaptive_max_pool1d(data, output_size, layout)


def adaptive_avg_pool1d(data, output_size=None, layout="NCW"):
    r"""1D adaptive average pooling operator. This operator is experimental.

    This operator takes data as input and does 1D average value calculation
    across each window represented by W.

    In the default case, where the data_layout is `NCW`
    a data Tensor with shape `(batch_size, in_channels, width)`,
    to produce an output Tensor with shape
    (batch_size, in_channels, output_width).

    The pooling kernel and stride sizes are automatically chosen for
    desired output sizes.

    For output_size:
        If this argument is not provided, input width will be used
        as output width.

        If a single integer is provided for output_size, the output size is
        (N x C x output_size) for any input (NCW).

    Parameters
    ----------
    data : tvm.relay.Expr
        The input data to the operator.

    output_size : int or tuple of int, optional
        Output width.

    layout : str, optional
        Layout of the input.

    Returns
    -------
    result : tvm.relay.Expr
        The computed result.
    """
    # `[] or output_size` always evaluates to `output_size`, so None was never
    # replaced. Substitute the empty list explicitly; an explicit `is None`
    # test keeps an integer 0 from being treated as "not provided".
    if output_size is None:
        output_size = []
    elif isinstance(output_size, int):
        output_size = [output_size]
    return _make.adaptive_avg_pool1d(data, output_size, layout)


def adaptive_max_pool2d(data, output_size=None, layout="NCHW"):
r"""2D adaptive max pooling operator. This operator is experimental.
Expand Down Expand Up @@ -3142,6 +3230,71 @@ def adaptive_avg_pool3d(data, output_size=None, layout="NCDHW"):
return _make.adaptive_avg_pool3d(data, output_size, layout)


def global_max_pool1d(data, layout="NCW"):
    r"""1D global maximum pooling operator.

    Reduces the whole W axis of the input to a single element by taking the
    maximum. It is expressed as a 1D adaptive max pool with an output
    width of 1.

    With the default `NCW` layout and data of shape (b, c, w):

    .. math::

        \mbox{out}(b, c, 1) = \max_{n=0, \ldots, w-1} \mbox{data}(b, c, n)

    Parameters
    ----------
    data : tvm.relay.Expr
        The input data to the operator.

    layout : str, optional
        Layout of the input.

    Returns
    -------
    result : tvm.relay.Expr
        The computed result.
    """
    # An output width of 1 makes the adaptive pool cover the full W axis.
    return _make.adaptive_max_pool1d(data, [1], layout)


def global_avg_pool1d(data, layout="NCW"):
    r"""1D global average pooling operator.

    Reduces the whole W axis of the input to a single element by taking the
    mean. It is expressed as a 1D adaptive average pool with an output
    width of 1.

    With the default `NCW` layout and data of shape (b, c, w):

    .. math::

        \mbox{out}(b, c, 1) = \frac{1}{w} \sum_{n=0}^{w-1} \mbox{data}(b, c, n)

    Parameters
    ----------
    data : tvm.relay.Expr
        The input data to the operator.

    layout : str, optional
        Layout of the input.

    Returns
    -------
    result : tvm.relay.Expr
        The computed result.
    """
    # An output width of 1 makes the adaptive pool cover the full W axis.
    return _make.adaptive_avg_pool1d(data, [1], layout)


def global_max_pool3d(data, layout="NCDHW"):
r"""3D global maximum pooling operator.
Expand Down
37 changes: 27 additions & 10 deletions python/tvm/topi/testing/adaptive_pool_python.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,17 @@ def _end_index(index, odim, idim):
return int(np.ceil((index + 1) * idim / odim))


def _pool1d(in_size, out_size, np_data, np_op):
    """Reference 1D adaptive pooling over a single row of data.

    Each output position is `np_op` applied to the slice of `np_data`
    bounded by `_start_index` and `_end_index` for that position.
    """
    pooled = np.zeros(out_size).astype(np_data.dtype)
    out_width = out_size[0]
    for pos in range(out_width):
        window = slice(
            _start_index(pos, out_width, in_size[0]),
            _end_index(pos, out_width, in_size[0]),
        )
        pooled[pos] = np_op(np_data[window])
    return pooled


def _pool2d(in_size, out_size, np_data, np_op):
out = np.zeros(out_size).astype(np_data.dtype)
oh, ow = out_size
Expand Down Expand Up @@ -61,8 +72,8 @@ def _pool3d(in_size, out_size, np_data, np_op):
return out


def adaptive_pool_nchw(np_data, out_size, pool_op, np_op):
""" The reference function for adaptive pool, nchw layout """
def adaptive_pool_channel_first(np_data, out_size, pool_op, np_op):
""" The reference function for adaptive pool, channel first layout """
ishape = np_data.shape
n, c = ishape[:2]
oshape = (n, c) + out_size
Expand All @@ -75,16 +86,18 @@ def adaptive_pool_nchw(np_data, out_size, pool_op, np_op):
return np_out


def adaptive_pool_nhwc(np_data, out_size, pool_op, np_op):
""" The reference function for adaptive pool, nhwc layout """
def adaptive_pool_channel_last(np_data, out_size, pool_op, np_op):
""" The reference function for adaptive pool, channel last layout """
ishape = np_data.shape
n, c = ishape[0], ishape[-1]
oshape = (n,) + out_size + (c,)
np_out = np.zeros(oshape).astype(np_data.dtype)

for i in range(n):
for j in range(c):
if len(out_size) == 2:
if len(out_size) == 1:
np_out[i, :, j] = pool_op(ishape[1:-1], out_size, np_data[i, :, j], np_op)
elif len(out_size) == 2:
np_out[i, :, :, j] = pool_op(ishape[1:-1], out_size, np_data[i, :, :, j], np_op)
else:
np_out[i, :, :, :, j] = pool_op(
Expand All @@ -96,16 +109,20 @@ def adaptive_pool_nhwc(np_data, out_size, pool_op, np_op):

def adaptive_pool(np_data, out_size, pool_type, layout):
    """The reference function for adaptive pool, for 1d, 2d and 3d.

    Parameters
    ----------
    np_data : numpy.ndarray
        Input data, laid out according to `layout`.

    out_size : int or sequence of int
        Requested output size per pooled dimension. A bare int is treated as
        a 1D output size.

    pool_type : str
        "avg" selects mean pooling; any other value selects max pooling.

    layout : str
        One of "NCW", "NCHW", "NCDHW" (channel first) or
        "NWC", "NHWC", "NDHWC" (channel last).

    Returns
    -------
    np_out : numpy.ndarray
        The pooled result produced by the channel-first or channel-last
        reference implementation.
    """
    if isinstance(out_size, int):
        out_size = (out_size,)
    # The layout-specific helpers concatenate out_size with tuples when
    # building the output shape, so a list must be converted first.
    out_size = tuple(out_size)

    if len(out_size) == 1:
        pool_op = _pool1d
    elif len(out_size) == 2:
        pool_op = _pool2d
    else:
        assert len(out_size) == 3
        pool_op = _pool3d

    np_op = np.mean if pool_type == "avg" else np.max

    if layout in ["NCW", "NCHW", "NCDHW"]:
        return adaptive_pool_channel_first(np_data, out_size, pool_op, np_op)

    assert layout in ["NWC", "NHWC", "NDHWC"]
    return adaptive_pool_channel_last(np_data, out_size, pool_op, np_op)
Loading

0 comments on commit f46b556

Please sign in to comment.