From 912993ff3984ddb47ec98888d645a8f157daa3fc Mon Sep 17 00:00:00 2001
From: Masahiro Masuda
Date: Fri, 1 Apr 2022 09:11:32 +0900
Subject: [PATCH] [ARM] Fix int8 NCHWc compute and alter layout (#10839)

This PR fixes a bug in the TE ARM int8 compute for NCHWc conv2d, introduced in
https://github.com/apache/tvm/pull/10310. The compute itself, not the schedule,
is broken, for the following reasons:

* We use `n_elems = 8` in
  https://github.com/apache/tvm/blob/e9091d6c68d5d70c28881e5c75bfe72e385c1f4d/python/tvm/topi/arm_cpu/conv2d_alter_op.py#L350,
  so the innermost axis of the transformed kernel has extent 8:
  https://github.com/apache/tvm/blob/e9091d6c68d5d70c28881e5c75bfe72e385c1f4d/python/tvm/topi/arm_cpu/conv2d_alter_op.py#L375
* In the TE compute, we iterate over the innermost axis `ic_s_inner` of the kernel at
  https://github.com/apache/tvm/blob/f6f252f0abc8f621a96506739f9534083d1fe213/python/tvm/topi/nn/conv2d.py#L577.
  `ic_s_inner` has extent `n_elems`
  (https://github.com/apache/tvm/blob/f6f252f0abc8f621a96506739f9534083d1fe213/python/tvm/topi/nn/conv2d.py#L566),
  and `n_elems` defaults to 4
  (https://github.com/apache/tvm/blob/f6f252f0abc8f621a96506739f9534083d1fe213/python/tvm/topi/nn/conv2d.py#L478).
* The ARM code that calls this compute does not pass `n_elems` explicitly:
  https://github.com/apache/tvm/blob/e9091d6c68d5d70c28881e5c75bfe72e385c1f4d/python/tvm/topi/arm_cpu/conv2d_int8.py#L106-L108
* Thus, even though the innermost axis of the kernel has extent 8, the TE compute
  loops over only `n_elems = 4` of the input channel dimension, silently dropping
  half of each channel block from the accumulation (a concrete sketch follows below).

Initially, I tried to keep `n_elems = 8` in alter layout and fix the intrinsic
definition instead. But `n_elems = 8` breaks tensorization pattern matching: the
compute now does a 4x8 innermost loop, while the intrinsic is supposed to do a 4x4
dot product, see
https://github.com/apache/tvm/blob/7896108fc41663a1fecbb52345194a93278e9e28/python/tvm/topi/arm_cpu/tensor_intrin.py#L467-L479.
Setting `num_int8_elements = 8` there does fix the tensorize pattern matching, but
the result was still incorrect.

Rather than adapting the intrinsic implementation in
https://github.com/apache/tvm/blob/7896108fc41663a1fecbb52345194a93278e9e28/python/tvm/topi/arm_cpu/tensor_intrin.py#L492
to a 4x8 dot product, I settled on setting `n_elems = 4` in alter layout. That
change alone turned out to be enough to get correct output. Moreover,
`n_elems = 8` is simply wrong for the dot-product path in
https://github.com/apache/tvm/blob/7896108fc41663a1fecbb52345194a93278e9e28/python/tvm/topi/arm_cpu/conv2d_int8.py#L154-L155,
which computes a 4x4 dot product in one instruction.

@tkonolige I suggest redoing the perf benchmark, since the numbers in
https://github.com/apache/tvm/pull/10310 are invalid.
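To make the mismatch concrete, here is a minimal NumPy sketch (illustrative only,
not TVM code; the loop and variable names merely mirror the TE compute) of what a
`n_elems = 4` reduction does to a kernel packed with an inner extent of 8:

```python
import numpy as np

ic = 8               # one packed input-channel block
n_elems_packed = 8   # inner kernel extent produced by alter layout (the bug)
n_elems_compute = 4  # default n_elems the TE compute reduces over

rng = np.random.default_rng(0)
data = rng.integers(-128, 128, size=ic, dtype=np.int32)
kernel = rng.integers(-128, 128, size=(ic // n_elems_packed, n_elems_packed), dtype=np.int32)

# Reference: reduce over the whole channel block.
ref = int((data * kernel.reshape(-1)).sum())

# Buggy compute: `ic_s_inner` runs only up to n_elems = 4 per chunk,
# so the last 4 channels of every 8-wide chunk never contribute.
acc = 0
for icc in range(kernel.shape[0]):
    for ic_s_inner in range(n_elems_compute):
        acc += int(data[icc * n_elems_packed + ic_s_inner] * kernel[icc, ic_s_inner])

print(ref == acc)  # almost surely False: half of the accumulation is dropped
```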
cc @mbrookhart @Mousius @junrushao1994 @vinx13
---
 python/tvm/topi/arm_cpu/conv2d_alter_op.py    |  2 +-
 python/tvm/topi/arm_cpu/conv2d_int8.py        |  9 +++++---
 python/tvm/topi/arm_cpu/tensor_intrin.py      | 21 ++++++++++---------
 python/tvm/topi/nn/conv2d.py                  |  1 -
 python/tvm/topi/x86/conv2d_int8.py            |  2 +-
 .../topi/python/test_topi_conv2d_int8.py      | 14 ++++++-------
 6 files changed, 26 insertions(+), 23 deletions(-)

diff --git a/python/tvm/topi/arm_cpu/conv2d_alter_op.py b/python/tvm/topi/arm_cpu/conv2d_alter_op.py
index eb719dd66777..728e0db102fe 100644
--- a/python/tvm/topi/arm_cpu/conv2d_alter_op.py
+++ b/python/tvm/topi/arm_cpu/conv2d_alter_op.py
@@ -347,7 +347,7 @@ def _alter_conv2d_layout(attrs, inputs, tinfos, out_type):
     batch_size, in_channel, height, width = get_const_tuple(data_tensor.shape)
     out_channel, _, kh, kw = get_const_tuple(kernel_tensor.shape)
 
-    n_elems = 8
+    n_elems = 4
 
     if cfg.is_fallback:
         _get_default_config_int8(
diff --git a/python/tvm/topi/arm_cpu/conv2d_int8.py b/python/tvm/topi/arm_cpu/conv2d_int8.py
index 91e3e79cf8c7..b6ab89de8b0a 100644
--- a/python/tvm/topi/arm_cpu/conv2d_int8.py
+++ b/python/tvm/topi/arm_cpu/conv2d_int8.py
@@ -57,7 +57,7 @@ def conv2d_NCHWc_int8(cfg, data, kernel, strides, padding, dilation, layout, out
         n, ic_chunk, ih, iw, ic_bn = get_const_tuple(data.shape)
         in_channel = ic_chunk * ic_bn
 
-        oc_chunk, ic_chunk, kh, kw, ic_bn, oc_bn, n_elems = get_const_tuple(kernel.shape)
+        oc_chunk, ic_chunk, kh, kw, ic_bn, oc_bn, _ = get_const_tuple(kernel.shape)
         num_filter = oc_chunk * oc_bn
     else:
         # data is nchw, implicitly treat it as nchw1c
@@ -103,8 +103,10 @@ def conv2d_NCHWc_int8(cfg, data, kernel, strides, padding, dilation, layout, out
     if len(data.shape) == 4:
         data, kernel = _pack_data(cfg, data, kernel)
 
+    n_elems = int(kernel.shape[-1])
+
     return nn.conv2d_NCHWc_int8(
-        data, kernel, strides, padding, dilation, layout, out_layout, out_dtype
+        data, kernel, strides, padding, dilation, layout, out_layout, out_dtype, n_elems=n_elems
     )
 
 
@@ -149,7 +151,8 @@ def _callback(op):
             args = [s, cfg, data_vec, kernel_vec, conv_out, outs[0]]
             # int8 conv kernel is 7-dim
-            _, _, kh, kw, _, _, _ = get_const_tuple(kernel_vec.shape)
+            _, _, kh, kw, _, _, n_elems = get_const_tuple(kernel_vec.shape)
+            assert n_elems == 4
             dtype = "uint" if data.dtype == "uint8" else "int"
             if is_dotprod_available():
                 intrin = dot_int8_int8_int32_neon_82(int32_lanes=4, dtype=dtype)
diff --git a/python/tvm/topi/arm_cpu/tensor_intrin.py b/python/tvm/topi/arm_cpu/tensor_intrin.py
index d6b6f225890a..e27d00f17617 100644
--- a/python/tvm/topi/arm_cpu/tensor_intrin.py
+++ b/python/tvm/topi/arm_cpu/tensor_intrin.py
@@ -614,21 +614,22 @@ def _instr(index):
             ib.emit(outs[0].vstore(0, tvm.tir.const(0, int_32xl)))
             return ib.get()
 
-        def pairwise_add_mul(idx):
-            # this broadcasts data to the vector size
-            a_int8 = ins[0].vload([0], "int8x4")
-            re_int32 = tvm.tir.call_intrin("int32", "tir.reinterpret", a_int8)
-            vec_ai32 = re_int32.astype("int32x2")
-            vec_a = tvm.tir.call_intrin(int_8xl, "tir.reinterpret", vec_ai32)
+        # this broadcasts data to the vector size
+        a_int8 = ins[0].vload([0], "int8x4")
+        re_int32 = tvm.tir.call_intrin("int32", "tir.reinterpret", a_int8)
+        vec_ai32 = re_int32.astype("int32x2")
+        vec_a = tvm.tir.call_intrin(int_8xl, "tir.reinterpret", vec_ai32)
 
-            vec_b = ins[1].vload([idx * 2, 0], int_8xl)  # we take two inputs at a time
+        vec_b = ins[1].vload([0, 0], "int8x16")
 
+        def pairwise_add_mul(extract_half):
+            vec_b_half = tvm.tir.call_intrin("int8x8", extract_half, vec_b)
             multiply = tvm.tir.call_llvm_pure_intrin(
                 "int16x8",
                 "llvm.aarch64.neon.smull.v8i16",  # saturating pairwise multiplication
                 tvm.tir.const(2, "uint32"),
                 vec_a,
-                vec_b,
+                vec_b_half,
             )
             pairwise_reduction = tvm.tir.call_llvm_pure_intrin(
                 "int32x4",
@@ -638,8 +639,8 @@ def pairwise_add_mul(idx):
             )
             return pairwise_reduction
 
-        pair_1 = pairwise_add_mul(0)
-        pair_2 = pairwise_add_mul(1)
+        pair_1 = pairwise_add_mul("tir.vectorlow")
+        pair_2 = pairwise_add_mul("tir.vectorhigh")
         quad_reduction = tvm.tir.call_llvm_pure_intrin(
             "int32x4",
             "llvm.aarch64.neon.addp.v4i32",
diff --git a/python/tvm/topi/nn/conv2d.py b/python/tvm/topi/nn/conv2d.py
index 68eb4eb6f01b..c27ea81144ac 100644
--- a/python/tvm/topi/nn/conv2d.py
+++ b/python/tvm/topi/nn/conv2d.py
@@ -486,7 +486,6 @@ def conv2d_NCHWc_int8(
     oc_chunk, ic_chunk_group, kernel_height, kernel_width, _, oc_bn, _ = get_const_tuple(
         kernel.shape
     )
-    num_filter = oc_chunk * oc_bn
     groups = ic_chunk // ic_chunk_group
 
     dilated_kernel_h = (kernel_height - 1) * dilation_h + 1
diff --git a/python/tvm/topi/x86/conv2d_int8.py b/python/tvm/topi/x86/conv2d_int8.py
index b0edb02b0804..048d9468051b 100644
--- a/python/tvm/topi/x86/conv2d_int8.py
+++ b/python/tvm/topi/x86/conv2d_int8.py
@@ -120,7 +120,7 @@ def _pack_data(cfg, data, kernel):
     kernel = te.compute(
         (oc_chunk, ic_chunk, kh, kw, ic_bn // n_elems, oc_bn, n_elems),
         lambda occ, icc, k_h, k_w, icbc, ocb, icbb: kernel[
-            occ * oc_bn + ocb, icc * ic_bn + icbc * ic_bn // n_elems + icbb, k_h, k_w
+            occ * oc_bn + ocb, icc * ic_bn + icbc * n_elems + icbb, k_h, k_w
         ],
         name="kernel_vec",
     )
diff --git a/tests/python/topi/python/test_topi_conv2d_int8.py b/tests/python/topi/python/test_topi_conv2d_int8.py
index 96457d9b08e6..860118531e51 100644
--- a/tests/python/topi/python/test_topi_conv2d_int8.py
+++ b/tests/python/topi/python/test_topi_conv2d_int8.py
@@ -21,7 +21,6 @@
 import tvm
 from tvm import te
 from tvm import autotvm
-from tvm.autotvm.task.space import FallbackConfigEntity
 from tvm import topi
 import tvm.topi.testing
 from tvm.contrib.pickle_memoize import memoize
@@ -34,6 +33,7 @@
 from common import Int8Fallback
 import tvm.testing
 import pytest
+import platform
 
 
 def compile_conv2d_NHWC_gemm_int8_arm(
@@ -299,7 +299,6 @@ def get_ref_data():
 
     a_np, w_np, b_np, c_np = get_ref_data()
 
-    print("Running on target: %s" % target)
     with tvm.target.Target(target):
         C = compute(
             A,
@@ -311,8 +310,6 @@ def get_ref_data():
             "NCHW",
             out_dtype,
         )
-        print(C.shape)
-        print(bias.shape)
         if add_bias:
             C = topi.add(C, bias)
         if add_relu:
@@ -342,6 +339,8 @@ def get_ref_data():
     if build_only:
         return
 
+    print("Running on target: %s" % target)
+
     func(*run_args)
 
     tvm.testing.assert_allclose(c.numpy(), c_np, rtol=1e-5)
@@ -364,14 +363,15 @@ def get_ref_data():
         #     ),
     ]
 
-    # TODO(tvm-team): Properly run ARM code on CI aarch64 environment
+    build_only_aarch64 = platform.machine() != "aarch64"
+
     targets.append(
         (
            "llvm -device arm_cpu -mtriple aarch64-linux-gnu -mattr=+neon,+v8.2a,+dotprod",
            topi.arm_cpu.conv2d_NCHWc_int8,
            topi.arm_cpu.schedule_conv2d_NCHWc_int8,
            8,
-           True,
+           build_only_aarch64,
        )
    )
 
@@ -382,7 +382,7 @@ def get_ref_data():
            topi.arm_cpu.conv2d_NCHWc_int8,
            topi.arm_cpu.schedule_conv2d_NCHWc_int8,
            8,
-           True,
+           build_only_aarch64,
        )
    )
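For completeness, a tiny standalone Python check (illustrative only, not part of
the patch; constants chosen to expose the difference) of the packing-index fix in
`_pack_data`: the old expression `icc * ic_bn + icbc * ic_bn // n_elems + icbb`
and the corrected `icc * ic_bn + icbc * n_elems + icbb` coincide only when
`ic_bn == n_elems ** 2` (e.g. 16 and 4), which is presumably why the x86 default
configuration masked the bug:

```python
ic_bn, n_elems = 16, 8  # a configuration where the two formulas disagree

def old_index(icc, icbc, icbb):
    # note Python precedence: icc * ic_bn + ((icbc * ic_bn) // n_elems) + icbb
    return icc * ic_bn + icbc * ic_bn // n_elems + icbb

def new_index(icc, icbc, icbb):
    return icc * ic_bn + icbc * n_elems + icbb

for fn in (old_index, new_index):
    covered = sorted(
        fn(0, icbc, icbb)
        for icbc in range(ic_bn // n_elems)
        for icbb in range(n_elems)
    )
    # Only the corrected formula visits each input channel exactly once.
    print(fn.__name__, covered == list(range(ic_bn)))
# old_index False (channels overlap), new_index True
```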