diff --git a/python/tvm/relay/op/strategy/cuda.py b/python/tvm/relay/op/strategy/cuda.py
index 5f24dbda9d35..eee5d9a685b3 100644
--- a/python/tvm/relay/op/strategy/cuda.py
+++ b/python/tvm/relay/op/strategy/cuda.py
@@ -557,7 +557,7 @@ def conv2d_transpose_strategy_cuda(attrs, inputs, out_type, target):
     groups = attrs.groups
     assert layout == "NCHW", "only support nchw for now"
     assert dilation == (1, 1), "not support dilate now"
-    assert groups == 1, "only support groups == 1 for now"
+    assert groups == 1, "only support groups == 1 when targeting cuda/gpu"
     strategy = _op.OpStrategy()
     strategy.add_implementation(
         wrap_compute_conv2d_transpose(topi.cuda.conv2d_transpose_nchw),
diff --git a/python/tvm/relay/op/strategy/generic.py b/python/tvm/relay/op/strategy/generic.py
index f3632c4197ea..0a328ca51154 100644
--- a/python/tvm/relay/op/strategy/generic.py
+++ b/python/tvm/relay/op/strategy/generic.py
@@ -446,7 +446,7 @@ def deformable_conv2d_strategy(attrs, inputs, out_type, target):


 # conv2d_transpose
-def wrap_compute_conv2d_transpose(topi_compute):
+def wrap_compute_conv2d_transpose(topi_compute, has_groups=False):
     """wrap conv2d_transpose topi compute"""

     def compute_conv2d_transpose(attrs, inputs, out_dtype):
@@ -456,7 +456,11 @@ def compute_conv2d_transpose(attrs, inputs, out_dtype):
         out_dtype = attrs.out_dtype
         out_dtype = inputs[0].dtype if out_dtype in ("same", "") else out_dtype
         output_padding = get_const_tuple(attrs.output_padding)
-        out = topi_compute(inputs[0], inputs[1], strides, padding, out_dtype, output_padding)
+        # Group-aware TOPI computes take `groups` as an extra trailing argument.
+        args = [inputs[0], inputs[1], strides, padding, out_dtype, output_padding]
+        if has_groups:
+            args.append(attrs.groups)
+        out = topi_compute(*args)
         return [out]

     return compute_conv2d_transpose
@@ -471,13 +475,19 @@ def conv2d_transpose_strategy(attrs, inputs, out_type, target):
     groups = attrs.groups
     assert layout == "NCHW", "only support nchw for now"
     assert dilation == (1, 1), "not support dilate now"
-    assert groups == 1, "only support groups == 1 for now"
     strategy = _op.OpStrategy()
-    strategy.add_implementation(
-        wrap_compute_conv2d_transpose(topi.nn.conv2d_transpose_nchw),
-        wrap_topi_schedule(topi.generic.schedule_conv2d_transpose_nchw),
-        name="conv2d_transpose_nchw.generic",
-    )
+    if groups == 1:
+        strategy.add_implementation(
+            wrap_compute_conv2d_transpose(topi.nn.conv2d_transpose_nchw),
+            wrap_topi_schedule(topi.generic.schedule_conv2d_transpose_nchw),
+            name="conv2d_transpose_nchw.generic",
+        )
+    else:  # group_conv2d_transpose
+        strategy.add_implementation(
+            wrap_compute_conv2d_transpose(topi.nn.group_conv2d_transpose_nchw, has_groups=True),
+            wrap_topi_schedule(topi.generic.schedule_group_conv2d_transpose_nchw),
+            name="group_conv2d_transpose_nchw.generic",
+        )
     return strategy
diff --git a/python/tvm/relay/op/strategy/x86.py b/python/tvm/relay/op/strategy/x86.py
index 1c8d1b478cb1..a421b120fab4 100644
--- a/python/tvm/relay/op/strategy/x86.py
+++ b/python/tvm/relay/op/strategy/x86.py
@@ -281,13 +281,19 @@ def conv2d_transpose_strategy_cpu(attrs, inputs, out_type, target):
     groups = attrs.groups
     assert layout == "NCHW", "only support nchw for now"
     assert dilation == (1, 1), "not support dilate now"
-    assert groups == 1, "only support groups == 1 for now"
     strategy = _op.OpStrategy()
-    strategy.add_implementation(
-        wrap_compute_conv2d_transpose(topi.x86.conv2d_transpose_nchw),
-        wrap_topi_schedule(topi.x86.schedule_conv2d_transpose_nchw),
-        name="conv2d_transpose_nchw.x86",
-    )
+    if groups == 1:
+        strategy.add_implementation(
+            wrap_compute_conv2d_transpose(topi.x86.conv2d_transpose_nchw),
+            wrap_topi_schedule(topi.x86.schedule_conv2d_transpose_nchw),
+            name="conv2d_transpose_nchw.x86",
+        )
+    else:
+        strategy.add_implementation(
+            wrap_compute_conv2d_transpose(topi.nn.group_conv2d_transpose_nchw, has_groups=True),
+            wrap_topi_schedule(topi.generic.schedule_group_conv2d_transpose_nchw),
+            name="group_conv2d_transpose_nchw.x86",
+        )
     return strategy
diff --git a/python/tvm/topi/generic/nn.py b/python/tvm/topi/generic/nn.py
index 1b3214154687..22a90aa2cd07 100644
--- a/python/tvm/topi/generic/nn.py
+++ b/python/tvm/topi/generic/nn.py
@@ -428,6 +428,23 @@ def schedule_group_conv2d_nchw(outs):
     return _default_schedule(outs, False)


+def schedule_group_conv2d_transpose_nchw(outs):
+    """Schedule for group_conv2d_transpose_nchw
+
+    Parameters
+    ----------
+    outs: Array of Tensor
+        The computation graph description of group_conv2d_transpose_nchw
+        in the format of an array of tensors.
+
+    Returns
+    -------
+    sch: Schedule
+        The computation schedule for the op.
+    """
+    return _default_schedule(outs, False)
+
+
 def schedule_group_conv2d_nhwc(outs):
     """Schedule for group_conv2d_nhwc

diff --git a/python/tvm/topi/nn/conv2d_transpose.py b/python/tvm/topi/nn/conv2d_transpose.py
index 99c7442240c7..2871699350ed 100644
--- a/python/tvm/topi/nn/conv2d_transpose.py
+++ b/python/tvm/topi/nn/conv2d_transpose.py
@@ -16,6 +16,8 @@
 # under the License.
 # pylint: disable=invalid-name, unused-variable, unused-argument
 """Transposed 2D convolution operators (sometimes called Deconvolution)."""
+import collections
+from itertools import repeat
+
 import tvm
 from tvm import relay, te
@@ -25,6 +27,22 @@
 from .utils import get_pad_tuple


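+# Normalize an int or an iterable of length n into an n-tuple (the usual
+# scalar-or-tuple argument convention). Only _pair is exercised below, to
+# normalize the stride argument of group_conv2d_transpose_nchw.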
name="conv2d_transpose_nchw.x86", - ) + if groups == 1: + strategy.add_implementation( + wrap_compute_conv2d_transpose(topi.x86.conv2d_transpose_nchw), + wrap_topi_schedule(topi.x86.schedule_conv2d_transpose_nchw), + name="conv2d_transpose_nchw.x86", + ) + else: + strategy.add_implementation( + wrap_compute_conv2d_transpose(topi.nn.group_conv2d_transpose_nchw, has_groups=True), + wrap_topi_schedule(topi.generic.schedule_group_conv2d_transpose_nchw), + name="group_conv2d_transpose_nchw.x86", + ) return strategy diff --git a/python/tvm/topi/generic/nn.py b/python/tvm/topi/generic/nn.py index 1b3214154687..22a90aa2cd07 100644 --- a/python/tvm/topi/generic/nn.py +++ b/python/tvm/topi/generic/nn.py @@ -428,6 +428,23 @@ def schedule_group_conv2d_nchw(outs): return _default_schedule(outs, False) +def schedule_group_conv2d_transpose_nchw(outs): + """Schedule for group_conv2d_transpose_nchw + + Parameters + ---------- + outs: Array of Tensor + The computation graph description of group_conv2d_nhwc + in the format of an array of tensors. + + Returns + ------- + sch: Schedule + The computation schedule for the op. + """ + return _default_schedule(outs, False) + + def schedule_group_conv2d_nhwc(outs): """Schedule for group_conv2d_nhwc diff --git a/python/tvm/topi/nn/conv2d_transpose.py b/python/tvm/topi/nn/conv2d_transpose.py index 99c7442240c7..2871699350ed 100644 --- a/python/tvm/topi/nn/conv2d_transpose.py +++ b/python/tvm/topi/nn/conv2d_transpose.py @@ -16,6 +16,8 @@ # under the License. # pylint: disable=invalid-name, unused-variable, unused-argument """Transposed 2D convolution operators (sometimes called Deconvolution).""" +import collections + import tvm from tvm import relay, te @@ -25,6 +27,22 @@ from .utils import get_pad_tuple +def _ntuple(n): + def parse(x): + if isinstance(x, collections.abc.Iterable): + assert len(x) == n, f"Input can only have {n} elements, but got {len(x)} instead: {x}." + return x + return tuple(repeat(x, n)) + + return parse + + +_single = _ntuple(1) +_pair = _ntuple(2) +_triple = _ntuple(3) +_quadruple = _ntuple(4) + + def conv2d_transpose_nchw(Input, Filter, strides, padding, out_dtype, output_padding): """Transposed 2D convolution nchw forward operator. @@ -116,6 +134,117 @@ def declaration_conv2d_transpose_impl(data, kernel, strides, padding, out_dtype, return Output +def group_conv2d_transpose_nchw(data, kernel, stride, padding, out_dtype, output_padding, groups): + """Group convolution operator in NCHW layout. + + Parameters + ---------- + data : tvm.te.Tensor + 4-D with shape [batch, in_channel, in_height, in_width] + + kernel : tvm.te.Tensor + 4-D with shape [in_channel, out_channel // groups, filter_height, filter_width] + + stride : int or a list/tuple of two ints + Stride size, or [stride_height, stride_width] + + padding : int or a list/tuple of 2 or 4 ints + padding size, or + [pad_height, pad_width] for 2 ints, or + [pad_top, pad_left, pad_bottom, pad_right] for 4 ints + + out_dtype : str + The output data type. This is used for mixed precision. + + output_padding : tuple of ints + Used to get the right output shape for gradients + + groups : int + number of groups + + out_dtype : str + The output type. This is used for mixed precision. 
+    return te.compute(
+        (batch, out_channels, out_h, out_w),
+        lambda b, c, h, w: te.sum(
+            data_pad[
+                b, c // (out_channels // groups) * (in_channels // groups) + dc, h + dh, w + dw
+            ].astype(out_dtype)
+            * kernel_transform[
+                c % (out_channels // groups),
+                c // (out_channels // groups) * (in_channels // groups) + dc,
+                dh,
+                dw,
+            ].astype(out_dtype),
+            axis=[dc, dh, dw],
+        ),
+        tag="group_conv2d_transpose_nchw",
+    )
+
+
 def layout_transform(tensor: "relay.Expr", current_layout: str, desired_layout: str):
     """Transform a tensor with the current layout to the desired layout.
diff --git a/python/tvm/topi/testing/conv2d_transpose_python.py b/python/tvm/topi/testing/conv2d_transpose_python.py
index c7c0d9f2529a..a38d8bc9f031 100644
--- a/python/tvm/topi/testing/conv2d_transpose_python.py
+++ b/python/tvm/topi/testing/conv2d_transpose_python.py
@@ -22,7 +22,7 @@
 from tvm.topi.nn.utils import get_pad_tuple


-def conv2d_transpose_nchw_python(a_np, w_np, stride, padding, output_padding):
+def _conv2d_transpose_nchw_python(a_np, w_np, stride, padding, output_padding):
     """Transposed convolution operator in NCHW layout.

     Parameters
@@ -141,3 +141,41 @@ def conv2d_transpose_nhwc_python(
     )
     res_nhwc = np.transpose(res_nchw, (0, 2, 3, 1))
     return res_nhwc
+
+
+def conv2d_transpose_nchw_python(a_np, w_np, stride, padding, output_padding, groups=1):
+    """Transposed convolution operator in NCHW layout, with optional grouping.
+
+    Parameters
+    ----------
+    a_np : numpy.ndarray
+        4-D with shape [batch, in_channel, in_height, in_width]
+
+    w_np : numpy.ndarray
+        4-D with shape [in_channel, num_filter // groups, filter_height, filter_width]
+
+    stride : int or a list/tuple of two ints
+        Stride size, or [stride_height, stride_width]
+
+    padding : int or str
+        Padding size, or ['VALID', 'SAME']
+
+    output_padding : int or a list/tuple of two ints
+        Used to disambiguate the output shape.
+
+    groups : int
+        Number of groups
+
+    Returns
+    -------
+    b_np : np.ndarray
+        4-D with shape [batch, out_channel, out_height, out_width]
+    """
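+    # Grouped transposed convolution: run the single-group reference on per-group channel
+    # slices and concatenate along the channel axis. Weights are IOHW, so the group split is
+    # along axis 0 (input channels). Illustrative shapes: a_np (1, 4, 32, 32), w_np (4, 2, 5, 5),
+    # groups=2, stride (1, 1), padding (0, 0, 0, 0), output_padding (0, 0) -> b_np (1, 4, 36, 36).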
+    a_slices = np.array_split(a_np, groups, axis=1)
+    w_slices = np.array_split(w_np, groups, axis=0)
+    b_slices = [
+        _conv2d_transpose_nchw_python(a_slice, w_slice, stride, padding, output_padding)
+        for a_slice, w_slice in zip(a_slices, w_slices)
+    ]
+    b_np = np.concatenate(b_slices, axis=1)
+    return b_np
diff --git a/src/relay/op/nn/convolution.h b/src/relay/op/nn/convolution.h
index d9958076adc1..e5e64fa5be65 100644
--- a/src/relay/op/nn/convolution.h
+++ b/src/relay/op/nn/convolution.h
@@ -1053,18 +1053,18 @@ bool Conv2DTransposeRel(const Array<Type>& types, int num_inputs, const Attrs& a
   const auto trans_in_layout = tir::BijectiveLayout(in_layout, kNCHW);
   ICHECK(trans_in_layout.defined())
-      << "Conv only support input layouts that are convertible from NCHW."
+      << "Conv2DTranspose only supports input layouts that are convertible from NCHW."
       << " But got " << in_layout;

   const auto trans_kernel_layout = tir::BijectiveLayout(kernel_layout, kIOHW);
   ICHECK(trans_kernel_layout.defined())
-      << "Conv only support kernel layouts that are convertible from IOHW."
+      << "Conv2DTranspose only supports kernel layouts that are convertible from IOHW."
       << " But got " << kernel_layout;

   Layout out_layout(param->out_layout == "" ? param->data_layout : param->out_layout);
   const auto trans_out_layout = tir::BijectiveLayout(out_layout, kNCHW);
   ICHECK(trans_out_layout.defined())
-      << "Conv only support output layouts that are convertible from NCHW."
+      << "Conv2DTranspose only supports output layouts that are convertible from NCHW."
       << " But got " << out_layout;

   IndexExpr channels, dilated_ksize_y, dilated_ksize_x;
@@ -1099,16 +1099,21 @@ bool Conv2DTransposeRel(const Array<Type>& types, int num_inputs, const Attrs& a
       // check the size
       ICHECK(reporter->AssertEQ(param->kernel_size[0], wshape[2]) &&
              reporter->AssertEQ(param->kernel_size[1], wshape[3]))
-          << "Conv2D: shape of weight is inconsistent with kernel_size, "
+          << "Conv2DTranspose: shape of weight is inconsistent with kernel_size, "
           << " kernel_size=" << param->kernel_size << " wshape=" << Array<IndexExpr>(wshape);
     }
     if (param->channels.defined()) {
-      ICHECK(reporter->AssertEQ(param->channels, wshape[1]))
-          << "Conv2D: shape of weight is inconsistent with channels, "
-          << " channels=" << param->channels << " wshape=" << Array<IndexExpr>(wshape);
+      ICHECK(reporter->AssertEQ(indexdiv(param->channels, param->groups), wshape[1]))
+          << "Conv2DTranspose: shape of weight is inconsistent with out_channels, "
+          << " out_channels // groups != weight.shape[1] "
+          << " out_channels=" << param->channels << " groups=" << param->groups
+          << " weight.shape=" << Array<IndexExpr>(wshape);
     }
     if (!dshape_nchw[1].as<tir::AnyNode>() && !wshape[0].as<tir::AnyNode>()) {
-      ICHECK(reporter->AssertEQ(indexdiv(dshape_nchw[1], param->groups), wshape[0]));
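+      // For transposed conv, the IOHW kernel's I axis carries the full input-channel count;
+      // grouping only divides the O axis (channels / groups), so the data channel dimension
+      // must equal wshape[0] exactly (no division by groups here).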
<< " But got " << out_layout; IndexExpr channels, dilated_ksize_y, dilated_ksize_x; @@ -1099,16 +1099,21 @@ bool Conv2DTransposeRel(const Array& types, int num_inputs, const Attrs& a // check the size ICHECK(reporter->AssertEQ(param->kernel_size[0], wshape[2]) && reporter->AssertEQ(param->kernel_size[1], wshape[3])) - << "Conv2D: shape of weight is inconsistent with kernel_size, " + << "Conv2DTransposed: shape of weight is inconsistent with kernel_size, " << " kernel_size=" << param->kernel_size << " wshape=" << Array(wshape); } if (param->channels.defined()) { - ICHECK(reporter->AssertEQ(param->channels, wshape[1])) - << "Conv2D: shape of weight is inconsistent with channels, " - << " channels=" << param->channels << " wshape=" << Array(wshape); + ICHECK(reporter->AssertEQ(indexdiv(param->channels, param->groups), wshape[1])) + << "Conv2DTransposed: shape of weight is inconsistent with out_channels, " + << " out_channels // groups != weight.shape[1] " + << " out_channels=" << param->channels << " groups=" << param->groups + << " weight.shape=" << Array(wshape); } if (!dshape_nchw[1].as() && !wshape[0].as()) { - ICHECK(reporter->AssertEQ(indexdiv(dshape_nchw[1], param->groups), wshape[0])); + ICHECK(reporter->AssertEQ(dshape_nchw[1], wshape[0])) + << "Conv2DTransposed: shape of weight is inconsistent with in_channels." + << " data.shape= " << Array(dshape_nchw) << " groups= " << param->groups + << " weight.shape= " << Array(wshape); } channels = wshape[1]; dilated_ksize_y = 1 + (wshape[2] - 1) * param->dilation[0]; diff --git a/tests/python/topi/python/test_topi_group_conv2d_transpose.py b/tests/python/topi/python/test_topi_group_conv2d_transpose.py new file mode 100644 index 000000000000..90b7500c6cd4 --- /dev/null +++ b/tests/python/topi/python/test_topi_group_conv2d_transpose.py @@ -0,0 +1,156 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
+    verify_group_conv2d_transpose_nchw(1, 1, (224, 224), 1, (1, 1), (1, 1), (0, 0, 0, 0), (0, 0), 1)
+    verify_group_conv2d_transpose_nchw(
+        1, 3, (224, 224), 32, (3, 3), (1, 1), (0, 0, 0, 0), (0, 0), 1
+    )
+    verify_group_conv2d_transpose_nchw(
+        1, 3, (224, 224), 32, (3, 3), (3, 3), (0, 0, 0, 0), (0, 0), 1
+    )
+    verify_group_conv2d_transpose_nchw(
+        1, 3, (224, 224), 32, (3, 3), (1, 1), (0, 0, 0, 0), (0, 0), 1
+    )
+    verify_group_conv2d_transpose_nchw(
+        1, 3, (224, 224), 32, (3, 3), (2, 2), (1, 1, 1, 1), (0, 0), 1
+    )
+    verify_group_conv2d_transpose_nchw(1, 4, (32, 32), 4, (5, 5), (1, 1), (0, 0, 0, 0), (0, 0), 2)
+    verify_group_conv2d_transpose_nchw(1, 9, (32, 32), 9, (5, 5), (1, 1), (0, 0, 0, 0), (0, 0), 3)
+    verify_group_conv2d_transpose_nchw(1, 4, (32, 32), 16, (5, 5), (2, 2), (1, 1, 1, 1), (0, 0), 4)
+    verify_group_conv2d_transpose_nchw(
+        1, 32, (8192, 1), 8, (31, 1), (2, 1), (14, 0, 15, 0), (0, 0), 2
+    )
+    verify_group_conv2d_transpose_nchw(
+        1, 512, (8, 1), 256, (31, 1), (2, 1), (14, 0, 15, 0), (0, 0), 16
+    )
+    verify_group_conv2d_transpose_nchw(
+        1, 512, (8, 1), 256, (31, 1), (2, 1), (14, 0, 15, 0), (1, 0), 16
+    )
+    verify_group_conv2d_transpose_nchw(
+        1, 64, (64, 64), 64, (4, 4), (1, 1), (0, 0, 0, 0), (0, 0), 64
+    )
+    verify_group_conv2d_transpose_nchw(
+        1, 128, (32, 32), 128, (4, 4), (1, 1), (0, 0, 0, 0), (0, 0), 128
+    )
+    verify_group_conv2d_transpose_nchw(
+        1, 256, (16, 16), 256, (4, 4), (1, 1), (0, 0, 0, 0), (0, 0), 256
+    )
+
+
+if __name__ == "__main__":
+    test_group_conv2d_transpose_nchw()
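Usage sketch (illustrative, not part of the patch): the snippet below drives the new TOPI compute directly and checks it against the grouped NumPy reference added above, mirroring the new test file. The shapes, the "llvm" target, and all parameter values are assumptions chosen only for the example.

    import numpy as np
    import tvm
    import tvm.testing
    import tvm.topi.testing
    from tvm import te, topi

    # data: NCHW (1, 4, 32, 32); weight: IOHW (in_c, out_c // groups, kh, kw), groups=2
    data = te.placeholder((1, 4, 32, 32), name="data")
    weight = te.placeholder((4, 2, 5, 5), name="weight")
    with tvm.target.Target("llvm"):
        out = topi.nn.group_conv2d_transpose_nchw(
            data, weight, (1, 1), (0, 0, 0, 0), "float32", (0, 0), 2
        )
        s = topi.generic.schedule_group_conv2d_transpose_nchw([out])
    func = tvm.build(s, [data, weight, out], "llvm")

    a_np = np.random.uniform(size=(1, 4, 32, 32)).astype("float32")
    w_np = np.random.uniform(size=(4, 2, 5, 5)).astype("float32")
    dev = tvm.device("llvm", 0)
    a, w = tvm.nd.array(a_np, dev), tvm.nd.array(w_np, dev)
    c = tvm.nd.array(np.zeros((1, 4, 36, 36), dtype="float32"), dev)  # (32 - 1) * 1 + 5 = 36
    func(a, w, c)
    ref = tvm.topi.testing.conv2d_transpose_nchw_python(
        a_np, w_np, (1, 1), (0, 0, 0, 0), (0, 0), 2
    )
    tvm.testing.assert_allclose(c.numpy(), ref, rtol=1e-5)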