Skip to content

Commit

Permalink
Fix Conv2D TIR type sensitivity
Browse files Browse the repository at this point in the history
Change-Id: I3741f9dd8bb5952590ff8c586f6b96e5c3a03795
  • Loading branch information
mbaret authored and manupak committed Aug 25, 2021
1 parent f4aaae4 commit 02cbf82
Show file tree
Hide file tree
Showing 4 changed files with 61 additions and 11 deletions.
2 changes: 1 addition & 1 deletion python/tvm/relay/backend/contrib/ethosu/te/convolution.py
Original file line number Diff line number Diff line change
Expand Up @@ -188,7 +188,7 @@ def conv2d_compute(
).astype(ifm.dtype)
* weight[cc, rh, rw, rc].astype(ifm.dtype)
# This is a trick to load 10 elements of the scale_bias at once, not accurate maths
+ (scale_bias[cc, 0] * scale_bias[cc, 9]),
+ (scale_bias[cc, 0] * scale_bias[cc, 9]).astype(ifm.dtype),
axis=[rh, rw, rc],
),
name="ethosu_conv2d",
Expand Down
4 changes: 3 additions & 1 deletion python/tvm/relay/backend/contrib/ethosu/te/dma.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,9 @@ def _pad(*indices):
not_zero.append(indices[i] < tensor.shape[i] + pad_before[i])
if not_zero:
not_zero = tvm.tir.all(*not_zero)
return tvm.tir.if_then_else(not_zero, tensor(*index_tuple), tvm.tir.const(0, "uint8"))
return tvm.tir.if_then_else(
not_zero, tensor(*index_tuple), tvm.tir.const(0, tensor.dtype)
)
return tensor(*index_tuple)

return _pad
Expand Down
18 changes: 9 additions & 9 deletions python/tvm/relay/backend/contrib/ethosu/tir/convolution.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
"""Extract information from the convolution operators in TIR."""
import tvm
from ..vela_api import SCALE_BIAS_LENGTH
from .utils import get_outer_loops, get_op_attrs, get_base_address
from .utils import get_outer_loops, get_op_attrs, get_base_address, get_loads, get_stores
from .dma import get_ifm_params, get_ofm_params
from .spec import SerialKernel, SerialAddressRange, SerialActivation, Serial2DConvolution

Expand Down Expand Up @@ -53,9 +53,12 @@ def get_conv2d_params(stmt, producers, consumers):
rh = inner
rw = rh.body
rc = rw.body
compute = rc.body.value.b
input_pointer = compute.a.a.buffer_var
output_pointer = rc.body.buffer_var
# loads = [output, input, weights, scale_bias, scale_bias]
loads = get_loads(rc.body)
# stores = [output]
stores = get_stores(rc.body)
input_pointer = loads[1].buffer_var
output_pointer = stores[0].buffer_var
# Get feature map info
serial_ifm, serial_padding = get_ifm_params(input_pointer, producers)
serial_ofm, replace_pointer = get_ofm_params(output_pointer, consumers)
Expand All @@ -69,17 +72,14 @@ def get_conv2d_params(stmt, producers, consumers):
dilation_h=int(attrs["dilation_h"]),
)
# Get scale_bias info
scale_bias_mul = compute.b
if isinstance(scale_bias_mul, tvm.tir.Cast):
scale_bias_mul = scale_bias_mul.value
scale_bias_load = scale_bias_mul.a
scale_bias_load = loads[3]
scale_bias_base = get_base_address(scale_bias_load.index)
serial_scale_bias = SerialAddressRange(
address=tvm.tir.Load("uint8", scale_bias_load.buffer_var, scale_bias_base),
length=SCALE_BIAS_LENGTH * serial_ofm[3],
)
# Get weight info
weight_load = compute.a.b
weight_load = loads[2]
weight_base = get_base_address(weight_load.index)
serial_weight = SerialAddressRange(
address=tvm.tir.Load("uint8", weight_load.buffer_var, weight_base),
Expand Down
48 changes: 48 additions & 0 deletions python/tvm/relay/backend/contrib/ethosu/tir/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -172,3 +172,51 @@ def get_outer_loops(stmt, layout):
b = w.body
return n, h, w, cb, b, b.body
return None


def get_loads(stmt):
"""Get the Load statements.
Parameters
----------
stmt : tvm.tir.Stmt
The statement to get the Loads from.
Returns
-------
loads : list of tvm.tir.Load
The Loads found.
"""
loads = []

def _visit(s):
if isinstance(s, tvm.tir.Load):
loads.append(s)

tvm.tir.stmt_functor.post_order_visit(stmt, _visit)
return loads


def get_stores(stmt):
"""Get the Store statements.
Parameters
----------
stmt : tvm.tir.Stmt
The statement to get the Stores from.
Returns
-------
stores : list of tvm.tir.Store
The Stores found.
"""
stores = []

def _visit(s):
if isinstance(s, tvm.tir.Store):
stores.append(s)

tvm.tir.stmt_functor.post_order_visit(stmt, _visit)
return stores

0 comments on commit 02cbf82

Please sign in to comment.