Skip to content

Commit

Permalink
Make x86 work.
Browse files Browse the repository at this point in the history
  • Loading branch information
Ubuntu committed Mar 11, 2020
1 parent 5ef694f commit faf39ad
Show file tree
Hide file tree
Showing 3 changed files with 49 additions and 50 deletions.
31 changes: 0 additions & 31 deletions python/tvm/contrib/util.py
Original file line number Diff line number Diff line change
Expand Up @@ -166,34 +166,3 @@ def which(exec_name):
if os.path.isfile(full_path) and os.access(full_path, os.X_OK):
return full_path
return None


def get_lower_ir(s):
"""Get lower ir code of a schedule.
This is useful for debug, since you don't have to find all inputs/outputs
for a schedule in a fused subgraph.
Parameters
----------
s: Schedule
Returns
-------
ir: str
The lower ir
"""
from tvm.te import tensor
from tvm.driver.build_module import lower

outputs = s.outputs

inputs = []
def find_all(op):
if isinstance(op, tensor.PlaceholderOp):
inputs.append(op.output(0))
else:
for x in op.input_tensors:
find_all(x.op)

for out in outputs:
find_all(out)

return lower(s, inputs, simple_mode=True)
2 changes: 1 addition & 1 deletion python/tvm/relay/op/strategy/cuda.py
Original file line number Diff line number Diff line change
Expand Up @@ -86,7 +86,7 @@ def conv2d_strategy_cuda(attrs, inputs, out_type, target):
if groups == 1:
if layout == "NCHW":
assert kernel_layout == "OIHW"
if data.dtype in ('int8', 'uint8') and kernel.dtype in ('int8', 'uint8'):
if data.dtype in ('int8', 'uint8') and kernel.dtype in ('int8', 'uint8'):
assert data.dtype == kernel.dtype
strategy.add_implementation(
wrap_compute_conv2d(topi.cuda.conv2d_nchw_int8),
Expand Down
66 changes: 48 additions & 18 deletions topi/python/topi/generic/conv2d.py
Original file line number Diff line number Diff line change
Expand Up @@ -144,7 +144,8 @@ def schedule_conv_NCHWc_cpu_common_int8(s, cfg, data_vec, kernel_vec, conv_out,
parallel_axis = s[data_vec].fuse(batch, ic_chunk, ih)
s[data_vec].parallel(parallel_axis)

oc_chunk, ic_chunk, oh, ow, ic_block, oc_block = s[kernel_vec].op.axis
# conv2d_nchwc_int8 has 7D kernel
oc_chunk, ic_chunk, oh, ow, ic_block, oc_block, _ = s[kernel_vec].op.axis
s[kernel_vec].reorder(oc_chunk, oh, ic_chunk, ow, ic_block, oc_block)
oc_bn = cfg["tile_oc"].size[-1]
if oc_bn > 1:
Expand Down Expand Up @@ -189,13 +190,26 @@ def schedule_conv_NCHWc_cpu_common_int8(s, cfg, data_vec, kernel_vec, conv_out,
s[CC].unroll(oc_f_inner)

if C != O:
batch, oc_chunk, oh, ow, oc_block = s[O].op.axis
ow_chunk, ow_block = s[O].split(ow, factor=reg_n)
s[O].reorder(oc_chunk, oh, ow_chunk, ow_block, oc_block)
parallel_axis = s[O].fuse(batch, oc_chunk, oh)
s[C].compute_at(s[O], parallel_axis)
s[O].vectorize(oc_block)
s[O].parallel(parallel_axis)
out_ndim = len(s[O].op.axis)
if out_ndim == 5:
batch, oc_chunk, oh, ow, oc_block = s[O].op.axis
ow_chunk, ow_block = s[O].split(ow, factor=reg_n)
s[O].reorder(oc_chunk, oh, ow_chunk, ow_block, oc_block)
parallel_axis = s[O].fuse(batch, oc_chunk, oh)
s[C].compute_at(s[O], parallel_axis)
s[O].vectorize(oc_block)
s[O].parallel(parallel_axis)
elif out_ndim == 4:
batch, oc, oh, ow = s[O].op.axis
ow_chunk, ow_block = s[O].split(ow, factor=reg_n)
oc_chunk, oc_block = s[O].split(oc, factor=oc_bn)
s[O].reorder(oc_chunk, oh, ow_chunk, ow_block, oc_block)
parallel_axis = s[O].fuse(batch, oc_chunk, oh)
s[C].compute_at(s[O], parallel_axis)
s[O].vectorize(oc_block)
s[O].parallel(parallel_axis)
else:
raise ValueError("Unsupported output ndim: %s" % out_ndim)

return s

Expand Down Expand Up @@ -234,7 +248,8 @@ def schedule_conv_NCHWc_cpu_1x1_int8(s, cfg, data_vec, kernel_vec, conv_out,
parallel_axis = s[data_vec].fuse(batch, ic_chunk, ih)
s[data_vec].parallel(parallel_axis)

oc_chunk, ic_chunk, oh, ow, ic_block, oc_block = s[kernel_vec].op.axis
# Conv2d int8 schedule has 7D kernel
oc_chunk, ic_chunk, oh, ow, ic_block, oc_block, _ = s[kernel_vec].op.axis
s[kernel_vec].reorder(oc_chunk, oh, ic_chunk, ow, ic_block, oc_block)
oc_bn = cfg["tile_oc"].size[-1]
if oc_bn > 1:
Expand Down Expand Up @@ -277,14 +292,29 @@ def schedule_conv_NCHWc_cpu_1x1_int8(s, cfg, data_vec, kernel_vec, conv_out,
s[CC].unroll(oh_inner)

if C != O:
batch, oc_chunk, oh, ow, oc_block = s[O].op.axis
oh_outer, oh_inner = s[O].split(oh, factor=oh_factor)
ow_outer, ow_inner = s[O].split(ow, factor=ow_factor)
s[O].reorder(oc_chunk, oh_outer, ow_outer, oh_inner, ow_inner, oc_block)

parallel_axis = s[O].fuse(batch, oc_chunk, oh_outer)
s[C].compute_at(s[O], parallel_axis)
s[O].vectorize(oc_block)
s[O].parallel(parallel_axis)
out_ndim = len(s[O].op.axis)
if out_ndim == 5:
batch, oc_chunk, oh, ow, oc_block = s[O].op.axis
oh_outer, oh_inner = s[O].split(oh, factor=oh_factor)
ow_outer, ow_inner = s[O].split(ow, factor=ow_factor)
s[O].reorder(oc_chunk, oh_outer, ow_outer, oh_inner, ow_inner, oc_block)

parallel_axis = s[O].fuse(batch, oc_chunk, oh_outer)
s[C].compute_at(s[O], parallel_axis)
s[O].vectorize(oc_block)
s[O].parallel(parallel_axis)
elif out_ndim == 4:
batch, oc, oh, ow = s[O].op.axis
oc_chunk, oc_block = s[O].split(oc, factor=oc_bn)
oh_outer, oh_inner = s[O].split(oh, factor=oh_factor)
ow_outer, ow_inner = s[O].split(ow, factor=ow_factor)
s[O].reorder(oc_chunk, oh_outer, ow_outer, oh_inner, ow_inner, oc_block)

parallel_axis = s[O].fuse(batch, oc_chunk, oh_outer)
s[C].compute_at(s[O], parallel_axis)
s[O].vectorize(oc_block)
s[O].parallel(parallel_axis)
else:
raise ValueError("Unsupported output ndim: %s" % out_ndim)

return s

0 comments on commit faf39ad

Please sign in to comment.