From 7c10538f6491dd8e3b8684c80c3153013aca94be Mon Sep 17 00:00:00 2001
From: Masahiro Masuda
Date: Tue, 29 Mar 2022 02:41:03 +0900
Subject: [PATCH] [ARM] Fix NCHWc int8 dot product schedule lowering (#10773)

* [ARM] Fix NCHWc int8 dot product schedule lowering

* fix arm task extraction test not running

* skip test on i386
---
 python/tvm/relay/op/strategy/arm_cpu.py   | 16 +--
 python/tvm/topi/arm_cpu/conv2d_int8.py    |  3 +-
 python/tvm/topi/arm_cpu/tensor_intrin.py  |  2 +-
 .../topi/python/test_topi_conv2d_int8.py  | 97 +++++++++----------
 .../test_meta_schedule_integration.py     |  4 +-
 5 files changed, 60 insertions(+), 62 deletions(-)

diff --git a/python/tvm/relay/op/strategy/arm_cpu.py b/python/tvm/relay/op/strategy/arm_cpu.py
index 862377887fecc..03e884e8a9656 100644
--- a/python/tvm/relay/op/strategy/arm_cpu.py
+++ b/python/tvm/relay/op/strategy/arm_cpu.py
@@ -88,14 +88,6 @@ def conv2d_strategy_arm_cpu(attrs, inputs, out_type, target):
     if groups == 1:
         if layout == "NCHW":
             if kernel_layout == "OIHW":
-                # ARM conv2d spatial pack schedule.
-                strategy.add_implementation(
-                    wrap_compute_conv2d(topi.arm_cpu.conv2d_nchw_spatial_pack),
-                    wrap_topi_schedule(topi.arm_cpu.schedule_conv2d_nchw_spatial_pack),
-                    name="conv2d_nchw_spatial_pack.arm_cpu",
-                    plevel=10,
-                )
-
                 if (
                     topi.arm_cpu.is_int8_hw_support(data.dtype, kernel.dtype)
                     and kernel.shape[1] >= 64
@@ -107,6 +99,14 @@ def conv2d_strategy_arm_cpu(attrs, inputs, out_type, target):
                         plevel=15,
                     )
                 else:
+                    # ARM conv2d spatial pack schedule.
+                    strategy.add_implementation(
+                        wrap_compute_conv2d(topi.arm_cpu.conv2d_nchw_spatial_pack),
+                        wrap_topi_schedule(topi.arm_cpu.schedule_conv2d_nchw_spatial_pack),
+                        name="conv2d_nchw_spatial_pack.arm_cpu",
+                        plevel=10,
+                    )
+
                     strategy.add_implementation(
                         wrap_compute_conv2d(topi.x86.conv2d_nchw),
                         wrap_topi_schedule(topi.x86.schedule_conv2d_nchw),
diff --git a/python/tvm/topi/arm_cpu/conv2d_int8.py b/python/tvm/topi/arm_cpu/conv2d_int8.py
index d09433b16a784..91e3e79cf8c7a 100644
--- a/python/tvm/topi/arm_cpu/conv2d_int8.py
+++ b/python/tvm/topi/arm_cpu/conv2d_int8.py
@@ -152,8 +152,9 @@ def _callback(op):
             _, _, kh, kw, _, _, _ = get_const_tuple(kernel_vec.shape)
             dtype = "uint" if data.dtype == "uint8" else "int"
             if is_dotprod_available():
-                intrin = dot_int8_int8_int32_neon_82(int32_lanes=4)
+                intrin = dot_int8_int8_int32_neon_82(int32_lanes=4, dtype=dtype)
             elif is_neon_available():
+                assert dtype == "int", "uint8 not supported if dot product is not available"
                 intrin = dot_int8_int8_int32_neon()
             else:
                 raise RuntimeError(
diff --git a/python/tvm/topi/arm_cpu/tensor_intrin.py b/python/tvm/topi/arm_cpu/tensor_intrin.py
index 1f3577a46681e..d6b6f225890af 100644
--- a/python/tvm/topi/arm_cpu/tensor_intrin.py
+++ b/python/tvm/topi/arm_cpu/tensor_intrin.py
@@ -516,7 +516,7 @@ def _instr(index):
                 int32_lanes * num_int8_elements,
             )
             vdot = tvm.tir.call_llvm_pure_intrin(
-                dtype_c, inst, tvm.tir.const(2, "uint32"), vec_c, vec_a, vec_b
+                dtype_c, inst, tvm.tir.const(3, "uint32"), vec_c, vec_a, vec_b
             )
             ib.emit(outs[0].vstore(0, vdot))
             return ib.get()
diff --git a/tests/python/topi/python/test_topi_conv2d_int8.py b/tests/python/topi/python/test_topi_conv2d_int8.py
index d7a8d8bf2ca5b..96457d9b08e6c 100644
--- a/tests/python/topi/python/test_topi_conv2d_int8.py
+++ b/tests/python/topi/python/test_topi_conv2d_int8.py
@@ -259,7 +259,7 @@ def verify_conv2d_NCHWc_int8(
     lo = -128 if in_dtype == "int8" else 0
     hi = 127 if in_dtype == "int8" else 255

-    def check_target(target, compute, schedule, oc_block_factor):
+    def check_target(target, compute, schedule, oc_block_factor, build_only):
         dev = tvm.device(target, 0)
         if not tvm.testing.device_enabled(target):
             print("Skip because %s is not enabled" % target)
@@ -323,45 +323,27 @@ def get_ref_data():
         w = tvm.nd.array(w_np.astype(dtype), dev)
         b = tvm.nd.array(b_np.astype(out_dtype), dev)
         c = tvm.nd.array(np.zeros(get_const_tuple(C.shape), dtype=C.dtype), dev)
+
         if add_bias:
-            tvm.build(
-                s,
-                [A, W, bias, C],
-                target,
-                name="relu_%d_%d_%d_%d_%d_%d_%d_%d"
-                % (batch, in_channel, in_size, num_filter, kernel, stride, padding_sum, dilation),
-            )
-            func = tvm.build(
-                s,
-                [A, W, bias, C],
-                target,
-                name="relu_%d_%d_%d_%d_%d_%d_%d_%d"
-                % (batch, in_channel, in_size, num_filter, kernel, stride, padding_sum, dilation),
-            )
-            try:
-                func(a, w, b, c)
-            except tvm.TVMError as e:
-                if "architecture mismatch" in str(e):
-                    print(f"Skipping execution because {target} is not supported by this CPU")
-                    return
-                else:
-                    raise
+            compile_args = [A, W, bias, C]
+            run_args = [a, w, b, c]
         else:
-            func = tvm.build(
-                s,
-                [A, W, C],
-                target,
-                name="relu_%d_%d_%d_%d_%d_%d_%d_%d"
-                % (batch, in_channel, in_size, num_filter, kernel, stride, padding_sum, dilation),
-            )
-            try:
-                func(a, w, c)
-            except tvm.TVMError as e:
-                if "architecture mismatch" in str(e):
-                    print(f"Skipping execution because {target} is not supported by this CPU")
-                    return
-                else:
-                    raise
+            compile_args = [A, W, C]
+            run_args = [a, w, c]
+
+        func = tvm.build(
+            s,
+            compile_args,
+            target,
+            name="relu_%d_%d_%d_%d_%d_%d_%d_%d"
+            % (batch, in_channel, in_size, num_filter, kernel, stride, padding_sum, dilation),
+        )
+
+        if build_only:
+            return
+
+        func(*run_args)
+
         tvm.testing.assert_allclose(c.numpy(), c_np, rtol=1e-5)

     targets = [
@@ -370,6 +352,7 @@ def get_ref_data():
             lambda a, w, s, p, d, l, ol, o: topi.cuda.conv2d_NCHWc_int8(a, w, s, p, d, l, o),
             topi.cuda.schedule_conv2d_NCHWc_int8,
             4,
+            False,
         ),
         # Disable on CI since it does not support spirv int8 dot product
         # (
@@ -377,22 +360,34 @@ def get_ref_data():
         #     lambda a, w, s, p, d, l, ol, o: topi.cuda.conv2d_NCHWc_int8(a, w, s, p, d, l, o),
         #     topi.cuda.schedule_conv2d_NCHWc_int8,
         #     4,
+        #     False,
         # ),
     ]

-    # TODO(Mousius) Re-enable once implementation is fixed
-    # if in_dtype == "int8":
-    #     targets.append(
-    #         (
-    #             "llvm -device arm_cpu -mtriple aarch64-linux-gnu -mattr=+neon",
-    #             topi.arm_cpu.conv2d_NCHWc_int8,
-    #             topi.arm_cpu.schedule_conv2d_NCHWc_int8,
-    #             8,
-    #         )
-    #     )
-
-    for target, compute, schedule, oc_block_factor in targets:
-        check_target(target, compute, schedule, oc_block_factor)
+    # TODO(tvm-team): Properly run ARM code on CI aarch64 environment
+    targets.append(
+        (
+            "llvm -device arm_cpu -mtriple aarch64-linux-gnu -mattr=+neon,+v8.2a,+dotprod",
+            topi.arm_cpu.conv2d_NCHWc_int8,
+            topi.arm_cpu.schedule_conv2d_NCHWc_int8,
+            8,
+            True,
+        )
+    )
+
+    if in_dtype == "int8":
+        targets.append(
+            (
+                "llvm -device arm_cpu -mtriple aarch64-linux-gnu -mattr=+neon",
+                topi.arm_cpu.conv2d_NCHWc_int8,
+                topi.arm_cpu.schedule_conv2d_NCHWc_int8,
+                8,
+                True,
+            )
+        )
+
+    for target, compute, schedule, oc_block_factor, build_only in targets:
+        check_target(target, compute, schedule, oc_block_factor, build_only)


 def verify_conv2d_nchw_int8(
diff --git a/tests/python/unittest/test_meta_schedule_integration.py b/tests/python/unittest/test_meta_schedule_integration.py
index 8186d3c178d65..d70c5ab1dc0e6 100644
--- a/tests/python/unittest/test_meta_schedule_integration.py
+++ b/tests/python/unittest/test_meta_schedule_integration.py
@@ -20,6 +20,7 @@

 import pytest
 import tvm
+import tvm.testing
 from tvm import relay
 from tvm import meta_schedule as ms
 from tvm.ir.module import IRModule
@@ -151,7 +152,8 @@ def extract_task_qbert():
     assert "vnni" in annotations["schedule_rule"]


-def extract_task_arm_conv2d_nchwc():
+@tvm.testing.skip_if_32bit(reason="Apparently the LLVM version on i386 image is too old")
+def test_extract_task_arm_conv2d_nchwc():
     data_shape = (1, 64, 128, 128)
     weight_shape = (32, 64, 1, 1)
     bias_shape = (weight_shape[0],)
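
Appended for reference only, not part of the patch: a minimal sketch of how the re-enabled NCHWc int8 path could be exercised from Relay. It mirrors the new test entries above; the target string is the one added to the test, the shapes and random parameters are illustrative assumptions, and the module is only compiled rather than run, matching the build_only behaviour used when the host cannot execute AArch64 code.

# Illustrative sketch only. Assumes a TVM build that includes this change and
# an LLVM with AArch64 support; shapes and values are arbitrary examples.
import numpy as np
import tvm
from tvm import relay

data = relay.var("data", shape=(1, 64, 56, 56), dtype="int8")
weight = relay.var("weight", shape=(64, 64, 3, 3), dtype="int8")
conv = relay.nn.conv2d(
    data, weight, kernel_size=(3, 3), padding=(1, 1), channels=64, out_dtype="int32"
)
mod = tvm.IRModule.from_expr(relay.Function([data, weight], conv))
params = {"weight": np.random.randint(-128, 128, size=(64, 64, 3, 3)).astype("int8")}

# Same target string as the new test entry; with 64 input channels and the
# dotprod attribute, the conv2d_NCHWc_int8.arm_cpu strategy should be selected.
target = "llvm -device arm_cpu -mtriple aarch64-linux-gnu -mattr=+neon,+v8.2a,+dotprod"
with tvm.transform.PassContext(opt_level=3):
    lib = relay.build(mod, target=target, params=params)  # compile only, do not run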