[topi] add ARM v8.2 udot (uint8) support #3978

Merged · 8 commits · Oct 1, 2019
1 change: 1 addition & 0 deletions topi/python/topi/arm_cpu/__init__.py
@@ -3,6 +3,7 @@
from . import conv2d
from . import depthwise_conv2d
from . import conv2d_transpose
from . import conv2d_int8
from . import bitserial_conv2d
from . import bitserial_dense
from . import injective
112 changes: 112 additions & 0 deletions topi/python/topi/arm_cpu/conv2d_int8.py
@@ -0,0 +1,112 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
# pylint: disable=invalid-name,unused-variable,unused-argument,no-member
"""Conv2D int8 schedule on ARM"""

import tvm
from tvm import autotvm
from .. import generic, tag
from ..util import get_const_tuple
from ..nn.conv2d import conv2d_NCHWc_int8
from ..generic import conv2d as conv2d_generic
from .. import nn
from ..nn.conv2d import _get_workload as _get_conv2d_workload
from .tensor_intrin import dot_int8_int8_int32


def _get_default_config(cfg, data, kernel, strides, padding, out_dtype):
"""
Get default int8 schedule config for the workload
"""
wkl = _get_conv2d_workload(data, kernel, strides, padding, out_dtype)
is_kernel_1x1 = wkl.hkernel == 1 and wkl.wkernel == 1
if is_kernel_1x1:
conv2d_generic.fallback_schedule_cpu_1x1_int8(
cfg, wkl, int32_lanes=2, num_int8_elements=4)
else:
conv2d_generic.fallback_schedule_cpu_common_int8(
cfg, wkl, int32_lanes=2, num_int8_elements=4)


@autotvm.register_topi_compute(conv2d_NCHWc_int8, ['arm_cpu'], 'direct')
def _declaration_conv_NCHWc_int8(cfg, data, kernel, strides,
padding, dilation, layout, out_layout, out_dtype):
    # layout and out_layout are not used here;
    # we keep them for debugging convenience when dumping the autotvm workload
n, ic_chunk, ih, iw, ic_bn = get_const_tuple(data.shape)
in_channel = ic_chunk * ic_bn

oc_chunk, ic_chunk, kh, kw, ic_bn, oc_bn, n_elems = get_const_tuple(kernel.shape)
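    # the kernel arrives pre-packed in a 7-D layout:
    # (oc_chunk, ic_chunk, kh, kw, ic_bn, oc_bn, n_elems), where each group of
    # n_elems int8 values along the innermost axis is reduced into one int32 lane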
num_filter = oc_chunk * oc_bn

    # If no config was set, we can fall back to the NCHW config.
if cfg.is_fallback:
_get_default_config(cfg, tvm.placeholder((n, in_channel, ih, iw), dtype=data.dtype),
tvm.placeholder((num_filter, in_channel, kh, kw), dtype=kernel.dtype),
strides, padding, out_dtype)
return nn.conv2d_NCHWc_int8_compute(data,
kernel,
strides,
padding,
dilation,
layout,
out_layout,
out_dtype)


@autotvm.register_topi_schedule(generic.schedule_conv2d_NCHWc_int8, ['arm_cpu'], ['direct'])
def _schedule_conv2d_NCHWc_int8(cfg, outs):
"""Create schedule for tensors"""
s = tvm.create_schedule([x.op for x in outs])
scheduled_ops = []

def traverse(op):
"""Traverse operators from computation graph"""
# inline all one-to-one-mapping operators except the last stage (output)
if tag.is_broadcast(op.tag):
if op not in s.outputs:
s[op].compute_inline()
for tensor in op.input_tensors:
if isinstance(tensor.op, tvm.tensor.ComputeOp) and tensor.op not in scheduled_ops:
traverse(tensor.op)

if 'conv2d_NCHWc_int8' in op.tag:
conv_out = op.output(0)
kernel = conv_out.op.input_tensors[1]
data_vec = conv_out.op.input_tensors[0]
data = data_vec.op.input_tensors[0] \
if isinstance(data_vec.op, tvm.tensor.ComputeOp) and "pad" not in data_vec.op.tag \
else data_vec
if isinstance(data.op, tvm.tensor.ComputeOp) and "pad" in data.op.tag:
data_pad = data
data = data_pad.op.input_tensors[0]

args = [s, cfg, data_vec, conv_out, outs[0]]
# int8 conv kernel is 7-dim
_, _, kh, kw, _, _, _ = get_const_tuple(kernel.shape)
dtype = "uint" if data.dtype == "uint8" else "int"
if kh == 1 and kw == 1:
conv2d_generic.schedule_conv_NCHWc_cpu_1x1_int8(
*args, int32_lanes=4, intrin=dot_int8_int8_int32(int32_lanes=4, dtype=dtype))
else:
conv2d_generic.schedule_conv_NCHWc_cpu_common_int8(
*args, int32_lanes=4, intrin=dot_int8_int8_int32(int32_lanes=4, dtype=dtype))

scheduled_ops.append(op)

traverse(outs[0].op)
return s
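
For context, here is a minimal sketch of how the new compute and schedule could be driven end to end. It is not part of the diff: the shapes, the NCHW4c data layout, the packed kernel layout, and the target string (notably the '+dotprod' LLVM feature name) are illustrative assumptions, and without tuning logs AutoTVM will fall back to the default config above.

import tvm
import topi

# uint8 activations in NCHW4c layout: (n, ic_chunk, h, w, ic_bn)
data = tvm.placeholder((1, 32, 56, 56, 4), dtype='uint8', name='data')
# pre-packed uint8 kernel: (oc_chunk, ic_chunk, kh, kw, ic_bn//4, oc_bn, n_elems)
kernel = tvm.placeholder((8, 32, 3, 3, 1, 4, 4), dtype='uint8', name='kernel')

# hypothetical AArch64 target; '+dotprod' is assumed to enable udot/sdot
target = 'llvm -device=arm_cpu -mtriple=aarch64-linux-gnu -mattr=+v8.2a,+dotprod'
with tvm.target.create(target):
    out = topi.nn.conv2d_NCHWc_int8(data, kernel, (1, 1), (1, 1), (1, 1),
                                    'NCHW4c', 'NCHW4c', 'int32')
    s = topi.generic.schedule_conv2d_NCHWc_int8([out])
    func = tvm.build(s, [data, kernel, out], target=target)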
110 changes: 110 additions & 0 deletions topi/python/topi/arm_cpu/tensor_intrin.py
@@ -0,0 +1,110 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
# pylint: disable=invalid-name,unused-variable,unused-argument,no-member
"""Conv2D int8 schedule on ARM"""

import tvm

def dot_int8_int8_int32(int32_lanes, dtype='uint'):
"""
Int8 dot product by every 4 elements using ARM v8.2 udot.
    This function takes two arrays of int8 datatype -- data[4] and
    kernel[int32_lanes][4] -- and computes a dot product of data[4] with every
    4 elements of kernel, producing output[int32_lanes] of int32/uint32 datatype.
The pseudo code is as follows.

.. code-block:: c

        void dot_int8_int8_int32(int8 data[4], int8 kernel[int32_lanes][4],
                                 int32 output[int32_lanes]) {
            for (int i = 0; i < int32_lanes; i++) {
                output[i] = 0;
                for (int k = 0; k < 4; k++) {
                    output[i] += data[k] * kernel[i][k];
                }
            }
        }

    Physically, the kernel array sits in a vector register and
    data[4] is broadcast into another vector register. This
    function returns a TensorIntrin that can be used to tensorize
    a schedule.

Parameters
----------
    int32_lanes: int
        How many int32/uint32 lanes to produce
    dtype: str, optional, {"uint", "int"}
        Whether it operates on unsigned or signed int8 values

Returns
-------
intrin : TensorIntrin
The ARM uint8 TensorIntrin that can be used in tensorizing schedule
"""
num_int8_elements = 4 # 4 int8 elements in int32

data = tvm.placeholder((num_int8_elements,), dtype='%s8' % dtype, name='data')
kernel = tvm.placeholder((int32_lanes, num_int8_elements), dtype='%s8' % dtype, name='kernel')

k = tvm.reduce_axis((0, num_int8_elements), name='k')
C = tvm.compute((int32_lanes,),
lambda i: tvm.sum(data[k].astype('%s32' % dtype) *
kernel[i, k].astype('%s32' % dtype),
axis=k), name="C")

a_buffer = tvm.decl_buffer(data.shape, dtype='%s8' % dtype, name="a_buffer",
offset_factor=1,
strides=[1])
b_buffer = tvm.decl_buffer(kernel.shape, dtype='%s8' % dtype, name="b_buffer",
offset_factor=1,
strides=[tvm.var('s'), 1])

def _intrin_func(ins, outs):
def _instr(index):
ib = tvm.ir_builder.create()
if index == 1:
ib.emit(outs[0].vstore(0, tvm.const(0, '%s32x%d' % (dtype, int32_lanes))))
return ib.get()

dtype_a = '%s8x%d' % (dtype, num_int8_elements)
dtype_b = '%s8x%d' % (dtype, int32_lanes * num_int8_elements)
dtype_c = '%s32x%d' % (dtype, int32_lanes)

a_int8 = ins[0].vload([0], dtype_a)
re_int32 = tvm.call_pure_intrin('%s32' % dtype, 'reinterpret', a_int8)
            # broadcast the packed 4 x int8 (reinterpreted as one int32) across all lanes
vec_ai32 = re_int32.astype(dtype_c)

vec_a = tvm.call_pure_intrin(dtype_b, 'reinterpret', vec_ai32)
vec_b = ins[1].vload([0, 0], dtype_b)
vec_c = outs[0].vload([0], dtype_c)

inst = 'udot' if dtype == 'uint' else 'sdot'
inst = 'llvm.aarch64.neon.%s.v%di32.v%di8' % (
inst, int32_lanes, int32_lanes * num_int8_elements)
vdot = tvm.call_llvm_intrin(dtype_c,
inst,
tvm.const(2, 'uint32'),
vec_c, vec_a, vec_b)
ib.emit(outs[0].vstore(0, vdot))
return ib.get()

# body, reset, update
return _instr(0), _instr(1), _instr(2)

with tvm.build_config(offset_factor=1, partition_const_loop=True):
        return tvm.decl_tensor_intrin(C.op, _intrin_func,
                                      binds={data: a_buffer, kernel: b_buffer})
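
As a usage note, the intrinsic above is meant to be pattern-matched against an identical 4-element reduction via tensorize. A minimal sketch, not part of the diff, assuming an LLVM build with AArch64 support and that '+dotprod' is the correct feature flag:

import tvm
from topi.arm_cpu.tensor_intrin import dot_int8_int8_int32

# the same compute pattern the intrinsic declares: a 4-element uint8 dot
# product replicated across 4 int32 output lanes
data = tvm.placeholder((4,), dtype='uint8', name='data')
kernel = tvm.placeholder((4, 4), dtype='uint8', name='kernel')
rk = tvm.reduce_axis((0, 4), name='rk')
C = tvm.compute((4,),
                lambda i: tvm.sum(data[rk].astype('uint32') *
                                  kernel[i, rk].astype('uint32'), axis=rk),
                name='C')

s = tvm.create_schedule(C.op)
# replace the whole (i, rk) loop nest with one udot-backed intrinsic call
s[C].tensorize(C.op.axis[0], dot_int8_int8_int32(int32_lanes=4, dtype='uint'))
func = tvm.build(s, [data, kernel, C],
                 target='llvm -mtriple=aarch64-linux-gnu -mattr=+v8.2a,+dotprod')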