Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[ARM] Fix int8 NCHWc compute and alter layout #10839

Merged
merged 8 commits into from
Apr 1, 2022

Conversation

masahi
Copy link
Member

@masahi masahi commented Mar 31, 2022

This PR fixes a bug in TE ARM int8 compute for NCHWc conv2d, introduced in #10310. The compute itself, not the schedule, is broken for the following reasons:

Initially, I tried to keep n_elems = 8 in alter layout and fix the intrinsic definition. But n_elems = 8 breaks tensorization pattern matching, since now the compute is doing 4x8 innermost loop but this intrinsic is supposed to do 4x4 dot product, see

num_int8_elements = 4 # 4 int8 elements in int32
data = te.placeholder((num_int8_elements,), dtype="%s8" % dtype, name="data")
kernel = te.placeholder((int32_lanes, num_int8_elements), dtype="%s8" % dtype, name="kernel")
k = te.reduce_axis((0, num_int8_elements), name="k")
C = te.compute(
(int32_lanes,),
lambda i: te.sum(
data[k].astype("%s32" % dtype) * kernel[i, k].astype("%s32" % dtype), axis=k
),
name="C",
)
. Setting num_int8_elements = 8 there does fix the tensorize pattern matching, but the result was still incorrect.

Rather than fixing the intrin implementation in

def _intrin_func(ins, outs):
to adapt for 4x8 dot product, I settled on setting n_elems = 4 in alter layout. It turned out this change is enough to get the correct output. Moreover, n_elems = 8 is simply wrong for the dot product path in
if is_dotprod_available():
intrin = dot_int8_int8_int32_neon_82(int32_lanes=4, dtype=dtype)
which computes 4x4 dot product in one instruction.

@tkonolige I suggest doing perf benchmark again, since the numbers in #10310 are invalid.

cc @mbrookhart @Mousius @junrushao1994 @vinx13

@masahi masahi changed the title [ARM] Fix int8 NCHWc compute and tensor intrin for non dot product path [ARM] Fix int8 NCHWc compute and tensor intrin for non dot product path (rpi etc) Mar 31, 2022
@masahi masahi changed the title [ARM] Fix int8 NCHWc compute and tensor intrin for non dot product path (rpi etc) [ARM] Fix int8 NCHWc compute and alter layout Mar 31, 2022
@@ -364,7 +365,7 @@ def get_ref_data():
# ),
]

# TODO(tvm-team): Properly run ARM code on CI aarch64 environment
# TODO(tvm-team): Figure out ARM dot product availability on CI aarch64 environment
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

cc @Mousius @u99127, I'd love to test the dot-product schedule on the aarch64 CI, do you know if it is supposed? Automatic detection would require /proc/cpuinfo etc as suggested by @u99127 in #10773 (comment), which I'd rather avoid.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

As far as I know, the CI environment should be good to run the dot-product schedules, I can take a look at cpuinfo detection later 😸

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

yup I enabled the dot product test on CI, it seems to be working!

https://ci.tlcpack.ai/blue/rest/organizations/jenkins/pipelines/tvm/branches/PR-10839/runs/5/nodes/316/steps/542/log/?start=0
(Search Running on target: llvm -device arm_cpu -mtriple aarch64-linux-gnu -mattr=+neon,+v8.2a,+dotprod)

@@ -120,7 +120,7 @@ def _pack_data(cfg, data, kernel):
kernel = te.compute(
(oc_chunk, ic_chunk, kh, kw, ic_bn // n_elems, oc_bn, n_elems),
lambda occ, icc, k_h, k_w, icbc, ocb, icbb: kernel[
occ * oc_bn + ocb, icc * ic_bn + icbc * ic_bn // n_elems + icbb, k_h, k_w
occ * oc_bn + ocb, icc * ic_bn + icbc * n_elems + icbb, k_h, k_w
Copy link
Member Author

@masahi masahi Mar 31, 2022

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

cc @tkonolige please have a look at this change. Since test_topi_conv2d_int8.py doesn't use the alter layout code (which had a bug), and _pack_data is using n_elems = 4, the reason aarch64 CI failed on test_topi_conv2d_int8.py was probably due to this bug.

Copy link
Contributor

@tkonolige tkonolige left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks for fixing this @masahi! After reading through the code again (that I wrote...), it is doing a 4x4 dot product, so n_elems should be 4.

@junrushao junrushao merged commit 912993f into apache:main Apr 1, 2022
pfk-beta pushed a commit to pfk-beta/tvm that referenced this pull request Apr 11, 2022
This PR fixes a bug in TE ARM int8 compute for NCHWc conv2d, introduced in apache#10310. The compute itself, not the schedule, is broken for the following reasons:

* We are using `n_elems = 8` in https://github.com/apache/tvm/blob/e9091d6c68d5d70c28881e5c75bfe72e385c1f4d/python/tvm/topi/arm_cpu/conv2d_alter_op.py#L350. Thus, the innermost axis of the transformed kernel has extent 8: https://github.com/apache/tvm/blob/e9091d6c68d5d70c28881e5c75bfe72e385c1f4d/python/tvm/topi/arm_cpu/conv2d_alter_op.py#L375
* In the TE compute, we iterate over the innermost axis `ic_s_inner` of the kernel at https://github.com/apache/tvm/blob/f6f252f0abc8f621a96506739f9534083d1fe213/python/tvm/topi/nn/conv2d.py#L577. `ic_s_inner` has extent `n_elems` according to https://github.com/apache/tvm/blob/f6f252f0abc8f621a96506739f9534083d1fe213/python/tvm/topi/nn/conv2d.py#L566. `n_elems` is 4 by default according to https://github.com/apache/tvm/blob/f6f252f0abc8f621a96506739f9534083d1fe213/python/tvm/topi/nn/conv2d.py#L478
* The ARM code that calls this compute does not explicitly pass `n_elems`, according to https://github.com/apache/tvm/blob/e9091d6c68d5d70c28881e5c75bfe72e385c1f4d/python/tvm/topi/arm_cpu/conv2d_int8.py#L106-L108
* Thus, even though the innermost axis of the kernel has extent 8, the TE compute only loops over `n_elems = 4` of the input channel dimension. 

Initially, I tried to keep `n_elems = 8` in alter layout and fix the intrinsic definition. But `n_elems = 8` breaks tensorization pattern matching, since now the compute is doing 4x8 innermost loop but this intrinsic is supposed to do 4x4 dot product, see https://github.com/apache/tvm/blob/7896108fc41663a1fecbb52345194a93278e9e28/python/tvm/topi/arm_cpu/tensor_intrin.py#L467-L479. Setting `num_int8_elements = 8` there does fix the tensorize pattern matching, but the result was still incorrect.

Rather than fixing the intrin implementation in https://github.com/apache/tvm/blob/7896108fc41663a1fecbb52345194a93278e9e28/python/tvm/topi/arm_cpu/tensor_intrin.py#L492 to adapt for 4x8 dot product, I settled on setting `n_elems = 4` in alter layout. It turned out this change is enough to get the correct output. Moreover, `n_elems = 8` is simply wrong for the dot product path in https://github.com/apache/tvm/blob/7896108fc41663a1fecbb52345194a93278e9e28/python/tvm/topi/arm_cpu/conv2d_int8.py#L154-L155 which computes 4x4 dot product in one instruction. 

@tkonolige I suggest doing perf benchmark again, since the numbers in apache#10310 are invalid.

cc @mbrookhart @Mousius  @junrushao1994 @vinx13
mehrdadh pushed a commit to mehrdadh/tvm that referenced this pull request Apr 11, 2022
This PR fixes a bug in TE ARM int8 compute for NCHWc conv2d, introduced in apache#10310. The compute itself, not the schedule, is broken for the following reasons:

* We are using `n_elems = 8` in https://github.com/apache/tvm/blob/e9091d6c68d5d70c28881e5c75bfe72e385c1f4d/python/tvm/topi/arm_cpu/conv2d_alter_op.py#L350. Thus, the innermost axis of the transformed kernel has extent 8: https://github.com/apache/tvm/blob/e9091d6c68d5d70c28881e5c75bfe72e385c1f4d/python/tvm/topi/arm_cpu/conv2d_alter_op.py#L375
* In the TE compute, we iterate over the innermost axis `ic_s_inner` of the kernel at https://github.com/apache/tvm/blob/f6f252f0abc8f621a96506739f9534083d1fe213/python/tvm/topi/nn/conv2d.py#L577. `ic_s_inner` has extent `n_elems` according to https://github.com/apache/tvm/blob/f6f252f0abc8f621a96506739f9534083d1fe213/python/tvm/topi/nn/conv2d.py#L566. `n_elems` is 4 by default according to https://github.com/apache/tvm/blob/f6f252f0abc8f621a96506739f9534083d1fe213/python/tvm/topi/nn/conv2d.py#L478
* The ARM code that calls this compute does not explicitly pass `n_elems`, according to https://github.com/apache/tvm/blob/e9091d6c68d5d70c28881e5c75bfe72e385c1f4d/python/tvm/topi/arm_cpu/conv2d_int8.py#L106-L108
* Thus, even though the innermost axis of the kernel has extent 8, the TE compute only loops over `n_elems = 4` of the input channel dimension. 

Initially, I tried to keep `n_elems = 8` in alter layout and fix the intrinsic definition. But `n_elems = 8` breaks tensorization pattern matching, since now the compute is doing 4x8 innermost loop but this intrinsic is supposed to do 4x4 dot product, see https://github.com/apache/tvm/blob/7896108fc41663a1fecbb52345194a93278e9e28/python/tvm/topi/arm_cpu/tensor_intrin.py#L467-L479. Setting `num_int8_elements = 8` there does fix the tensorize pattern matching, but the result was still incorrect.

Rather than fixing the intrin implementation in https://github.com/apache/tvm/blob/7896108fc41663a1fecbb52345194a93278e9e28/python/tvm/topi/arm_cpu/tensor_intrin.py#L492 to adapt for 4x8 dot product, I settled on setting `n_elems = 4` in alter layout. It turned out this change is enough to get the correct output. Moreover, `n_elems = 8` is simply wrong for the dot product path in https://github.com/apache/tvm/blob/7896108fc41663a1fecbb52345194a93278e9e28/python/tvm/topi/arm_cpu/conv2d_int8.py#L154-L155 which computes 4x4 dot product in one instruction. 

@tkonolige I suggest doing perf benchmark again, since the numbers in apache#10310 are invalid.

cc @mbrookhart @Mousius  @junrushao1994 @vinx13
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

5 participants