[Strategy] Add group_conv2d_nchw_int8 in cuda strategy (apache#8167)
* add group_nchw_int8.cuda

* fix

* fix style

* fix style

* fix style

* fix style

* fix style

* fix style

Co-authored-by: wangyucheng <wangyucheng@sensetime.com>
2 people authored and Trevor Morris committed Jun 17, 2021
1 parent bbd192c commit d1cf696
Showing 3 changed files with 240 additions and 9 deletions.
34 changes: 27 additions & 7 deletions python/tvm/relay/op/strategy/cuda.py
@@ -334,17 +334,37 @@ def conv2d_strategy_cuda(attrs, inputs, out_type, target):
cudnn_impl = True

if layout == "NCHW":
# TODO(@vinx13, @icemelon9): Use group_conv2d_NCHWc_int8 when dtype is int8/uint8.
assert kernel_layout == "OIHW"
strategy.add_implementation(
wrap_compute_conv2d(topi.cuda.group_conv2d_nchw, has_groups=True),
wrap_topi_schedule(topi.cuda.schedule_group_conv2d_nchw),
name="group_conv2d_nchw.cuda",
)
_, channels, _, _ = get_const_tuple(data.shape)
out_channels, in_channels, _, _ = get_const_tuple(kernel.shape)
oc_chunk = out_channels // 4
ic_chunk = in_channels // 4

if (
data.dtype in ["int8", "uint8"]
and kernel.dtype in ["int8", "uint8"]
and channels % groups == 0
and out_channels % groups == 0
and channels % 4 == 0
and out_channels % 4 == 0
and groups <= oc_chunk
and groups <= ic_chunk
):
strategy.add_implementation(
wrap_compute_conv2d(topi.cuda.group_conv2d_nchw_int8, has_groups=True),
wrap_topi_schedule(topi.cuda.schedule_group_conv2d_nchw_int8),
name="group_conv2d_nchw_int8.cuda",
)
else:
strategy.add_implementation(
wrap_compute_conv2d(topi.cuda.group_conv2d_nchw, has_groups=True),
wrap_topi_schedule(topi.cuda.schedule_group_conv2d_nchw),
name="group_conv2d_nchw.cuda",
)
elif layout == "NCHW4c" and data.dtype in ["int8", "uint8"]:
assert kernel_layout == "OIHW4o4i"
strategy.add_implementation(
wrap_compute_conv2d(topi.cuda.group_conv2d_NCHWc_int8, True),
wrap_compute_conv2d(topi.cuda.group_conv2d_NCHWc_int8, has_groups=True),
wrap_topi_schedule(topi.cuda.schedule_group_conv2d_NCHWc_int8),
name="group_conv2d_NCHWc_int8.cuda",
)
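
For reference, the dispatch added above selects the int8 implementation only when both operands are int8/uint8, the channel counts are divisible by 4, and each group spans whole 4-channel blocks; otherwise it keeps registering the existing group_conv2d_nchw.cuda path. The following standalone sketch (not part of the patch; the helper name is hypothetical) mirrors that condition:

def int8_group_conv_eligible(data_dtype, kernel_dtype, channels, out_channels,
                             kernel_in_channels, groups):
    # kernel_in_channels is the kernel's second dimension, i.e. channels // groups
    # for an OIHW grouped convolution.
    oc_chunk = out_channels // 4
    ic_chunk = kernel_in_channels // 4
    return (
        data_dtype in ("int8", "uint8")
        and kernel_dtype in ("int8", "uint8")
        and channels % groups == 0
        and out_channels % groups == 0
        and channels % 4 == 0
        and out_channels % 4 == 0
        and groups <= oc_chunk
        and groups <= ic_chunk
    )

# 256 channels split into 4 groups: every group owns whole 4-channel blocks,
# so group_conv2d_nchw_int8.cuda is registered.
print(int8_group_conv_eligible("int8", "int8", 256, 256, 64, 4))        # True
# float32 operands (or channels not divisible by 4) keep the float path.
print(int8_group_conv_eligible("float32", "float32", 256, 256, 64, 4))  # False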
26 changes: 25 additions & 1 deletion python/tvm/topi/cuda/group_conv2d_nchw.py
@@ -15,6 +15,7 @@
# specific language governing permissions and limitations
# under the License.
# pylint: disable=invalid-name
# pylint: disable=no-value-for-parameter
"""The template for cuda group_conv2d_nchw"""
import tvm
from tvm import te
@@ -23,11 +23,28 @@
from .injective import schedule_injective_from_existing
from .tensor_intrin import dp4a
from ..nn.pad import pad
from ..nn.conv2d import unpack_NCHWc_to_nchw
from ..nn.utils import get_pad_tuple
from ..utils import traverse_inline, get_const_tuple, get_const_int
from .. import nn


def group_conv2d_nchw_int8(data, kernel, strides, padding, dilation, groups, out_dtype="float32"):
"""Compute group_conv2d internally using group_conv2d_nchwc layout for int8 dtype"""
assert data.dtype in ("int8", "uint8")
assert kernel.dtype in ("int8", "uint8")
assert data.dtype == kernel.dtype
packed_out = group_conv2d_NCHWc_int8(
data, kernel, strides, padding, dilation, groups, out_dtype
)
return unpack_NCHWc_to_nchw(packed_out, out_dtype)


def schedule_group_conv2d_nchw_int8(outs):
"""Create schedule for tensors"""
return schedule_group_conv2d_NCHWc_int8(outs)


@autotvm.register_topi_compute("group_conv2d_nchw.cuda")
def group_conv2d_nchw(_, data, kernel, stride, padding, dilation, groups, out_dtype="float32"):
return nn.group_conv2d_nchw(data, kernel, stride, padding, dilation, groups, out_dtype)
@@ -422,7 +440,13 @@ def _schedule_group_conv2d_NCHWc_int8(cfg, s, output):

oc_chunk = get_const_int(output.shape[1])
# tile and bind spatial axes
n, f, y, x, c = s[output].op.axis
if len(s[output].op.axis) == 5:
n, f, y, x, c = s[output].op.axis
else:
# During auto-tuning task extraction the expected output is 4D. Since auto-tuning tasks
# are created from scratch, the real tuning will still happen on the 5D output.
n, f, y, x = s[output].op.axis

cfg.define_split("tile_n", n, num_outputs=4)
cfg.define_split("tile_g", cfg.axis(groups), num_outputs=2)
cfg.define_split("tile_f", cfg.axis(oc_chunk // groups), num_outputs=4)
189 changes: 188 additions & 1 deletion tests/python/topi/python/test_topi_group_conv2d.py
@@ -30,6 +30,22 @@
import tvm.testing


def _transform_data(data, bn):
# NCHW -> NCHW[x]c
batch_size, channel, height, width = data.shape
data = np.reshape(data, (batch_size, channel // bn, bn, height, width))
data = np.transpose(data, (0, 1, 3, 4, 2))
return data


def _transform_kernel(kernel, ic_bn, oc_bn):
# OIHW -> OIHW[x]o[x]i
out_channel, in_channel, kh, kw = kernel.shape
kernel = np.reshape(kernel, (out_channel // oc_bn, oc_bn, in_channel // ic_bn, ic_bn, kh, kw))
kernel = np.transpose(kernel, (0, 2, 4, 5, 1, 3))
return kernel
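
As a quick illustration of what these helpers produce (not part of the diff, and assuming the two functions above are in scope): for the 128-channel, 32-group workloads used below, the blocked shapes come out as follows.

import numpy as np

a = np.zeros((1, 128, 56, 56), dtype="int8")    # NCHW
w = np.zeros((128, 4, 3, 3), dtype="int8")      # OIHW, 32 groups
print(_transform_data(a, 4).shape)              # (1, 32, 56, 56, 4)  -> NCHW4c
print(_transform_kernel(w, 4, 4).shape)         # (32, 1, 3, 3, 4, 4) -> OIHW4o4i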


_group_conv2d_nchw_implement = {
"generic": (topi.nn.group_conv2d_nchw, topi.generic.schedule_group_conv2d_nchw),
"gpu": (topi.cuda.group_conv2d_nchw, topi.cuda.schedule_group_conv2d_nchw),
@@ -154,6 +170,7 @@ def check_target(target):


oc_block_factor = 4
ic_block_factor = 4


def verify_group_conv2d_NCHWc_int8(
@@ -176,6 +193,151 @@

in_height = in_width = in_size

A = te.placeholder(
(batch, in_channel // ic_block_factor, in_height, in_width, ic_block_factor),
name="A",
dtype="int8",
)
W = te.placeholder(
(
num_filter // oc_block_factor,
(in_channel // groups) // ic_block_factor,
kernel,
kernel,
oc_block_factor,
ic_block_factor,
),
name="W",
dtype="int8",
)
bias = te.placeholder(
(num_filter // oc_block_factor, 1, 1, oc_block_factor), name="bias", dtype="int8"
)

bias_shape = get_const_tuple(bias.shape)
dtype = A.dtype

@memoize("topi.tests.test_topi_group_conv2d.verify_group_conv2d_NCHWc_int8")
def get_ref_data():
a_np = np.random.randint(
low=-128, high=127, size=(batch, in_channel, in_height, in_width)
).astype(dtype)
w_np = np.random.randint(
low=-128, high=128, size=(num_filter, in_channel // groups, kernel, kernel)
).astype(dtype)
b_np = np.random.uniform(size=bias_shape).astype(dtype)
dw_np = tvm.topi.testing.dilate_python(w_np, (1, 1, dilation, dilation))
c_np = tvm.topi.testing.conv2d_nchw_python(a_np, dw_np, stride, padding, groups).astype(
dtype
)

# convert to NCHWc
_, _, out_height, out_width = c_np.shape
c_np = c_np.reshape(
(batch, num_filter // oc_block_factor, oc_block_factor, out_height, out_width)
).transpose(0, 1, 3, 4, 2)

if add_bias:
b_np = np.random.uniform(size=bias_shape).astype(dtype)
c_np += b_np
if add_relu:
c_np = np.maximum(c_np, 0)

return (
_transform_data(a_np, ic_block_factor),
_transform_kernel(w_np, ic_block_factor, oc_block_factor),
b_np,
c_np,
)

a_np, w_np, b_np, c_np = get_ref_data()

def check_target(target):
dev = tvm.device(target, 0)
if not tvm.testing.device_enabled(target):
print("Skip because %s is not enabled" % target)
return
if target == "cuda" and not tvm.contrib.nvcc.have_int8(dev.compute_version):
print("Skip because int8 intrinsics are not available")
return

print("Running on target: %s" % target)
with tvm.target.Target(target):
C = topi.cuda.group_conv2d_NCHWc_int8(A, W, stride, padding, dilation, groups, dtype)
if add_bias:
C = topi.add(C, bias)
if add_relu:
C = topi.nn.relu(C)
s = topi.cuda.schedule_group_conv2d_NCHWc_int8([C])

a = tvm.nd.array(a_np, dev)
w = tvm.nd.array(w_np, dev)
b = tvm.nd.array(b_np, dev)
c = tvm.nd.array(np.zeros(get_const_tuple(C.shape), dtype=C.dtype), dev)
if add_bias:
func = tvm.build(
s,
[A, W, bias, C],
target,
name="relu_%d_%d_%d_%d_%d_%d_%d_%d_%d"
% (
batch,
in_channel,
in_size,
num_filter,
kernel,
stride,
padding,
dilation,
groups,
),
)
func(a, w, b, c)
else:
func = tvm.build(
s,
[A, W, C],
target,
name="relu_%d_%d_%d_%d_%d_%d_%d_%d_%d"
% (
batch,
in_channel,
in_size,
num_filter,
kernel,
stride,
padding,
dilation,
groups,
),
)
func(a, w, c)
tvm.testing.assert_allclose(c.numpy(), c_np, rtol=1e-5)

for target in ["cuda"]:
check_target(target)


def verify_group_conv2d_nchw_int8(
batch,
in_channel,
in_size,
num_filter,
kernel,
stride,
padding,
dilation,
groups,
add_bias=False,
add_relu=False,
):
print(
"Workload: (%d, %d, %d, %d, %d, %d, %d, %d, %d)"
% (batch, in_channel, in_size, num_filter, kernel, stride, padding, dilation, groups)
)

in_height = in_width = in_size

A = te.placeholder((batch, in_channel, in_height, in_width), name="A", dtype="int8")
W = te.placeholder((num_filter, in_channel // groups, kernel, kernel), name="W", dtype="int8")
bias = te.placeholder(
@@ -187,7 +349,7 @@ def verify_group_conv2d_NCHWc_int8(
bias_shape = get_const_tuple(bias.shape)
dtype = A.dtype

@memoize("topi.tests.test_topi_group_conv2d.verify_group_conv2d_NCHWc_int8")
@memoize("topi.tests.test_topi_group_conv2d.verify_group_conv2d_nchw_int8")
def get_ref_data():
a_np = np.random.randint(low=-128, high=127, size=a_shape).astype(dtype)
w_np = np.random.randint(low=-128, high=128, size=w_shape).astype(dtype)
@@ -442,6 +604,30 @@ def test_group_conv2d_NCHWc_int8():
verify_group_conv2d_NCHWc_int8(9, 128, 56, 128, 3, 1, 1, 1, 32)


@tvm.testing.requires_cuda
def test_group_conv2d_nchw_int8():
with Int8Fallback():
# ResNeXt-50 workload
verify_group_conv2d_nchw_int8(1, 128, 56, 128, 3, 1, 1, 1, 32)
verify_group_conv2d_nchw_int8(1, 256, 56, 256, 3, 2, 1, 1, 32)
verify_group_conv2d_nchw_int8(1, 256, 28, 256, 3, 1, 1, 1, 32)
verify_group_conv2d_nchw_int8(1, 512, 28, 512, 3, 2, 1, 1, 32)
verify_group_conv2d_nchw_int8(1, 512, 14, 512, 3, 1, 1, 1, 32)
verify_group_conv2d_nchw_int8(1, 1024, 14, 1024, 3, 2, 1, 1, 32)
verify_group_conv2d_nchw_int8(1, 1024, 7, 1024, 3, 1, 1, 1, 32)

# bias, relu
verify_group_conv2d_nchw_int8(1, 128, 56, 128, 3, 1, 1, 1, 32, add_relu=True)
verify_group_conv2d_nchw_int8(1, 128, 56, 128, 3, 1, 1, 1, 32, add_bias=True)
verify_group_conv2d_nchw_int8(1, 128, 56, 128, 3, 1, 1, 1, 32, add_relu=True, add_bias=True)
# dilation
verify_group_conv2d_nchw_int8(1, 128, 56, 128, 3, 1, 1, 2, 32)

# batch size
verify_group_conv2d_nchw_int8(2, 128, 56, 128, 3, 1, 1, 1, 32)
verify_group_conv2d_nchw_int8(9, 128, 56, 128, 3, 1, 1, 1, 32)


def test_group_conv2d_nhwc():
# ResNeXt-50 workload
verify_group_conv2d_nhwc(1, 128, 56, 128, 3, 1, 1, 1, 32)
@@ -468,4 +654,5 @@ def test_group_conv2d_nhwc():
if __name__ == "__main__":
test_group_conv2d_nchw()
test_group_conv2d_NCHWc_int8()
test_group_conv2d_nchw_int8()
test_group_conv2d_nhwc()
