From faf39ad6a47d1094e4d22eea9c73aa0975d8f0c4 Mon Sep 17 00:00:00 2001
From: Ubuntu
Date: Wed, 11 Mar 2020 02:06:16 +0000
Subject: [PATCH] Make x86 work.

---
 python/tvm/contrib/util.py           | 31 -------------
 python/tvm/relay/op/strategy/cuda.py |  2 +-
 topi/python/topi/generic/conv2d.py   | 66 ++++++++++++++++++++--------
 3 files changed, 49 insertions(+), 50 deletions(-)

diff --git a/python/tvm/contrib/util.py b/python/tvm/contrib/util.py
index 25d361585f96..2ebe175e8160 100644
--- a/python/tvm/contrib/util.py
+++ b/python/tvm/contrib/util.py
@@ -166,34 +166,3 @@ def which(exec_name):
         if os.path.isfile(full_path) and os.access(full_path, os.X_OK):
             return full_path
     return None
-
-
-def get_lower_ir(s):
-    """Get lower ir code of a schedule.
-    This is useful for debug, since you don't have to find all inputs/outputs
-    for a schedule in a fused subgraph.
-    Parameters
-    ----------
-    s: Schedule
-    Returns
-    -------
-    ir: str
-        The lower ir
-    """
-    from tvm.te import tensor
-    from tvm.driver.build_module import lower
-
-    outputs = s.outputs
-
-    inputs = []
-    def find_all(op):
-        if isinstance(op, tensor.PlaceholderOp):
-            inputs.append(op.output(0))
-        else:
-            for x in op.input_tensors:
-                find_all(x.op)
-
-    for out in outputs:
-        find_all(out)
-
-    return lower(s, inputs, simple_mode=True)
diff --git a/python/tvm/relay/op/strategy/cuda.py b/python/tvm/relay/op/strategy/cuda.py
index 646ceff34f1c..e5eff1c6b790 100644
--- a/python/tvm/relay/op/strategy/cuda.py
+++ b/python/tvm/relay/op/strategy/cuda.py
@@ -86,7 +86,7 @@ def conv2d_strategy_cuda(attrs, inputs, out_type, target):
     if groups == 1:
         if layout == "NCHW":
             assert kernel_layout == "OIHW"
-            if data.dtype in ('int8', 'uint8') and kernel.dtype in ('int8', 'uint8'):
+            if data.dtype in ('int8', 'uint8') and kernel.dtype in ('int8', 'uint8'):
                 assert data.dtype == kernel.dtype
                 strategy.add_implementation(
                     wrap_compute_conv2d(topi.cuda.conv2d_nchw_int8),
diff --git a/topi/python/topi/generic/conv2d.py b/topi/python/topi/generic/conv2d.py
index 69984a169ac6..2d9f78b645db 100644
--- a/topi/python/topi/generic/conv2d.py
+++ b/topi/python/topi/generic/conv2d.py
@@ -144,7 +144,8 @@ def schedule_conv_NCHWc_cpu_common_int8(s, cfg, data_vec, kernel_vec, conv_out,
         parallel_axis = s[data_vec].fuse(batch, ic_chunk, ih)
         s[data_vec].parallel(parallel_axis)
 
-    oc_chunk, ic_chunk, oh, ow, ic_block, oc_block = s[kernel_vec].op.axis
+    # conv2d_nchwc_int8 has 7D kernel
+    oc_chunk, ic_chunk, oh, ow, ic_block, oc_block, _ = s[kernel_vec].op.axis
     s[kernel_vec].reorder(oc_chunk, oh, ic_chunk, ow, ic_block, oc_block)
     oc_bn = cfg["tile_oc"].size[-1]
     if oc_bn > 1:
@@ -189,13 +190,26 @@ def schedule_conv_NCHWc_cpu_common_int8(s, cfg, data_vec, kernel_vec, conv_out,
     s[CC].unroll(oc_f_inner)
 
     if C != O:
-        batch, oc_chunk, oh, ow, oc_block = s[O].op.axis
-        ow_chunk, ow_block = s[O].split(ow, factor=reg_n)
-        s[O].reorder(oc_chunk, oh, ow_chunk, ow_block, oc_block)
-        parallel_axis = s[O].fuse(batch, oc_chunk, oh)
-        s[C].compute_at(s[O], parallel_axis)
-        s[O].vectorize(oc_block)
-        s[O].parallel(parallel_axis)
+        out_ndim = len(s[O].op.axis)
+        if out_ndim == 5:
+            batch, oc_chunk, oh, ow, oc_block = s[O].op.axis
+            ow_chunk, ow_block = s[O].split(ow, factor=reg_n)
+            s[O].reorder(oc_chunk, oh, ow_chunk, ow_block, oc_block)
+            parallel_axis = s[O].fuse(batch, oc_chunk, oh)
+            s[C].compute_at(s[O], parallel_axis)
+            s[O].vectorize(oc_block)
+            s[O].parallel(parallel_axis)
+        elif out_ndim == 4:
+            batch, oc, oh, ow = s[O].op.axis
+            ow_chunk, ow_block = s[O].split(ow, factor=reg_n)
+            oc_chunk, oc_block = s[O].split(oc, factor=oc_bn)
+            s[O].reorder(oc_chunk, oh, ow_chunk, ow_block, oc_block)
+            parallel_axis = s[O].fuse(batch, oc_chunk, oh)
+            s[C].compute_at(s[O], parallel_axis)
+            s[O].vectorize(oc_block)
+            s[O].parallel(parallel_axis)
+        else:
+            raise ValueError("Unsupported output ndim: %s" % out_ndim)
 
     return s
 
@@ -234,7 +248,8 @@ def schedule_conv_NCHWc_cpu_1x1_int8(s, cfg, data_vec, kernel_vec, conv_out,
         parallel_axis = s[data_vec].fuse(batch, ic_chunk, ih)
         s[data_vec].parallel(parallel_axis)
 
-    oc_chunk, ic_chunk, oh, ow, ic_block, oc_block = s[kernel_vec].op.axis
+    # Conv2d int8 schedule has 7D kernel
+    oc_chunk, ic_chunk, oh, ow, ic_block, oc_block, _ = s[kernel_vec].op.axis
     s[kernel_vec].reorder(oc_chunk, oh, ic_chunk, ow, ic_block, oc_block)
     oc_bn = cfg["tile_oc"].size[-1]
     if oc_bn > 1:
@@ -277,14 +292,29 @@ def schedule_conv_NCHWc_cpu_1x1_int8(s, cfg, data_vec, kernel_vec, conv_out,
     s[CC].unroll(oh_inner)
 
     if C != O:
-        batch, oc_chunk, oh, ow, oc_block = s[O].op.axis
-        oh_outer, oh_inner = s[O].split(oh, factor=oh_factor)
-        ow_outer, ow_inner = s[O].split(ow, factor=ow_factor)
-        s[O].reorder(oc_chunk, oh_outer, ow_outer, oh_inner, ow_inner, oc_block)
-
-        parallel_axis = s[O].fuse(batch, oc_chunk, oh_outer)
-        s[C].compute_at(s[O], parallel_axis)
-        s[O].vectorize(oc_block)
-        s[O].parallel(parallel_axis)
+        out_ndim = len(s[O].op.axis)
+        if out_ndim == 5:
+            batch, oc_chunk, oh, ow, oc_block = s[O].op.axis
+            oh_outer, oh_inner = s[O].split(oh, factor=oh_factor)
+            ow_outer, ow_inner = s[O].split(ow, factor=ow_factor)
+            s[O].reorder(oc_chunk, oh_outer, ow_outer, oh_inner, ow_inner, oc_block)
+
+            parallel_axis = s[O].fuse(batch, oc_chunk, oh_outer)
+            s[C].compute_at(s[O], parallel_axis)
+            s[O].vectorize(oc_block)
+            s[O].parallel(parallel_axis)
+        elif out_ndim == 4:
+            batch, oc, oh, ow = s[O].op.axis
+            oc_chunk, oc_block = s[O].split(oc, factor=oc_bn)
+            oh_outer, oh_inner = s[O].split(oh, factor=oh_factor)
+            ow_outer, ow_inner = s[O].split(ow, factor=ow_factor)
+            s[O].reorder(oc_chunk, oh_outer, ow_outer, oh_inner, ow_inner, oc_block)
+
+            parallel_axis = s[O].fuse(batch, oc_chunk, oh_outer)
+            s[C].compute_at(s[O], parallel_axis)
+            s[O].vectorize(oc_block)
+            s[O].parallel(parallel_axis)
+        else:
+            raise ValueError("Unsupported output ndim: %s" % out_ndim)
 
     return s
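
Below is a minimal, standalone sketch of the scheduling pattern the new out_ndim == 4 branches apply to a plain NCHW output: split the channel axis by oc_bn and the width axis by reg_n, then reorder, fuse, vectorize, and parallelize. The placeholder shape, the dummy "+1" compute, and the tile sizes oc_bn=16 / reg_n=8 are illustrative assumptions, not values taken from the patch; only the axis manipulation mirrors the branch added in conv2d.py.

# Hypothetical sketch (not part of the patch); assumes TVM's te schedule API
# circa the 0.7 development cycle this patch targets.
import tvm
from tvm import te

batch, oc, oh, ow = 1, 64, 56, 56                   # assumed 4D NCHW output shape
A = te.placeholder((batch, oc, oh, ow), name="A")
O = te.compute((batch, oc, oh, ow),
               lambda n, c, h, w: A[n, c, h, w] + 1, name="O")

s = te.create_schedule(O.op)
oc_bn, reg_n = 16, 8                                 # stand-ins for cfg["tile_oc"] / cfg["tile_ow"]
n, c, h, w = s[O].op.axis
ow_chunk, ow_block = s[O].split(w, factor=reg_n)     # block the width axis
oc_chunk, oc_block = s[O].split(c, factor=oc_bn)     # block the channel axis
s[O].reorder(oc_chunk, h, ow_chunk, ow_block, oc_block)
parallel_axis = s[O].fuse(n, oc_chunk, h)
s[O].vectorize(oc_block)                             # vectorize over the channel block
s[O].parallel(parallel_axis)

# Same debug flow as the removed get_lower_ir helper above.
print(tvm.lower(s, [A, O], simple_mode=True))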