Skip to content

Commit

Permalink
[Strategy] Support for Int8 schedules - CUDA/x86 (apache#5031)
Browse files Browse the repository at this point in the history
* [CUDA] Op strategy changes for Int8 schedules.

* Applying Haichen's suggestions.

* Make 4D output work for task extraction.

* Make x86 work.

* Fix lint.

* Lint fixes.

* Tests, comments, out channel a multiple of 4.

* Topi test.

Co-authored-by: Ubuntu <ubuntu@ip-172-31-38-96.us-west-2.compute.internal>
  • Loading branch information
2 people authored and zhiics committed Apr 17, 2020
1 parent 59a13c2 commit 8147ae8
Show file tree
Hide file tree
Showing 8 changed files with 273 additions and 28 deletions.
8 changes: 4 additions & 4 deletions python/tvm/relay/frontend/mxnet.py
Original file line number Diff line number Diff line change
Expand Up @@ -1373,8 +1373,8 @@ def _get_sum(_res, _output_scale, out_dtype):

# 3) Clip/cast to change the out dtype.
_res = relay.clip(_res,
a_min=float(tvm.api.min_value(out_dtype).value),
a_max=float(tvm.api.max_value(out_dtype).value))
a_min=float(tvm.tir.op.min_value(out_dtype).value),
a_max=float(tvm.tir.op.max_value(out_dtype).value))
_res = relay.cast(_res, out_dtype)
return _res

Expand Down Expand Up @@ -1647,8 +1647,8 @@ def _get_bias_requantize_scale(_inputs, _data_scale, _kernel_scale):
_op.multiply(_op.cast(bias_data, 'float32'), bias_requantize_scale)
rounded_bias = _op.round(multiplied_bias)
clipped_bias = _op.clip(rounded_bias,
a_min=tvm.api.min_value('int32').value,
a_max=tvm.api.max_value('int32').value)
a_min=tvm.tir.op.min_value('int32').value,
a_max=tvm.tir.op.max_value('int32').value)
requantized_bias = _op.cast(clipped_bias, 'int32')
res = _op.nn.bias_add(res, requantized_bias, axis=-1)
enable_float_output = attrs.get_bool('enable_float_output', False)
Expand Down
16 changes: 11 additions & 5 deletions python/tvm/relay/op/strategy/cuda.py
Original file line number Diff line number Diff line change
Expand Up @@ -85,12 +85,18 @@ def conv2d_strategy_cuda(attrs, inputs, out_type, target):

if groups == 1:
if layout == "NCHW":
# TODO(@vinx13, @icemelon9): Use conv2d_NCHWc_int8 when dtype is int8/uint8.
assert kernel_layout == "OIHW"
strategy.add_implementation(
wrap_compute_conv2d(topi.cuda.conv2d_nchw),
wrap_topi_schedule(topi.cuda.schedule_conv2d_nchw),
name="conv2d_nchw.cuda")
if data.dtype in ('int8', 'uint8') and kernel.dtype in ('int8', 'uint8'):
assert data.dtype == kernel.dtype
strategy.add_implementation(
wrap_compute_conv2d(topi.cuda.conv2d_nchw_int8),
wrap_topi_schedule(topi.cuda.schedule_conv2d_nchw_int8),
name="conv2d_nchw_int8.cuda")
else:
strategy.add_implementation(
wrap_compute_conv2d(topi.cuda.conv2d_nchw),
wrap_topi_schedule(topi.cuda.schedule_conv2d_nchw),
name="conv2d_nchw.cuda")
_, _, kh, kw = get_const_tuple(kernel.shape)
if 2 < kh < 8 and 2 < kw < 8 and kh == kw and stride_h == 1 and stride_w == 1 and \
dilation_h == 1 and dilation_w == 1:
Expand Down
14 changes: 14 additions & 0 deletions python/tvm/relay/qnn/op/legalizations.py
Original file line number Diff line number Diff line change
Expand Up @@ -264,3 +264,17 @@ def _qnn_dense_legalize_intel_cpu(attrs, inputs, types):
if is_fast_int8_on_intel():
return helper_change_dtypes_to_uint8_int8(attrs, inputs, types, relay.qnn.op.dense)
return helper_no_fast_int8_hw_legalization(attrs, inputs, types, relay.nn.dense)

#####################
# CUDA legalizations.
#####################

@qnn_conv2d_legalize.register('cuda')
def _qnn_conv2d_legalize_cuda(attrs, inputs, types):
# CUDA prefers the dtypes to be same.
return helper_change_dtypes_to_be_same(attrs, inputs, types, relay.qnn.op.conv2d)

@qnn_dense_legalize.register('cuda')
def _qnn_dense_legalize_cuda(attrs, inputs, types):
# CUDA prefers the dtypes to be same.
return helper_change_dtypes_to_be_same(attrs, inputs, types, relay.qnn.op.dense)
14 changes: 14 additions & 0 deletions tests/python/relay/test_pass_qnn_legalize.py
Original file line number Diff line number Diff line change
Expand Up @@ -177,6 +177,13 @@ def _get_mod(data_dtype, kernel_dtype):
legalized_mod = relay.qnn.transform.Legalize()(mod)
assert 'cast' in legalized_mod.astext() and "qnn" not in legalized_mod.astext()

###########################################
# Check transformations for CUDA platforms.
###########################################
with tvm.target.create('cuda'):
legalized_mod = relay.qnn.transform.Legalize()(mod)
assert 'cast' in legalized_mod.astext() and "qnn" in legalized_mod.astext()


def test_qnn_legalize_qnn_dense():
def _get_mod(data_dtype, kernel_dtype):
Expand Down Expand Up @@ -257,6 +264,13 @@ def _get_mod(data_dtype, kernel_dtype):
legalized_mod = relay.qnn.transform.Legalize()(mod)
assert 'cast' in legalized_mod.astext() and "qnn" not in legalized_mod.astext()

###########################################
# Check transformations for CUDA platforms.
###########################################
with tvm.target.create('cuda'):
legalized_mod = relay.qnn.transform.Legalize()(mod)
assert 'cast' in legalized_mod.astext() and "qnn" in legalized_mod.astext()


if __name__ == "__main__":
test_qnn_legalize()
Expand Down
80 changes: 80 additions & 0 deletions topi/python/topi/cuda/conv2d_alter_op.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@
from .. import nn
from ..util import get_const_tuple
from .conv2d_winograd import _infer_tile_size
from ..nn import conv2d_legalize

logger = logging.getLogger('topi')

Expand Down Expand Up @@ -135,3 +136,82 @@ def _alter_conv2d_layout(attrs, inputs, tinfos, out_type):
return relay.nn.conv2d(*inputs, **new_attrs)

return None

@conv2d_legalize.register("cuda")
def _conv2d_legalize(attrs, inputs, arg_types):
"""Legalizes Conv2D op.
Parameters
----------
attrs : tvm.ir.Attrs
Attributes of current convolution
inputs : list of tvm.relay.Expr
The args of the Relay expr to be legalized
types : list of types
List of input and output types
Returns
-------
result : tvm.relay.Expr
The legalized expr
"""

# Dilation not supported yet. Return None if dilation is not (1, 1)
dilation = attrs.get_int_tuple("dilation")
if not (dilation[0] == 1 and dilation[1] == 1):
return None

# No legalization for depthwise convolutions yet.
groups = attrs.get_int("groups")
if groups != 1:
return None

# Collect the input tensors.
data_tensor, kernel_tensor = arg_types[0], arg_types[1]
data_dtype = data_tensor.dtype

# Collect the output tensor.
output_tensor = arg_types[2]

# Collect the input exprs.
data, kernel = inputs

# Get the conv attrs
new_attrs = {k: attrs[k] for k in attrs.keys()}

# Get data layout. Return None if not NCHW
data_layout = attrs['data_layout']
kernel_layout = attrs['kernel_layout']

# Pad input and output channels to use int8 schedule.
if data_dtype in ['int8', 'uint8']:
if data_layout == 'NCHW' and kernel_layout == "OIHW":
oc_modified = False
in_channel = data_tensor.shape[1].value
out_channel = kernel_tensor.shape[0].value

# Pad input channel
if in_channel % 4 != 0:
new_in_channel = ((in_channel + 4) // 4) * 4
diff = new_in_channel - in_channel
pad_width = ((0, 0), (0, diff), (0, 0), (0, 0))
data = relay.nn.pad(data, pad_width=pad_width)
kernel = relay.nn.pad(kernel, pad_width=pad_width)

# Pad output channel
new_out_channel = out_channel
if out_channel % 4 != 0:
new_out_channel = ((out_channel + 4) // 4) * 4
diff = new_out_channel - out_channel
kernel = relay.nn.pad(kernel, pad_width=((0, diff), (0, 0), (0, 0), (0, 0)))
oc_modified = True

if oc_modified:
new_attrs['channels'] = new_out_channel
out = tvm.relay.nn.conv2d(data, kernel, **new_attrs)
original_out_shape = [x.value for x in output_tensor.shape]
out = relay.strided_slice(out, begin=(0, 0, 0, 0), end=original_out_shape)
else:
out = relay.nn.conv2d(data, kernel, **new_attrs)
return out
return None
22 changes: 21 additions & 1 deletion topi/python/topi/cuda/conv2d_int8.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
# specific language governing permissions and limitations
# under the License.
# pylint: disable=invalid-name
# pylint: disable=no-value-for-parameter
"""Int8 conv2d in NCHWc layout"""
import tvm
from tvm import te
Expand All @@ -23,10 +24,23 @@
from .injective import schedule_injective_from_existing
from .tensor_intrin import dp4a
from ..nn.pad import pad
from ..nn.conv2d import unpack_NCHWc_to_nchw
from ..nn.util import get_pad_tuple
from ..util import get_const_tuple, traverse_inline


def conv2d_nchw_int8(data, kernel, strides, padding, dilation, out_dtype='int32'):
"""Compute conv2d internally using conv2d_nchwc layout for int8 dtype"""
assert data.dtype in ('int8', 'uint8')
assert kernel.dtype in ('int8', 'uint8')
assert data.dtype == kernel.dtype
packed_out = conv2d_NCHWc_int8(data, kernel, strides, padding, dilation, "NCHW", out_dtype)
return unpack_NCHWc_to_nchw(packed_out, out_dtype)

def schedule_conv2d_nchw_int8(outs):
"""Create schedule for tensors"""
return schedule_conv2d_NCHWc_int8(outs)

@autotvm.register_topi_compute("conv2d_NCHWc_int8.cuda")
def conv2d_NCHWc_int8(cfg, data, kernel, stride, padding, dilation, layout, out_dtype):
"""Convolution operator in NCHW[x]c layout for int8.
Expand Down Expand Up @@ -205,7 +219,13 @@ def _schedule_conv2d_NCHWc_int8(cfg, s, output):
output = s.outputs[0].output(0)

# tile and bind spatial axes
n, f, y, x, c = s[output].op.axis
if len(s[output].op.axis) == 5:
n, f, y, x, c = s[output].op.axis
else:
# For task extraction of auto-tuning, the expected output is 4D. Since auto-tuning tasks
# are created from scratch, therefore the real auto-tuning will still happen on 5D output.
n, f, y, x = s[output].op.axis

cfg.define_split("tile_n", cfg.axis(n), num_outputs=4)
cfg.define_split("tile_f", cfg.axis(f), num_outputs=4)
cfg.define_split("tile_y", cfg.axis(y), num_outputs=4)
Expand Down
66 changes: 48 additions & 18 deletions topi/python/topi/generic/conv2d.py
Original file line number Diff line number Diff line change
Expand Up @@ -144,7 +144,8 @@ def schedule_conv_NCHWc_cpu_common_int8(s, cfg, data_vec, kernel_vec, conv_out,
parallel_axis = s[data_vec].fuse(batch, ic_chunk, ih)
s[data_vec].parallel(parallel_axis)

oc_chunk, ic_chunk, oh, ow, ic_block, oc_block = s[kernel_vec].op.axis
# conv2d_nchwc_int8 has 7D kernel
oc_chunk, ic_chunk, oh, ow, ic_block, oc_block, _ = s[kernel_vec].op.axis
s[kernel_vec].reorder(oc_chunk, oh, ic_chunk, ow, ic_block, oc_block)
oc_bn = cfg["tile_oc"].size[-1]
if oc_bn > 1:
Expand Down Expand Up @@ -189,13 +190,26 @@ def schedule_conv_NCHWc_cpu_common_int8(s, cfg, data_vec, kernel_vec, conv_out,
s[CC].unroll(oc_f_inner)

if C != O:
batch, oc_chunk, oh, ow, oc_block = s[O].op.axis
ow_chunk, ow_block = s[O].split(ow, factor=reg_n)
s[O].reorder(oc_chunk, oh, ow_chunk, ow_block, oc_block)
parallel_axis = s[O].fuse(batch, oc_chunk, oh)
s[C].compute_at(s[O], parallel_axis)
s[O].vectorize(oc_block)
s[O].parallel(parallel_axis)
out_ndim = len(s[O].op.axis)
if out_ndim == 5:
batch, oc_chunk, oh, ow, oc_block = s[O].op.axis
ow_chunk, ow_block = s[O].split(ow, factor=reg_n)
s[O].reorder(oc_chunk, oh, ow_chunk, ow_block, oc_block)
parallel_axis = s[O].fuse(batch, oc_chunk, oh)
s[C].compute_at(s[O], parallel_axis)
s[O].vectorize(oc_block)
s[O].parallel(parallel_axis)
elif out_ndim == 4:
batch, oc, oh, ow = s[O].op.axis
ow_chunk, ow_block = s[O].split(ow, factor=reg_n)
oc_chunk, oc_block = s[O].split(oc, factor=oc_bn)
s[O].reorder(oc_chunk, oh, ow_chunk, ow_block, oc_block)
parallel_axis = s[O].fuse(batch, oc_chunk, oh)
s[C].compute_at(s[O], parallel_axis)
s[O].vectorize(oc_block)
s[O].parallel(parallel_axis)
else:
raise ValueError("Unsupported output ndim: %s" % out_ndim)

return s

Expand Down Expand Up @@ -234,7 +248,8 @@ def schedule_conv_NCHWc_cpu_1x1_int8(s, cfg, data_vec, kernel_vec, conv_out,
parallel_axis = s[data_vec].fuse(batch, ic_chunk, ih)
s[data_vec].parallel(parallel_axis)

oc_chunk, ic_chunk, oh, ow, ic_block, oc_block = s[kernel_vec].op.axis
# Conv2d int8 schedule has 7D kernel
oc_chunk, ic_chunk, oh, ow, ic_block, oc_block, _ = s[kernel_vec].op.axis
s[kernel_vec].reorder(oc_chunk, oh, ic_chunk, ow, ic_block, oc_block)
oc_bn = cfg["tile_oc"].size[-1]
if oc_bn > 1:
Expand Down Expand Up @@ -277,14 +292,29 @@ def schedule_conv_NCHWc_cpu_1x1_int8(s, cfg, data_vec, kernel_vec, conv_out,
s[CC].unroll(oh_inner)

if C != O:
batch, oc_chunk, oh, ow, oc_block = s[O].op.axis
oh_outer, oh_inner = s[O].split(oh, factor=oh_factor)
ow_outer, ow_inner = s[O].split(ow, factor=ow_factor)
s[O].reorder(oc_chunk, oh_outer, ow_outer, oh_inner, ow_inner, oc_block)

parallel_axis = s[O].fuse(batch, oc_chunk, oh_outer)
s[C].compute_at(s[O], parallel_axis)
s[O].vectorize(oc_block)
s[O].parallel(parallel_axis)
out_ndim = len(s[O].op.axis)
if out_ndim == 5:
batch, oc_chunk, oh, ow, oc_block = s[O].op.axis
oh_outer, oh_inner = s[O].split(oh, factor=oh_factor)
ow_outer, ow_inner = s[O].split(ow, factor=ow_factor)
s[O].reorder(oc_chunk, oh_outer, ow_outer, oh_inner, ow_inner, oc_block)

parallel_axis = s[O].fuse(batch, oc_chunk, oh_outer)
s[C].compute_at(s[O], parallel_axis)
s[O].vectorize(oc_block)
s[O].parallel(parallel_axis)
elif out_ndim == 4:
batch, oc, oh, ow = s[O].op.axis
oc_chunk, oc_block = s[O].split(oc, factor=oc_bn)
oh_outer, oh_inner = s[O].split(oh, factor=oh_factor)
ow_outer, ow_inner = s[O].split(ow, factor=ow_factor)
s[O].reorder(oc_chunk, oh_outer, ow_outer, oh_inner, ow_inner, oc_block)

parallel_axis = s[O].fuse(batch, oc_chunk, oh_outer)
s[C].compute_at(s[O], parallel_axis)
s[O].vectorize(oc_block)
s[O].parallel(parallel_axis)
else:
raise ValueError("Unsupported output ndim: %s" % out_ndim)

return s
Loading

0 comments on commit 8147ae8

Please sign in to comment.