Skip to content

Commit

Permalink
[BYOC] [ACL] ACL Runtime padding workaround
Browse files Browse the repository at this point in the history
This workaround prevents execution of operations via ACL runtime
in case if arguments or output tensor require memory padding.
Workaround is applicable to all ACL versions prior forecoming ACL 20.11
(which will not use data padding).
  • Loading branch information
d-smirnov committed Nov 5, 2020
1 parent 3ff0100 commit ad7d495
Show file tree
Hide file tree
Showing 8 changed files with 106 additions and 30 deletions.
42 changes: 36 additions & 6 deletions python/tvm/relay/op/contrib/arm_compute_lib.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,8 @@
# pylint: disable=invalid-name, unused-argument
"""Arm Compute Library supported operators."""
import tvm
import numpy as np

from tvm.relay.expr import const
from tvm.relay import transform
from tvm.relay.build_module import bind_params_by_name
Expand Down Expand Up @@ -279,7 +281,7 @@ def dense(expr):
return False
if attrs.out_dtype != "float32" and attrs.out_dtype != "":
return False
return True
return not require_padding([*args, expr.checked_type])


def qnn_dense(expr):
Expand All @@ -293,7 +295,7 @@ def qnn_dense(expr):
return False
if attrs.out_dtype != "int32":
return False
return True
return not require_padding([*args, expr.checked_type])


@tvm.ir.register_op_attr("nn.max_pool2d", "target.arm_compute_lib")
Expand All @@ -305,14 +307,41 @@ def max_pool2d(expr):
typ = args[0].checked_type
if typ.dtype not in ["float32", "uint8"]:
return False
return True
return not require_padding([*args, expr.checked_type])


def require_padding(inputs):
"""Checks whether supplied data will require padding.
Most of the operators ACL up to 20.11 uses padded data.
"""

def _check(shape, dtype):
"""NEON has 128bits/16bytes per vector"""
if len(shape) == 0:
return False
return (shape[-1] * np.dtype(dtype).itemsize) % 16 != 0

for i in inputs:
if isinstance(i, (tvm.relay.expr.Var, tvm.relay.expr.Call)):
if _check(i.checked_type.shape, i.checked_type.dtype):
return True
elif isinstance(i, tvm.relay.expr.Constant):
if _check(i.data.shape, i.data.dtype):
return True
elif isinstance(i, tvm.ir.tensor_type.TensorType):
if _check(i.shape, i.dtype):
return True
else:
raise Exception("Not supported")
return False


@tvm.ir.register_op_attr("nn.avg_pool2d", "target.arm_compute_lib")
def avg_pool2d(expr, from_quantized_composite=False):
"""Check if the external ACL codegen for avgpool2d should be used."""
attrs, args = expr.attrs, expr.args
typ = args[0].checked_type

if from_quantized_composite:
if typ.dtype != "int32":
return False
Expand All @@ -321,7 +350,8 @@ def avg_pool2d(expr, from_quantized_composite=False):
return False
if attrs.layout != "NHWC":
return False
return True

return not require_padding([*args, expr.checked_type])


@tvm.ir.register_op_attr("nn.global_max_pool2d", "target.arm_compute_lib")
Expand All @@ -333,7 +363,7 @@ def global_max_pool2d(expr):
return False
if attrs.layout != "NHWC":
return False
return True
return not require_padding([*args, expr.checked_type])


@tvm.ir.register_op_attr("nn.global_avg_pool2d", "target.arm_compute_lib")
Expand All @@ -345,7 +375,7 @@ def global_avg_pool2d(expr):
return False
if attrs.layout != "NHWC":
return False
return True
return not require_padding([*args, expr.checked_type])


@tvm.ir.register_op_attr("maximum", "target.arm_compute_lib")
Expand Down
1 change: 1 addition & 0 deletions src/runtime/contrib/arm_compute_lib/acl_utils.cc
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,7 @@ arm_compute::Tensor MakeACLTensor(const JSONGraphNode& tensor_rep, void* data,
std::vector<int64_t> shape = tensor_rep.GetOpShape()[0];
DLDataType dtype = tensor_rep.GetOpDataType()[0];
arm_compute::TensorInfo info = MakeACLTensorInfo(shape, dtype, scale, offset);
info.set_is_resizable(false);
tensor.allocator()->init(info);
if (data != nullptr) {
CheckACLError(tensor.allocator()->import_memory(data));
Expand Down
3 changes: 2 additions & 1 deletion tests/python/contrib/test_arm_compute_lib/infrastructure.py
Original file line number Diff line number Diff line change
Expand Up @@ -276,10 +276,11 @@ def verify_codegen(
module,
known_good_codegen,
num_acl_modules,
tvm_ops=0,
target="llvm -mtriple=aarch64-linux-gnu -mattr=+neon",
):
"""Check acl codegen against a known good output."""
module = build_module(module, target)
module = build_module(module, target, tvm_ops=tvm_ops, acl_partitions=num_acl_modules)
acl_modules = extract_acl_modules(module)

assert len(acl_modules) == num_acl_modules, (
Expand Down
62 changes: 48 additions & 14 deletions tests/python/contrib/test_arm_compute_lib/test_dense.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,8 +20,8 @@

import tvm
from tvm import relay

from .infrastructure import (
from tvm import testing
from test_arm_compute_lib.infrastructure import (
Device,
skip_runtime_test,
skip_codegen_test,
Expand Down Expand Up @@ -185,18 +185,34 @@ def test_dense():
np.random.seed(0)

dtype = ["float32"]
shape = [((1, 128), (16, 128), 16), ((32, 32), (32, 32), 32), ((1, 64), (1, 64), 1)]
shape = [
(1, (1, 128), (16, 128), 16),
(1, (32, 32), (32, 32), 32),
(0, (1, 64), (1, 64), 1),
(0, (11, 2), (2, 2), 2),
]
composite = [False, True]
trials = generate_trials([dtype, shape, composite], 3)

for dtype, (shape, weight_shape, units), composite in trials:
for dtype, (acl_partitions, shape, weight_shape, units), composite in trials:
outputs = []
inputs = {"a": tvm.nd.array(np.random.uniform(-128, 127, shape).astype(dtype))}
func, params = _get_model(
shape, weight_shape, units, dtype, var_names=iter(inputs), has_bias=composite
)
for acl in [False, True]:
outputs.append(build_and_run(func, inputs, 1, params, device, enable_acl=acl)[0])
outputs.append(
build_and_run(
func,
inputs,
1,
params,
device,
enable_acl=acl,
tvm_ops=(1 - acl_partitions) * (2 - int(not composite)),
acl_partitions=acl_partitions,
)[0]
)

config = {
"shape": shape,
Expand All @@ -215,18 +231,18 @@ def test_codegen_dense():
np.random.seed(0)

dtype = ["float32"]
shape = [((1, 128), (16, 128), 16), ((32, 32), (32, 32), 32), ((1, 64), (1, 64), 1)]
shape = [(1, (1, 128), (16, 128), 16), (1, (32, 32), (32, 32), 32), (0, (1, 64), (1, 64), 1)]
composite = [False, True]
trials = generate_trials([dtype, shape, composite], 3)

for dtype, (shape, weight_shape, units), composite in trials:
for dtype, (acl_partitions, shape, weight_shape, units), composite in trials:
inputs = {"a"}

args = (shape, weight_shape, units, dtype)

func, params = _get_model(*args, var_names=iter(inputs), has_bias=composite)
exp_codegen = _get_expected_codegen(*args, has_bias=composite)
verify_codegen(func, exp_codegen, 1)
verify_codegen(func, exp_codegen, acl_partitions, 1 - acl_partitions)


def test_qnn_dense():
Expand All @@ -239,11 +255,18 @@ def test_qnn_dense():
np.random.seed(0)

dtype = ["uint8"]
shape = [((1, 128), (16, 128), 16), ((32, 32), (32, 32), 32), ((1, 64), (1, 64), 1)]
shape = [
(0, (4, 4), (4, 4), 4),
(1, (16, 16), (4, 16), 4),
(1, (1, 128), (16, 128), 16),
(1, (32, 32), (32, 32), 32),
(0, (1, 64), (1, 64), 1),
]

composite = [False, True]
trials = generate_trials([dtype, shape, composite], 3)

for dtype, (shape, weight_shape, units), composite in trials:
for dtype, (acl_partitions, shape, weight_shape, units), composite in trials:
outputs = []
inputs = {"a": tvm.nd.array(np.random.uniform(0, 255, shape).astype(dtype))}
input_zp = 100
Expand All @@ -270,7 +293,18 @@ def test_qnn_dense():
)

for acl in [False, True]:
outputs.append(build_and_run(func, inputs, 1, params, device, enable_acl=acl)[0])
outputs.append(
build_and_run(
func,
inputs,
1,
params,
device,
tvm_ops=(1 - acl_partitions) * (3 - int(not composite)),
acl_partitions=acl_partitions,
enable_acl=acl,
)[0]
)

config = {
"shape": shape,
Expand All @@ -295,11 +329,11 @@ def test_codegen_qnn_dense():
np.random.seed(0)

dtype = ["uint8"]
shape = [((1, 128), (16, 128), 16), ((32, 32), (32, 32), 32), ((1, 64), (1, 64), 1)]
shape = [(1, (1, 128), (16, 128), 16), (1, (32, 32), (32, 32), 32), (0, (1, 64), (1, 64), 1)]
composite = [False, True]
trials = generate_trials([dtype, shape, composite], 3)

for dtype, (shape, weight_shape, units), composite in trials:
for dtype, (acl_partitions, shape, weight_shape, units), composite in trials:
inputs = {"a"}
args = (shape, weight_shape, units, dtype)

Expand All @@ -323,7 +357,7 @@ def test_codegen_qnn_dense():
has_bias=composite,
)
exp_codegen = _get_expected_codegen(*args, has_bias=composite)
verify_codegen(func, exp_codegen, 1)
verify_codegen(func, exp_codegen, acl_partitions, 2 - 2 * acl_partitions)


if __name__ == "__main__":
Expand Down
1 change: 1 addition & 0 deletions tests/python/contrib/test_arm_compute_lib/test_maximum.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@

import tvm
from tvm import relay
from tvm import testing

from .infrastructure import (
skip_runtime_test,
Expand Down
7 changes: 4 additions & 3 deletions tests/python/contrib/test_arm_compute_lib/test_network.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,11 +17,12 @@
"""Arm Compute Library network tests."""

import numpy as np

import pytest
from tvm import testing
from tvm import relay

from .infrastructure import skip_runtime_test, build_and_run, verify
from .infrastructure import Device
from test_arm_compute_lib.infrastructure import skip_runtime_test, build_and_run, verify
from test_arm_compute_lib.infrastructure import Device


def _build_and_run_network(mod, params, inputs, device, tvm_ops, acl_partitions, atol, rtol):
Expand Down
15 changes: 11 additions & 4 deletions tests/python/contrib/test_arm_compute_lib/test_pooling.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,15 +20,16 @@

import tvm
from tvm import relay
from tvm import testing

from .infrastructure import (
from test_arm_compute_lib.infrastructure import (
skip_runtime_test,
skip_codegen_test,
build_and_run,
verify,
verify_codegen,
)
from .infrastructure import Device
from test_arm_compute_lib.infrastructure import Device


def _calculate_output_shape(shape, sizes, padding, strides):
Expand Down Expand Up @@ -166,7 +167,12 @@ def test_pooling():
fp32_dtype = ("float32", -127, 128, 0.001, 0.001)
uint8_dtype = ("uint8", 0, 255, 1, 0)

# nn.max_pool2d(%arm_compute_lib_11_i0, pool_size=[3, 3], strides=[2, 2], padding=[0, 0, 0, 0], layout="NHWC") /* ty=Tensor[(1, 27, 27, 256), float32] */

trials = [
# ["nn.avg_pool2d", fp32_dtype, (2, 2), (2, 2), (1, 1), False, False, (16, 16, 16)],
# ["nn.avg_pool2d", fp32_dtype, (2, 2), (2, 2), (0, 0), False, False, (2,2,1)],
["nn.max_pool2d", fp32_dtype, (3, 3), (2, 2), (0, 0), False, False, (27, 27, 512)],
["nn.max_pool2d", fp32_dtype, (2, 2), (2, 2), (0, 0), False, True, (16, 16, 16)],
["nn.max_pool2d", fp32_dtype, (3, 3), (2, 2), (1, 1), True, True, (15, 15, 16)],
["nn.max_pool2d", fp32_dtype, (2, 2), (2, 2), (0, 1), False, False, (16, 16, 16)],
Expand All @@ -175,7 +181,8 @@ def test_pooling():
["nn.avg_pool2d", fp32_dtype, (2, 2), (2, 2), (1, 1), False, False, (16, 16, 16)],
["nn.avg_pool2d", fp32_dtype, (2, 2), (2, 2), (0, 0), False, True, (16, 16, 16)],
["nn.avg_pool2d", fp32_dtype, (3, 3), (2, 2), (0, 1), True, False, (15, 15, 16)],
["nn.avg_pool2d", uint8_dtype, (2, 2), (2, 2), (1, 1), False, True, (16, 16, 16)],
# 20.05: "exclude_padding equal false is not supported for AVG Pooling with padding on quantized types"
# ["nn.avg_pool2d", uint8_dtype, (2, 2), (2, 2), (1, 1), False, True, (16, 16, 16)],
["nn.avg_pool2d", uint8_dtype, (3, 3), (2, 2), (0, 1), False, False, (16, 16, 16)],
["nn.l2_pool2d", fp32_dtype, (2, 2), (2, 2), (0, 1), True, False, (16, 16, 16)],
["nn.l2_pool2d", fp32_dtype, (3, 3), (2, 2), (0, 0), False, False, (16, 16, 16)],
Expand Down Expand Up @@ -211,6 +218,7 @@ def test_pooling():
"padding": pad,
"ceil_mode": ceil_mode,
"count_include_pad": count_include_pad,
"inputs": inputs,
}
verify_saturation = True if dtype == "uint8" else False

Expand Down Expand Up @@ -255,7 +263,6 @@ def test_global_pooling():
}

func = _get_global_pooling_model(shape, dtype, typef, iter(inputs))

config = {
"shape": shape,
"pooling type": typef,
Expand Down
5 changes: 3 additions & 2 deletions tests/python/contrib/test_arm_compute_lib/test_reshape.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@

import tvm
from tvm import relay
from tvm import testing

from .infrastructure import (
skip_runtime_test,
Expand Down Expand Up @@ -77,7 +78,7 @@ def test_reshape():
]:
inputs = {"a": tvm.nd.array(np.random.uniform(low, high, (1, 1, 1, 1000)).astype(dtype))}

for new_shape in [(1, 1000), (10, 10, 10)]:
for new_shape in [(1, 1000), (10, 10, 10), (10, 100, 1), (1, 1000, 1)]:
outputs = []
func = _get_model(inputs["a"].shape, new_shape, dtype, iter(inputs))
for acl in [False, True]:
Expand All @@ -98,7 +99,7 @@ def test_codegen_reshape():
shape = (1, 1, 1, 1000)
inputs = {"a"}
for dtype in ["float32", "uint8"]:
for new_shape in [(1, 1000), (10, 10, 10)]:
for new_shape in [(1, 1000), (10, 10, 10), (10, 100, 1)]:
args = (shape, new_shape, dtype)
func = _get_model(*args, iter(inputs))
exp_codegen = _get_expected_codegen(*args)
Expand Down

0 comments on commit ad7d495

Please sign in to comment.