[Relay][Topi][AutoTVM] Winograd support for Conv3D (apache#5186)
* Functional conv3d winograd working.

* Formatted python code.

* Registered conv3d winograd compute and started adding relay without_weight_transform operator.

* Add topi testing for conv3d winograd.

* Format file.

* Small tweak to unrolling to prevent the build from getting stuck.

* Refactoring convolution ops in relay.

* Refactored relay convolutions.

* Bug fixes.

* Fixed static bug in convolution.

* Added conv3d alter op layout and related support.

* Bug fixes and testing done.

* Fix a few autotvm bugs.

* Drop silly debug print.

* Removed debug_skip_region.

* Add variant of conv3d_winograd that doesn't transform depth.

* Initial infrastructure done for depthless conv.

* Fix no_depth schedule bugs.

* Automatic topi switching between depth and depthless winograd.

* Fixed bug in schedule.

* Lint fixes.

* Removed indents in convolution.cc

* Fixed a few missed indents.

* Fixed flop count.

* One more small tweak.

* Change kernel pack inner axes order.

* Style changes.

* Comment fixes.
jwfromm authored and dpankratz committed Apr 24, 2020
1 parent 216b717 commit 4c2e3cd
Showing 17 changed files with 2,192 additions and 621 deletions.
5 changes: 5 additions & 0 deletions docs/langref/relay_op.rst
@@ -82,8 +82,13 @@ This level enables typical convnet models.
tvm.relay.nn.pad
tvm.relay.nn.lrn
tvm.relay.nn.l2_normalize
tvm.relay.nn.bitpack
tvm.relay.nn.bitserial_dense
tvm.relay.nn.bitserial_conv2d
tvm.relay.nn.contrib_conv2d_winograd_without_weight_transform
tvm.relay.nn.contrib_conv2d_winograd_weight_transform
tvm.relay.nn.contrib_conv3d_winograd_without_weight_transform
tvm.relay.nn.contrib_conv3d_winograd_weight_transform


**Level 3: Additional Math And Transform Operators**
71 changes: 67 additions & 4 deletions include/tvm/relay/attrs/nn.h
@@ -156,12 +156,12 @@ struct Conv2DAttrs : public tvm::AttrsNode<Conv2DAttrs> {
};

/*! \brief Attributes used in winograd weight transformation operators */
struct Conv2DWinogradWeightTransformAttrs :
public tvm::AttrsNode<Conv2DWinogradWeightTransformAttrs> {
struct ConvWinogradWeightTransformAttrs :
public tvm::AttrsNode<ConvWinogradWeightTransformAttrs> {
int tile_size;

TVM_DECLARE_ATTRS(Conv2DWinogradWeightTransformAttrs,
"relay.attrs.Conv2DWinogradWeightTransformAttrs") {
TVM_DECLARE_ATTRS(ConvWinogradWeightTransformAttrs,
"relay.attrs.ConvWinogradWeightTransformAttrs") {
TVM_ATTR_FIELD(tile_size)
.describe("Tile size of winograd. E.g. 2 for F(2x2, 3x3) and 4 for F(4x4, 3x3)");
}
@@ -306,6 +306,69 @@ struct Conv3DAttrs : public tvm::AttrsNode<Conv3DAttrs> {
}
};

/*! \brief Attributes used in 3d winograd convolution operators */
struct Conv3DWinogradAttrs : public tvm::AttrsNode<Conv3DWinogradAttrs> {
int tile_size;
Array<IndexExpr> strides;
Array<IndexExpr> padding;
Array<IndexExpr> dilation;
int groups;
IndexExpr channels;
Array<IndexExpr> kernel_size;
std::string data_layout;
std::string kernel_layout;
std::string out_layout;
DataType out_dtype;

TVM_DECLARE_ATTRS(Conv3DWinogradAttrs, "relay.attrs.Conv3DWinogradAttrs") {
TVM_ATTR_FIELD(tile_size)
.describe("The tile size of winograd. E.g. 2 for F(2x2x2, 3x3x3) and 4 for F(4x4x4, 3x3x3)");
TVM_ATTR_FIELD(strides).set_default(Array<IndexExpr>({1, 1, 1}))
.describe("Specifies the strides of the convolution.");
TVM_ATTR_FIELD(padding).set_default(Array<IndexExpr>({0, 0, 0}))
.describe("If padding is non-zero, then the input is implicitly zero-padded"
"Padding support both symmetric and asymmetric as"
"one int : same padding used on all sides"
"three int : back, bottom, right will use same padding as front, top, left"
"six int : padding width in the order of (front, top, left, back, bottom,"
"right)");
TVM_ATTR_FIELD(dilation).set_default(Array<IndexExpr>({1, 1, 1}))
.describe("Specifies the dilation rate to use for dilated convolution.");
TVM_ATTR_FIELD(groups).set_default(1)
.describe("Controls the connections between inputs and outputs."
"At groups=1, all inputs are convolved to all outputs."
"At groups=2, the operation becomes equivalent to having two convolution"
"layers side by side, each seeing half the input channels, and producing"
"half the output channels, and both subsequently concatenated.");
TVM_ATTR_FIELD(channels)
.describe("The number of output channels in the convolution."
" If it is not set, inferred by shape of the weight.")
.set_default(NullValue<IndexExpr>());
TVM_ATTR_FIELD(kernel_size)
.describe("Specifies the dimensions of the convolution window.")
.set_default(NullValue<Array<IndexExpr> >());
TVM_ATTR_FIELD(data_layout).set_default("NCDHW")
.describe("Dimension ordering of input data. Can be 'NCDHW', 'NDHWC', etc."
"'N', 'C', 'D', 'H', 'W' stands for batch, channel, depth, height, and width"
"dimensions respectively. Convolution is applied on the 'D', 'H' and"
"'W' dimensions.");
TVM_ATTR_FIELD(kernel_layout).set_default("OIDHW")
.describe("Dimension ordering of weight. Can be 'OIDHW', 'OIDHW16o16i', etc."
"'O', 'I', 'D', 'H', 'W' stands for num_filter, input_channel, depth, height,"
"and width dimensions respectively.");
TVM_ATTR_FIELD(out_layout).set_default("")
.describe("Dimension ordering of output. Can be 'NCDHW', 'NDHWC', etc."
"'N', 'C', 'D', 'H', 'W' stands for batch, channel, depth, height, and width"
"dimensions respectively. Default to be same as input layout.");

// use 0 bits to indicate none.
TVM_ATTR_FIELD(out_dtype)
.set_default(NullValue<DataType>())
.describe("Output data type, set to explicit type under mixed precision setting");
}
};
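
As a quick illustration (not part of this diff; the helper name is made up), the strides, padding, and dilation attributes combine per axis in the usual convolution output-size arithmetic:

def conv_out_size(in_size, kernel, stride, pad_lo, pad_hi, dilation=1):
    # standard convolution arithmetic, applied independently to D, H and W
    effective_kernel = dilation * (kernel - 1) + 1
    return (in_size + pad_lo + pad_hi - effective_kernel) // stride + 1

# e.g. a 16^3 input with a 3x3x3 kernel, stride 1 and padding 1 keeps its size:
assert conv_out_size(16, 3, 1, 1, 1) == 16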


/*! \brief Attributes used in softmax operators */
struct SoftmaxAttrs : public tvm::AttrsNode<SoftmaxAttrs> {
int axis;
23 changes: 23 additions & 0 deletions python/tvm/relay/op/nn/_nn.py
@@ -178,6 +178,29 @@ def legalize_conv2d_transpose(attrs, inputs, types):
reg.register_strategy("nn.conv3d", strategy.conv3d_strategy)
reg.register_pattern("nn.conv3d", OpPattern.OUT_ELEMWISE_FUSABLE)

@reg.register_alter_op_layout("nn.conv3d")
def alter_op_layout_conv3d(attrs, inputs, tinfos, out_type):
    """Alter the layout of conv3d."""
    return topi.nn.conv3d_alter_layout(attrs, inputs, tinfos, out_type)

# conv3d_winograd related operators
reg.register_strategy("nn.contrib_conv3d_winograd_without_weight_transform",
                      strategy.conv3d_winograd_without_weight_transform_strategy)
reg.register_pattern("nn.contrib_conv3d_winograd_without_weight_transform",
                     OpPattern.OUT_ELEMWISE_FUSABLE)

@reg.register_compute("nn.contrib_conv3d_winograd_weight_transform")
def compute_contrib_conv3d_winograd_weight_transform(attrs, inputs, out_dtype):
    """Compute definition of contrib_conv3d_winograd_weight_transform"""
    out = topi.nn.conv3d_winograd_weight_transform(
        inputs[0], attrs.get_int('tile_size'))
    return [out]

reg.register_schedule("nn.contrib_conv3d_winograd_weight_transform",
                      strategy.schedule_conv3d_winograd_weight_transform)
reg.register_pattern("nn.contrib_conv3d_winograd_weight_transform",
                     OpPattern.OUT_ELEMWISE_FUSABLE)
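
For reference, tile_size is the m in F(m x m x m, r x r x r): the weight transform produces alpha = m + r - 1 points per axis. A small sketch of that arithmetic (illustrative, not part of this diff):

def winograd_alpha(tile_size, kernel_size=3):
    # F(m, r) needs m + r - 1 input points per axis:
    # tile_size=2 -> alpha=4, tile_size=4 -> alpha=6 for 3x3x3 kernels
    return tile_size + kernel_size - 1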


# conv1d_transpose
reg.register_strategy("nn.conv1d_transpose", strategy.conv1d_transpose_strategy)
100 changes: 97 additions & 3 deletions python/tvm/relay/op/nn/nn.py
@@ -19,7 +19,7 @@
from __future__ import absolute_import as _abs
from ...expr import TupleWrapper
from . import _make
from .util import get_pad_tuple2d
from .util import get_pad_tuple2d, get_pad_tuple3d


def conv1d(data,
@@ -295,13 +295,84 @@ def conv3d(data,
        strides = (strides, strides, strides)
    if isinstance(dilation, int):
        dilation = (dilation, dilation, dilation)
    if isinstance(padding, int):
        padding = (padding, padding, padding)
    padding = get_pad_tuple3d(padding)
    return _make.conv3d(data, weight, strides, padding, dilation,
                        groups, channels, kernel_size, data_layout,
                        kernel_layout, out_layout, out_dtype)


def contrib_conv3d_winograd_without_weight_transform(data,
                                                     weight,
                                                     tile_size,
                                                     strides=(1, 1, 1),
                                                     padding=(0, 0, 0),
                                                     dilation=(1, 1, 1),
                                                     groups=1,
                                                     channels=None,
                                                     kernel_size=None,
                                                     data_layout="NCDHW",
                                                     kernel_layout="OIDHW",
                                                     out_layout="",
                                                     out_dtype=""):
    r"""3D convolution with the winograd algorithm.

    The basic parameters are the same as the ones in vanilla conv3d.
    It assumes the weight is pre-transformed by
    nn.contrib_conv3d_winograd_weight_transform.

    Parameters
    ----------
    data : tvm.relay.Expr
        The input data to the operator.

    weight : tvm.relay.Expr
        The weight expressions.

    tile_size : int
        The tile size of winograd. E.g. 2 for F(2x2x2, 3x3x3) and 4 for F(4x4x4, 3x3x3).

    strides : tuple of int, optional
        The strides of convolution.

    padding : tuple of int, optional
        The padding of convolution on both sides of inputs before convolution.

    dilation : tuple of int, optional
        Specifies the dilation rate to be used for dilated convolution.

    groups : int, optional
        Number of groups for grouped convolution.

    channels : int, optional
        Number of output channels of this convolution.

    kernel_size : tuple of int, optional
        The spatial dimensions of the convolution kernel.

    data_layout : str, optional
        Layout of the input.

    kernel_layout : str, optional
        Layout of the weight.

    out_layout : str, optional
        Layout of the output; by default, out_layout is the same as data_layout.

    out_dtype : str, optional
        Specifies the output data type for mixed precision conv3d.

    Returns
    -------
    result : tvm.relay.Expr
        The computed result.
    """
    # convert 3-way padding to 6-way padding
    padding = get_pad_tuple3d(padding)
    return _make.contrib_conv3d_winograd_without_weight_transform(
        data, weight, tile_size, strides, padding, dilation,
        groups, channels, kernel_size, data_layout,
        kernel_layout, out_layout, out_dtype)


def conv2d_transpose(data,
                     weight,
                     strides=(1, 1),
@@ -1952,6 +2023,29 @@ def contrib_conv2d_winograd_weight_transform(weight,
    return _make.contrib_conv2d_winograd_weight_transform(weight, tile_size)


def contrib_conv3d_winograd_weight_transform(weight,
                                             tile_size):
    r"""Weight transformation part for 3D convolution with the winograd algorithm.

    We separate this as a single op to enable pre-compute for inference.
    Use this together with nn.contrib_conv3d_winograd_without_weight_transform.

    Parameters
    ----------
    weight : tvm.relay.Expr
        The weight expressions.

    tile_size : int
        The tile size of winograd. E.g. 2 for F(2x2x2, 3x3x3) and 4 for F(4x4x4, 3x3x3).

    Returns
    -------
    result : tvm.relay.Expr
        The computed result.
    """
    return _make.contrib_conv3d_winograd_weight_transform(weight, tile_size)
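
Taken together, the two ops form a pair: transform the weight once at compile time, then run the transformed convolution. A usage sketch (the shapes here are assumptions, not from this diff):

from tvm import relay

data = relay.var("data", shape=(1, 16, 8, 8, 8), dtype="float32")       # NCDHW
weight = relay.var("weight", shape=(32, 16, 3, 3, 3), dtype="float32")  # OIDHW
wt = relay.nn.contrib_conv3d_winograd_weight_transform(weight, tile_size=4)
out = relay.nn.contrib_conv3d_winograd_without_weight_transform(
    data, wt, tile_size=4, channels=32, kernel_size=(3, 3, 3))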


def contrib_conv2d_winograd_nnpack_weight_transform(weight,
convolution_algorithm,
out_dtype=""):
43 changes: 43 additions & 0 deletions python/tvm/relay/op/nn/util.py
@@ -54,3 +54,46 @@ def get_pad_tuple2d(padding):
    pad_top = (pad_h + 1) // 2
    pad_left = (pad_w + 1) // 2
    return pad_top, pad_left, pad_h - pad_top, pad_w - pad_left


def get_pad_tuple3d(padding):
    """Common code to get the pad option.

    Parameters
    ----------
    padding : Union[int, Tuple[int, ...]]
        Padding size.

    Returns
    -------
    pad_front : int
        Padding size on front.

    pad_top : int
        Padding size on top.

    pad_left : int
        Padding size on left.

    pad_back : int
        Padding size on back.

    pad_down : int
        Padding size on bottom.

    pad_right : int
        Padding size on right.
    """
    # compute the padding size
    if isinstance(padding, container.Array):
        padding = list(padding)
    if isinstance(padding, (tuple, list)):
        if len(padding) == 3:
            pad_d = padding[0] * 2
            pad_h = padding[1] * 2
            pad_w = padding[2] * 2
        elif len(padding) == 6:
            return padding[0], padding[1], padding[2], padding[3], padding[4], padding[5]
        else:
            raise ValueError("Size of padding can only be 3 or 6")
    elif isinstance(padding, int):
        pad_d = pad_h = pad_w = padding * 2
    else:
        raise ValueError("Unknown padding option %s" % padding)
    pad_front = (pad_d + 1) // 2
    pad_top = (pad_h + 1) // 2
    pad_left = (pad_w + 1) // 2
    return pad_front, pad_top, pad_left, pad_d - pad_front, pad_h - pad_top, pad_w - pad_left
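
Concretely, the helper normalizes all three accepted forms to the 6-tuple (front, top, left, back, bottom, right):

get_pad_tuple3d(1)                    # -> (1, 1, 1, 1, 1, 1)
get_pad_tuple3d((1, 2, 3))            # -> (1, 2, 3, 1, 2, 3)
get_pad_tuple3d((0, 1, 1, 0, 1, 1))   # 6-tuples pass through unchanged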
16 changes: 13 additions & 3 deletions python/tvm/relay/op/op_attrs.py
@@ -34,9 +34,19 @@ class Conv2DWinogradAttrs(Attrs):
"""Attributes for nn.contrib_conv2d_winograd_without_weight_transform"""


@tvm._ffi.register_object("relay.attrs.Conv2DWinogradWeightTransformAttrs")
class Conv2DWinogradWeightTransformAttrs(Attrs):
"""Attributes for nn.contrib_conv2d_winograd_weight_transform"""
@tvm._ffi.register_object("relay.attrs.Conv3DAttrs")
class Conv3DAttrs(Attrs):
"""Attributes for nn.conv3d"""


@tvm._ffi.register_object("relay.attrs.Conv3DWinogradAttrs")
class Conv3DWinogradAttrs(Attrs):
"""Attributes for nn.contrib_conv3d_winograd_without_weight_transform"""


@tvm._ffi.register_object("relay.attrs.ConvWinogradWeightTransformAttrs")
class ConvWinogradWeightTransformAttrs(Attrs):
"""Attributes for nn.contrib_convNd_winograd_weight_transform"""


@tvm._ffi.register_object("relay.attrs.Conv2DWinogradNNPACKWeightTransformAttrs")
32 changes: 32 additions & 0 deletions python/tvm/relay/op/strategy/cuda.py
@@ -233,13 +233,25 @@ def conv2d_transpose_strategy_cuda(attrs, inputs, out_type, target):
def conv3d_strategy_cuda(attrs, inputs, out_type, target):
"""conv3d cuda strategy"""
    strategy = _op.OpStrategy()
    _, kernel = inputs
    layout = attrs.data_layout
    _, stride_h, stride_w = attrs.get_int_tuple("strides")
    _, dilation_h, dilation_w = attrs.get_int_tuple("dilation")
    assert layout in ["NCDHW", "NDHWC"], "Layout {} is not supported yet".format(layout)
    if layout == "NCDHW":
        strategy.add_implementation(wrap_compute_conv3d(topi.cuda.conv3d_ncdhw),
                                    wrap_topi_schedule(topi.cuda.schedule_conv3d_ncdhw),
                                    name="conv3d_ncdhw.cuda",
                                    plevel=10)
        _, _, _, kh, kw = get_const_tuple(kernel.shape)
        if 2 < kh < 8 and 2 < kw < 8 and kh == kw and \
                stride_h == 1 and stride_w == 1 and \
                dilation_h == 1 and dilation_w == 1:
            strategy.add_implementation(
                wrap_compute_conv3d(topi.cuda.conv3d_ncdhw_winograd),
                wrap_topi_schedule(topi.cuda.schedule_conv3d_ncdhw_winograd),
                name="conv3d_ncdhw_winograd.cuda",
                plevel=5)
    else:  # layout == "NDHWC":
        strategy.add_implementation(wrap_compute_conv3d(topi.cuda.conv3d_ndhwc),
                                    wrap_topi_schedule(topi.cuda.schedule_conv3d_ndhwc),
@@ -252,6 +264,26 @@ def conv3d_strategy_cuda(attrs, inputs, out_type, target):
                                    plevel=15)
    return strategy
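
The winograd implementation is only registered under the guard above. The same eligibility check as a standalone predicate (a sketch mirroring the code, not part of this diff):

def winograd_applicable(kh, kw, stride_h, stride_w, dilation_h, dilation_w):
    # square 3x3 through 7x7 kernels, unit stride and dilation in H and W
    return (2 < kh < 8 and 2 < kw < 8 and kh == kw
            and stride_h == 1 and stride_w == 1
            and dilation_h == 1 and dilation_w == 1)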

@conv3d_winograd_without_weight_transform_strategy.register(["cuda", "gpu"])
def conv3d_winograd_without_weight_transform_strategy_cuda(attrs, inputs, out_type, target):
    """conv3d_winograd_without_weight_transform cuda strategy"""
    dilation = attrs.get_int_tuple("dilation")
    groups = attrs.get_int("groups")
    layout = attrs.data_layout
    assert dilation == (1, 1, 1), "Dilation is not supported yet"
    assert groups == 1, "Arbitrary group numbers are not supported yet"
    strategy = _op.OpStrategy()
    if layout == "NCDHW":
        strategy.add_implementation(
            wrap_compute_conv3d(topi.cuda.conv3d_ncdhw_winograd_without_weight_transform),
            wrap_topi_schedule(
                topi.cuda.schedule_conv3d_ncdhw_winograd_without_weight_transform),
            name="conv3d_ncdhw_winograd_without_weight_transform.cuda")
    else:
        raise RuntimeError("Unsupported conv3d_winograd_without_weight_transform layout {}".
                           format(layout))
    return strategy

@conv1d_strategy.register(["cuda", "gpu"])
def conv1d_strategy_cuda(attrs, inputs, out_type, target):
"""conv1d cuda strategy"""
13 changes: 13 additions & 0 deletions python/tvm/relay/op/strategy/generic.py
@@ -374,6 +374,19 @@ def conv3d_strategy(attrs, inputs, out_type, target):
raise ValueError("Not support this layout {} yet".format(layout))
return strategy

# conv3d_winograd_without_weight_transform
@override_native_generic_func("conv3d_winograd_without_weight_transform_strategy")
def conv3d_winograd_without_weight_transfrom_strategy(attrs, inputs, out_type, target):
"""conv3d_winograd_without_weight_transfrom generic strategy"""
raise ValueError("No generic implemenation for conv3d_winograd_without_weight_transform")

# conv3d_winograd_weight_transform
@generic_func
def schedule_conv3d_winograd_weight_transform(attrs, outs, target):
"""Schedule conv3d_winograd_weight_transform"""
with target:
return topi.generic.schedule_conv3d_winograd_weight_transform(outs)

# conv1d
def wrap_compute_conv1d(topi_compute):
"""wrap conv1d topi compute"""
