From 4c2e3cde7dcfde264eb6d763dc58e6318f42c41f Mon Sep 17 00:00:00 2001 From: Josh Fromm Date: Sun, 5 Apr 2020 14:59:38 -0700 Subject: [PATCH] [Relay][Topi][AutoTVM] Winograd support for Conv3D (#5186) * Functional conv3d winograd working. * Formatted python code. * registered conv3d winograd compute and started adding relay without_weight_transform operator. * Add topi testing for conv3d winograd. * Format file. * small tweak to unrolling to prevent build sticking. * Refactoring convolution ops in relay. * Refactored relay convolutions. * Bug fixes. * Fixed static bug in convolution. * Added conv3d alter op layout and related support. * Bug fixes and testing done. * Fix a few autotvm bugs. * Drop silly debug print. * Removed debug_skip_region. * Add variant of conv3d_winograd that doesn't transform depth. * initial infrastructure done for depthless conv. * Fix no_depth schedule bugs. * automatic topi switching between depth and depthless winograd. * Fixed bug in schedule. * lint fixes. * Removed indents in convolution.cc * missed a few indents oops. * fixed flop count. * One more small tweak. * Change kernel pack inner axes order. * Style changes. * Comment fixes. --- docs/langref/relay_op.rst | 5 + include/tvm/relay/attrs/nn.h | 71 +- python/tvm/relay/op/nn/_nn.py | 23 + python/tvm/relay/op/nn/nn.py | 100 +- python/tvm/relay/op/nn/util.py | 43 + python/tvm/relay/op/op_attrs.py | 16 +- python/tvm/relay/op/strategy/cuda.py | 32 + python/tvm/relay/op/strategy/generic.py | 13 + src/relay/op/nn/convolution.cc | 900 ++++++------------ src/relay/op/nn/convolution.h | 531 +++++++++++ tests/python/relay/test_op_level2.py | 92 +- topi/python/topi/cuda/__init__.py | 2 + topi/python/topi/cuda/conv3d_alter_op.py | 95 ++ topi/python/topi/cuda/conv3d_winograd.py | 627 ++++++++++++ topi/python/topi/generic/nn.py | 37 + topi/python/topi/nn/conv3d.py | 75 +- .../tests/python/test_topi_conv3d_winograd.py | 151 +++ 17 files changed, 2192 insertions(+), 621 deletions(-) create mode 100644 topi/python/topi/cuda/conv3d_alter_op.py create mode 100644 topi/python/topi/cuda/conv3d_winograd.py create mode 100644 topi/tests/python/test_topi_conv3d_winograd.py diff --git a/docs/langref/relay_op.rst b/docs/langref/relay_op.rst index ac636f81ab3d6..f1d7d442a14ce 100644 --- a/docs/langref/relay_op.rst +++ b/docs/langref/relay_op.rst @@ -82,8 +82,13 @@ This level enables typical convnet models. tvm.relay.nn.pad tvm.relay.nn.lrn tvm.relay.nn.l2_normalize + tvm.relay.nn.bitpack + tvm.relay.nn.bitserial_dense + tvm.relay.nn.bitserial_conv2d tvm.relay.nn.contrib_conv2d_winograd_without_weight_transform tvm.relay.nn.contrib_conv2d_winograd_weight_transform + tvm.relay.nn.contrib_conv3d_winograd_without_weight_transform + tvm.relay.nn.contrib_conv3d_winograd_weight_transform **Level 3: Additional Math And Transform Operators** diff --git a/include/tvm/relay/attrs/nn.h b/include/tvm/relay/attrs/nn.h index 5794ddd0217b2..536e4145db292 100644 --- a/include/tvm/relay/attrs/nn.h +++ b/include/tvm/relay/attrs/nn.h @@ -156,12 +156,12 @@ struct Conv2DAttrs : public tvm::AttrsNode { }; /*! 
\brief Attributes used in winograd weight transformation operators */ -struct Conv2DWinogradWeightTransformAttrs : - public tvm::AttrsNode { +struct ConvWinogradWeightTransformAttrs : + public tvm::AttrsNode { int tile_size; - TVM_DECLARE_ATTRS(Conv2DWinogradWeightTransformAttrs, - "relay.attrs.Conv2DWinogradWeightTransformAttrs") { + TVM_DECLARE_ATTRS(ConvWinogradWeightTransformAttrs, + "relay.attrs.ConvWinogradWeightTransformAttrs") { TVM_ATTR_FIELD(tile_size) .describe("Tile size of winograd. E.g. 2 for F(2x2, 3x3) and 4 for F(4x4, 3x3)"); } @@ -306,6 +306,69 @@ struct Conv3DAttrs : public tvm::AttrsNode { } }; +/*! \brief Attributes used in 3d winograd convolution operators */ +struct Conv3DWinogradAttrs : public tvm::AttrsNode { + int tile_size; + Array strides; + Array padding; + Array dilation; + int groups; + IndexExpr channels; + Array kernel_size; + std::string data_layout; + std::string kernel_layout; + std::string out_layout; + DataType out_dtype; + + TVM_DECLARE_ATTRS(Conv3DWinogradAttrs, "relay.attrs.Conv3DWinogradAttrs") { + TVM_ATTR_FIELD(tile_size) + .describe("The tile size of winograd. E.g. 2 for F(2x2x2, 3x3x3) and 4 for F(4x4x4, 3x3x3)"); + TVM_ATTR_FIELD(strides).set_default(Array({1, 1, 1})) + .describe("Specifies the strides of the convolution."); + TVM_ATTR_FIELD(padding).set_default(Array({0, 0, 0})) + .describe("If padding is non-zero, then the input is implicitly zero-padded" + "Padding support both symmetric and asymmetric as" + "one int : same padding used on all sides" + "three int : back, bottom, right will use same padding as front, top, left" + "six int : padding width in the order of (front, top, left, back, bottom," + "right)"); + TVM_ATTR_FIELD(dilation).set_default(Array({1, 1, 1})) + .describe("Specifies the dilation rate to use for dilated convolution."); + TVM_ATTR_FIELD(groups).set_default(1) + .describe("Controls the connections between inputs and outputs." + "At groups=1, all inputs are convolved to all outputs." + "At groups=2, the operation becomes equivalent to having two convolution" + "layers side by side, each seeing half the input channels, and producing" + "half the output channels, and both subsequently concatenated."); + TVM_ATTR_FIELD(channels) + .describe("The number of output channels in the convolution." + " If it is not set, inferred by shape of the weight.") + .set_default(NullValue()); + TVM_ATTR_FIELD(kernel_size) + .describe("Specifies the dimensions of the convolution window.") + .set_default(NullValue >()); + TVM_ATTR_FIELD(data_layout).set_default("NCDHW") + .describe("Dimension ordering of input data. Can be 'NCDHW', 'NDHWC', etc." + "'N', 'C', 'D', 'H', 'W' stands for batch, channel, depth, height, and width" + "dimensions respectively. Convolution is applied on the 'D', 'H' and" + "'W' dimensions."); + TVM_ATTR_FIELD(kernel_layout).set_default("OIDHW") + .describe("Dimension ordering of weight. Can be 'OIDHW', 'OIDHW16o16i', etc." + "'O', 'I', 'D', 'H', 'W' stands for num_filter, input_channel, depth, height," + "and width dimensions respectively."); + TVM_ATTR_FIELD(out_layout).set_default("") + .describe("Dimension ordering of output. Can be 'NCDHW', 'NDHWC', etc." + "'N', 'C', 'D', 'H', 'W' stands for batch, channel, depth, height, and width" + "dimensions respectively. Default to be same as input layout."); + + // use 0 bits to indicate none. + TVM_ATTR_FIELD(out_dtype) + .set_default(NullValue()) + .describe("Output data type, set to explicit type under mixed precision setting"); + } +}; + + /*! 
\brief Attributes used in softmax operators */ struct SoftmaxAttrs : public tvm::AttrsNode { int axis; diff --git a/python/tvm/relay/op/nn/_nn.py b/python/tvm/relay/op/nn/_nn.py index 65a1162d6a2d8..39d98c0a81ab8 100644 --- a/python/tvm/relay/op/nn/_nn.py +++ b/python/tvm/relay/op/nn/_nn.py @@ -178,6 +178,29 @@ def legalize_conv2d_transpose(attrs, inputs, types): reg.register_strategy("nn.conv3d", strategy.conv3d_strategy) reg.register_pattern("nn.conv3d", OpPattern.OUT_ELEMWISE_FUSABLE) +@reg.register_alter_op_layout("nn.conv3d") +def alter_op_layout_conv3d(attrs, inputs, tinfos, out_type): + """Alternate the layout of conv3d""" + return topi.nn.conv3d_alter_layout(attrs, inputs, tinfos, out_type) + +# conv3d_winograd related operators +reg.register_strategy("nn.contrib_conv3d_winograd_without_weight_transform", + strategy.conv3d_winograd_without_weight_transfrom_strategy) +reg.register_pattern("nn.contrib_conv3d_winograd_without_weight_transform", + OpPattern.OUT_ELEMWISE_FUSABLE) + +@reg.register_compute("nn.contrib_conv3d_winograd_weight_transform") +def compute_contrib_conv3d_winograd_weight_transform(attrs, inputs, out_dtype): + """Compute definition of contrib_conv3d_winograd_weight_transform""" + out = topi.nn.conv3d_winograd_weight_transform( + inputs[0], attrs.get_int('tile_size')) + return [out] + +reg.register_schedule("nn.contrib_conv3d_winograd_weight_transform", + strategy.schedule_conv3d_winograd_weight_transform) +reg.register_pattern("nn.contrib_conv3d_winograd_weight_transform", + OpPattern.OUT_ELEMWISE_FUSABLE) + # conv1d_transpose reg.register_strategy("nn.conv1d_transpose", strategy.conv1d_transpose_strategy) diff --git a/python/tvm/relay/op/nn/nn.py b/python/tvm/relay/op/nn/nn.py index 64148bb2414e5..a126e8dcba94a 100644 --- a/python/tvm/relay/op/nn/nn.py +++ b/python/tvm/relay/op/nn/nn.py @@ -19,7 +19,7 @@ from __future__ import absolute_import as _abs from ...expr import TupleWrapper from . import _make -from .util import get_pad_tuple2d +from .util import get_pad_tuple2d, get_pad_tuple3d def conv1d(data, @@ -295,13 +295,84 @@ def conv3d(data, strides = (strides, strides, strides) if isinstance(dilation, int): dilation = (dilation, dilation, dilation) - if isinstance(padding, int): - padding = (padding, padding, padding) + padding = get_pad_tuple3d(padding) return _make.conv3d(data, weight, strides, padding, dilation, groups, channels, kernel_size, data_layout, kernel_layout, out_layout, out_dtype) +def contrib_conv3d_winograd_without_weight_transform(data, + weight, + tile_size, + strides=(1, 1, 1), + padding=(0, 0, 0), + dilation=(1, 1, 1), + groups=1, + channels=None, + kernel_size=None, + data_layout="NCDHW", + kernel_layout="OIDHW", + out_layout="", + out_dtype=""): + r"""3D convolution with winograd algorithm. + + The basic parameters are the same as the ones in vanilla conv3d. + It assumes the weight is pre-transformed by nn.contrib_conv3d_winograd_weight_transform + + Parameters + ---------- + data : tvm.relay.Expr + The input data to the operator. + + weight : tvm.relay.Expr + The weight expressions. + + tile_size : int + The Tile size of winograd. E.g. 2 for F(2x2x2, 3x3x3) and 4 for F(4x4x4, 3x3x3) + + strides : tuple of int, optional + The strides of convolution. + + padding : tuple of int, optional + The padding of convolution on both sides of inputs before convolution. + + dilation : tuple of int, optional + Specifies the dilation rate to be used for dilated convolution. + + groups : int, optional + Number of groups for grouped convolution. 
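+        Note that only groups=1 is currently supported by the winograd
+        implementation.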
+ + channels : int, optional + Number of output channels of this convolution. + + kernel_size : tuple of int, optional + The spatial of the convolution kernel. + + data_layout : str, optional + Layout of the input. + + kernel_layout : str, optional + Layout of the weight. + + out_layout : str, optional + Layout of the output, by default, out_layout is the same as data_layout + + out_dtype : str, optional + Specifies the output data type for mixed precision conv2d. + + Returns + ------- + result : tvm.relay.Expr + The computed result. + """ + # convert 3-way padding to 6-way padding + padding = get_pad_tuple3d(padding) + return _make.contrib_conv3d_winograd_without_weight_transform( + data, weight, tile_size, strides, padding, dilation, + groups, channels, kernel_size, data_layout, + kernel_layout, out_layout, out_dtype) + + def conv2d_transpose(data, weight, strides=(1, 1), @@ -1952,6 +2023,29 @@ def contrib_conv2d_winograd_weight_transform(weight, return _make.contrib_conv2d_winograd_weight_transform(weight, tile_size) +def contrib_conv3d_winograd_weight_transform(weight, + tile_size): + r"""Weight Transformation part for 3D convolution with winograd algorithm. + + We separate this as a single op to enable pre-compute for inference. + Use this together with nn.contrib_conv3d_winograd_without_weight_transform + + Parameters + ---------- + weight : tvm.relay.Expr + The weight expressions. + + tile_size : int + The Tile size of winograd. E.g. 2 for F(2x2x2, 3x3x3) and 4 for F(4x4x4, 3x3x3) + + Returns + ------- + result : tvm.relay.Expr + The computed result. + """ + return _make.contrib_conv3d_winograd_weight_transform(weight, tile_size) + + def contrib_conv2d_winograd_nnpack_weight_transform(weight, convolution_algorithm, out_dtype=""): diff --git a/python/tvm/relay/op/nn/util.py b/python/tvm/relay/op/nn/util.py index 323ef7f9310e2..1fdcad73c74ef 100644 --- a/python/tvm/relay/op/nn/util.py +++ b/python/tvm/relay/op/nn/util.py @@ -54,3 +54,46 @@ def get_pad_tuple2d(padding): pad_top = (pad_h + 1) // 2 pad_left = (pad_w + 1) // 2 return pad_top, pad_left, pad_h - pad_top, pad_w - pad_left + + +def get_pad_tuple3d(padding): + """Common code to get the pad option + Parameters + ---------- + padding : Union[int, Tuple[int, ...]] + Padding size + Returns + ------- + pad_front : int + Padding size on front + pad_top : int + Padding size on top + pad_left : int + Padding size on left + pad_back : int + Padding size on back + pad_down : int + Padding size on down. + pad_right : int + Padding size on right. 
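+
+    Examples
+    --------
+    Illustrative expansions, derived from the rules below: 3-way padding is
+    doubled and split evenly between the two sides of each axis, while 6-way
+    padding is already explicit and is returned unchanged.
+
+    >>> get_pad_tuple3d(1)
+    (1, 1, 1, 1, 1, 1)
+    >>> get_pad_tuple3d((1, 2, 3))
+    (1, 2, 3, 1, 2, 3)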
+ """ + # compute the padding size + if isinstance(padding, container.Array): + padding = list(padding) + if isinstance(padding, (tuple, list)): + if len(padding) == 3: + pad_d = padding[0] * 2 + pad_h = padding[1] * 2 + pad_w = padding[2] * 2 + elif len(padding) == 6: + return padding[0], padding[1], padding[2], padding[3], padding[4], padding[5] + else: + raise ValueError("Size of padding can only be 3 or 6") + elif isinstance(padding, int): + pad_d = pad_h = pad_w = padding * 2 + else: + raise ValueError("Unknown padding option %s" % padding) + pad_front = (pad_d + 1) // 2 + pad_top = (pad_h + 1) // 2 + pad_left = (pad_w + 1) // 2 + return pad_front, pad_top, pad_left, pad_d - pad_front, pad_h - pad_top, pad_w - pad_left diff --git a/python/tvm/relay/op/op_attrs.py b/python/tvm/relay/op/op_attrs.py index edc2160e38bc4..1a07486cd095f 100644 --- a/python/tvm/relay/op/op_attrs.py +++ b/python/tvm/relay/op/op_attrs.py @@ -34,9 +34,19 @@ class Conv2DWinogradAttrs(Attrs): """Attributes for nn.contrib_conv2d_winograd_without_weight_transform""" -@tvm._ffi.register_object("relay.attrs.Conv2DWinogradWeightTransformAttrs") -class Conv2DWinogradWeightTransformAttrs(Attrs): - """Attributes for nn.contrib_conv2d_winograd_weight_transform""" +@tvm._ffi.register_object("relay.attrs.Conv3DAttrs") +class Conv3DAttrs(Attrs): + """Attributes for nn.conv3d""" + + +@tvm._ffi.register_object("relay.attrs.Conv3DWinogradAttrs") +class Conv3DWinogradAttrs(Attrs): + """Attributes for nn.contrib_conv3d_winograd_without_weight_transform""" + + +@tvm._ffi.register_object("relay.attrs.ConvWinogradWeightTransformAttrs") +class ConvWinogradWeightTransformAttrs(Attrs): + """Attributes for nn.contrib_convNd_winograd_weight_transform""" @tvm._ffi.register_object("relay.attrs.Conv2DWinogradNNPACKWeightTransformAttrs") diff --git a/python/tvm/relay/op/strategy/cuda.py b/python/tvm/relay/op/strategy/cuda.py index db03c5965470d..45ee7016912e1 100644 --- a/python/tvm/relay/op/strategy/cuda.py +++ b/python/tvm/relay/op/strategy/cuda.py @@ -233,13 +233,25 @@ def conv2d_transpose_strategy_cuda(attrs, inputs, out_type, target): def conv3d_strategy_cuda(attrs, inputs, out_type, target): """conv3d cuda strategy""" strategy = _op.OpStrategy() + _, kernel = inputs layout = attrs.data_layout + _, stride_h, stride_w = attrs.get_int_tuple("strides") + _, dilation_h, dilation_w = attrs.get_int_tuple("dilation") assert layout in ["NCDHW", "NDHWC"], "Not support this layout {} yet".format(layout) if layout == "NCDHW": strategy.add_implementation(wrap_compute_conv3d(topi.cuda.conv3d_ncdhw), wrap_topi_schedule(topi.cuda.schedule_conv3d_ncdhw), name="conv3d_ncdhw.cuda", plevel=10) + _, _, _, kh, kw = get_const_tuple(kernel.shape) + if 2 < kh < 8 and 2 < kw < 8 and kh == kw and \ + stride_h == 1 and stride_w == 1 and \ + dilation_h == 1 and dilation_w == 1: + strategy.add_implementation( + wrap_compute_conv3d(topi.cuda.conv3d_ncdhw_winograd), + wrap_topi_schedule(topi.cuda.schedule_conv3d_ncdhw_winograd), + name="conv3d_ncdhw_winograd.cuda", + plevel=5) else: # layout == "NDHWC": strategy.add_implementation(wrap_compute_conv3d(topi.cuda.conv3d_ndhwc), wrap_topi_schedule(topi.cuda.schedule_conv3d_ndhwc), @@ -252,6 +264,26 @@ def conv3d_strategy_cuda(attrs, inputs, out_type, target): plevel=15) return strategy +@conv3d_winograd_without_weight_transfrom_strategy.register(["cuda", "gpu"]) +def conv3d_winograd_without_weight_transfrom_strategy_cuda(attrs, inputs, out_type, target): + """conv3d_winograd_without_weight_transfrom cuda strategy""" + 
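# A pre-transformed winograd kernel is only valid for undilated,
+    # single-group convolutions, so reject anything else before dispatching
+    # on the data layout below.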
+    dilation = attrs.get_int_tuple("dilation")
+    groups = attrs.get_int("groups")
+    layout = attrs.data_layout
+    assert dilation == (1, 1, 1), "Dilation is not supported yet"
+    assert groups == 1, "Only groups == 1 is supported"
+    strategy = _op.OpStrategy()
+    if layout == "NCDHW":
+        strategy.add_implementation(
+            wrap_compute_conv3d(topi.cuda.conv3d_ncdhw_winograd_without_weight_transform),
+            wrap_topi_schedule(
+                topi.cuda.schedule_conv3d_ncdhw_winograd_without_weight_transform),
+            name="conv3d_ncdhw_winograd_without_weight_transform.cuda")
+    else:
+        raise RuntimeError("Unsupported conv3d_winograd_without_weight_transform layout {}".
+                           format(layout))
+    return strategy
+
 @conv1d_strategy.register(["cuda", "gpu"])
 def conv1d_strategy_cuda(attrs, inputs, out_type, target):
     """conv1d cuda strategy"""
diff --git a/python/tvm/relay/op/strategy/generic.py b/python/tvm/relay/op/strategy/generic.py
index 573df3675eee8..388e104dca290 100644
--- a/python/tvm/relay/op/strategy/generic.py
+++ b/python/tvm/relay/op/strategy/generic.py
@@ -374,6 +374,19 @@ def conv3d_strategy(attrs, inputs, out_type, target):
         raise ValueError("Not support this layout {} yet".format(layout))
     return strategy
 
+# conv3d_winograd_without_weight_transform
+@override_native_generic_func("conv3d_winograd_without_weight_transform_strategy")
+def conv3d_winograd_without_weight_transfrom_strategy(attrs, inputs, out_type, target):
+    """conv3d_winograd_without_weight_transform generic strategy"""
+    raise ValueError("No generic implementation for conv3d_winograd_without_weight_transform")
+
+# conv3d_winograd_weight_transform
+@generic_func
+def schedule_conv3d_winograd_weight_transform(attrs, outs, target):
+    """Schedule conv3d_winograd_weight_transform"""
+    with target:
+        return topi.generic.schedule_conv3d_winograd_weight_transform(outs)
+
 # conv1d
 def wrap_compute_conv1d(topi_compute):
     """wrap conv1d topi compute"""
diff --git a/src/relay/op/nn/convolution.cc b/src/relay/op/nn/convolution.cc
index 547d5a6ff6921..66dab57fd9479 100644
--- a/src/relay/op/nn/convolution.cc
+++ b/src/relay/op/nn/convolution.cc
@@ -59,10 +59,113 @@ Expr MakeConv(Expr data,
   attrs->kernel_layout = std::move(kernel_layout);
   attrs->out_layout = std::move(out_layout);
   attrs->out_dtype = std::move(out_dtype);
-  static const Op& op = Op::Get(op_name);
+  const Op& op = Op::Get(op_name);
   return Call(op, {data, weight}, Attrs(attrs), {});
 }
 
+template <typename T>
+Expr MakeConvWinograd(Expr data,
+                      Expr weight,
+                      int tile_size,
+                      Array<IndexExpr> strides,
+                      Array<IndexExpr> padding,
+                      Array<IndexExpr> dilation,
+                      int groups,
+                      IndexExpr channels,
+                      Array<IndexExpr> kernel_size,
+                      std::string data_layout,
+                      std::string kernel_layout,
+                      std::string out_layout,
+                      DataType out_dtype,
+                      std::string op_name) {
+  auto attrs = make_object<T>();
+  attrs->tile_size = tile_size;
+  attrs->strides = std::move(strides);
+  attrs->padding = std::move(padding);
+  attrs->dilation = std::move(dilation);
+  attrs->groups = groups;
+  attrs->channels = std::move(channels);
+  attrs->kernel_size = std::move(kernel_size);
+  attrs->data_layout = std::move(data_layout);
+  attrs->kernel_layout = std::move(kernel_layout);
+  attrs->out_layout = std::move(out_layout);
+  attrs->out_dtype = std::move(out_dtype);
+  const Op& op = Op::Get(op_name);
+  return Call(op, {data, weight}, Attrs(attrs), {});
+}
+
+Expr MakeConvWinogradWeightTransform(Expr weight,
+                                     int tile_size,
+                                     std::string op_name) {
+  auto attrs = make_object<ConvWinogradWeightTransformAttrs>();
+  attrs->tile_size = tile_size;
+  const Op& op = Op::Get(op_name);
+  return Call(op, {weight}, Attrs(attrs), {});
+}
+
+template <typename T>
+Expr MakeConvTranspose(Expr data, + Expr weight, + Array strides, + Array padding, + Array dilation, + int groups, + IndexExpr channels, + Array kernel_size, + std::string data_layout, + std::string kernel_layout, + std::string out_layout, + Array output_padding, + DataType out_dtype, + std::string op_name) { + auto attrs = make_object(); + attrs->strides = std::move(strides); + attrs->padding = std::move(padding); + attrs->dilation = std::move(dilation); + attrs->groups = groups; + attrs->channels = std::move(channels); + attrs->kernel_size = std::move(kernel_size); + attrs->data_layout = std::move(data_layout); + attrs->kernel_layout = std::move(kernel_layout); + attrs->out_layout = std::move(out_layout); + attrs->output_padding = std::move(output_padding); + attrs->out_dtype = std::move(out_dtype); + const Op& op = Op::Get(op_name); + return Call(op, {data, weight}, Attrs(attrs), {}); +} + +template +Expr MakeDeformableConv(Expr data, + Expr offset, + Expr weight, + Array strides, + Array padding, + Array dilation, + int deformable_groups, + int groups, + int channels, + Array kernel_size, + std::string data_layout, + std::string kernel_layout, + std::string out_layout, + DataType out_dtype, + std::string op_name) { + auto attrs = make_object(); + attrs->strides = strides; + attrs->padding = padding; + attrs->dilation = dilation; + attrs->deformable_groups = deformable_groups; + attrs->groups = groups; + attrs->channels = channels; + attrs->kernel_size = kernel_size; + attrs->data_layout = data_layout; + attrs->kernel_layout = kernel_layout; + attrs->out_layout = out_layout; + attrs->out_dtype = out_dtype; + const Op& op = Op::Get(op_name); + return Call(op, {data, offset, weight}, Attrs{attrs}, {}); +} + // relay.nn.conv1d TVM_REGISTER_NODE_TYPE(Conv1DAttrs); @@ -153,6 +256,7 @@ with the layer input to produce a tensor of outputs. .add_type_rel("Conv2D", Conv2DRel) .set_attr("FInferCorrectLayout", ConvInferCorrectLayout); + // relay.nn.conv3d TVM_REGISTER_NODE_TYPE(Conv3DAttrs); @@ -198,138 +302,29 @@ with the layer input to produce a tensor of outputs. .add_type_rel("Conv3D", Conv3DRel) .set_attr("FInferCorrectLayout", ConvInferCorrectLayout); + // relay.nn.conv2d_transpose TVM_REGISTER_NODE_TYPE(Conv2DTransposeAttrs); -bool Conv2DTransposeRel(const Array& types, - int num_inputs, - const Attrs& attrs, - const TypeReporter& reporter) { - CHECK_EQ(types.size(), 3); - const auto* data = types[0].as(); - const auto* weight = types[1].as(); - if (data == nullptr) return false; - - static const Layout kNCHW("NCHW"); - static const Layout kOIHW("OIHW"); - - const Conv2DTransposeAttrs* param = attrs.as(); - CHECK(param != nullptr); - const Layout in_layout(param->data_layout); - const Layout kernel_layout(param->kernel_layout); - - const auto trans_in_layout = tir::BijectiveLayout(in_layout, kNCHW); - CHECK(trans_in_layout.defined()) - << "Conv only support input layouts that are convertible from NCHW." - << " But got " << in_layout; - - const auto trans_kernel_layout = tir::BijectiveLayout(kernel_layout, kOIHW); - CHECK(trans_kernel_layout.defined()) - << "Conv only support kernel layouts that are convertible from OIHW." - << " But got "<< kernel_layout; - - Layout out_layout(param->out_layout == "" ? param->data_layout : param->out_layout); - const auto trans_out_layout = tir::BijectiveLayout(out_layout, kNCHW); - CHECK(trans_out_layout.defined()) - << "Conv only support output layouts that are convertible from NCHW." 
- << " But got " << out_layout; - - IndexExpr channels, dilated_ksize_y, dilated_ksize_x; - - auto dshape_nchw = trans_in_layout.ForwardShape(data->shape); - - // infer weight if the kernel_size and channels are defined - if (param->kernel_size.defined() && param->channels.defined()) { - CHECK_EQ(param->kernel_size.size(), 2); - CHECK_EQ(param->dilation.size(), 2); - - Array wshape({dshape_nchw[1], - indexdiv(param->channels, param->groups), - param->kernel_size[0], - param->kernel_size[1]}); - - wshape = trans_kernel_layout.BackwardShape(wshape); - dilated_ksize_y = 1 + (param->kernel_size[0] - 1) * param->dilation[0]; - dilated_ksize_x = 1 + (param->kernel_size[1] - 1) * param->dilation[1]; - channels = param->channels; - - // assign result to reporter - reporter->Assign(types[1], TensorType(wshape, data->dtype)); - } else { - // use weight to infer the conv shape. - if (weight == nullptr) return false; - auto wshape = trans_kernel_layout.ForwardShape(weight->shape); - if (param->kernel_size.defined()) { - CHECK_EQ(param->kernel_size.size(), 2); - // check the size - CHECK(reporter->AssertEQ(param->kernel_size[0], wshape[2]) && - reporter->AssertEQ(param->kernel_size[1], wshape[3])) - << "Conv2D: shape of weight is inconsistent with kernel_size, " - << " kernel_size=" << param->kernel_size - << " wshape=" << Array(wshape); - } - if (param->channels.defined()) { - CHECK(reporter->AssertEQ(param->channels, wshape[1])) - << "Conv2D: shape of weight is inconsistent with channels, " - << " channels=" << param->channels - << " wshape=" << Array(wshape); - } - CHECK(reporter->AssertEQ(indexdiv(dshape_nchw[1], param->groups), wshape[0])); - channels = wshape[1]; - dilated_ksize_y = 1 + (wshape[2] - 1) * param->dilation[0]; - dilated_ksize_x = 1 + (wshape[3] - 1) * param->dilation[1]; - } - // dilation - Array oshape({dshape_nchw[0], channels, 0, 0}); - IndexExpr pad_h, pad_w; - GetPaddingHeightWidth(param->padding, &pad_h, &pad_w); - oshape.Set(2, (param->strides[0] * (dshape_nchw[2] - 1) + dilated_ksize_y - - pad_h + param->output_padding[0])); - oshape.Set(3, (param->strides[1] * (dshape_nchw[3] - 1) + dilated_ksize_x - - pad_w + param->output_padding[1])); - - DataType out_dtype = param->out_dtype; - if (out_dtype.bits() == 0) { - out_dtype = data->dtype; - } - oshape = trans_out_layout.BackwardShape(oshape); - reporter->Assign(types[2], TensorType(oshape, out_dtype)); - return true; -} - - -Expr MakeConv2DTranspose(Expr data, - Expr weight, - Array strides, - Array padding, - Array dilation, - int groups, - IndexExpr channels, - Array kernel_size, - std::string data_layout, - std::string kernel_layout, - std::string out_layout, - Array output_padding, - DataType out_dtype) { - auto attrs = make_object(); - attrs->channels = std::move(channels); - attrs->kernel_size = std::move(kernel_size); - attrs->strides = std::move(strides); - attrs->padding = std::move(padding); - attrs->output_padding = std::move(output_padding); - attrs->dilation = std::move(dilation); - attrs->groups = groups; - attrs->data_layout = std::move(data_layout); - attrs->kernel_layout = std::move(kernel_layout); - attrs->out_layout = std::move(out_layout); - attrs->out_dtype = std::move(out_dtype); - static const Op& op = Op::Get("nn.conv2d_transpose"); - return Call(op, {data, weight}, Attrs(attrs), {}); -} - - TVM_REGISTER_GLOBAL("relay.op.nn._make.conv2d_transpose") -.set_body_typed(MakeConv2DTranspose); +.set_body_typed([](Expr data, + Expr weight, + Array strides, + Array padding, + Array dilation, + int groups, + 
IndexExpr channels, + Array kernel_size, + std::string data_layout, + std::string kernel_layout, + std::string out_layout, + Array output_padding, + DataType out_dtype) { + return MakeConvTranspose( + data, weight, strides, padding, dilation, + groups, channels, kernel_size, data_layout, + kernel_layout, out_layout, output_padding, out_dtype, "nn.conv2d_transpose"); +}); RELAY_REGISTER_OP("nn.conv2d_transpose") .describe(R"code(Transposed 2D convolution layer (sometimes called Deconvolution). @@ -359,136 +354,31 @@ v (batch_size, channels, out_height, out_width) if `layout` is `NCHW` .add_argument("weight", "Tensor", "The weight tensor.") .set_support_level(2) .set_attr("FInferCorrectLayout", - ConvInferCorrectLayout) -.add_type_rel("Conv2DTranspose", Conv2DTransposeRel); - + ConvInferCorrectLayout) +.add_type_rel("Conv2DTranspose", Conv2DTransposeRel); // relay.nn.conv1d_transpose TVM_REGISTER_NODE_TYPE(Conv1DTransposeAttrs); -bool Conv1DTransposeRel(const Array& types, - int num_inputs, - const Attrs& attrs, - const TypeReporter& reporter) { - CHECK_EQ(types.size(), 3); - const auto* data = types[0].as(); - const auto* weight = types[1].as(); - if (data == nullptr) return false; - - static const Layout kNCW("NCW"); - static const Layout kOIW("OIW"); - - const Conv1DTransposeAttrs* param = attrs.as(); - CHECK(param != nullptr); - const Layout in_layout(param->data_layout); - const Layout kernel_layout(param->kernel_layout); - - const auto trans_in_layout = tir::BijectiveLayout(in_layout, kNCW); - CHECK(trans_in_layout.defined()) - << "Conv only support input layouts that are convertible from NCW." - << " But got " << in_layout; - - const auto trans_kernel_layout = tir::BijectiveLayout(kernel_layout, kOIW); - CHECK(trans_kernel_layout.defined()) - << "Conv only support kernel layouts that are convertible from OIW." - << " But got "<< kernel_layout; - - Layout out_layout(param->out_layout == "" ? param->data_layout : param->out_layout); - const auto trans_out_layout = tir::BijectiveLayout(out_layout, kNCW); - CHECK(trans_out_layout.defined()) - << "Conv only support output layouts that are convertible from NCW." - << " But got " << out_layout; - - IndexExpr channels, dilated_ksize_y, dilated_ksize_x; - - auto dshape_ncw = trans_in_layout.ForwardShape(data->shape); - - // infer weight if the kernel_size and channels are defined - if (param->kernel_size.defined() && param->channels.defined()) { - CHECK_EQ(param->kernel_size.size(), 1); - CHECK_EQ(param->dilation.size(), 1); - - Array wshape({dshape_ncw[1], - indexdiv(param->channels, param->groups), - param->kernel_size[0]}); - - wshape = trans_kernel_layout.BackwardShape(wshape); - dilated_ksize_x = 1 + (param->kernel_size[0] - 1) * param->dilation[0]; - channels = param->channels; - - // assign result to reporter - reporter->Assign(types[1], TensorType(wshape, data->dtype)); - } else { - // use weight to infer the conv shape. 
- if (weight == nullptr) return false; - auto wshape = trans_kernel_layout.ForwardShape(weight->shape); - if (param->kernel_size.defined()) { - CHECK_EQ(param->kernel_size.size(), 1); - // check the size - CHECK(reporter->AssertEQ(param->kernel_size[0], wshape[2])) - << "Conv1D: shape of weight is inconsistent with kernel_size, " - << " kernel_size=" << param->kernel_size - << " wshape=" << Array(wshape); - } - if (param->channels.defined()) { - CHECK(reporter->AssertEQ(param->channels, wshape[1])) - << "Conv1D: shape of weight is inconsistent with channels, " - << " channels=" << param->channels - << " wshape=" << Array(wshape); - } - CHECK(reporter->AssertEQ(indexdiv(dshape_ncw[1], param->groups), wshape[0])); - channels = wshape[1]; - dilated_ksize_x = 1 + (wshape[2] - 1) * param->dilation[0]; - } - // dilation - IndexExpr pad_w; - GetPaddingWidth(param->padding, &pad_w); - Array oshape({dshape_ncw[0], channels, 0}); - oshape.Set(2, (param->strides[0] * (dshape_ncw[2] - 1) + dilated_ksize_x - - pad_w + param->output_padding[0])); - - DataType out_dtype = param->out_dtype; - if (out_dtype.bits() == 0) { - out_dtype = data->dtype; - } - oshape = trans_out_layout.BackwardShape(oshape); - reporter->Assign(types[2], TensorType(oshape, out_dtype)); - return true; -} - - -Expr MakeConv1DTranspose(Expr data, - Expr weight, - Array strides, - Array padding, - Array dilation, - int groups, - IndexExpr channels, - Array kernel_size, - std::string data_layout, - std::string kernel_layout, - std::string out_layout, - Array output_padding, - DataType out_dtype) { - auto attrs = make_object(); - attrs->channels = std::move(channels); - attrs->kernel_size = std::move(kernel_size); - attrs->strides = std::move(strides); - attrs->padding = std::move(padding); - attrs->output_padding = std::move(output_padding); - attrs->dilation = std::move(dilation); - attrs->groups = groups; - attrs->data_layout = std::move(data_layout); - attrs->kernel_layout = std::move(kernel_layout); - attrs->out_layout = std::move(out_layout); - attrs->out_dtype = std::move(out_dtype); - static const Op& op = Op::Get("nn.conv1d_transpose"); - return Call(op, {data, weight}, Attrs(attrs), {}); -} - - TVM_REGISTER_GLOBAL("relay.op.nn._make.conv1d_transpose") -.set_body_typed(MakeConv1DTranspose); +.set_body_typed([](Expr data, + Expr weight, + Array strides, + Array padding, + Array dilation, + int groups, + IndexExpr channels, + Array kernel_size, + std::string data_layout, + std::string kernel_layout, + std::string out_layout, + Array output_padding, + DataType out_dtype) { + return MakeConvTranspose( + data, weight, strides, padding, dilation, + groups, channels, kernel_size, data_layout, + kernel_layout, out_layout, output_padding, out_dtype, "nn.conv1d_transpose"); +}); RELAY_REGISTER_OP("nn.conv1d_transpose") .describe(R"code(Transposed 1D convolution layer (sometimes called Deconvolution). @@ -516,128 +406,30 @@ said convolution. 
.add_argument("data", "Tensor", "The input tensor.") .add_argument("weight", "Tensor", "The weight tensor.") .set_support_level(2) -.add_type_rel("Conv1DTranspose", Conv1DTransposeRel); - +.add_type_rel("Conv1DTranspose", Conv1DTransposeRel); // relay.nn.contrib_conv2d_winograd_without_weight_transform TVM_REGISTER_NODE_TYPE(Conv2DWinogradAttrs); -template -bool Conv2DWinogradRel(const Array& types, - int num_inputs, - const Attrs& attrs, - const TypeReporter& reporter) { - CHECK_EQ(types.size(), 3); - const auto* data = types[0].as(); - if (data == nullptr) return false; - static const Layout kNCHW("NCHW"); - static const Layout kOIHW("OIHW"); - - const Param* param = attrs.as(); - CHECK(param != nullptr); - const Layout in_layout(param->data_layout); - const Layout kernel_layout(param->kernel_layout); - - const auto trans_in_layout = tir::BijectiveLayout(in_layout, kNCHW); - CHECK(trans_in_layout.defined()) - << "Conv only support input layouts that are convertible from NCHW." - << " But got " << in_layout; - - const auto trans_kernel_layout = tir::BijectiveLayout(kernel_layout, kOIHW); - CHECK(trans_kernel_layout.defined()) - << "Conv only support kernel layouts that are convertible from OIHW." - << " But got "<< kernel_layout; - - Layout out_layout(param->out_layout == "" ? param->data_layout : param->out_layout); - const auto trans_out_layout = tir::BijectiveLayout(out_layout, kNCHW); - CHECK(trans_out_layout.defined()) - << "Conv only support output layouts that are convertible from NCHW." - << " But got " << out_layout; - - Array dshape_nchw = trans_in_layout.ForwardShape(data->shape); - - IndexExpr channels, dilated_ksize_y, dilated_ksize_x; - - CHECK(param->kernel_size.defined() && param->channels.defined()) - << "The kernel size and channels of a Conv must be set or infered by previous pass"; - - CHECK_EQ(param->kernel_size.size(), 2); - CHECK_EQ(param->dilation.size(), 2); - - channels = param->channels; - dilated_ksize_y = 1 + (param->kernel_size[0] - 1) * param->dilation[0]; - dilated_ksize_x = 1 + (param->kernel_size[1] - 1) * param->dilation[1]; - - // NOTE: Do not check weight shape here! - // Different backend requires different layout to compute - // the batch gemm stage in winograd efficiently, but we want to - // make this op work for all backends. - // So we accept all weight shapes, and assume the TOPI developers - // can handle this correctly in alter_op_layout. - - // dilation - Array oshape({dshape_nchw[0], channels, 0, 0}); - - IndexExpr pad_h, pad_w; - GetPaddingHeightWidth(param->padding, &pad_h, &pad_w); - if (!dshape_nchw[2].as()) { - oshape.Set(2, (dshape_nchw[2] + pad_h - - dilated_ksize_y) / param->strides[0] + 1); - } else { - oshape.Set(2, dshape_nchw[2]); - } - if (!dshape_nchw[3].as()) { - oshape.Set(3, (dshape_nchw[3] + pad_w - - dilated_ksize_x) / param->strides[1] + 1); - } else { - oshape.Set(3, dshape_nchw[3]); - } - - DataType out_dtype = param->out_dtype; - if (out_dtype.bits() == 0) { - out_dtype = data->dtype; - } - oshape = trans_out_layout.BackwardShape(oshape); - // assign output type - reporter->Assign(types[2], TensorType(oshape, out_dtype)); - return true; -} - - -// Positional relay function to create conv2d winograd operator -// used by frontend FFI. 
-Expr MakeConv2DWinograd(Expr data, - Expr weight, - int tile_size, - Array strides, - Array padding, - Array dilation, - int groups, - IndexExpr channels, - Array kernel_size, - std::string data_layout, - std::string kernel_layout, - std::string out_layout, - DataType out_dtype) { - auto attrs = make_object(); - attrs->tile_size = tile_size; - attrs->strides = std::move(strides); - attrs->padding = std::move(padding); - attrs->dilation = std::move(dilation); - attrs->groups = groups; - attrs->channels = channels; - attrs->kernel_size = std::move(kernel_size); - attrs->data_layout = std::move(data_layout); - attrs->kernel_layout = std::move(kernel_layout); - attrs->out_layout = std::move(out_layout); - attrs->out_dtype = std::move(out_dtype); - static const Op& op = Op::Get("nn.contrib_conv2d_winograd_without_weight_transform"); - return Call(op, {data, weight}, Attrs(attrs), {}); -} - - TVM_REGISTER_GLOBAL("relay.op.nn._make.contrib_conv2d_winograd_without_weight_transform") -.set_body_typed(MakeConv2DWinograd); +.set_body_typed([](Expr data, + Expr weight, + int tile_size, + Array strides, + Array padding, + Array dilation, + int groups, + IndexExpr channels, + Array kernel_size, + std::string data_layout, + std::string kernel_layout, + std::string out_layout, + DataType out_dtype) { + return MakeConvWinograd( + data, weight, tile_size, strides, padding, dilation, + groups, channels, kernel_size, data_layout, + kernel_layout, out_layout, out_dtype, "nn.contrib_conv2d_winograd_without_weight_transform"); +}); RELAY_REGISTER_OP("nn.contrib_conv2d_winograd_without_weight_transform") @@ -662,46 +454,14 @@ RELAY_REGISTER_OP("nn.contrib_conv2d_winograd_without_weight_transform") ConvInferCorrectLayout); // relay.nn.contrib_conv2d_winograd_weight_transform -TVM_REGISTER_NODE_TYPE(Conv2DWinogradWeightTransformAttrs); - -bool Conv2DWinogradWeightTransformRel(const Array& types, - int num_inputs, - const Attrs& attrs, - const TypeReporter& reporter) { - CHECK_EQ(types.size(), 2); - const auto* data = types[0].as(); - if (data == nullptr) return false; - - const Conv2DWinogradWeightTransformAttrs* param = attrs.as(); - CHECK(param != nullptr); - - CHECK_EQ(data->shape.size(), 4) << "Only support NCHW normal kernel layout"; - - // each pad width element should be a pair of positive integers - std::vector oshape { - param->tile_size + data->shape[2] - 1, - param->tile_size + data->shape[3] - 1, - data->shape[0], - data->shape[1], - }; - - reporter->Assign(types[1], TensorType(Array(oshape), - data->dtype)); - return true; -} - -Expr MakeConv2DWinogradWeightTransform(Expr weight, - int tile_size) { - auto attrs = make_object(); - attrs->tile_size = tile_size; - static const Op& op = Op::Get("nn.contrib_conv2d_winograd_weight_transform"); - return Call(op, {weight}, Attrs(attrs), {}); -} - +TVM_REGISTER_NODE_TYPE(ConvWinogradWeightTransformAttrs); TVM_REGISTER_GLOBAL("relay.op.nn._make.contrib_conv2d_winograd_weight_transform") -.set_body_typed(MakeConv2DWinogradWeightTransform); - +.set_body_typed([](Expr weight, + int tile_size) { + return MakeConvWinogradWeightTransform( + weight, tile_size, "nn.contrib_conv2d_winograd_weight_transform"); +}); RELAY_REGISTER_OP("nn.contrib_conv2d_winograd_weight_transform") .describe(R"code(Weight transformation of winograd fast convolution algorithm. @@ -711,47 +471,82 @@ weight transformation in advance. 
- **weight**: (channels, in_channels, kernel_size[0], kernel_size[1]) )code" TVM_ADD_FILELINE) -.set_attrs_type() +.set_attrs_type() .set_num_inputs(1) .add_argument("weight", "Tensor", "The weight tensor.") .set_support_level(10) .add_type_rel("Conv2DWinogradWeightTransform", Conv2DWinogradWeightTransformRel); +// relay.nn.contrib_conv3d_winograd_without_weight_transform +TVM_REGISTER_NODE_TYPE(Conv3DWinogradAttrs); + +TVM_REGISTER_GLOBAL("relay.op.nn._make.contrib_conv3d_winograd_without_weight_transform") +.set_body_typed([](Expr data, + Expr weight, + int tile_size, + Array strides, + Array padding, + Array dilation, + int groups, + IndexExpr channels, + Array kernel_size, + std::string data_layout, + std::string kernel_layout, + std::string out_layout, + DataType out_dtype) { + return MakeConvWinograd( + data, weight, tile_size, strides, padding, dilation, + groups, channels, kernel_size, data_layout, + kernel_layout, out_layout, out_dtype, "nn.contrib_conv3d_winograd_without_weight_transform"); +}); + +RELAY_REGISTER_OP("nn.contrib_conv3d_winograd_without_weight_transform") +.describe(R"code(Compute conv3d with winograd algorithm. Only supports NCDHW layout. + This operator assumes the weight tensor is already pre-transformed by + nn.contrib_conv3d_winograd_weight_transform. + +- **data**: Input is 5D array of shape (batch_size, in_channels, depth, height, width) +- **weight**: Any shape + We do not check the shape for this input tensor. Since different backend + has different layout strategy. + +- **out**: Output is 5D array of shape (batch_size, channels, depth, out_height, out_width) +)code" TVM_ADD_FILELINE) +.set_attrs_type() +.set_num_inputs(2) +.add_argument("data", "Tensor", "The input tensor.") +.add_argument("weight", "Tensor", "The weight tensor.") +.set_support_level(10) +.add_type_rel("Conv3DWinograd", Conv3DWinogradRel) +.set_attr("FInferCorrectLayout", + ConvInferCorrectLayout); + +// relay.nn.contrib_conv3d_winograd_weight_transform +TVM_REGISTER_GLOBAL("relay.op.nn._make.contrib_conv3d_winograd_weight_transform") +.set_body_typed([](Expr weight, + int tile_size) { + return MakeConvWinogradWeightTransform( + weight, tile_size, "nn.contrib_conv3d_winograd_weight_transform"); +}); + +RELAY_REGISTER_OP("nn.contrib_conv3d_winograd_weight_transform") + .describe(R"code(Weight transformation of winograd fast 3d convolution algorithm. + +Separate this into another operator in order to enable Precompute Pass to compute the +weight transformation in advance. 
+ +- **weight**: (channels, in_channels, kernel_size[0], kernel_size[1], kernel_size[2]) +)code" TVM_ADD_FILELINE) +.set_attrs_type() +.set_num_inputs(1) +.add_argument("weight", "Tensor", "The weight tensor.") +.set_support_level(10) +.add_type_rel("Conv3DWinogradWeightTransform", Conv3DWinogradWeightTransformRel); + // relay.nn.contrib_conv2d_winograd_nnpack_weight_transform TVM_REGISTER_NODE_TYPE(Conv2DWinogradNNPACKWeightTransformAttrs); -bool Conv2DWinogradNNPACKWeightTransformRel(const Array& types, - int num_inputs, - const Attrs& attrs, - const TypeReporter& reporter) { - CHECK_EQ(types.size(), 2); - const auto* data = types[0].as(); - if (data == nullptr) { - return false; - } - - const Conv2DWinogradNNPACKWeightTransformAttrs* param = - attrs.as(); - CHECK(param != nullptr); - - CHECK_EQ(data->shape.size(), 4) << "Only support NCHW normal kernel layout"; - - std::vector oshape{ - data->shape[0], - data->shape[1], - 8, - 8, - }; - - DataType out_dtype = param->out_dtype; - if (out_dtype.bits() == 0) { - out_dtype = data->dtype; - } - reporter->Assign(types[1], TensorType(Array(oshape), out_dtype)); - return true; -} - Expr MakeConv2DWinogradNNPACKWeightTransform(Expr weight, int convolution_algorithm, DataType out_dtype) { @@ -779,38 +574,27 @@ weight transformation in advance. .set_support_level(10) .add_type_rel("Conv2DWinogradNNPACKWeightTransform", Conv2DWinogradNNPACKWeightTransformRel); + // Positional relay function to create conv2d NCHWc operator // used by frontend FFI. -Expr MakeConv2DNCHWc(Expr data, - Expr kernel, - Array strides, - Array padding, - Array dilation, - int groups, - IndexExpr channels, - Array kernel_size, - std::string data_layout, - std::string kernel_layout, - std::string out_layout, - DataType out_dtype) { - auto attrs = make_object(); - attrs->strides = std::move(strides); - attrs->padding = std::move(padding); - attrs->dilation = std::move(dilation); - attrs->groups = groups; - attrs->channels = channels; - attrs->kernel_size = std::move(kernel_size); - attrs->data_layout = std::move(data_layout); - attrs->kernel_layout = std::move(kernel_layout); - attrs->out_layout = std::move(out_layout); - attrs->out_dtype = std::move(out_dtype); - static const Op& op = Op::Get("nn.contrib_conv2d_NCHWc"); - return Call(op, {data, kernel}, Attrs(attrs), {}); -} - TVM_REGISTER_GLOBAL("relay.op.nn._make.contrib_conv2d_NCHWc") -.set_body_typed(MakeConv2DNCHWc); - +.set_body_typed([](Expr data, + Expr weight, + Array strides, + Array padding, + Array dilation, + int groups, + IndexExpr channels, + Array kernel_size, + std::string data_layout, + std::string kernel_layout, + std::string out_layout, + DataType out_dtype) { + return MakeConv( + data, weight, strides, padding, dilation, + groups, channels, kernel_size, data_layout, + kernel_layout, out_layout, out_dtype, "nn.contrib_conv2d_NCHWc"); +}); RELAY_REGISTER_OP("nn.contrib_conv2d_NCHWc") .describe(R"code(Compute conv2d with NCHWc data layout. Only supports NCHW layout. @@ -831,35 +615,24 @@ RELAY_REGISTER_OP("nn.contrib_conv2d_NCHWc") // Positional relay function to create depthwise conv2d NCHWc operator // used by frontend FFI. 
-Expr MakeDepthwiseConv2DNCHWc(Expr data, - Expr kernel, - Array strides, - Array padding, - Array dilation, - int groups, - IndexExpr channels, - Array kernel_size, - std::string data_layout, - std::string kernel_layout, - std::string out_layout, - DataType out_dtype) { - auto attrs = make_object(); - attrs->strides = std::move(strides); - attrs->padding = std::move(padding); - attrs->dilation = std::move(dilation); - attrs->groups = groups; - attrs->channels = channels; - attrs->kernel_size = std::move(kernel_size); - attrs->data_layout = std::move(data_layout); - attrs->kernel_layout = std::move(kernel_layout); - attrs->out_layout = std::move(out_layout); - attrs->out_dtype = std::move(out_dtype); - static const Op& op = Op::Get("nn.contrib_depthwise_conv2d_NCHWc"); - return Call(op, {data, kernel}, Attrs(attrs), {}); -} - TVM_REGISTER_GLOBAL("relay.op.nn._make.contrib_depthwise_conv2d_NCHWc") -.set_body_typed(MakeDepthwiseConv2DNCHWc); +.set_body_typed([](Expr data, + Expr weight, + Array strides, + Array padding, + Array dilation, + int groups, + IndexExpr channels, + Array kernel_size, + std::string data_layout, + std::string kernel_layout, + std::string out_layout, + DataType out_dtype) { + return MakeConv( + data, weight, strides, padding, dilation, + groups, channels, kernel_size, data_layout, + kernel_layout, out_layout, out_dtype, "nn.contrib_depthwise_conv2d_NCHWc"); +}); RELAY_REGISTER_OP("nn.contrib_depthwise_conv2d_NCHWc") @@ -879,85 +652,6 @@ RELAY_REGISTER_OP("nn.contrib_depthwise_conv2d_NCHWc") ConvInferCorrectLayout); -bool DeformableConv2DRel(const Array& types, int num_inputs, const Attrs& attrs, - const TypeReporter& reporter) { - CHECK_EQ(types.size(), 4); - const auto* data = types[0].as(); - const auto* weight = types[2].as(); - - CHECK(data); - auto* param = attrs.as(); - CHECK_EQ(param->data_layout, "NCHW") << "data layout not supported."; - CHECK_EQ(param->kernel_layout, "OIHW") << "kernel_layout not supported."; - - IndexExpr channels, dilated_ksize_y, dilated_ksize_x, ksize_y, ksize_x; - - // infer weight shape if kernel_size and channels are defiend - if (param->kernel_size.defined() && param->channels.defined()) { - CHECK_EQ(param->kernel_size.size(), 2); - CHECK_EQ(param->dilation.size(), 2); - Array wshape( - {param->channels, - indexdiv(data->shape[1], param->groups), - param->kernel_size[0], - param->kernel_size[1]}); - channels = param->channels; - ksize_y = param->kernel_size[0]; - ksize_x = param->kernel_size[1]; - dilated_ksize_y = 1 + (param->kernel_size[0] - 1) * param->dilation[0]; - dilated_ksize_x = 1 + (param->kernel_size[1] - 1) * param->dilation[1]; - // assign result to reporter - reporter->Assign(types[2], TensorType(wshape, data->dtype)); - } else { - // use weight to infer the conv shape. 
- if (weight == nullptr) return false; - auto wshape = weight->shape; - if (param->kernel_size.defined()) { - CHECK_EQ(param->kernel_size.size(), 2); - // check the size - CHECK(reporter->AssertEQ(param->kernel_size[0], wshape[2]) && - reporter->AssertEQ(param->kernel_size[1], wshape[3])) - << "DeformableConv2D: shape of weight is inconsistent with kernel_size, " - << " kernel_size=" << param->kernel_size - << " wshape=" << wshape; - } - if (param->channels.defined()) { - CHECK(reporter->AssertEQ(param->channels, wshape[0])) - << "DeformableConv2D: shape of weight is inconsistent with channels, " - << " channels=" << param->channels - << " wshape=" << wshape; - } - CHECK(reporter->AssertEQ(indexdiv(data->shape[1], param->groups), wshape[1])); - channels = wshape[0]; - ksize_y = wshape[2]; - ksize_x = wshape[3]; - dilated_ksize_y = 1 + (wshape[2] - 1) * param->dilation[0]; - dilated_ksize_x = 1 + (wshape[3] - 1) * param->dilation[1]; - } - // dilation - Array oshape({data->shape[0], channels, 0, 0}); - - IndexExpr pad_h, pad_w; - GetPaddingHeightWidth(param->padding, &pad_h, &pad_w); - oshape.Set(2, indexdiv(data->shape[2] + pad_h - dilated_ksize_y, - param->strides[0]) + 1); - oshape.Set(3, indexdiv(data->shape[3] + pad_w - dilated_ksize_x, - param->strides[1]) + 1); - DataType out_dtype = param->out_dtype; - - // infer offset shape - Array offset_shape({data->shape[0], 2 * ksize_y * ksize_x * param->deformable_groups, - oshape[2], oshape[3]}); - reporter->Assign(types[1], TensorType(offset_shape, data->dtype)); - if (out_dtype.bits() == 0) { - out_dtype = data->dtype; - } - - reporter->Assign(types[3], TensorType(oshape, out_dtype)); - return true; -} - - TVM_REGISTER_NODE_TYPE(DeformableConv2DAttrs); RELAY_REGISTER_OP("nn.deformable_conv2d") @@ -986,42 +680,30 @@ by concating all the *g* results. .add_argument("offset", "Tensor", "The offset tensor.") .add_argument("weight", "Tensor", "The weight tensor.") .set_support_level(5) -.add_type_rel("DeformableConv2D", DeformableConv2DRel); +.add_type_rel("DeformableConv2D", DeformableConv2DRel); // Positional relay function to create deformable_conv2d operator // used by frontend FFI. 
-Expr MakeDeformableConv2D(Expr data, - Expr offset, - Expr weight, - Array strides, - Array padding, - Array dilation, - int deformable_groups, - int groups, - int channels, - Array kernel_size, - std::string data_layout, - std::string kernel_layout, - std::string out_layout, - DataType out_dtype) { - auto attrs = make_object(); - attrs->strides = strides; - attrs->padding = padding; - attrs->dilation = dilation; - attrs->deformable_groups = deformable_groups; - attrs->groups = groups; - attrs->channels = channels; - attrs->kernel_size = kernel_size; - attrs->data_layout = data_layout; - attrs->kernel_layout = kernel_layout; - attrs->out_layout = out_layout; - attrs->out_dtype = out_dtype; - static const Op& op = Op::Get("nn.deformable_conv2d"); - return Call(op, {data, offset, weight}, Attrs{attrs}, {}); -} - TVM_REGISTER_GLOBAL("relay.op.nn._make.deformable_conv2d") -.set_body_typed(MakeDeformableConv2D); +.set_body_typed([](Expr data, + Expr offset, + Expr weight, + Array strides, + Array padding, + Array dilation, + int deformable_groups, + int groups, + int channels, + Array kernel_size, + std::string data_layout, + std::string kernel_layout, + std::string out_layout, + DataType out_dtype) { + return MakeDeformableConv( + data, offset, weight, strides, padding, dilation, + deformable_groups, groups, channels, kernel_size, data_layout, + kernel_layout, out_layout, out_dtype, "nn.deformable_conv2d"); +}); } // namespace relay } // namespace tvm diff --git a/src/relay/op/nn/convolution.h b/src/relay/op/nn/convolution.h index 05c11719c3206..6c5aebe2bd4c1 100644 --- a/src/relay/op/nn/convolution.h +++ b/src/relay/op/nn/convolution.h @@ -29,12 +29,15 @@ #include #include +#include #include "../op_common.h" namespace tvm { namespace relay { + +// Standard convolution operator shape relations template bool Conv1DRel(const Array& types, int num_inputs, const Attrs& attrs, const TypeReporter& reporter) { @@ -363,6 +366,533 @@ bool Conv3DRel(const Array& types, int num_inputs, const Attrs& attrs, return true; } + +// Winograd convolution shape relations +inline bool Conv2DWinogradWeightTransformRel(const Array& types, int num_inputs, + const Attrs& attrs, const TypeReporter& reporter) { + CHECK_EQ(types.size(), 2); + const auto* data = types[0].as(); + if (data == nullptr) return false; + + const ConvWinogradWeightTransformAttrs* param = attrs.as(); + CHECK(param != nullptr); + + CHECK_EQ(data->shape.size(), 4) << "Only support NCHW normal kernel layout"; + + std::vector oshape { + param->tile_size + data->shape[2] - 1, + param->tile_size + data->shape[3] - 1, + data->shape[0], + data->shape[1], + }; + + reporter->Assign(types[1], TensorType(Array(oshape), + data->dtype)); + return true; +} + +inline bool Conv3DWinogradWeightTransformRel(const Array& types, int num_inputs, + const Attrs& attrs, const TypeReporter& reporter) { + CHECK_EQ(types.size(), 2); + const auto* data = types[0].as(); + if (data == nullptr) return false; + + const ConvWinogradWeightTransformAttrs* param = attrs.as(); + CHECK(param != nullptr); + + CHECK_EQ(data->shape.size(), 5) << "Only support NCDHW normal kernel layout"; + + // Shape of packed weights depends on whether depth is being transformed or not. 
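+  // Kernel depths strictly between 2 and 8 are transformed along D, H and W,
+  // giving a (tile+D-1, tile+H-1, tile+W-1, O, I) output; otherwise only H
+  // and W are tiled and the depth extent passes through unchanged.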
+ Array oshape({0, 0, 0, data->shape[0], data->shape[1]}); + auto* depth_imm = data->shape[2].as(); + bool transform_depth = (depth_imm->value > 2)&&(depth_imm->value < 8); + if (transform_depth) { + oshape.Set(0, param->tile_size + data->shape[2] - 1); + oshape.Set(1, param->tile_size + data->shape[3] - 1); + oshape.Set(2, param->tile_size + data->shape[4] - 1); + } else { + oshape.Set(0, param->tile_size + data->shape[3] - 1); + oshape.Set(1, param->tile_size + data->shape[4] - 1); + oshape.Set(2, data->shape[2]); + } + + reporter->Assign(types[1], TensorType(oshape, data->dtype)); + return true; +} + +inline bool Conv2DWinogradNNPACKWeightTransformRel(const Array& types, int num_inputs, + const Attrs& attrs, + const TypeReporter& reporter) { + CHECK_EQ(types.size(), 2); + const auto* data = types[0].as(); + if (data == nullptr) { + return false; + } + + const Conv2DWinogradNNPACKWeightTransformAttrs* param = + attrs.as(); + CHECK(param != nullptr); + + CHECK_EQ(data->shape.size(), 4) << "Only support NCHW normal kernel layout"; + + std::vector oshape{ + data->shape[0], + data->shape[1], + 8, + 8, + }; + + DataType out_dtype = param->out_dtype; + if (out_dtype.bits() == 0) { + out_dtype = data->dtype; + } + reporter->Assign(types[1], TensorType(Array(oshape), out_dtype)); + return true; +} + +template +bool Conv2DWinogradRel(const Array& types, + int num_inputs, + const Attrs& attrs, + const TypeReporter& reporter) { + CHECK_EQ(types.size(), 3); + const auto* data = types[0].as(); + if (data == nullptr) return false; + static const Layout kNCHW("NCHW"); + static const Layout kOIHW("OIHW"); + + const AttrType* param = attrs.as(); + CHECK(param != nullptr); + const Layout in_layout(param->data_layout); + const Layout kernel_layout(param->kernel_layout); + + const auto trans_in_layout = tir::BijectiveLayout(in_layout, kNCHW); + CHECK(trans_in_layout.defined()) + << "Conv only support input layouts that are convertible from NCHW." + << " But got " << in_layout; + + const auto trans_kernel_layout = tir::BijectiveLayout(kernel_layout, kOIHW); + CHECK(trans_kernel_layout.defined()) + << "Conv only support kernel layouts that are convertible from OIHW." + << " But got "<< kernel_layout; + + Layout out_layout(param->out_layout == "" ? param->data_layout : param->out_layout); + const auto trans_out_layout = tir::BijectiveLayout(out_layout, kNCHW); + CHECK(trans_out_layout.defined()) + << "Conv only support output layouts that are convertible from NCHW." + << " But got " << out_layout; + + Array dshape_nchw = trans_in_layout.ForwardShape(data->shape); + + IndexExpr channels, dilated_ksize_y, dilated_ksize_x; + + CHECK(param->kernel_size.defined() && param->channels.defined()) + << "The kernel size and channels of a Conv must be set or inferred by previous pass"; + + CHECK_EQ(param->kernel_size.size(), 2); + CHECK_EQ(param->dilation.size(), 2); + + channels = param->channels; + dilated_ksize_y = 1 + (param->kernel_size[0] - 1) * param->dilation[0]; + dilated_ksize_x = 1 + (param->kernel_size[1] - 1) * param->dilation[1]; + + // NOTE: Do not check weight shape here! + // Different backend requires different layout to compute + // the batch gemm stage in winograd efficiently, but we want to + // make this op work for all backends. + // So we accept all weight shapes, and assume the TOPI developers + // can handle this correctly in alter_op_layout. 
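+  //
+  // Each output extent below follows the usual convolution formula,
+  //   out = (in + pad - dilated_ksize) / stride + 1,
+  // with symbolic (Any) input extents passed through unchanged.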
+ + // dilation + Array oshape({dshape_nchw[0], channels, 0, 0}); + + IndexExpr pad_h, pad_w; + GetPaddingHeightWidth(param->padding, &pad_h, &pad_w); + if (!dshape_nchw[2].as()) { + oshape.Set(2, (dshape_nchw[2] + pad_h + - dilated_ksize_y) / param->strides[0] + 1); + } else { + oshape.Set(2, dshape_nchw[2]); + } + if (!dshape_nchw[3].as()) { + oshape.Set(3, (dshape_nchw[3] + pad_w + - dilated_ksize_x) / param->strides[1] + 1); + } else { + oshape.Set(3, dshape_nchw[3]); + } + + DataType out_dtype = param->out_dtype; + if (out_dtype.bits() == 0) { + out_dtype = data->dtype; + } + oshape = trans_out_layout.BackwardShape(oshape); + // assign output type + reporter->Assign(types[2], TensorType(oshape, out_dtype)); + return true; +} + + +template +bool Conv3DWinogradRel(const Array& types, + int num_inputs, + const Attrs& attrs, + const TypeReporter& reporter) { + CHECK_EQ(types.size(), 3); + const auto* data = types[0].as(); + if (data == nullptr) return false; + static const Layout kNCDHW("NCDHW"); + static const Layout kOIDHW("OIDHW"); + + const AttrType* param = attrs.as(); + CHECK(param != nullptr); + const Layout in_layout(param->data_layout); + const Layout kernel_layout(param->kernel_layout); + + const auto trans_in_layout = tir::BijectiveLayout(in_layout, kNCDHW); + CHECK(trans_in_layout.defined()) + << "Conv only support input layouts that are convertible from NCDHW." + << " But got " << in_layout; + + const auto trans_kernel_layout = tir::BijectiveLayout(kernel_layout, kOIDHW); + CHECK(trans_kernel_layout.defined()) + << "Conv only support kernel layouts that are convertible from OIDHW." + << " But got "<< kernel_layout; + + Layout out_layout(param->out_layout == "" ? param->data_layout : param->out_layout); + const auto trans_out_layout = tir::BijectiveLayout(out_layout, kNCDHW); + CHECK(trans_out_layout.defined()) + << "Conv only support output layouts that are convertible from NCDHW." + << " But got " << out_layout; + + Array dshape_ncdhw = trans_in_layout.ForwardShape(data->shape); + + IndexExpr channels, dilated_ksize_d, dilated_ksize_y, dilated_ksize_x; + + CHECK(param->kernel_size.defined() && param->channels.defined()) + << "The kernel size and channels of a Conv must be set or inferred by previous pass"; + + CHECK_EQ(param->kernel_size.size(), 3); + CHECK_EQ(param->dilation.size(), 3); + + channels = param->channels; + dilated_ksize_d = 1 + (param->kernel_size[0] - 1) * param->dilation[0]; + dilated_ksize_y = 1 + (param->kernel_size[1] - 1) * param->dilation[1]; + dilated_ksize_x = 1 + (param->kernel_size[2] - 1) * param->dilation[2]; + + // NOTE: Do not check weight shape here! + // Different backend requires different layout to compute + // the batch gemm stage in winograd efficiently, but we want to + // make this op work for all backends. + // So we accept all weight shapes, and assume the TOPI developers + // can handle this correctly in alter_op_layout. 
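+  //
+  // Same shape arithmetic as Conv2DWinogradRel, extended with a depth axis:
+  // D, H and W each use their own padding, stride and dilated kernel size.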
+
+  // dilation
+  Array oshape({dshape_ncdhw[0], channels, 0, 0, 0});
+
+  IndexExpr pad_d, pad_h, pad_w;
+  GetPaddingDepthHeightWidth(param->padding, &pad_d, &pad_h, &pad_w);
+  if (!dshape_ncdhw[2].as()) {
+    oshape.Set(2, (dshape_ncdhw[2] + pad_d
+                   - dilated_ksize_d) / param->strides[0] + 1);
+  } else {
+    oshape.Set(2, dshape_ncdhw[2]);
+  }
+  if (!dshape_ncdhw[3].as()) {
+    oshape.Set(3, (dshape_ncdhw[3] + pad_h
+                   - dilated_ksize_y) / param->strides[1] + 1);
+  } else {
+    oshape.Set(3, dshape_ncdhw[3]);
+  }
+  if (!dshape_ncdhw[4].as()) {
+    oshape.Set(4, (dshape_ncdhw[4] + pad_w
+                   - dilated_ksize_x) / param->strides[2] + 1);
+  } else {
+    oshape.Set(4, dshape_ncdhw[4]);
+  }
+
+  DataType out_dtype = param->out_dtype;
+  if (out_dtype.bits() == 0) {
+    out_dtype = data->dtype;
+  }
+  oshape = trans_out_layout.BackwardShape(oshape);
+  // assign output type
+  reporter->Assign(types[2], TensorType(oshape, out_dtype));
+  return true;
+}
+
+
+// Transposed convolution shape relations
+template
+bool Conv1DTransposeRel(const Array& types,
+                        int num_inputs,
+                        const Attrs& attrs,
+                        const TypeReporter& reporter) {
+  CHECK_EQ(types.size(), 3);
+  const auto* data = types[0].as();
+  const auto* weight = types[1].as();
+  if (data == nullptr) return false;
+
+  static const Layout kNCW("NCW");
+  static const Layout kOIW("OIW");
+
+  const Conv1DTransposeAttrs* param = attrs.as();
+  CHECK(param != nullptr);
+  const Layout in_layout(param->data_layout);
+  const Layout kernel_layout(param->kernel_layout);
+
+  const auto trans_in_layout = tir::BijectiveLayout(in_layout, kNCW);
+  CHECK(trans_in_layout.defined())
+      << "Conv only support input layouts that are convertible from NCW."
+      << " But got " << in_layout;
+
+  const auto trans_kernel_layout = tir::BijectiveLayout(kernel_layout, kOIW);
+  CHECK(trans_kernel_layout.defined())
+      << "Conv only support kernel layouts that are convertible from OIW."
+      << " But got " << kernel_layout;
+
+  Layout out_layout(param->out_layout == "" ? param->data_layout : param->out_layout);
+  const auto trans_out_layout = tir::BijectiveLayout(out_layout, kNCW);
+  CHECK(trans_out_layout.defined())
+      << "Conv only support output layouts that are convertible from NCW."
+      << " But got " << out_layout;
+
+  IndexExpr channels, dilated_ksize_y, dilated_ksize_x;
+
+  auto dshape_ncw = trans_in_layout.ForwardShape(data->shape);
+
+  // infer weight if the kernel_size and channels are defined
+  if (param->kernel_size.defined() && param->channels.defined()) {
+    CHECK_EQ(param->kernel_size.size(), 1);
+    CHECK_EQ(param->dilation.size(), 1);
+
+    Array wshape({dshape_ncw[1],
+                  indexdiv(param->channels, param->groups),
+                  param->kernel_size[0]});
+
+    wshape = trans_kernel_layout.BackwardShape(wshape);
+    dilated_ksize_x = 1 + (param->kernel_size[0] - 1) * param->dilation[0];
+    channels = param->channels;
+
+    // assign result to reporter
+    reporter->Assign(types[1], TensorType(wshape, data->dtype));
+  } else {
+    // use weight to infer the conv shape.
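+    // For a transposed conv the kernel stores (in_channels, out_channels / groups, width),
+    // so wshape[0] is checked against the data channels and wshape[1] yields the
+    // output channels below.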
+    if (weight == nullptr) return false;
+    auto wshape = trans_kernel_layout.ForwardShape(weight->shape);
+    if (param->kernel_size.defined()) {
+      CHECK_EQ(param->kernel_size.size(), 1);
+      // check the size
+      CHECK(reporter->AssertEQ(param->kernel_size[0], wshape[2]))
+          << "Conv1D: shape of weight is inconsistent with kernel_size, "
+          << " kernel_size=" << param->kernel_size
+          << " wshape=" << Array(wshape);
+    }
+    if (param->channels.defined()) {
+      CHECK(reporter->AssertEQ(param->channels, wshape[1]))
+          << "Conv1D: shape of weight is inconsistent with channels, "
+          << " channels=" << param->channels
+          << " wshape=" << Array(wshape);
+    }
+    CHECK(reporter->AssertEQ(indexdiv(dshape_ncw[1], param->groups), wshape[0]));
+    channels = wshape[1];
+    dilated_ksize_x = 1 + (wshape[2] - 1) * param->dilation[0];
+  }
+  // dilation
+  IndexExpr pad_w;
+  GetPaddingWidth(param->padding, &pad_w);
+  Array oshape({dshape_ncw[0], channels, 0});
+  oshape.Set(2, (param->strides[0] * (dshape_ncw[2] - 1) + dilated_ksize_x -
+                 pad_w + param->output_padding[0]));
+
+  DataType out_dtype = param->out_dtype;
+  if (out_dtype.bits() == 0) {
+    out_dtype = data->dtype;
+  }
+  oshape = trans_out_layout.BackwardShape(oshape);
+  reporter->Assign(types[2], TensorType(oshape, out_dtype));
+  return true;
+}
+
+
+template
+bool Conv2DTransposeRel(const Array& types,
+                        int num_inputs,
+                        const Attrs& attrs,
+                        const TypeReporter& reporter) {
+  CHECK_EQ(types.size(), 3);
+  const auto* data = types[0].as();
+  const auto* weight = types[1].as();
+  if (data == nullptr) return false;
+
+  static const Layout kNCHW("NCHW");
+  static const Layout kOIHW("OIHW");
+
+  const Conv2DTransposeAttrs* param = attrs.as();
+  CHECK(param != nullptr);
+  const Layout in_layout(param->data_layout);
+  const Layout kernel_layout(param->kernel_layout);
+
+  const auto trans_in_layout = tir::BijectiveLayout(in_layout, kNCHW);
+  CHECK(trans_in_layout.defined())
+      << "Conv only support input layouts that are convertible from NCHW."
+      << " But got " << in_layout;
+
+  const auto trans_kernel_layout = tir::BijectiveLayout(kernel_layout, kOIHW);
+  CHECK(trans_kernel_layout.defined())
+      << "Conv only support kernel layouts that are convertible from OIHW."
+      << " But got " << kernel_layout;
+
+  Layout out_layout(param->out_layout == "" ? param->data_layout : param->out_layout);
+  const auto trans_out_layout = tir::BijectiveLayout(out_layout, kNCHW);
+  CHECK(trans_out_layout.defined())
+      << "Conv only support output layouts that are convertible from NCHW."
+      << " But got " << out_layout;
+
+  IndexExpr channels, dilated_ksize_y, dilated_ksize_x;
+
+  auto dshape_nchw = trans_in_layout.ForwardShape(data->shape);
+
+  // infer weight if the kernel_size and channels are defined
+  if (param->kernel_size.defined() && param->channels.defined()) {
+    CHECK_EQ(param->kernel_size.size(), 2);
+    CHECK_EQ(param->dilation.size(), 2);
+
+    Array wshape({dshape_nchw[1],
+                  indexdiv(param->channels, param->groups),
+                  param->kernel_size[0],
+                  param->kernel_size[1]});
+
+    wshape = trans_kernel_layout.BackwardShape(wshape);
+    dilated_ksize_y = 1 + (param->kernel_size[0] - 1) * param->dilation[0];
+    dilated_ksize_x = 1 + (param->kernel_size[1] - 1) * param->dilation[1];
+    channels = param->channels;
+
+    // assign result to reporter
+    reporter->Assign(types[1], TensorType(wshape, data->dtype));
+  } else {
+    // use weight to infer the conv shape.
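+    // As in the 1D case, the transposed-conv kernel is stored as
+    // (in_channels, out_channels / groups, kh, kw), hence channels = wshape[1] below.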
+    if (weight == nullptr) return false;
+    auto wshape = trans_kernel_layout.ForwardShape(weight->shape);
+    if (param->kernel_size.defined()) {
+      CHECK_EQ(param->kernel_size.size(), 2);
+      // check the size
+      CHECK(reporter->AssertEQ(param->kernel_size[0], wshape[2]) &&
+            reporter->AssertEQ(param->kernel_size[1], wshape[3]))
+          << "Conv2D: shape of weight is inconsistent with kernel_size, "
+          << " kernel_size=" << param->kernel_size
+          << " wshape=" << Array(wshape);
+    }
+    if (param->channels.defined()) {
+      CHECK(reporter->AssertEQ(param->channels, wshape[1]))
+          << "Conv2D: shape of weight is inconsistent with channels, "
+          << " channels=" << param->channels
+          << " wshape=" << Array(wshape);
+    }
+    CHECK(reporter->AssertEQ(indexdiv(dshape_nchw[1], param->groups), wshape[0]));
+    channels = wshape[1];
+    dilated_ksize_y = 1 + (wshape[2] - 1) * param->dilation[0];
+    dilated_ksize_x = 1 + (wshape[3] - 1) * param->dilation[1];
+  }
+  // dilation
+  Array oshape({dshape_nchw[0], channels, 0, 0});
+  IndexExpr pad_h, pad_w;
+  GetPaddingHeightWidth(param->padding, &pad_h, &pad_w);
+  oshape.Set(2, (param->strides[0] * (dshape_nchw[2] - 1) + dilated_ksize_y -
+                 pad_h + param->output_padding[0]));
+  oshape.Set(3, (param->strides[1] * (dshape_nchw[3] - 1) + dilated_ksize_x -
+                 pad_w + param->output_padding[1]));
+
+  DataType out_dtype = param->out_dtype;
+  if (out_dtype.bits() == 0) {
+    out_dtype = data->dtype;
+  }
+  oshape = trans_out_layout.BackwardShape(oshape);
+  reporter->Assign(types[2], TensorType(oshape, out_dtype));
+  return true;
+}
+
+
+// Deformable Convolution shape relations.
+template
+bool DeformableConv2DRel(const Array& types, int num_inputs, const Attrs& attrs,
+                         const TypeReporter& reporter) {
+  CHECK_EQ(types.size(), 4);
+  const auto* data = types[0].as();
+  const auto* weight = types[2].as();
+
+  CHECK(data);
+  auto* param = attrs.as();
+  CHECK_EQ(param->data_layout, "NCHW") << "data layout not supported.";
+  CHECK_EQ(param->kernel_layout, "OIHW") << "kernel_layout not supported.";
+
+  IndexExpr channels, dilated_ksize_y, dilated_ksize_x, ksize_y, ksize_x;
+
+  // infer weight shape if kernel_size and channels are defined
+  if (param->kernel_size.defined() && param->channels.defined()) {
+    CHECK_EQ(param->kernel_size.size(), 2);
+    CHECK_EQ(param->dilation.size(), 2);
+    Array wshape(
+        {param->channels,
+         indexdiv(data->shape[1], param->groups),
+         param->kernel_size[0],
+         param->kernel_size[1]});
+    channels = param->channels;
+    ksize_y = param->kernel_size[0];
+    ksize_x = param->kernel_size[1];
+    dilated_ksize_y = 1 + (param->kernel_size[0] - 1) * param->dilation[0];
+    dilated_ksize_x = 1 + (param->kernel_size[1] - 1) * param->dilation[1];
+    // assign result to reporter
+    reporter->Assign(types[2], TensorType(wshape, data->dtype));
+  } else {
+    // use weight to infer the conv shape.
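+    // Unlike the transposed variants above, a deformable conv kernel is plain
+    // OIHW, so wshape[0] is the output channels and wshape[1] the input channels
+    // per group.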
+ if (weight == nullptr) return false; + auto wshape = weight->shape; + if (param->kernel_size.defined()) { + CHECK_EQ(param->kernel_size.size(), 2); + // check the size + CHECK(reporter->AssertEQ(param->kernel_size[0], wshape[2]) && + reporter->AssertEQ(param->kernel_size[1], wshape[3])) + << "DeformableConv2D: shape of weight is inconsistent with kernel_size, " + << " kernel_size=" << param->kernel_size + << " wshape=" << wshape; + } + if (param->channels.defined()) { + CHECK(reporter->AssertEQ(param->channels, wshape[0])) + << "DeformableConv2D: shape of weight is inconsistent with channels, " + << " channels=" << param->channels + << " wshape=" << wshape; + } + CHECK(reporter->AssertEQ(indexdiv(data->shape[1], param->groups), wshape[1])); + channels = wshape[0]; + ksize_y = wshape[2]; + ksize_x = wshape[3]; + dilated_ksize_y = 1 + (wshape[2] - 1) * param->dilation[0]; + dilated_ksize_x = 1 + (wshape[3] - 1) * param->dilation[1]; + } + // dilation + Array oshape({data->shape[0], channels, 0, 0}); + + IndexExpr pad_h, pad_w; + GetPaddingHeightWidth(param->padding, &pad_h, &pad_w); + oshape.Set(2, indexdiv(data->shape[2] + pad_h - dilated_ksize_y, + param->strides[0]) + 1); + oshape.Set(3, indexdiv(data->shape[3] + pad_w - dilated_ksize_x, + param->strides[1]) + 1); + DataType out_dtype = param->out_dtype; + + // infer offset shape + Array offset_shape({data->shape[0], 2 * ksize_y * ksize_x * param->deformable_groups, + oshape[2], oshape[3]}); + reporter->Assign(types[1], TensorType(offset_shape, data->dtype)); + if (out_dtype.bits() == 0) { + out_dtype = data->dtype; + } + + reporter->Assign(types[3], TensorType(oshape, out_dtype)); + return true; +} + + template Array > ConvInferCorrectLayout( const Attrs& attrs, @@ -378,6 +908,7 @@ Array > ConvInferCorrectLayout( params->data_layout : params->out_layout}}; } + } // namespace relay } // namespace tvm #endif // TVM_RELAY_OP_NN_CONVOLUTION_H_ diff --git a/tests/python/relay/test_op_level2.py b/tests/python/relay/test_op_level2.py index 7a42fc329e043..771a63deec69f 100644 --- a/tests/python/relay/test_op_level2.py +++ b/tests/python/relay/test_op_level2.py @@ -25,6 +25,7 @@ from tvm.relay.testing import ctx_list, run_infer_type from tvm.contrib import util import topi.testing +from topi.cuda.conv3d_winograd import _infer_tile_size def test_conv1d_infer_type(): @@ -326,7 +327,7 @@ def _query_inside(self, target, workload): cfg['tile_y'] = autotvm.task.space.SplitEntity([-1, 1, 1, 1]) cfg['tile_x'] = autotvm.task.space.SplitEntity([-1, 1, 1, 1]) cfg['tile_rc'] = autotvm.task.space.SplitEntity([-1, 1]) - cfg['auto_unroll_max_setp'] = autotvm.task.space.OtherOptionEntity(1500) + cfg['auto_unroll_max_step'] = autotvm.task.space.OtherOptionEntity(1500) cfg['unroll_explicit'] = autotvm.task.space.OtherOptionEntity(1) self.memory[key] = cfg return cfg @@ -522,6 +523,94 @@ def run_test_conv3d(dtype, out_dtype, scale, dshape, kshape, run_test_conv3d("float32", "float32", 1, dshape, kshape, padding=(1, 1, 1), channels=10, kernel_size=(3, 3 ,3), except_targets=["cuda"]) +def test_conv3d_winograd(): + class WinogradFallback(autotvm.FallbackContext): + def _query_inside(self, target, workload): + key = (target, workload) + if key in self.memory: + return self.memory[key] + cfg = autotvm.task.space.FallbackConfigEntity() + cfg.is_fallback = False + cfg.cost = 0.1 if 'winograd' in workload[0] else 1 + cfg['tile_b'] = autotvm.task.space.SplitEntity([-1, 1, 1, 1]) + cfg['tile_y'] = autotvm.task.space.SplitEntity([-1, 1, 1, 1]) + cfg['tile_x'] = 
autotvm.task.space.SplitEntity([-1, 1, 1, 1]) + cfg['tile_rc'] = autotvm.task.space.SplitEntity([-1, 1]) + cfg['auto_unroll_max_step'] = autotvm.task.space.OtherOptionEntity(0) + cfg['unroll_explicit'] = autotvm.task.space.OtherOptionEntity(1) + self.memory[key] = cfg + return cfg + + def run_test_conv3d_cuda(dtype, out_dtype, scale, dshape, kshape, + padding=(1, 1, 1), + groups=1, + dilation=(1, 1, 1), + prepack=False, + **attrs): + + x = relay.var("x", shape=dshape, dtype=dtype) + w = relay.var("w", shape=kshape, dtype=dtype) + if prepack: + tile_size = _infer_tile_size(np.zeros(shape=dshape), np.zeros(shape=kshape)) + w_packed = relay.nn.contrib_conv3d_winograd_weight_transform(w, tile_size) + + y = relay.nn.contrib_conv3d_winograd_without_weight_transform( + x, w_packed, tile_size, + padding=padding, + dilation=dilation, + groups=groups, + channels=kshape[0], + **attrs) + else: + y = relay.nn.conv3d(x, w, + padding=padding, + dilation=dilation, + groups=groups, + **attrs) + func = relay.Function([x, w], y) + mod = tvm.IRModule() + mod['main'] = func + mod = relay.transform.InferType()(mod) + + data = np.random.uniform(-scale, scale, size=dshape).astype(dtype) + kernel = np.random.uniform(-scale, scale, size=kshape).astype(dtype) + ref_res = topi.testing.conv3d_ncdhw_python( + data.astype(out_dtype), kernel.astype(out_dtype), 1, padding, + groups=groups) + + with WinogradFallback(), relay.build_config(opt_level=3): + for target, ctx in ctx_list(): + if target != 'cuda': + continue + params = {'w': tvm.nd.array(kernel)} + graph, lib, params = relay.build_module.build(mod, target=target, params=params) + module = tvm.contrib.graph_runtime.create(graph, lib, ctx) + module.set_input('x', tvm.nd.array(data)) + module.set_input(**params) + module.run() + op_res1 = module.get_output(0) + tvm.testing.assert_allclose(op_res1.asnumpy(), ref_res, rtol=1e-3, atol=1e-3) + + # normal winograd: stride 1, padding 1, kernel 3x3x3 + dshape = (1, 32, 16, 16, 16) + kshape = (64, 32, 3, 3, 3) + run_test_conv3d_cuda("float32", "float32", 1, dshape, kshape, + padding=(1, 1, 1), kernel_size=(3, 3, 3)) + # Without depth transform using 1x3x3 kernel. + kshape = (64, 32, 1, 3, 3) + run_test_conv3d_cuda("float32", "float32", 1, dshape, kshape, + padding=(0, 1, 1), kernel_size=(1, 3, 3)) + + # extended winograd: stride 1, padding N, kernel NxNxN + dshape = (1, 61, 20, 20, 20) + kshape = (120, 61, 5, 5, 5) + run_test_conv3d_cuda("float32", "float32", 1, dshape, kshape, + padding=(2, 2, 2), channels=120, kernel_size=(5, 5, 5)) + # Without depth transform + kshape = (120, 61, 1, 5, 5) + run_test_conv3d_cuda("float32", "float32", 1, dshape, kshape, + padding=(0, 2, 2), channels=120, kernel_size=(1, 5, 5)) + def test_conv2d_transpose_infer_type(): # symbolic in batch dimension @@ -1268,6 +1357,7 @@ def test_bitpack_infer_type(): test_conv2d_winograd() test_conv3d_run() test_conv3d_ndhwc_run() + test_conv3d_winograd() test_bitserial_conv2d_infer_type() test_batch_flatten() test_upsampling() diff --git a/topi/python/topi/cuda/__init__.py b/topi/python/topi/cuda/__init__.py index 302171ee6466b..83ddedc996fe3 100644 --- a/topi/python/topi/cuda/__init__.py +++ b/topi/python/topi/cuda/__init__.py @@ -31,6 +31,8 @@ from .conv2d_transpose_nchw import * from .deformable_conv2d import * from .conv3d import * +from .conv3d_winograd import * +from . 
import conv3d_alter_op from .reduction import schedule_reduce from .softmax import schedule_softmax from .injective import schedule_injective, schedule_elemwise, schedule_broadcast diff --git a/topi/python/topi/cuda/conv3d_alter_op.py b/topi/python/topi/cuda/conv3d_alter_op.py new file mode 100644 index 0000000000000..fbda456823527 --- /dev/null +++ b/topi/python/topi/cuda/conv3d_alter_op.py @@ -0,0 +1,95 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# pylint: disable=invalid-name,unused-variable,unused-argument +"""Conv3D alter op and legalize functions for cuda backend""" + +import logging +import tvm +from tvm import te +from tvm import relay +from tvm import autotvm + +from .. import nn +from ..util import get_const_tuple +from .conv3d_winograd import _infer_tile_size + +logger = logging.getLogger('topi') + +@nn.conv3d_alter_layout.register(["cuda", "gpu"]) +def _alter_conv3d_layout(attrs, inputs, tinfos, out_type): + target = tvm.target.Target.current(allow_none=False) + dispatch_ctx = autotvm.task.DispatchContext.current + + _, outs = relay.backend.compile_engine.select_implementation( + relay.op.get("nn.conv3d"), attrs, tinfos, out_type, target) + workload = autotvm.task.get_workload(outs) + if workload is None: + # The best implementation is not an AutoTVM template, + # we then assume it's not necessary to alter this op. 
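+        # (e.g. a library implementation such as cudnn was selected instead)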
+ return None + cfg = dispatch_ctx.query(target, workload) + if cfg.is_fallback: # if is fallback, clear query cache and return None + autotvm.task.clear_fallback_cache(target, workload) + return None + + topi_tmpl = workload[0] + new_attrs = {k: attrs[k] for k in attrs.keys()} + + strides = attrs.get_int_tuple("strides") + padding = attrs.get_int_tuple("padding") + dilation = attrs.get_int_tuple("dilation") + groups = attrs.get_int('groups') + data_layout = attrs["data_layout"] + kernel_layout = attrs["kernel_layout"] + data, kernel = tinfos + out_dtype = out_type.dtype + + if topi_tmpl == "conv3d_ncdhw_winograd.cuda": + if dilation != (1, 1, 1): + logger.warning("Does not support weight pre-transform for dilated 3D convolution.") + return None + + assert data_layout == "NCDHW" and kernel_layout == "OIDHW" + N, CI, D, H, W = get_const_tuple(data.shape) + CO, _, KD, KH, KW = get_const_tuple(kernel.shape) + + # Pre-compute weight transformation in winograd + tile_size = _infer_tile_size(tinfos[0], tinfos[1]) + + weight = relay.nn.contrib_conv3d_winograd_weight_transform(inputs[1], tile_size=tile_size) + new_attrs['tile_size'] = tile_size + new_attrs['channels'] = CO + + # Store the same config for the altered operators (workload) + new_data = data + # Check if depth is transformed or not + if 2 < KD < 8 and KD == KH: + new_weight = te.placeholder( + (KD + tile_size - 1, KH + tile_size - 1, KW + tile_size - 1, CO, CI), + dtype=kernel.dtype) + else: + new_weight = te.placeholder( + (KH + tile_size - 1, KW + tile_size - 1, KD, CO, CI), + dtype=kernel.dtype) + new_workload = autotvm.task.args_to_workload( + [new_data, new_weight, strides, padding, dilation, out_dtype], + "conv3d_ncdhw_winograd_without_weight_transform.cuda") + dispatch_ctx.update(target, new_workload, cfg) + return relay.nn.contrib_conv3d_winograd_without_weight_transform( + inputs[0], weight, **new_attrs) + + return None diff --git a/topi/python/topi/cuda/conv3d_winograd.py b/topi/python/topi/cuda/conv3d_winograd.py new file mode 100644 index 0000000000000..c9e84468176dc --- /dev/null +++ b/topi/python/topi/cuda/conv3d_winograd.py @@ -0,0 +1,627 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# pylint: disable=invalid-name,unused-variable,unused-argument +"""Winograd template for cuda backend""" + +import logging +import tvm +from tvm import te +from tvm import autotvm + +from .. 
import nn +from ..util import get_const_int, get_const_tuple, traverse_inline, simplify +from ..nn.winograd_util import winograd_transform_matrices + +logger = logging.getLogger('conv3d_winograd') + + +def _infer_tile_size(data, kernel): + N, CI, D, H, W = get_const_tuple(data.shape) + + if H % 8 == 0: + return 4 + return 2 + + +def winograd_cuda(cfg, data, kernel, strides, padding, dilation, out_dtype, pre_computed): + """Compute declaration for winograd""" + tile_size = _infer_tile_size(data, kernel) + + N, CI, D, H, W = get_const_tuple(data.shape) + + if isinstance(dilation, int): + dilation_d = dilation_h = dilation_w = dilation + else: + dilation_d, dilation_h, dilation_w = dilation + DSTR, HSTR, WSTR = (strides, strides, strides) if isinstance(strides, int) else strides + + if not pre_computed: # kernel tensor is raw tensor, do strict check + if dilation_d != 1 or dilation_h != 1 or dilation_w != 1: + kernel = nn.dilate(kernel, (1, 1, dilation_d, dilation_h, dilation_w)) + CO, CI, KD, KH, KW = get_const_tuple(kernel.shape) + alpha = KW + tile_size - 1 + assert DSTR == 1 and HSTR == 1 and WSTR == 1 and KD == KH and KH == KW + else: + # kernel tensor is pre-transformed. this op is created by alter op layout. + # dilation is not supported + alpha, _, _, CO, CI = get_const_tuple(kernel.shape) + KD = KH = KW = alpha + 1 - tile_size + assert DSTR == 1 and HSTR == 1 and WSTR == 1 and \ + dilation_d == 1 and dilation_h == 1 and dilation_w == 1 + + pf, pt, pl, pb, pd, pr = nn.get_pad_tuple3d(padding, (KD, KH, KW)) + data_pad = nn.pad(data, (0, 0, pf, pt, pl), (0, 0, pb, pd, pr), name="data_pad") + + r = KW + m = tile_size + A, B, G = winograd_transform_matrices(m, r, out_dtype) + + D = (D + pf + pb - KD) // DSTR + 1 + H = (H + pt + pd - KH) // HSTR + 1 + W = (W + pl + pr - KW) // WSTR + 1 + nD, nH, nW = (D + m - 1) // m, (H + m - 1) // m, (W + m - 1) // m + P = N * nD * nH * nW + + # transform kernel + if not pre_computed: + # Check if we are currently tuning, if so we want to avoid counting + # prepacking in time costs. Just use a placeholder with the packed shape instead. 
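+        # The placeholder shape (alpha, alpha, alpha, CO, CI) matches the layout
+        # produced by nn.conv3d_winograd_weight_transform for depth-transformed kernels.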
+        if autotvm.GLOBAL_SCOPE.in_tuning:
+            kernel_pack = te.placeholder((alpha, alpha, alpha, CO, CI),
+                                         dtype=kernel.dtype,
+                                         name='kernel_pack')
+        else:
+            r_kd = te.reduce_axis((0, KD), name='r_kd')
+            r_kh = te.reduce_axis((0, KH), name='r_kh')
+            r_kw = te.reduce_axis((0, KW), name='r_kw')
+            kernel_pack = te.compute(
+                (alpha, alpha, alpha, CO, CI),
+                lambda omg, eps, nu, co, ci: te.sum(
+                    kernel[co][ci][r_kd][r_kh][r_kw] * G[omg][r_kd] * G[eps][r_kh] * G[nu][r_kw],
+                    axis=[r_kd, r_kh, r_kw]),
+                name='kernel_pack')
+    else:
+        kernel_pack = kernel
+
+    idxdiv = tvm.tir.indexdiv
+    idxmod = tvm.tir.indexmod
+    # pack input tile
+    input_tile = te.compute((CI, P, alpha, alpha, alpha),
+                            lambda c, p, omg, eps, nu: data_pad[idxdiv(p, (nD * nH * nW))]
+                            [c]
+                            [idxmod(idxdiv(p, nH * nW), nD) * m + omg]
+                            [idxmod(idxdiv(p, nW), nH) * m + eps]
+                            [idxmod(p, nW) * m + nu],
+                            name='d')
+
+    # transform data
+    r_a = te.reduce_axis((0, alpha), 'r_a')
+    r_b = te.reduce_axis((0, alpha), 'r_b')
+    r_c = te.reduce_axis((0, alpha), 'r_c')
+    data_pack = te.compute(
+        (alpha, alpha, alpha, CI, P),
+        lambda omg, eps, nu, ci, p: te.sum(
+            input_tile[ci][p][r_a][r_b][r_c] * B[r_a][omg] * B[r_b][eps] * B[r_c][nu],
+            axis=[r_a, r_b, r_c]),
+        name='data_pack')
+
+    # do batch gemm
+    ci = te.reduce_axis((0, CI), name='ci')
+    bgemm = te.compute(
+        (alpha, alpha, alpha, CO, P),
+        lambda omg, eps, nu, co, p: te.sum(
+            kernel_pack[omg][eps][nu][co][ci] * data_pack[omg][eps][nu][ci][p], axis=[ci]),
+        name='bgemm')
+
+    # inverse transform
+    r_a = te.reduce_axis((0, alpha), 'r_a')
+    r_b = te.reduce_axis((0, alpha), 'r_b')
+    r_c = te.reduce_axis((0, alpha), 'r_c')
+    inverse = te.compute((CO, P, m, m, m),
+                         lambda co, p, vd, vh, vw: te.sum(
+                             bgemm[r_a][r_b][r_c][co][p] * A[r_a][vd] * A[r_b][vh] * A[r_c][vw],
+                             axis=[r_a, r_b, r_c]),
+                         name='inverse')
+
+    # output
+    output = te.compute((N, CO, D, H, W),
+                        lambda n, co, d, h, w: inverse[co, n * nD * nH * nW + idxdiv(d, m) * nH * nW
+                                                       + idxdiv(h, m) * nW + idxdiv(w, m),
+                                                       idxmod(d, m),
+                                                       idxmod(h, m),
+                                                       idxmod(w, m)],
+                        name='output',
+                        tag='conv3d_ncdhw_winograd')
+    cfg.add_flop(2 * N * CO * D * H * W * CI * KD * KH * KW)
+
+    return output
+
+
+def winograd_without_depth_cuda(cfg, data, kernel, strides, padding, dilation, out_dtype,
+                                pre_computed):
+    """Compute declaration for winograd without transforming depth"""
+    tile_size = _infer_tile_size(data, kernel)
+
+    N, CI, D, H, W = get_const_tuple(data.shape)
+
+    if isinstance(dilation, int):
+        dilation_d = dilation_h = dilation_w = dilation
+    else:
+        dilation_d, dilation_h, dilation_w = dilation
+    DSTR, HSTR, WSTR = (strides, strides, strides) if isinstance(strides, int) else strides
+
+    if not pre_computed:  # kernel tensor is raw tensor, do strict check
+        if dilation_d != 1 or dilation_h != 1 or dilation_w != 1:
+            kernel = nn.dilate(kernel, (1, 1, dilation_d, dilation_h, dilation_w))
+        CO, CI, KD, KH, KW = get_const_tuple(kernel.shape)
+        alpha = KW + tile_size - 1
+        assert HSTR == 1 and WSTR == 1 and KH == KW
+    else:
+        # kernel tensor is pre-transformed. this op is created by alter op layout.
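+        # (see conv3d_alter_op.py, which rewrites the kernel into this packed layout)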
+        # dilation is not supported
+        alpha, _, KD, CO, CI = get_const_tuple(kernel.shape)
+        KH = KW = alpha + 1 - tile_size
+        assert HSTR == 1 and WSTR == 1 and dilation_h == 1 and dilation_w == 1
+
+    pf, pt, pl, pb, pd, pr = nn.get_pad_tuple3d(padding, (KD, KH, KW))
+    data_pad = nn.pad(data, (0, 0, pf, pt, pl), (0, 0, pb, pd, pr), name="data_pad")
+    out_depth = simplify((D - KD + pf + pb) // DSTR + 1)
+    D += pf + pb
+
+    r = KW
+    m = tile_size
+    A, B, G = winograd_transform_matrices(m, r, out_dtype)
+
+    H = (H + pt + pd - KH) // HSTR + 1
+    W = (W + pl + pr - KW) // WSTR + 1
+    nH, nW = (H + m - 1) // m, (W + m - 1) // m
+    P = N * nH * nW
+
+    # transform kernel
+    if not pre_computed:
+        # During autotuning don't count kernel packing as a time cost
+        # as it will later be removed via alter_op_layout.
+        if autotvm.GLOBAL_SCOPE.in_tuning:
+            kernel_pack = te.placeholder((alpha, alpha, KD, CO, CI),
+                                         dtype=kernel.dtype,
+                                         name='kernel_pack')
+        else:
+            r_kh = te.reduce_axis((0, KH), name='r_kh')
+            r_kw = te.reduce_axis((0, KW), name='r_kw')
+            kernel_pack = te.compute(
+                (alpha, alpha, KD, CO, CI),
+                lambda eps, nu, d, co, ci: te.sum(
+                    kernel[co][ci][d][r_kh][r_kw] * G[eps][r_kh] * G[nu][r_kw], axis=[r_kh, r_kw]),
+                name='kernel_pack')
+    else:
+        kernel_pack = kernel
+
+    idxdiv = tvm.tir.indexdiv
+    idxmod = tvm.tir.indexmod
+    # pack input tile
+    input_tile = te.compute((CI, D, P, alpha, alpha), lambda c, d, p, eps, nu:
+                            data_pad[idxdiv(p, (nH * nW))][c][d]
+                            [idxmod(idxdiv(p, nW), nH) * m + eps]
+                            [idxmod(p, nW) * m + nu], name='d')
+
+    # transform data
+    r_a = te.reduce_axis((0, alpha), 'r_a')
+    r_b = te.reduce_axis((0, alpha), 'r_b')
+    data_pack = te.compute((alpha, alpha, CI, D, P), lambda eps, nu, ci, d, p:
+                           te.sum(input_tile[ci][d][p][r_a][r_b] * B[r_a][eps] * B[r_b][nu],
+                                  axis=[r_a, r_b]), name='data_pack')
+
+    # do batch gemm
+    ci = te.reduce_axis((0, CI), name='ci')
+    rz = te.reduce_axis((0, KD), name='rz')
+    bgemm = te.compute((alpha, alpha, CO, out_depth, P), lambda eps, nu, co, d, p:
+                       te.sum(kernel_pack[eps][nu][rz][co][ci] *
+                              data_pack[eps][nu][ci][d * DSTR + rz][p],
+                              axis=[ci, rz]), name='bgemm')
+
+    # inverse transform
+    r_a = te.reduce_axis((0, alpha), 'r_a')
+    r_b = te.reduce_axis((0, alpha), 'r_b')
+    inverse = te.compute((CO, out_depth, P, m, m), lambda co, d, p, vh, vw:
+                         te.sum(bgemm[r_a][r_b][co][d][p] * A[r_a][vh] * A[r_b][vw],
+                                axis=[r_a, r_b]), name='inverse')
+
+    # output
+    output = te.compute((N, CO, out_depth, H, W), lambda n, co, d, h, w:
+                        inverse[co,
+                                d,
+                                n * nH * nW + idxdiv(h, m) * nW + idxdiv(w, m),
+                                idxmod(h, m),
+                                idxmod(w, m)],
+                        name='output', tag='conv3d_ncdhw_winograd_without_depth')
+    cfg.add_flop(2 * N * CO * D * H * W * CI * KD * KH * KW)
+
+    return output
+
+
+def schedule_winograd_cuda(cfg, s, output, pre_computed):
+    """Schedule winograd template"""
+    # get stages
+    inverse = s[output].op.input_tensors[0]
+    bgemm, A = s[inverse].op.input_tensors
+    kernel_pack, data_pack = s[bgemm].op.input_tensors
+    input_tile, B = s[data_pack].op.input_tensors
+    pad_data = s[input_tile].op.input_tensors[0]
+
+    # data transform
+    s[B].compute_inline()
+
+    data_l = s.cache_write(data_pack, 'local')
+    omg, eps, nu, c, p = s[data_l].op.axis
+    r_a, r_b, r_c = s[data_l].op.reduce_axis
+    # TODO unrolling by omg, eps, nu may improve performance but
+    # in some cases causes extremely long build times due to imperfect tiling.
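+    # For now only the small reduction loops (each of extent alpha) are unrolled.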
+ for axis in [r_a, r_b, r_c]: + s[data_l].unroll(axis) + + omg, eps, nu, c, p = s[data_pack].op.axis + p, pi = s[data_pack].split(p, 1) + fused = s[data_pack].fuse(c, p) + bb, tt = s[data_pack].split(fused, 128) + s[data_pack].reorder(bb, tt, pi, omg, eps, nu) + s[data_pack].bind(bb, te.thread_axis("blockIdx.x")) + s[data_pack].bind(tt, te.thread_axis("threadIdx.x")) + + s[data_l].compute_at(s[data_pack], pi) + s[input_tile].compute_at(s[data_pack], pi) + s[pad_data].compute_inline() + + # transform kernel + if not pre_computed and not autotvm.GLOBAL_SCOPE.in_tuning: + kernel, G = s[kernel_pack].op.input_tensors + omg, eps, nu, co, ci = s[kernel_pack].op.axis + s[G].compute_inline() + r_a, r_b, r_c = s[kernel_pack].op.reduce_axis + # Could add additional unrolling by omg, eps, nu in the future. + for axis in [r_a, r_b, r_c]: + s[kernel_pack].unroll(axis) + + fused = s[kernel_pack].fuse(co, ci) + bb, tt = s[kernel_pack].split(fused, 128) + s[kernel_pack].reorder(bb, tt, omg, eps, nu, r_a, r_b, r_c) + s[kernel_pack].bind(bb, te.thread_axis("blockIdx.x")) + s[kernel_pack].bind(tt, te.thread_axis("threadIdx.x")) + else: + kernel = kernel_pack + + if isinstance(kernel.op, tvm.te.ComputeOp) and "dilate" in kernel.op.tag: + s[kernel].compute_inline() + + ##### space definition begin ##### + b1, b2, b3, y, x = s[bgemm].op.axis + rc = s[bgemm].op.reduce_axis[0] + alpha = get_const_int(b1.dom.extent) + + cfg.define_split( + "tile_b", + cfg.axis(alpha * alpha * alpha), + num_outputs=4, + filter=lambda x: x.size[-3:] == [1, 1, 1]) + cfg.define_split("tile_y", y, num_outputs=4) + cfg.define_split("tile_x", x, num_outputs=4) + cfg.define_split("tile_rc", rc, num_outputs=2) + cfg.define_knob("auto_unroll_max_step", [0, 128, 1500]) + target = tvm.target.Target.current() + if target.target_name in ['nvptx', 'rocm']: + cfg.define_knob("unroll_explicit", [1]) + else: + cfg.define_knob("unroll_explicit", [0, 1]) + ##### space definition end ##### + + # batch gemm + C = bgemm + A0, B0 = kernel_pack, data_pack + + OL = s.cache_write(C, 'local') + AA = s.cache_read(A0, 'shared', [OL]) + BB = s.cache_read(B0, 'shared', [OL]) + + b = s[bgemm].fuse(b1, b2, b3) + + # tile and bind spatial axes + bgemm_scope, b = s[bgemm].split(b, nparts=1) + bz, vz, tz, zi = cfg["tile_b"].apply(s, C, b) + by, vy, ty, yi = cfg["tile_y"].apply(s, C, y) + bx, vx, tx, xi = cfg["tile_x"].apply(s, C, x) + s[C].bind(bz, te.thread_axis("blockIdx.z")) + s[C].bind(by, te.thread_axis("blockIdx.y")) + s[C].bind(bx, te.thread_axis("blockIdx.x")) + s[C].bind(vz, te.thread_axis("vthread")) + s[C].bind(vy, te.thread_axis("vthread")) + s[C].bind(vx, te.thread_axis("vthread")) + s[C].bind(tz, te.thread_axis("threadIdx.z")) + s[C].bind(ty, te.thread_axis("threadIdx.y")) + s[C].bind(tx, te.thread_axis("threadIdx.x")) + s[C].reorder(bgemm_scope, bz, by, bx, vz, vy, vx, tz, ty, tx, zi, yi, xi) + + # tile reduction axes + s[OL].compute_at(s[C], tx) + b1, b2, b3, y, x = s[OL].op.axis + b = s[OL].fuse(b1, b2, b3) + rc, = s[OL].op.reduce_axis + rco, rci = cfg['tile_rc'].apply(s, OL, rc) + s[OL].reorder(rco, rci, b, y, x) + + s[AA].compute_at(s[OL], rco) + s[BB].compute_at(s[OL], rco) + + # cooperative fetching + for load in [AA, BB]: + fused = s[load].fuse(*list(s[load].op.axis)) + fused, tx = s[load].split(fused, cfg["tile_x"].size[2]) + fused, ty = s[load].split(fused, cfg["tile_y"].size[2]) + fused, tz = s[load].split(fused, cfg["tile_b"].size[2]) + s[load].bind(tz, te.thread_axis("threadIdx.z")) + s[load].bind(ty, te.thread_axis("threadIdx.y")) + 
s[load].bind(tx, te.thread_axis("threadIdx.x")) + + s[C].pragma(bgemm_scope, 'auto_unroll_max_step', cfg['auto_unroll_max_step'].val) + s[C].pragma(bgemm_scope, 'unroll_explicit', cfg['unroll_explicit'].val) + + # schedule inverse, output and fusion + if output.op in s.outputs: + OL = None + else: + OL = output + s[OL].set_scope('local') + output = s.outputs[0] + + m = alpha - 3 + 1 + n, co, d, h, w = s[output].op.axis + do, di = s[output].split(d, m) + ho, hi = s[output].split(h, m) + wo, wi = s[output].split(w, m) + s[output].reorder(n, co, do, ho, wo, di, hi, wi) + inverse_scope, n = s[output].split(n, nparts=1) + + fused = s[output].fuse(n, co, do, ho, wo) + bb, tt = s[output].split(fused, 128) + + s[output].bind(bb, te.thread_axis("blockIdx.x")) + s[output].bind(tt, te.thread_axis("threadIdx.x")) + + if OL is not None: + s[OL].compute_at(s[output], tt) + + s[A].compute_inline() + co, p, vd, vh, vw = s[inverse].op.axis + r_a, r_b, r_c = s[inverse].op.reduce_axis + # Could add additional unrolling of vd, vh, vw, in the future + for axis in [r_a, r_b, r_c]: + s[inverse].unroll(axis) + s[inverse].compute_at(s[output], tt) + + return s + + +def schedule_winograd_no_depth_cuda(cfg, s, output, pre_computed): + """Schedule winograd template""" + # get stages + inverse = s[output].op.input_tensors[0] + bgemm, A = s[inverse].op.input_tensors + kernel_pack, data_pack = s[bgemm].op.input_tensors + input_tile, B = s[data_pack].op.input_tensors + pad_data = s[input_tile].op.input_tensors[0] + + # data transform + s[B].compute_inline() + + data_l = s.cache_write(data_pack, 'local') + eps, nu, c, d, p = s[data_l].op.axis + r_a, r_b = s[data_l].op.reduce_axis + for axis in [eps, nu, r_a, r_b]: + s[data_l].unroll(axis) + + eps, nu, c, d, p = s[data_pack].op.axis + p, pi = s[data_pack].split(p, 1) + fused = s[data_pack].fuse(c, d, p) + bb, tt = s[data_pack].split(fused, 128) + s[data_pack].reorder(bb, tt, pi, eps, nu) + s[data_pack].bind(bb, te.thread_axis("blockIdx.x")) + s[data_pack].bind(tt, te.thread_axis("threadIdx.x")) + + s[data_l].compute_at(s[data_pack], pi) + s[input_tile].compute_at(s[data_pack], pi) + s[pad_data].compute_inline() + + # transform kernel + if not pre_computed and not autotvm.GLOBAL_SCOPE.in_tuning: + kernel, G = s[kernel_pack].op.input_tensors + eps, nu, kd, co, ci = s[kernel_pack].op.axis + s[G].compute_inline() + r_a, r_b = s[kernel_pack].op.reduce_axis + for axis in [eps, nu, r_a, r_b]: + s[kernel_pack].unroll(axis) + + fused = s[kernel_pack].fuse(kd, co, ci) + bb, tt = s[kernel_pack].split(fused, 128) + s[kernel_pack].reorder(bb, tt, eps, nu, r_a, r_b) + s[kernel_pack].bind(bb, te.thread_axis("blockIdx.x")) + s[kernel_pack].bind(tt, te.thread_axis("threadIdx.x")) + else: + kernel = kernel_pack + + if isinstance(kernel.op, tvm.te.ComputeOp) and "dilate" in kernel.op.tag: + s[kernel].compute_inline() + + ##### space definition begin ##### + b1, b2, z, y, x = s[bgemm].op.axis + # Combine channel and depth axes. 
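+    # Here the bgemm axes are (eps, nu, co, depth, p); co (z) and depth (y) are
+    # fused into a single gemm axis further below.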
+ rc = s[bgemm].op.reduce_axis[0] + rz = s[bgemm].op.reduce_axis[1] + alpha = get_const_int(b1.dom.extent) + + cfg.define_split("tile_b", cfg.axis(alpha * alpha), num_outputs=4, + filter=lambda x: x.size[-3:] == [1, 1, 1]) + cfg.define_split("tile_y", y, num_outputs=4) + cfg.define_split("tile_x", x, num_outputs=4) + cfg.define_split("tile_rc", rc, num_outputs=2) + cfg.define_split("tile_rz", rz, num_outputs=2) + cfg.define_knob("auto_unroll_max_step", [0, 128, 1500]) + target = tvm.target.Target.current() + if target.target_name in ['nvptx', 'rocm']: + cfg.define_knob("unroll_explicit", [1]) + else: + cfg.define_knob("unroll_explicit", [0, 1]) + ##### space definition end ##### + + # batch gemm + C = bgemm + A0, B0 = kernel_pack, data_pack + + OL = s.cache_write(C, 'local') + AA = s.cache_read(A0, 'shared', [OL]) + BB = s.cache_read(B0, 'shared', [OL]) + + b = s[bgemm].fuse(b1, b2) + y = s[bgemm].fuse(z, y) + + # tile and bind spatial axes + bgemm_scope, b = s[bgemm].split(b, nparts=1) + bz, vz, tz, zi = cfg["tile_b"].apply(s, C, b) + by, vy, ty, yi = cfg["tile_y"].apply(s, C, y) + bx, vx, tx, xi = cfg["tile_x"].apply(s, C, x) + s[C].bind(bz, te.thread_axis("blockIdx.z")) + s[C].bind(by, te.thread_axis("blockIdx.y")) + s[C].bind(bx, te.thread_axis("blockIdx.x")) + s[C].bind(vz, te.thread_axis("vthread")) + s[C].bind(vy, te.thread_axis("vthread")) + s[C].bind(vx, te.thread_axis("vthread")) + s[C].bind(tz, te.thread_axis("threadIdx.z")) + s[C].bind(ty, te.thread_axis("threadIdx.y")) + s[C].bind(tx, te.thread_axis("threadIdx.x")) + s[C].reorder(bgemm_scope, bz, by, bx, vz, vy, vx, tz, ty, tx, zi, yi, xi) + + # tile reduction axes + s[OL].compute_at(s[C], tx) + b1, b2, y1, y2, x = s[OL].op.axis + y = s[OL].fuse(y1, y2) + b = s[OL].fuse(b1, b2) + rc, rz = s[OL].op.reduce_axis + rco, rci = cfg['tile_rc'].apply(s, OL, rc) + rzo, rzi = cfg['tile_rz'].apply(s, OL, rz) + s[OL].reorder(rco, rzo, rci, rzi, b, y, x) + + s[AA].compute_at(s[OL], rco) + s[BB].compute_at(s[OL], rco) + + # cooperative fetching + for load in [AA, BB]: + fused = s[load].fuse(*list(s[load].op.axis)) + fused, tx = s[load].split(fused, cfg["tile_x"].size[2]) + fused, ty = s[load].split(fused, cfg["tile_y"].size[2]) + fused, tz = s[load].split(fused, cfg["tile_b"].size[2]) + s[load].bind(tz, te.thread_axis("threadIdx.z")) + s[load].bind(ty, te.thread_axis("threadIdx.y")) + s[load].bind(tx, te.thread_axis("threadIdx.x")) + + s[C].pragma(bgemm_scope, 'auto_unroll_max_step', cfg['auto_unroll_max_step'].val) + s[C].pragma(bgemm_scope, 'unroll_explicit', cfg['unroll_explicit'].val) + + # schedule inverse, output and fusion + if output.op in s.outputs: + OL = None + else: + OL = output + s[OL].set_scope('local') + output = s.outputs[0] + + m = alpha - 3 + 1 + n, co, d, h, w = s[output].op.axis + do, di = s[output].split(d, m) + ho, hi = s[output].split(h, m) + wo, wi = s[output].split(w, m) + s[output].reorder(n, co, do, ho, wo, di, hi, wi) + inverse_scope, n = s[output].split(n, nparts=1) + + fused = s[output].fuse(n, co, do, ho, wo) + bb, tt = s[output].split(fused, 128) + + s[output].bind(bb, te.thread_axis("blockIdx.x")) + s[output].bind(tt, te.thread_axis("threadIdx.x")) + + if OL is not None: + s[OL].compute_at(s[output], tt) + + s[A].compute_inline() + co, d, p, vh, vw = s[inverse].op.axis + r_a, r_b = s[inverse].op.reduce_axis + for axis in [vh, vw, r_a, r_b]: + s[inverse].unroll(axis) + s[inverse].compute_at(s[output], tt) + + return s + + +@autotvm.register_topi_compute("conv3d_ncdhw_winograd.cuda") +def 
conv3d_ncdhw_winograd(cfg, data, kernel, strides, padding, dilation, out_dtype):
+    """Compute conv3d with winograd for NCDHW layout."""
+    CO, CI, KD, KH, KW = get_const_tuple(kernel.shape)
+    # Check if we can transform depth.
+    if 2 < KD < 8 and KD == KH:
+        return winograd_cuda(
+            cfg, data, kernel, strides, padding, dilation, out_dtype, pre_computed=False)
+
+    return winograd_without_depth_cuda(
+        cfg, data, kernel, strides, padding, dilation, out_dtype, pre_computed=False)
+
+
+@autotvm.register_topi_schedule("conv3d_ncdhw_winograd.cuda")
+def schedule_conv3d_ncdhw_winograd(cfg, outs):
+    """Dispatch to the schedule appropriate for the conv3d winograd algorithm used."""
+    s = te.create_schedule([x.op for x in outs])
+
+    def _callback(op):
+        if 'conv3d_ncdhw_winograd_without_depth' in op.tag:
+            schedule_winograd_no_depth_cuda(cfg, s, op.output(0), pre_computed=False)
+        elif 'conv3d_ncdhw_winograd' in op.tag:
+            schedule_winograd_cuda(cfg, s, op.output(0), pre_computed=False)
+
+    traverse_inline(s, outs[0].op, _callback)
+    return s
+
+
+@autotvm.register_topi_compute("conv3d_ncdhw_winograd_without_weight_transform.cuda")
+def conv3d_ncdhw_winograd_without_weight_transform(cfg, data, kernel, strides, padding, dilation,
+                                                   out_dtype):
+    """Compute conv3d with winograd for a pre-transformed NCDHW kernel."""
+    A, B, C, _, _ = get_const_tuple(kernel.shape)
+    # Check if we can transform depth.
+    if A == B == C:
+        return winograd_cuda(
+            cfg, data, kernel, strides, padding, dilation, out_dtype, pre_computed=True)
+
+    return winograd_without_depth_cuda(
+        cfg, data, kernel, strides, padding, dilation, out_dtype, pre_computed=True)
+
+
+@autotvm.register_topi_schedule("conv3d_ncdhw_winograd_without_weight_transform.cuda")
+def schedule_conv3d_ncdhw_winograd_without_weight_transform(cfg, outs):
+    """TOPI schedule callback"""
+    s = te.create_schedule([x.op for x in outs])
+
+    def _callback(op):
+        if 'conv3d_ncdhw_winograd_without_depth' in op.tag:
+            schedule_winograd_no_depth_cuda(cfg, s, op.output(0), pre_computed=True)
+        elif 'conv3d_ncdhw_winograd' in op.tag:
+            schedule_winograd_cuda(cfg, s, op.output(0), pre_computed=True)
+
+    traverse_inline(s, outs[0].op, _callback)
+    return s
diff --git a/topi/python/topi/generic/nn.py b/topi/python/topi/generic/nn.py
index 43b12822b2392..2be4bbb456dec 100644
--- a/topi/python/topi/generic/nn.py
+++ b/topi/python/topi/generic/nn.py
@@ -187,6 +187,43 @@ def schedule_conv2d_winograd_weight_transform(outs):
     return s
+
+def schedule_conv3d_winograd_weight_transform(outs):
+    """Schedule for weight transformation of 3D winograd
+
+    Parameters
+    ----------
+    outs: Array of Tensor
+        The computation graph description of this operator
+        in the format of an array of tensors.
+
+    Returns
+    -------
+    sch: Schedule
+        The computation schedule for the op.
+ """ + # Typically this is computed in PreCompute pass + # so we make a schedule here for cpu llvm + s = te.create_schedule([x.op for x in outs]) + output = outs[0] + _, G = s[output].op.input_tensors + s[G].compute_inline() + transform_depth = len(s[output].op.reduce_axis) == 3 + if transform_depth: + omg, eps, nu, ci, co = s[output].op.axis + r_kd, r_kh, r_kw = s[output].op.reduce_axis + s[output].reorder(co, ci, omg, eps, nu, r_kd, r_kh, r_kw) + for axis in [r_kd, r_kh, r_kw]: + s[output].unroll(axis) + else: + eps, nu, d, ci, co = s[output].op.axis + r_kh, r_kw = s[output].op.reduce_axis + s[output].reorder(co, ci, d, eps, nu, r_kh, r_kw) + for axis in [r_kh, r_kw]: + s[output].unroll(axis) + s[output].parallel(co) + return s + + def schedule_conv2d_winograd_without_weight_transform(outs): """Schedule for winograd without weight transformation diff --git a/topi/python/topi/nn/conv3d.py b/topi/python/topi/nn/conv3d.py index d6bd6424a9477..2bac284ab4014 100644 --- a/topi/python/topi/nn/conv3d.py +++ b/topi/python/topi/nn/conv3d.py @@ -17,11 +17,13 @@ # pylint: disable=invalid-name, unused-variable, too-many-locals # pylint: disable=unused-argument, redefined-builtin, no-else-return """Conv3D operators""" +import tvm from tvm import te from .pad import pad from .util import get_pad_tuple3d -from ..util import simplify +from ..util import simplify, get_const_tuple +from .winograd_util import winograd_transform_matrices def conv3d_ncdhw(Input, Filter, stride, padding, dilation, out_dtype=None): @@ -159,3 +161,74 @@ def conv3d_ndhwc(Input, Filter, stride, padding, dilation, out_dtype='float32'): Filter[rd, rh, rw, rc, cc].astype(out_dtype), axis=[rd, rh, rw, rc]), name="Conv3dOutput", tag="conv3d_ndhwc") return Output + + +def conv3d_winograd_weight_transform(kernel, tile_size): + """Weight transformation for 3D winograd + + Parameters + ---------- + kernel: Tensor + The raw kernel tensor with layout "NCDHW". + tile_size: int + Tile size of winograd transform. e.g. 2 for F(2x2, 3x3) and 4 for F(4x4, 3x3) + + Returns + ------- + output : tvm.te.Tensor + 5-D with shape [alpha, alpha, alpha, CO, CI] + """ + CO, CI, KD, KH, KW = get_const_tuple(kernel.shape) + + depth_transform = 2 < KD < 8 and KD == KH + + if depth_transform: + assert KD == KH == KW, "Only support NxNxN kernel" + else: + assert KH == KW, "Only supports DxNxN kernel" + + r = tile_size + KH - 1 + + r_kh = te.reduce_axis((0, KH), name='r_kh') + r_kw = te.reduce_axis((0, KW), name='r_kw') + _, _, G = winograd_transform_matrices(tile_size, KH, kernel.dtype) + if depth_transform: + shape = (r, r, r, CO, CI) + r_kd = te.reduce_axis((0, KD), name='r_kd') + return te.compute( + shape, + lambda omg, eps, nu, co, ci: te.sum( + kernel[co][ci][r_kd][r_kh][r_kw] * G[omg][r_kd] * G[eps][r_kh] * G[nu][r_kw], + axis=[r_kd, r_kh, r_kw]), + name='transform_weight') + else: + shape = (r, r, KD, CO, CI) + return te.compute( + shape, + lambda eps, nu, d, co, ci: te.sum( + kernel[co][ci][d][r_kh][r_kw] * G[eps][r_kh] * G[nu][r_kw], axis=[r_kh, r_kw]), + name='transform_weight') + + + +@tvm.target.generic_func +def conv3d_alter_layout(attrs, inputs, tinfos, out_type): + """Change Conv3D layout. + + Parameters + ---------- + attrs : tvm.ir.Attrs + Attributes of current convolution + inputs : tvm.relay.Expr + Grouped input symbols + tinfos : list + Input shape and dtype + out_type: type + The output type + + Note + ---- + Unlike other TOPI functions, this function operates on both graph level and operator level. 
+ """ + # not to change by default + return None diff --git a/topi/tests/python/test_topi_conv3d_winograd.py b/topi/tests/python/test_topi_conv3d_winograd.py new file mode 100644 index 0000000000000..6d0d99d00b10b --- /dev/null +++ b/topi/tests/python/test_topi_conv3d_winograd.py @@ -0,0 +1,151 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +"""Test code for 3d convolution with winograd.""" + +import numpy as np +import tvm +from tvm import te +from tvm import autotvm +import topi +import topi.testing +from tvm.contrib.pickle_memoize import memoize +from topi.nn.util import get_pad_tuple3d +from topi.util import get_const_tuple + +from common import get_all_backend + +_conv3d_ncdhw_implement = { + "gpu": (topi.cuda.conv3d_ncdhw_winograd, topi.cuda.schedule_conv3d_ncdhw_winograd), +} + + +def verify_conv3d_ncdhw(batch, + in_channel, + in_size, + num_filter, + depth_kernel, + space_kernel, + stride, + padding, + dilation=1, + add_bias=False, + add_relu=False): + pad_front, pad_top, pad_left, pad_back, pad_bottom, pad_right = get_pad_tuple3d( + padding, (depth_kernel, space_kernel, space_kernel)) + padding_sum = pad_front + pad_back + pad_top + pad_left + pad_bottom + pad_right + print("Workload: (%d, %d, %d, %d, %d, %d, %d, %d)" % + (batch, in_channel, in_size, num_filter, space_kernel, stride, padding_sum, dilation)) + + in_depth = in_height = in_width = in_size + + A = te.placeholder((batch, in_channel, in_depth, in_height, in_width), name='A') + W = te.placeholder((num_filter, in_channel, depth_kernel, space_kernel, space_kernel), name='W') + bias = te.placeholder((num_filter, 1, 1, 1), name='bias') + + a_shape = get_const_tuple(A.shape) + w_shape = get_const_tuple(W.shape) + bias_shape = get_const_tuple(bias.shape) + dtype = A.dtype + + @memoize("topi.tests.test_topi_conv3d_ncdhw.verify_conv3d_ncdhw") + def get_ref_data(): + a_np = np.random.uniform(size=a_shape).astype(dtype) + w_np = np.random.uniform(size=w_shape).astype(dtype) + b_np = np.random.uniform(size=bias_shape).astype(dtype) + dw_np = topi.testing.dilate_python(w_np, (1, 1, dilation, dilation, dilation)) + c_np = topi.testing.conv3d_ncdhw_python(a_np, dw_np, stride, padding) + if add_bias: + c_np += b_np + if add_relu: + c_np = np.maximum(c_np, 0) + return a_np, w_np, b_np, c_np + + a_np, w_np, b_np, c_np = get_ref_data() + + def check_device(device): + ctx = tvm.context(device, 0) + if not ctx.exist: + print("Skip because %s is not enabled" % device) + return + print("Running on target: %s" % device) + fcompute, fschedule = topi.testing.dispatch(device, _conv3d_ncdhw_implement) + with tvm.target.create(device): + C = fcompute(A, W, (stride, stride, stride), padding, (dilation, dilation, dilation), + dtype) + if add_bias: + C = topi.add(C, bias) + if add_relu: + C = topi.nn.relu(C) + s = 
fschedule([C])
+
+        a = tvm.nd.array(a_np, ctx)
+        w = tvm.nd.array(w_np, ctx)
+        b = tvm.nd.array(b_np, ctx)
+        c = tvm.nd.array(np.zeros(get_const_tuple(C.shape), dtype=C.dtype), ctx)
+        if add_bias:
+            func = tvm.build(
+                s, [A, W, bias, C],
+                device,
+                name="relu_%d_%d_%d_%d_%d_%d_%d_%d" %
+                (batch, in_channel, in_size, num_filter, space_kernel, stride, padding_sum, dilation))
+            func(a, w, b, c)
+        else:
+            func = tvm.build(
+                s, [A, W, C],
+                device,
+                name="relu_%d_%d_%d_%d_%d_%d_%d_%d" %
+                (batch, in_channel, in_size, num_filter, space_kernel, stride, padding_sum, dilation))
+            func(a, w, c)
+        tvm.testing.assert_allclose(c.asnumpy(), c_np, rtol=1e-4)
+
+    for device in ["cuda"]:
+        with autotvm.tophub.context(device):  # load tophub pre-tuned parameters
+            check_device(device)
+
+
+def test_conv3d_ncdhw():
+    # 3DCNN workloads, with and without depth transformation
+    verify_conv3d_ncdhw(1, 61, 20, 120, 3, 3, 1, 0)
+    verify_conv3d_ncdhw(1, 61, 20, 120, 1, 3, 1, 0)
+    verify_conv3d_ncdhw(1, 61, 20, 120, 5, 3, 1, 0)
+    verify_conv3d_ncdhw(1, 61, 20, 120, 5, 5, 1, 2)
+    verify_conv3d_ncdhw(1, 61, 20, 120, 1, 5, 1, 2)
+    verify_conv3d_ncdhw(1, 61, 20, 120, 7, 7, 1, 3)
+    verify_conv3d_ncdhw(1, 128, 12, 256, 3, 3, 1, 1)
+    verify_conv3d_ncdhw(1, 64, 12, 128, 3, 3, 1, 1)
+
+    # bias, relu
+    verify_conv3d_ncdhw(1, 64, 12, 128, 3, 3, 1, 1, add_relu=True)
+    verify_conv3d_ncdhw(1, 64, 12, 128, 3, 3, 1, 1, add_relu=True, add_bias=True)
+    verify_conv3d_ncdhw(1, 64, 12, 128, 1, 3, 1, 1, add_relu=True, add_bias=True)
+
+    # dilation = 2
+    verify_conv3d_ncdhw(1, 16, 12, 16, 3, 3, 1, "VALID", dilation=2)
+    verify_conv3d_ncdhw(1, 16, 12, 16, 1, 3, 1, "VALID", dilation=2)
+
+    # batch size
+    verify_conv3d_ncdhw(4, 32, 12, 64, 3, 3, 1, 1)
+    verify_conv3d_ncdhw(4, 32, 12, 64, 1, 3, 1, 1)
+
+    # weird workloads
+    verify_conv3d_ncdhw(2, 2, 2, 2, 3, 3, 1, 2)
+    verify_conv3d_ncdhw(3, 3, 3, 3, 3, 3, 1, 3)
+
+
+if __name__ == "__main__":
+    test_conv3d_ncdhw()