[Conv2DTransposed] Fix wrong shape check and add new TOPI module to support groups #9465

Merged Nov 12, 2021 (29 commits)

Commits
2408739 - f wrong type check in conv2d_transpose (Lyken17, Nov 8, 2021)
5dc34a9 - add test case for conv2d transpose (Lyken17, Nov 8, 2021)
0961ea5 - Merge branch 'main' of https://github.com/dmlc/tvm into main (Lyken17, Nov 8, 2021)
2885a57 - add groups support for conv2d_transpose (Lyken17, Nov 8, 2021)
8f5a979 - add naive implementation and schedule for conv2d with groups (Lyken17, Nov 8, 2021)
9bea044 - enable tests for cpu and arm_cpu, raise error for cuda platform (Lyken17, Nov 8, 2021)
a1c7308 - revert the cuda and generic strategy (Lyken17, Nov 10, 2021)
fcc2f00 - revert back the x86 strategy (Lyken17, Nov 10, 2021)
59e807b - revert back the arm_cpu strategy (Lyken17, Nov 10, 2021)
3eacec4 - revert back the arm_cpu strategy (Lyken17, Nov 10, 2021)
478ff9a - revert back the arm_cpu strategy (Lyken17, Nov 10, 2021)
11c2c75 - fix EOF of x86 (Lyken17, Nov 10, 2021)
064ee1a - clang lint updated c++ code (Lyken17, Nov 10, 2021)
efea36a - update topi implementation (Lyken17, Nov 10, 2021)
7e21200 - Revert test (Lyken17, Nov 10, 2021)
14c771b - Revert test (Lyken17, Nov 10, 2021)
8d48f22 - add generic/x86/arm specialization for conv2d_transpose with groups > 1 (Lyken17, Nov 10, 2021)
2d766f2 - remove commentted codes (Lyken17, Nov 10, 2021)
795053e - fix lint (alicja-SiMa-ai, Nov 10, 2021)
61db0d3 - fix lint (Lyken17, Nov 10, 2021)
7af455c - fix c++ lint (Lyken17, Nov 11, 2021)
6bc2ced - fix lint (Lyken17, Nov 11, 2021)
3d7ce97 - fix python lint (Lyken17, Nov 11, 2021)
a30b5d5 - remove comments and reformat (Lyken17, Nov 11, 2021)
56fa3bc - lint file (Lyken17, Nov 11, 2021)
0a1a437 - lint code (Lyken17, Nov 11, 2021)
562fdab - fix lint (Lyken17, Nov 11, 2021)
3e0db60 - update logging information in convolution.h (Lyken17, Nov 11, 2021)
87c3c6a - resolve conflicts and fix lint (Lyken17, Nov 12, 2021)
2 changes: 1 addition & 1 deletion python/tvm/relay/op/strategy/cuda.py
@@ -557,7 +557,7 @@ def conv2d_transpose_strategy_cuda(attrs, inputs, out_type, target):
groups = attrs.groups
assert layout == "NCHW", "only support nchw for now"
assert dilation == (1, 1), "not support dilate now"
assert groups == 1, "only support groups == 1 for now"
assert groups == 1, "only support groups == 1 when targetting cuda/gpu"
strategy = _op.OpStrategy()
strategy.add_implementation(
wrap_compute_conv2d_transpose(topi.cuda.conv2d_transpose_nchw),
26 changes: 18 additions & 8 deletions python/tvm/relay/op/strategy/generic.py
@@ -446,7 +446,7 @@ def deformable_conv2d_strategy(attrs, inputs, out_type, target):


# conv2d_transpose
def wrap_compute_conv2d_transpose(topi_compute):
def wrap_compute_conv2d_transpose(topi_compute, has_groups=False):
"""wrap conv2d_transpose topi compute"""

def compute_conv2d_transpose(attrs, inputs, out_dtype):
@@ -456,7 +456,11 @@ def compute_conv2d_transpose(attrs, inputs, out_dtype):
out_dtype = attrs.out_dtype
out_dtype = inputs[0].dtype if out_dtype in ("same", "") else out_dtype
output_padding = get_const_tuple(attrs.output_padding)
out = topi_compute(inputs[0], inputs[1], strides, padding, out_dtype, output_padding)
# out = topi_compute(inputs[0], inputs[1], strides, padding, out_dtype, output_padding)
args = [inputs[0], inputs[1], strides, padding, out_dtype, output_padding]
if has_groups:
args.append(attrs.groups)
out = topi_compute(*args)
return [out]

return compute_conv2d_transpose
@@ -471,13 +475,19 @@ def conv2d_transpose_strategy(attrs, inputs, out_type, target):
groups = attrs.groups
assert layout == "NCHW", "only support nchw for now"
assert dilation == (1, 1), "not support dilate now"
assert groups == 1, "only support groups == 1 for now"
strategy = _op.OpStrategy()
strategy.add_implementation(
wrap_compute_conv2d_transpose(topi.nn.conv2d_transpose_nchw),
wrap_topi_schedule(topi.generic.schedule_conv2d_transpose_nchw),
name="conv2d_transpose_nchw.generic",
)
if groups == 1:
strategy.add_implementation(
wrap_compute_conv2d_transpose(topi.nn.conv2d_transpose_nchw),
wrap_topi_schedule(topi.generic.schedule_conv2d_transpose_nchw),
name="conv2d_transpose_nchw.generic",
)
else: # group_transpose_conv2d
strategy.add_implementation(
wrap_compute_conv2d_transpose(topi.nn.group_conv2d_transpose_nchw, has_groups=True),
wrap_topi_schedule(topi.generic.schedule_group_conv2d_transpose_nchw),
name="group_conv2d_transpose_nchw.generic",
)
return strategy


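A minimal sketch of how the grouped path added above can be exercised from Relay; the shapes, target, and expected output type are illustrative, and it assumes a TVM build with this PR applied:

```python
# Illustrative only: exercises the groups > 1 branch added to conv2d_transpose_strategy.
import tvm
from tvm import relay

batch, in_c, out_c, groups = 1, 8, 8, 4
data = relay.var("data", shape=(batch, in_c, 16, 16), dtype="float32")
# IOHW kernel layout: [in_channel, out_channel // groups, kh, kw]
weight = relay.var("weight", shape=(in_c, out_c // groups, 3, 3), dtype="float32")

y = relay.nn.conv2d_transpose(
    data, weight,
    strides=(2, 2), padding=(1, 1), output_padding=(1, 1),
    channels=out_c, kernel_size=(3, 3), groups=groups,
)
mod = tvm.IRModule.from_expr(relay.Function([data, weight], y))
mod = relay.transform.InferType()(mod)
print(mod["main"].ret_type)  # expect Tensor[(1, 8, 32, 32), float32]

# groups > 1 now resolves to group_conv2d_transpose_nchw.generic / .x86
# instead of failing the old `assert groups == 1`.
with tvm.transform.PassContext(opt_level=3):
    lib = relay.build(mod, target="llvm")
```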
18 changes: 12 additions & 6 deletions python/tvm/relay/op/strategy/x86.py
@@ -281,13 +281,19 @@ def conv2d_transpose_strategy_cpu(attrs, inputs, out_type, target):
groups = attrs.groups
assert layout == "NCHW", "only support nchw for now"
assert dilation == (1, 1), "not support dilate now"
assert groups == 1, "only support groups == 1 for now"
strategy = _op.OpStrategy()
strategy.add_implementation(
wrap_compute_conv2d_transpose(topi.x86.conv2d_transpose_nchw),
wrap_topi_schedule(topi.x86.schedule_conv2d_transpose_nchw),
name="conv2d_transpose_nchw.x86",
)
if groups == 1:
strategy.add_implementation(
wrap_compute_conv2d_transpose(topi.x86.conv2d_transpose_nchw),
wrap_topi_schedule(topi.x86.schedule_conv2d_transpose_nchw),
name="conv2d_transpose_nchw.x86",
)
else:
strategy.add_implementation(
wrap_compute_conv2d_transpose(topi.nn.group_conv2d_transpose_nchw, has_groups=True),
wrap_topi_schedule(topi.generic.schedule_group_conv2d_transpose_nchw),
name="group_conv2d_transpose_nchw.x86",
)
return strategy


17 changes: 17 additions & 0 deletions python/tvm/topi/generic/nn.py
@@ -428,6 +428,23 @@ def schedule_group_conv2d_nchw(outs):
return _default_schedule(outs, False)


def schedule_group_conv2d_transpose_nchw(outs):
"""Schedule for group_conv2d_transpose_nchw

Parameters
----------
outs: Array of Tensor
The computation graph description of group_conv2d_transpose_nchw
in the format of an array of tensors.

Returns
-------
sch: Schedule
The computation schedule for the op.
"""
return _default_schedule(outs, False)


def schedule_group_conv2d_nhwc(outs):
"""Schedule for group_conv2d_nhwc

129 changes: 129 additions & 0 deletions python/tvm/topi/nn/conv2d_transpose.py
@@ -16,6 +16,8 @@
# under the License.
# pylint: disable=invalid-name, unused-variable, unused-argument
"""Transposed 2D convolution operators (sometimes called Deconvolution)."""
import collections
from itertools import repeat

import tvm
from tvm import relay, te

@@ -25,6 +27,22 @@
from .utils import get_pad_tuple


def _ntuple(n):
def parse(x):
if isinstance(x, collections.abc.Iterable):
assert len(x) == n, f"Input can only have {n} elements, but got {len(x)} instead: {x}."
return x
return tuple(repeat(x, n))

return parse


_single = _ntuple(1)
_pair = _ntuple(2)
_triple = _ntuple(3)
_quadruple = _ntuple(4)


def conv2d_transpose_nchw(Input, Filter, strides, padding, out_dtype, output_padding):
"""Transposed 2D convolution nchw forward operator.

@@ -116,6 +134,117 @@ def declaration_conv2d_transpose_impl(data, kernel, strides, padding, out_dtype,
return Output


def group_conv2d_transpose_nchw(data, kernel, stride, padding, out_dtype, output_padding, groups):
"""Group convolution operator in NCHW layout.

Parameters
----------
data : tvm.te.Tensor
4-D with shape [batch, in_channel, in_height, in_width]

kernel : tvm.te.Tensor
4-D with shape [in_channel, out_channel // groups, filter_height, filter_width]

stride : int or a list/tuple of two ints
Stride size, or [stride_height, stride_width]

padding : int or a list/tuple of 2 or 4 ints
padding size, or
[pad_height, pad_width] for 2 ints, or
[pad_top, pad_left, pad_bottom, pad_right] for 4 ints

out_dtype : str
The output data type. This is used for mixed precision.

output_padding : tuple of ints
Used to get the right output shape for gradients

groups : int
number of groups

Returns
-------
Output : tvm.te.Tensor
4-D with shape [batch, out_channel, out_height, out_width]
"""
if groups == 1:
return conv2d_transpose_nchw(data, kernel, stride, padding, out_dtype, output_padding)

# some pre-processing and preliminary checks
if out_dtype is None:
out_dtype = data.dtype

batch, in_channels, in_h, in_w = data.shape
_, out_c, filter_h, filter_w = kernel.shape
assert (
in_channels % groups == 0
), f"input channels {in_channels} must divide group size {groups}"
# assert out_c % groups == 0, f"output channels {in_c} must divide group size {groups}"

strides = _pair(stride)
# padding = _pair(padding)
# output_padding = _pair(output_padding)
# dilation = _pair(dilation)

stride_h, stride_w = strides
opad_h, opad_w = output_padding
assert (
opad_h < stride_h and opad_w < stride_w
), f"[{output_padding}] opad_h:{opad_h} < stride_h:{stride_h} \
and opad_w:{opad_w} < stride_w:{stride_w} does not satisfy."
# dilate data
data_dilate = dilate(data, [1, 1, stride_h, stride_w], name="data_dilate")
# pad data
fpad_top, fpad_left, fpad_bottom, fpad_right = get_pad_tuple(padding, (filter_h, filter_w))
bpad_top = filter_h - 1 - fpad_top
bpad_bottom = filter_h - 1 - fpad_bottom + opad_h
bpad_left = filter_w - 1 - fpad_left
bpad_right = filter_w - 1 - fpad_right + opad_w
data_pad = pad(
data_dilate, [0, 0, bpad_top, bpad_left], [0, 0, bpad_bottom, bpad_right], name="data_pad"
)
# transform kernel layout from IOHW to OIHW, and rotate kernel by 180 degrees
kernel_transform = te.compute(
(out_c, in_channels, filter_h, filter_w),
lambda i, o, h, w: kernel[o][i][filter_h - 1 - h][filter_w - 1 - w],
name="kernel_transform",
)

batch, in_channels, in_h, in_w = data_pad.shape
out_c, _, filter_h, filter_w = kernel_transform.shape

# convolution stage
out_channels = simplify(out_c * groups)

out_h = simplify(in_h - filter_h + 1)
out_w = simplify(in_w - filter_w + 1)
dc = te.reduce_axis((0, in_channels // groups), name="dc")
dh = te.reduce_axis((0, filter_h), name="dh")
dw = te.reduce_axis((0, filter_w), name="dw")

# data: batch, in_channels, out_h, out_w
# weight: out_channels // G, in_channels, out_h, out_w
return te.compute(
(batch, out_channels, out_h, out_w),
lambda b, c, h, w: te.sum(
data_pad[
b, c // (out_channels // groups) * (in_channels // groups) + dc, h + dh, w + dw
].astype(out_dtype)
* kernel_transform[
c % (out_channels // groups),
c // (out_channels // groups) * (in_channels // groups) + dc,
dh,
dw,
].astype(out_dtype),
axis=[dc, dh, dw],
),
tag="group_conv2d_transpose_nchw",
)


def layout_transform(tensor: "relay.Expr", current_layout: str, desired_layout: str):
"""Transform a tensor with the current layout to the desired layout.

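A small usage sketch of the new TOPI compute and its generic schedule at the te level; the shapes are illustrative, and the module paths (topi.nn.group_conv2d_transpose_nchw, topi.generic.schedule_group_conv2d_transpose_nchw) are assumed from the strategy code above:

```python
# Illustrative only: calls the new TOPI compute directly on te placeholders.
import tvm
from tvm import te, topi

groups = 4
data = te.placeholder((1, 8, 16, 16), name="data", dtype="float32")
# IOHW layout: [in_channel, out_channel // groups, kh, kw]
kernel = te.placeholder((8, 2, 3, 3), name="kernel", dtype="float32")

out = topi.nn.group_conv2d_transpose_nchw(
    data, kernel, (2, 2), (1, 1), "float32", (1, 1), groups
)
print(out.shape)  # expect [1, 8, 32, 32]

# The generic schedule is the default fallthrough used by the strategies above.
with tvm.target.Target("llvm"):
    s = topi.generic.schedule_group_conv2d_transpose_nchw([out])
    func = tvm.build(s, [data, kernel, out], target="llvm")
```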
40 changes: 39 additions & 1 deletion python/tvm/topi/testing/conv2d_transpose_python.py
@@ -22,7 +22,7 @@
from tvm.topi.nn.utils import get_pad_tuple


def conv2d_transpose_nchw_python(a_np, w_np, stride, padding, output_padding):
def _conv2d_transpose_nchw_python(a_np, w_np, stride, padding, output_padding):
"""Transposed convolution operator in NCHW layout.

Parameters
@@ -141,3 +141,41 @@ def conv2d_transpose_nhwc_python(
)
res_nhwc = np.transpose(res_nchw, (0, 2, 3, 1))
return res_nhwc


def conv2d_transpose_nchw_python(a_np, w_np, stride, padding, output_padding, groups=1):
"""Convolution operator in NCHW layout.

Parameters
----------
a_np : numpy.ndarray
4-D with shape [batch, in_channel, in_height, in_width]

w_np : numpy.ndarray
4-D with shape [in_channel, num_filter // groups, filter_height, filter_width]

stride : int or a list/tuple of two ints
Stride size, or [stride_height, stride_width]

padding : int or str
Padding size, or ['VALID', 'SAME']

output_padding : int or a list/tuple of two ints
Used to disambiguate the output shape.

groups : int
Number of groups

Returns
-------
b_np : np.ndarray
4-D with shape [batch, out_channel, out_height, out_width]
"""
a_slices = np.array_split(a_np, groups, axis=1)
w_slices = np.array_split(w_np, groups, axis=0)
b_slices = [
_conv2d_transpose_nchw_python(a_slice, w_slice, stride, padding, output_padding)
for a_slice, w_slice in zip(a_slices, w_slices)
]
b_np = np.concatenate(b_slices, axis=1)
return b_np
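A quick, illustrative sanity check of the grouped Python reference (random inputs, assumed sizes):

```python
# Illustrative only: grouped reference output, stitched per group along the channel axis.
import numpy as np
import tvm.topi.testing

groups = 4
a_np = np.random.uniform(size=(1, 8, 16, 16)).astype("float32")
w_np = np.random.uniform(size=(8, 2, 3, 3)).astype("float32")  # IOHW, out_c // groups == 2

b_np = tvm.topi.testing.conv2d_transpose_nchw_python(
    a_np, w_np, stride=(2, 2), padding=(1, 1), output_padding=(1, 1), groups=groups
)
print(b_np.shape)  # expect (1, 8, 32, 32)
```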
21 changes: 13 additions & 8 deletions src/relay/op/nn/convolution.h
@@ -1053,18 +1053,18 @@ bool Conv2DTransposeRel(const Array<Type>& types, int num_inputs, const Attrs& a

const auto trans_in_layout = tir::BijectiveLayout(in_layout, kNCHW);
ICHECK(trans_in_layout.defined())
<< "Conv only support input layouts that are convertible from NCHW."
<< "Conv2DTransposed only support input layouts that are convertible from NCHW."
<< " But got " << in_layout;

const auto trans_kernel_layout = tir::BijectiveLayout(kernel_layout, kIOHW);
ICHECK(trans_kernel_layout.defined())
<< "Conv only support kernel layouts that are convertible from IOHW."
<< "Conv2DTransposed only support kernel layouts that are convertible from IOHW."
<< " But got " << kernel_layout;

Layout out_layout(param->out_layout == "" ? param->data_layout : param->out_layout);
const auto trans_out_layout = tir::BijectiveLayout(out_layout, kNCHW);
ICHECK(trans_out_layout.defined())
<< "Conv only support output layouts that are convertible from NCHW."
<< "Conv2DTransposed only support output layouts that are convertible from NCHW."
<< " But got " << out_layout;

IndexExpr channels, dilated_ksize_y, dilated_ksize_x;
@@ -1099,16 +1099,21 @@ bool Conv2DTransposeRel(const Array<Type>& types, int num_inputs, const Attrs& a
// check the size
ICHECK(reporter->AssertEQ(param->kernel_size[0], wshape[2]) &&
reporter->AssertEQ(param->kernel_size[1], wshape[3]))
<< "Conv2D: shape of weight is inconsistent with kernel_size, "
<< "Conv2DTransposed: shape of weight is inconsistent with kernel_size, "
<< " kernel_size=" << param->kernel_size << " wshape=" << Array<IndexExpr>(wshape);
}
if (param->channels.defined()) {
ICHECK(reporter->AssertEQ(param->channels, wshape[1]))
<< "Conv2D: shape of weight is inconsistent with channels, "
<< " channels=" << param->channels << " wshape=" << Array<IndexExpr>(wshape);
ICHECK(reporter->AssertEQ(indexdiv(param->channels, param->groups), wshape[1]))
<< "Conv2DTransposed: shape of weight is inconsistent with out_channels, "
<< " out_channels // groups != weight.shape[1] "
<< " out_channels=" << param->channels << " groups=" << param->groups
<< " weight.shape=" << Array<IndexExpr>(wshape);
}
if (!dshape_nchw[1].as<tir::AnyNode>() && !wshape[0].as<tir::AnyNode>()) {
ICHECK(reporter->AssertEQ(indexdiv(dshape_nchw[1], param->groups), wshape[0]));
ICHECK(reporter->AssertEQ(dshape_nchw[1], wshape[0]))
<< "Conv2DTransposed: shape of weight is inconsistent with in_channels."
<< " data.shape= " << Array<IndexExpr>(dshape_nchw) << " groups= " << param->groups
<< " weight.shape= " << Array<IndexExpr>(wshape);
}
channels = wshape[1];
dilated_ksize_y = 1 + (wshape[2] - 1) * param->dilation[0];
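For reference, the shape relations the revised Conv2DTransposeRel enforces, spelled out with illustrative numbers (the asserts below mirror the ICHECKs above; all values are assumptions for the example):

```python
# Illustrative only: shape relations for an IOHW kernel in conv2d_transpose.
in_channels, out_channels, groups = 8, 8, 4
kh, kw = 3, 3

dshape = (1, in_channels, 16, 16)                        # NCHW data
wshape = (in_channels, out_channels // groups, kh, kw)   # IOHW weight

# channels check: out_channels // groups must equal wshape[1]
# (previously compared out_channels directly against wshape[1])
assert out_channels // groups == wshape[1]

# in_channels check: dshape[1] must equal wshape[0]
# (previously compared dshape[1] // groups against wshape[0])
assert dshape[1] == wshape[0]
```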