[ARM] Fix NCHWc int8 dot product schedule lowering (apache#10773)
* [ARM] Fix NCHWc int8 dot product schedule lowering

* fix arm task extraction test not running

* skip test on i386
masahi authored and pfk-beta committed Apr 11, 2022
1 parent 3238086 commit 7c10538
Showing 5 changed files with 60 additions and 62 deletions.
python/tvm/relay/op/strategy/arm_cpu.py (8 additions, 8 deletions)
@@ -88,14 +88,6 @@ def conv2d_strategy_arm_cpu(attrs, inputs, out_type, target):
     if groups == 1:
         if layout == "NCHW":
             if kernel_layout == "OIHW":
-                # ARM conv2d spatial pack schedule.
-                strategy.add_implementation(
-                    wrap_compute_conv2d(topi.arm_cpu.conv2d_nchw_spatial_pack),
-                    wrap_topi_schedule(topi.arm_cpu.schedule_conv2d_nchw_spatial_pack),
-                    name="conv2d_nchw_spatial_pack.arm_cpu",
-                    plevel=10,
-                )
-
                 if (
                     topi.arm_cpu.is_int8_hw_support(data.dtype, kernel.dtype)
                     and kernel.shape[1] >= 64
@@ -107,6 +99,14 @@ def conv2d_strategy_arm_cpu(attrs, inputs, out_type, target):
                         plevel=15,
                     )
                 else:
+                    # ARM conv2d spatial pack schedule.
+                    strategy.add_implementation(
+                        wrap_compute_conv2d(topi.arm_cpu.conv2d_nchw_spatial_pack),
+                        wrap_topi_schedule(topi.arm_cpu.schedule_conv2d_nchw_spatial_pack),
+                        name="conv2d_nchw_spatial_pack.arm_cpu",
+                        plevel=10,
+                    )
+
                 strategy.add_implementation(
                     wrap_compute_conv2d(topi.x86.conv2d_nchw),
                     wrap_topi_schedule(topi.x86.schedule_conv2d_nchw),
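Note on the change above: the ARM spatial-pack schedule is now registered only in the else branch, i.e. when the NCHWc int8 dot-product implementation does not apply, instead of unconditionally alongside it. A minimal sketch of exercising the int8 path end to end (the shapes, the channel count of 64, and the opt level are illustrative; cross-compiling this target needs an LLVM build with the AArch64 backend and a TVM checkout containing this commit):

import tvm
from tvm import relay

data = relay.var("data", shape=(1, 64, 56, 56), dtype="int8")
weight = relay.var("weight", shape=(64, 64, 3, 3), dtype="int8")
conv = relay.nn.conv2d(
    data, weight, kernel_size=(3, 3), padding=(1, 1),
    data_layout="NCHW", kernel_layout="OIHW", out_dtype="int32",
)
mod = tvm.IRModule.from_expr(relay.Function([data, weight], conv))

# Same target flavor as the updated test below; dotprod makes is_int8_hw_support true.
target = "llvm -device arm_cpu -mtriple aarch64-linux-gnu -mattr=+neon,+v8.2a,+dotprod"
with tvm.transform.PassContext(opt_level=3):
    relay.build(mod, target=target)  # compilation only, nothing is executed on the host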
python/tvm/topi/arm_cpu/conv2d_int8.py (2 additions, 1 deletion)
@@ -152,8 +152,9 @@ def _callback(op):
             _, _, kh, kw, _, _, _ = get_const_tuple(kernel_vec.shape)
             dtype = "uint" if data.dtype == "uint8" else "int"
             if is_dotprod_available():
-                intrin = dot_int8_int8_int32_neon_82(int32_lanes=4)
+                intrin = dot_int8_int8_int32_neon_82(int32_lanes=4, dtype=dtype)
             elif is_neon_available():
+                assert dtype == "int", "uint8 not supported if dot product is not available"
                 intrin = dot_int8_int8_int32_neon()
             else:
                 raise RuntimeError(
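For orientation, a NumPy-only reference (not TVM API) of the 4-lane dot-product step these intrinsics tensorize: each int32 accumulator lane adds a dot product of four 8-bit values, and the dtype argument threaded through above selects whether those values are read as signed (sdot) or unsigned (udot), which is why the plain NEON fallback now asserts on uint8.

import numpy as np

def dot_4x4_reference(acc, data, kernel):
    # acc: (4,) int32; data, kernel: (16,) int8 or uint8 values of matching signedness
    a = data.astype("int32").reshape(4, 4)
    b = kernel.astype("int32").reshape(4, 4)
    return acc + (a * b).sum(axis=1)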
python/tvm/topi/arm_cpu/tensor_intrin.py (1 addition, 1 deletion)
@@ -516,7 +516,7 @@ def _instr(index):
                 int32_lanes * num_int8_elements,
             )
             vdot = tvm.tir.call_llvm_pure_intrin(
-                dtype_c, inst, tvm.tir.const(2, "uint32"), vec_c, vec_a, vec_b
+                dtype_c, inst, tvm.tir.const(3, "uint32"), vec_c, vec_a, vec_b
             )
             ib.emit(outs[0].vstore(0, vdot))
             return ib.get()
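About the constant that changed from 2 to 3: in call_llvm_pure_intrin the uint32 immediately after the intrinsic name is the number of call arguments TVM's LLVM codegen uses to resolve the overloaded intrinsic signature, and the aarch64 sdot/udot intrinsics take three operands (an accumulator plus two 8-bit vectors), so all three must participate. A hedged sketch of the same call pattern in isolation (the helper name is hypothetical; the intrinsic name format mirrors the code above):

import tvm
from tvm import tir

def make_sdot_call(vec_c, vec_a, vec_b, int32_lanes=4):
    # vec_c: int32x4 accumulator; vec_a, vec_b: int8x16 vectors already loaded as PrimExprs
    inst = "llvm.aarch64.neon.sdot.v%di32.v%di8" % (int32_lanes, int32_lanes * 4)
    # 3 = number of trailing arguments considered when resolving the intrinsic's signature
    return tir.call_llvm_pure_intrin(
        "int32x%d" % int32_lanes, inst, tir.const(3, "uint32"), vec_c, vec_a, vec_b
    )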
tests/python/topi/python/test_topi_conv2d_int8.py (46 additions, 51 deletions)
@@ -259,7 +259,7 @@ def verify_conv2d_NCHWc_int8(
     lo = -128 if in_dtype == "int8" else 0
     hi = 127 if in_dtype == "int8" else 255
 
-    def check_target(target, compute, schedule, oc_block_factor):
+    def check_target(target, compute, schedule, oc_block_factor, build_only):
         dev = tvm.device(target, 0)
         if not tvm.testing.device_enabled(target):
             print("Skip because %s is not enabled" % target)
@@ -323,45 +323,27 @@ def get_ref_data():
         w = tvm.nd.array(w_np.astype(dtype), dev)
         b = tvm.nd.array(b_np.astype(out_dtype), dev)
         c = tvm.nd.array(np.zeros(get_const_tuple(C.shape), dtype=C.dtype), dev)
 
         if add_bias:
-            tvm.build(
-                s,
-                [A, W, bias, C],
-                target,
-                name="relu_%d_%d_%d_%d_%d_%d_%d_%d"
-                % (batch, in_channel, in_size, num_filter, kernel, stride, padding_sum, dilation),
-            )
-            func = tvm.build(
-                s,
-                [A, W, bias, C],
-                target,
-                name="relu_%d_%d_%d_%d_%d_%d_%d_%d"
-                % (batch, in_channel, in_size, num_filter, kernel, stride, padding_sum, dilation),
-            )
-            try:
-                func(a, w, b, c)
-            except tvm.TVMError as e:
-                if "architecture mismatch" in str(e):
-                    print(f"Skipping execution because {target} is not supported by this CPU")
-                    return
-                else:
-                    raise
+            compile_args = [A, W, bias, C]
+            run_args = [a, w, b, c]
         else:
-            func = tvm.build(
-                s,
-                [A, W, C],
-                target,
-                name="relu_%d_%d_%d_%d_%d_%d_%d_%d"
-                % (batch, in_channel, in_size, num_filter, kernel, stride, padding_sum, dilation),
-            )
-            try:
-                func(a, w, c)
-            except tvm.TVMError as e:
-                if "architecture mismatch" in str(e):
-                    print(f"Skipping execution because {target} is not supported by this CPU")
-                    return
-                else:
-                    raise
+            compile_args = [A, W, C]
+            run_args = [a, w, c]
+
+        func = tvm.build(
+            s,
+            compile_args,
+            target,
+            name="relu_%d_%d_%d_%d_%d_%d_%d_%d"
+            % (batch, in_channel, in_size, num_filter, kernel, stride, padding_sum, dilation),
+        )
+
+        if build_only:
+            return
+
+        func(*run_args)
 
         tvm.testing.assert_allclose(c.numpy(), c_np, rtol=1e-5)
 
     targets = [
@@ -370,29 +352,42 @@ def get_ref_data():
             lambda a, w, s, p, d, l, ol, o: topi.cuda.conv2d_NCHWc_int8(a, w, s, p, d, l, o),
             topi.cuda.schedule_conv2d_NCHWc_int8,
             4,
+            False,
         ),
         # Disable on CI since it does not support spirv int8 dot product
         # (
         #     "vulkan -from_device=0",
        #     lambda a, w, s, p, d, l, ol, o: topi.cuda.conv2d_NCHWc_int8(a, w, s, p, d, l, o),
         #     topi.cuda.schedule_conv2d_NCHWc_int8,
         #     4,
+        #     False,
         # ),
     ]
 
-    # TODO(Mousius) Re-enable once implementation is fixed
-    # if in_dtype == "int8":
-    #     targets.append(
-    #         (
-    #             "llvm -device arm_cpu -mtriple aarch64-linux-gnu -mattr=+neon",
-    #             topi.arm_cpu.conv2d_NCHWc_int8,
-    #             topi.arm_cpu.schedule_conv2d_NCHWc_int8,
-    #             8,
-    #         )
-    #     )
-
-    for target, compute, schedule, oc_block_factor in targets:
-        check_target(target, compute, schedule, oc_block_factor)
+    # TODO(tvm-team): Properly run ARM code on CI aarch64 environment
+    targets.append(
+        (
+            "llvm -device arm_cpu -mtriple aarch64-linux-gnu -mattr=+neon,+v8.2a,+dotprod",
+            topi.arm_cpu.conv2d_NCHWc_int8,
+            topi.arm_cpu.schedule_conv2d_NCHWc_int8,
+            8,
+            True,
+        )
+    )
+
+    if in_dtype == "int8":
+        targets.append(
+            (
+                "llvm -device arm_cpu -mtriple aarch64-linux-gnu -mattr=+neon",
+                topi.arm_cpu.conv2d_NCHWc_int8,
+                topi.arm_cpu.schedule_conv2d_NCHWc_int8,
+                8,
+                True,
+            )
+        )
+
+    for target, compute, schedule, oc_block_factor, build_only in targets:
+        check_target(target, compute, schedule, oc_block_factor, build_only)
 
 
 def verify_conv2d_nchw_int8(
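The new build_only flag lets the AArch64 entries above be compiled on the x86 CI host without being run. A standalone sketch of that cross-compile-only pattern (the element-wise kernel is illustrative; it needs an LLVM build with the AArch64 backend):

import tvm
from tvm import te

n = te.var("n")
A = te.placeholder((n,), name="A", dtype="int8")
B = te.compute((n,), lambda i: A[i] + tvm.tir.const(1, "int8"), name="B")
s = te.create_schedule(B.op)

target = "llvm -device arm_cpu -mtriple aarch64-linux-gnu -mattr=+neon,+v8.2a,+dotprod"
func = tvm.build(s, [A, B], target)  # lowering and codegen only; nothing is executed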
tests/python/unittest/test_meta_schedule_integration.py (3 additions, 1 deletion)
@@ -20,6 +20,7 @@
 
 import pytest
 import tvm
+import tvm.testing
 from tvm import relay
 from tvm import meta_schedule as ms
 from tvm.ir.module import IRModule
@@ -151,7 +152,8 @@ def extract_task_qbert():
         assert "vnni" in annotations["schedule_rule"]
 
 
-def extract_task_arm_conv2d_nchwc():
+@tvm.testing.skip_if_32bit(reason="Apparently the LLVM version on i386 image is too old")
+def test_extract_task_arm_conv2d_nchwc():
     data_shape = (1, 64, 128, 128)
     weight_shape = (32, 64, 1, 1)
     bias_shape = (weight_shape[0],)
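Why the rename matters: pytest only collects functions whose names start with test_, so the old extract_task_arm_conv2d_nchwc was never actually run. A minimal hedged sketch of the resulting pattern (the body here is a placeholder, not the real task-extraction logic):

import sys
import pytest
import tvm.testing

@tvm.testing.skip_if_32bit(reason="Apparently the LLVM version on i386 image is too old")
def test_extract_task_example():
    pass  # a real test would build a Relay conv2d and extract tuning tasks, as above

if __name__ == "__main__":
    sys.exit(pytest.main([__file__] + sys.argv[1:]))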
