Skip to content

Commit

Permalink
Improve AArch64 depthwise convolution through smlal/smlal2 intrinsic (a…
Browse files Browse the repository at this point in the history
…pache#6711)

* Improve depthwise convolution through smlal/smlal2 intrinsic

- Added an intrinsic to load a single int16x8 vector and produce two
  int32x4 output vectors through smlal/smlal2 instructions

- Changed the NHWC depthwise schedule to accomodate the aforementioned
  intrinsic

Change-Id: I347c3bf98fa8dd87057304dcda0d78e558424c57

* Address review comments

* Rebasing - 2

* Rebasing - 3

* Rebasing - 3

* Fix linting
  • Loading branch information
Giuseppe Rossini authored and trevor-m committed Dec 4, 2020
1 parent 8d122f1 commit 97ea828
Show file tree
Hide file tree
Showing 3 changed files with 200 additions and 9 deletions.
65 changes: 56 additions & 9 deletions python/tvm/topi/arm_cpu/depthwise_conv2d.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,8 @@
from .. import nn
from ..utils import traverse_inline, get_const_tuple, get_const_int
from ..nn.utils import get_pad_tuple
from .tensor_intrin import smlal_int16_int32
from .arm_utils import is_aarch64_arm


@autotvm.register_topi_compute("depthwise_conv2d_nchw.arm_cpu")
Expand Down Expand Up @@ -222,7 +224,6 @@ def compute_depthwise_conv2d_nhwc(_, data, kernel, strides, padding, dilation, o
output : tvm.te.Tensor
4-D with shape [batch, out_height, out_width, out_channel]
"""

out_dtype = out_dtype or data.dtype

N, IH, IW, IC = get_const_tuple(data.shape)
Expand Down Expand Up @@ -288,60 +289,106 @@ def schedule_depthwise_conv2d_nhwc(cfg, outs):

##### space definition begin #####
n, h, w, c = s[out].op.axis
# Split the number of input/output channels
cfg.define_split("tile_c", c, num_outputs=2)
# Split the height of the convolution
_, hi = cfg.define_split("tile_h", h, num_outputs=2)
# Split the width of the convolution
_, wi = cfg.define_split("tile_w", w, num_outputs=2)
# Additional out (e.g., requantization, bias addition, etc..)
# 0: locate the output on the second last axis of the main compuation
# 1: locate the output closest to the main computation
cfg.define_knob("locate_output", [0, 1])
# Determine if we should unroll the computation of the inner tile
cfg.define_knob("unroll_tile", [True, False])

# fallback support
if cfg.is_fallback:
cfg["tile_c"] = SplitEntity([-1, 8])
cfg["tile_h"] = SplitEntity([-1, 2])
cfg["tile_w"] = SplitEntity([-1, 2])
cfg["locate_output"] = OtherOptionEntity(1)
cfg["unroll_tile"] = OtherOptionEntity(True)
##### space definition end #####

def schedule_conv(conv):
conv_data = conv.op.input_tensors[0]
kernel_data = conv.op.input_tensors[1]
in_type = conv_data.dtype

_, _, IC, channel_multiplier = get_const_tuple(kernel_data.shape)

n, w, h, c = conv.op.axis
r_h, r_w = conv.op.reduce_axis
ho, hi = cfg["tile_h"].apply(s, conv, h)
wo, wi = cfg["tile_w"].apply(s, conv, w)
co, ci = cfg["tile_c"].apply(s, conv, c)

split_val = cfg["tile_c"].size[-1]
use_tensorization = (
(in_type == "int16")
and (split_val == 8)
and (IC % split_val == 0)
and (channel_multiplier == 1)
and is_aarch64_arm()
)

data_pad_value = -1
if conv_data.name == "data_pad":
assert isinstance(conv_data.op, tvm.te.ComputeOp)
# Define a policy for padding computation
cfg.define_knob("data_pad_inline", [1, 2, 3])
# Define a strategy for padding computation
cfg.define_knob("data_pad_strategy", [1, 2, 3])
if cfg.is_fallback:
cfg["data_pad_inline"] = OtherOptionEntity(3)
if cfg["data_pad_inline"].val == 1:
# We cannot inline padding when tensorizing.
# So, if we can tensorize, let's compute_at the closest axis
cfg["data_pad_strategy"] = (
OtherOptionEntity(2) if use_tensorization else OtherOptionEntity(3)
)
# Compute padding on the third to last axis of the computation
if cfg["data_pad_strategy"].val == 1:
s[conv_data].vectorize(list(s[conv_data].op.axis)[-1])
s[conv_data].compute_at(s[conv], ho)
if cfg["data_pad_inline"].val == 2:
# Compute padding on the second to last axis of the computation
if cfg["data_pad_strategy"].val == 2:
s[conv_data].vectorize(list(s[conv_data].op.axis)[-1])
s[conv_data].compute_at(s[conv], wo)
if cfg["data_pad_inline"].val == 3:
# Inline padding during computation
if cfg["data_pad_strategy"].val == 3:
s[conv_data].compute_inline()
data_pad_value = cfg["data_pad_strategy"].val

if use_tensorization and data_pad_value != 3:
smlal = smlal_int16_int32()
s[conv].tensorize(ci, smlal)
else:
s[conv].vectorize(ci)

if cfg["unroll_tile"].val:
s[conv].unroll(r_h)
s[conv].unroll(r_w)
s[conv].unroll(wi)
s[conv].unroll(hi)

s[conv].reorder(n, ho, wo, co, hi, wi, r_h, r_w, ci)
fused_n_ho = s[conv].fuse(n, ho)
s[conv].vectorize(ci)
return fused_n_ho

def schedule_conv_out(out):
n, h, w, c = out.op.axis
co, ci = cfg["tile_c"].apply(s, out, c)
wo, wi = cfg["tile_w"].apply(s, out, w)
ho, hi = cfg["tile_h"].apply(s, out, h)
s[out].reorder(n, ho, wo, co, hi, wi)
s[out].reorder(n, ho, wo, co, hi, wi, ci)
if cfg["unroll_tile"]:
s[out].unroll(wi)
s[out].unroll(hi)

if out.dtype in ["int8", "uint8"]:
# In case of quantized convolution further split the channel in batches of 4 elements
# so that we can use arm intrinsics to run fixed_point_multiplication
ci_outer, ci_inner = s[out].split(ci, 4)
s[out].vectorize(ci_inner)
s[out].unroll(ci_outer)

fused_n_ho = s[out].fuse(n, ho)
return hi, wi, fused_n_ho
Expand Down
90 changes: 90 additions & 0 deletions python/tvm/topi/arm_cpu/tensor_intrin.py
Original file line number Diff line number Diff line change
Expand Up @@ -879,6 +879,96 @@ def _instr(index):
)


def smlal_int16_int32():
"""
Intrinsic to be used in order to load two int16x8 vectors and multiply
them together through a pair of smlal/smlal2 instructions. The pseudo-code
for the algorithm is as follows:
vec_a = vload(A, "int16x8")
vec_b = vload(B, "int16x8")
vec_c[0:4] += vec_a[0:4]*vec_b[0:4] // -> smlal instruction
vec_c[4:8] += vec_a[4:8]*vec_b[4:8] // -> smlal2 instruction
So we load a single int16x8 vector and we accumulate its lower (0:4) and
higher part separately.
"""
int16_lanes = 8
A = te.placeholder((int16_lanes,), dtype="int16", name="A")
B = te.placeholder((int16_lanes, 1), dtype="int16", name="B")
C = te.compute(
(int16_lanes,),
lambda i: A[i].astype("int32") * B[i, 0].astype("int32"),
name="C",
)

a_buffer = tvm.tir.decl_buffer(
A.shape, dtype="int16", name="a_buffer", offset_factor=1, strides=[1]
)
b_buffer = tvm.tir.decl_buffer(
B.shape,
dtype="int16",
name="b_buffer",
offset_factor=1,
strides=[te.var("sb"), 1],
)
c_buffer = tvm.tir.decl_buffer(
C.shape,
dtype="int32",
name="c_buffer",
offset_factor=1,
strides=[1],
)

def _intrin_func(ins, outs):
def _instr(index):
ib = tvm.tir.ir_builder.create()
if index == 1:
ib.emit(outs[0].vstore(0, tvm.tir.const(0, "int32x8")))
return ib.get()

vec_a = ins[0].vload([0], "int16x8")
vec_b = ins[1].vload([0, 0], "int16x8")
inst = "llvm.aarch64.neon.smull"

# Higher part of the vector
vec_c_h = outs[0].vload([4], "int32x4")
vec_a_h = tvm.tir.call_intrin("int16x4", "tir.vectorhigh", vec_a)
vec_b_h = tvm.tir.call_intrin("int16x4", "tir.vectorhigh", vec_b)
vmull_h = tvm.tir.call_llvm_pure_intrin(
"int32x4", inst, tvm.tir.const(2, "uint32"), vec_a_h, vec_b_h
)
vec_out_h = vec_c_h + vmull_h

# Lower part of the vector
vec_c_l = outs[0].vload([0], "int32x4")
vec_a_l = tvm.tir.call_intrin("int16x4", "tir.vectorlow", vec_a)
vec_b_l = tvm.tir.call_intrin("int16x4", "tir.vectorlow", vec_b)
vmull_l = tvm.tir.call_llvm_pure_intrin(
"int32x4", inst, tvm.tir.const(2, "uint32"), vec_a_l, vec_b_l
)
vec_out_l = vec_c_l + vmull_l

# Combine higher and lower part in a single int32x8 vector to store
# (this will require two different store instructions, since the
# length of a NEON vector is fixed at 128
vec_out = tvm.tir.call_intrin("int32x8", "tir.vectorcombine", vec_out_l, vec_out_h)
ib.emit(outs[0].vstore(0, vec_out))
return ib.get()

# body, reset, update
return _instr(0), _instr(1), _instr(2)

buffer_params = {"offset_factor": 1}
return te.decl_tensor_intrin(
C.op,
_intrin_func,
binds={A: a_buffer, B: b_buffer, C: c_buffer},
default_buffer_params=buffer_params,
)


def _q_multiply_shift_arm(op):
"""
Implementation of q_multiply_shift_arm through arm intrinsics
Expand Down
54 changes: 54 additions & 0 deletions tests/python/topi/python/test_topi_depthwise_conv2d.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,55 @@
}


def compile_depthwise_NHWC_int8_arm(
batch,
in_channel,
in_size,
kernel,
depth_multiplier,
stride,
padding,
add_bias=False,
dilation=1,
):
pad_top, pad_left, pad_bottom, pad_right = get_pad_tuple(padding, (kernel, kernel))
padding_sum = pad_top + pad_left + pad_bottom + pad_right

in_height = in_width = in_size
A = te.placeholder((batch, in_height, in_width, in_channel), name="A", dtype="int16")
W = te.placeholder((kernel, kernel, in_channel, depth_multiplier), name="W", dtype="int16")
bias = te.placeholder((in_channel * depth_multiplier,), name="bias", dtype="int32")
dtype = "int32"

device = "llvm -device=arm_cpu -mtriple=aarch64-linux-gnu"
compute = topi.arm_cpu.compute_depthwise_conv2d_nhwc
schedule = topi.arm_cpu.schedule_depthwise_conv2d_nhwc

if not tvm.testing.device_enabled(device):
print("Skip because %s is not enabled" % device)
return

print("Compiling on arm AArch64 target: %s" % device)
with tvm.target.Target(device):
assert topi.arm_cpu.arm_utils.is_aarch64_arm(), "AArch64 target not recognized"

C = compute(A, W, (stride, stride), padding, (dilation, dilation), dtype)
if add_bias:
C += bias
ins_outs = [A, W, bias, C]
else:
ins_outs = [A, W, C]

s = schedule([C])

func = tvm.build(
s,
ins_outs,
device,
name="depthwise_conv2d",
)


def depthwise_conv2d_with_workload_nchw(
batch, in_channel, in_height, channel_multiplier, filter_height, stride, padding, dilation=1
):
Expand Down Expand Up @@ -478,6 +527,7 @@ def test_depthwise_conv2d():
depthwise_conv2d_with_workload_nhwc(4, 256, 64, 2, 5, 2, "SAME")
depthwise_conv2d_with_workload_nhwc(1, 728, 32, 1, 3, 1, "VALID")
depthwise_conv2d_with_workload_nhwc(4, 256, 64, 2, 5, 2, "VALID")

# dilation = 2
# disabled because it uses too large shared memory on cuda
# depthwise_conv2d_with_workload_nhwc(1, 728, 64, 1, 3, 1, "SAME", dilation=2)
Expand All @@ -487,6 +537,10 @@ def test_depthwise_conv2d():
depthwise_conv2d_with_workload_NCHWc(1, 728, 32, 1, 3, 1, "SAME")
depthwise_conv2d_with_workload_NCHWc(1, 728, 32, 1, 3, 1, "VALID")

# Test compilation on arm devices
compile_depthwise_NHWC_int8_arm(1, 728, 32, 1, 3, 1, "SAME")
compile_depthwise_NHWC_int8_arm(1, 728, 32, 1, 1, 1, "SAME", True)


if __name__ == "__main__":
test_depthwise_conv2d()

0 comments on commit 97ea828

Please sign in to comment.