Skip to content

Commit

Permalink
[microNPU] Fix incorrectly calculated stride when converting NHWC to …
Browse files Browse the repository at this point in the history
…NHCWB16 (#9560)

Fixes an issue that causes strides to be incorrectly calculated when the
number of channels in the input is less than 16 and involves a conversion
from NHWC to NHCWB16. This is due to TVM being 'too smart' when analyzing
generated TE and removing compute that is deemed unnecessary. Consequently,
strides over data are incorrectly calculated leading to an output
mismatch.

The PR uses a reduce sum operation to trick TE's data dependency
analyzer into looping over a whole block (16), rather than the number
of channels actually used (< 16). This causes the calculated strides to
be a multiple of 16 which is required for NHCWB16 format.

Change-Id: Ibf76a94a12cebf51fa716fcac1de932a271c4a6d
  • Loading branch information
lhutton1 committed Nov 25, 2021
1 parent 0e818bb commit 238958f
Show file tree
Hide file tree
Showing 3 changed files with 160 additions and 46 deletions.
11 changes: 10 additions & 1 deletion python/tvm/relay/backend/contrib/ethosu/te/dma.py
Original file line number Diff line number Diff line change
Expand Up @@ -146,6 +146,12 @@ def write_compute(
def convert_to_nhwc_compute(tensor: te.Tensor, layout: str, channels: int) -> te.Tensor:
"""Converts a tensor into NHWC layout if it's in NHWCB16 layout.
When the current layout is NHCWB16, a reduce sum operation is inserted
to ensure that the whole of the input tensor has a data dependency on
the copy operation. Without this, TVM removes compute that is deemed to
be unnecessary, which causes strides for the NPU to be calculated
incorrectly.
Parameters
----------
tensor : te.Tensor
Expand All @@ -167,9 +173,12 @@ def convert_to_nhwc_compute(tensor: te.Tensor, layout: str, channels: int) -> te
"layout": layout,
}
if layout == "NHCWB16":
rc = te.reduce_axis((0, 16), name="rc")
return te.compute(
(tensor.shape[0], tensor.shape[1], tensor.shape[3], channels),
lambda nn, hh, ww, cc: tensor(nn, hh, te.indexdiv(cc, 16), ww, te.indexmod(cc, 16)),
lambda nn, hh, ww, cc: te.sum(
tensor(nn, hh, te.indexdiv(cc, 16), ww, te.indexmod(rc, 16)), axis=rc
),
name="ethosu_convert_to_nhwc",
attrs=convert_to_nhwc_attrs,
)
Expand Down
12 changes: 10 additions & 2 deletions python/tvm/relay/backend/contrib/ethosu/tir/dma.py
Original file line number Diff line number Diff line change
Expand Up @@ -94,10 +94,18 @@ def get_convert_to_nhwc_params(stmt):
The pointer produced by the operation.
"""
_, body = get_op_attrs(stmt)
attrs, body = get_op_attrs(stmt)
_, _, _, c, _, inner = get_outer_loops(body, "NHWC")

# Ignore the reduce sum operation inserted to ensure
# compute that is deemed uneccesary isn't removed by TVM.
if attrs["layout"] == "NHCWB16":
inner = inner.body
input_pointer = inner.value.b.buffer_var
else:
input_pointer = inner.value.buffer_var

output_pointer = inner.buffer_var
input_pointer = inner.value.buffer_var
return c.extent, input_pointer, output_pointer


Expand Down
183 changes: 140 additions & 43 deletions tests/python/contrib/test_ethosu/test_replace_pooling.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,52 +26,18 @@
from .infra import make_ethosu_pooling, get_pooling_args


@pytest.mark.parametrize(
"ifm_shape, ofm_channels, ifm_layout, ofm_layout, rounding_mode",
[
((1, 5, 9, 3), 3, "NHWC", "NHWC", "TFL"),
((1, 8, 3, 9, 16), 40, "NHCWB16", "NHCWB16", "NATURAL"),
((1, 8, 3, 9, 16), 40, "NHCWB16", "NHWC", "TRUNCATE"),
((1, 8, 9, 40), 40, "NHWC", "NHCWB16", "TFL"),
],
)
@pytest.mark.parametrize("pooling_type", ["AVG", "MAX"])
@pytest.mark.parametrize("activation", ["NONE", "CLIP"])
def test_pooling_single(
def _create_serial_pooling(
ifm_shape,
ofm_channels,
ifm_layout,
ofm_layout,
pool_shape,
pooling_type,
activation,
rounding_mode,
strides,
padding,
activation="NONE",
rounding_mode="TFL",
):
pool_shape = (3, 2)
strides = (1, 2)
padding = (1, 1, 1, 0)
ifm = relay.var("ifm", shape=ifm_shape, dtype="int8")
pooling = make_ethosu_pooling(
ifm,
pooling_type,
pool_shape,
ofm_channels,
strides,
padding,
activation,
ifm_layout,
ofm_layout,
rounding_mode,
)
func = relay.Function(relay.analysis.free_vars(pooling), pooling)
func = run_opt_pass(func, relay.transform.InferType())
mod, _ = lower_to_tir(func)
data = []

def _visit(stmt):
if isinstance(stmt, tvm.tir.Call):
data.append(get_pooling_args(stmt))

tvm.tir.stmt_functor.post_order_visit(mod["main"].body, _visit)
if ifm_layout == "NHWC":
ifm_stride_c = 1
ifm_stride_w = ifm_shape[3]
Expand All @@ -80,7 +46,7 @@ def _visit(stmt):
ofm_width = (ifm_shape[2] - pool_shape[1] + padding[1] + padding[1]) // strides[1] + 1
else:
ifm_stride_w = 16
ifm_stride_c = 16 * ifm_shape[3]
ifm_stride_c = 16 * ifm_shape[3] if ofm_channels >= 16 else 1
ifm_stride_h = 16 * ifm_shape[2] * ifm_shape[3]
ofm_height = (ifm_shape[1] - pool_shape[0] + padding[0] + padding[0]) // strides[0] + 1
ofm_width = (ifm_shape[3] - pool_shape[1] + padding[1] + padding[1]) // strides[1] + 1
Expand All @@ -91,10 +57,10 @@ def _visit(stmt):
ofm_stride_h = ofm_channels * ofm_width if ofm_height > 1 else 1
else:
ofm_stride_w = 16
ofm_stride_c = 16 * ofm_width
ofm_stride_c = 16 * ofm_width if ofm_channels >= 16 else 1
ofm_stride_h = 16 * ofm_width * ((ofm_channels - 1) // 16 + 1)

serial_pooling = spec.SerialPooling(
return spec.SerialPooling(
ifm=spec.SerialFeatureMap(
data_type="int8",
height=ifm_shape[1],
Expand Down Expand Up @@ -154,8 +120,139 @@ def _visit(stmt):
upscale="NONE",
)


@pytest.mark.parametrize(
"ifm_shape, ofm_channels, ifm_layout, ofm_layout, rounding_mode",
[
((1, 5, 9, 3), 3, "NHWC", "NHWC", "TFL"),
((1, 8, 3, 9, 16), 40, "NHCWB16", "NHCWB16", "NATURAL"),
((1, 8, 3, 9, 16), 40, "NHCWB16", "NHWC", "TRUNCATE"),
((1, 8, 9, 40), 40, "NHWC", "NHCWB16", "TFL"),
((1, 8, 9, 8), 8, "NHWC", "NHCWB16", "TFL"),
],
)
@pytest.mark.parametrize("pooling_type", ["AVG", "MAX"])
@pytest.mark.parametrize("activation", ["NONE", "CLIP"])
def test_pooling_single(
ifm_shape,
ofm_channels,
ifm_layout,
ofm_layout,
pooling_type,
activation,
rounding_mode,
):
pool_shape = (3, 2)
strides = (1, 2)
padding = (1, 1, 1, 0)
ifm = relay.var("ifm", shape=ifm_shape, dtype="int8")
pooling = make_ethosu_pooling(
ifm,
pooling_type,
pool_shape,
ofm_channels,
strides,
padding,
activation,
ifm_layout,
ofm_layout,
rounding_mode,
)
func = relay.Function(relay.analysis.free_vars(pooling), pooling)
func = run_opt_pass(func, relay.transform.InferType())
mod, _ = lower_to_tir(func)
data = []

def _visit(stmt):
if isinstance(stmt, tvm.tir.Call):
data.append(get_pooling_args(stmt))

tvm.tir.stmt_functor.post_order_visit(mod["main"].body, _visit)

serial_pooling = _create_serial_pooling(
ifm_shape,
ofm_channels,
ifm_layout,
ofm_layout,
pool_shape,
pooling_type,
strides,
padding,
activation,
rounding_mode,
)
assert data[0] == ["ethosu_pooling"] + list(serial_pooling)


def test_correct_stride_with_multiple_pooling():
"""Testing a specific case of two pooling operations with NHWC inputs/outputs
but a NHCWB16 intermediate tensor. This lead to elements being accessed in the
wrong order by the NPU, due to incorrect stride values being calculated."""

ifm_shape = (1, 4, 4, 8)
ofm_channels = 8
pooling_type = "MAX"
pool_shape = (1, 1)
strides = (1, 1)
padding = (0, 0, 0, 0)

ifm = relay.var("ifm", shape=ifm_shape, dtype="int8")
op = make_ethosu_pooling(
ifm,
pooling_type,
pool_shape,
ofm_channels,
strides,
padding,
ifm_layout="NHWC",
ofm_layout="NHCWB16",
)
op = make_ethosu_pooling(
op,
pooling_type,
pool_shape,
ofm_channels,
strides,
padding,
ifm_layout="NHCWB16",
ofm_layout="NHWC",
)
func = relay.Function(relay.analysis.free_vars(op), op)
func = run_opt_pass(func, relay.transform.InferType())
mod, _ = lower_to_tir(func)

data = []

def _visit(stmt):
if isinstance(stmt, tvm.tir.Call):
data.append(get_pooling_args(stmt))

tvm.tir.stmt_functor.post_order_visit(mod["main"].body, _visit)

serial_pooling_1 = _create_serial_pooling(
[1, 4, 4, 8],
8,
"NHWC",
"NHCWB16",
pool_shape,
pooling_type,
strides,
padding,
)
serial_pooling_2 = _create_serial_pooling(
[1, 4, 1, 4, 16],
8,
"NHCWB16",
"NHWC",
pool_shape,
pooling_type,
strides,
padding,
)

assert data[0] == ["ethosu_pooling"] + list(serial_pooling_1)
assert data[1] == ["ethosu_pooling"] + list(serial_pooling_2)


if __name__ == "__main__":
pytest.main([__file__])

0 comments on commit 238958f

Please sign in to comment.