Skip to content

Commit

Permalink
Rebase, improve DivideConstants and expand testing
Browse files Browse the repository at this point in the history
Make the DivideConstants to operate on non-flattened
tensors to support two core execution in U65.
  • Loading branch information
ekalda committed May 5, 2022
1 parent fe41d03 commit a3d1858
Show file tree
Hide file tree
Showing 8 changed files with 366 additions and 137 deletions.
12 changes: 6 additions & 6 deletions python/tvm/relay/backend/contrib/ethosu/tir/convolution.py
Original file line number Diff line number Diff line change
Expand Up @@ -89,12 +89,12 @@ def get_conv2d_params(stmt, producers_consumers):
scale_bias2_length = SCALE_BIAS_LENGTH * math.floor(channels / 2)

serial_scale_bias = SerialAddressRange(
address=tvm.tir.BufferLoad("uint8", scale_bias_load.buffer, scale_bias_base),
address=tvm.tir.BufferLoad(scale_bias_load.buffer, scale_bias_base),
length=scale_bias_length,
)
serial_scale_bias2 = SerialAddressRange(
address=tvm.tir.BufferLoad(
"uint8", scale_bias_load.buffer, scale_bias_base + scale_bias_length
scale_bias_load.buffer, [scale_bias_base[0] + scale_bias_length]
),
length=scale_bias2_length,
)
Expand All @@ -107,18 +107,18 @@ def get_conv2d_params(stmt, producers_consumers):
)

serial_weight = SerialAddressRange(
address=tvm.tir.BufferLoad("uint8", weight_load.buffer, weight_base),
address=tvm.tir.BufferLoad(weight_load.buffer, weight_base),
length=weight_length,
)
serial_weight2 = SerialAddressRange(
address=tvm.tir.BufferLoad("uint8", weight_load.buffer, weight_base + weight_length),
address=tvm.tir.BufferLoad(weight_load.buffer, [weight_base[0] + weight_length]),
length=weight2_length,
)
else:
scale_bias_length = SCALE_BIAS_LENGTH * channels

serial_scale_bias = SerialAddressRange(
address=tvm.tir.BufferLoad("uint8", scale_bias_load.buffer, scale_bias_base),
address=tvm.tir.BufferLoad(scale_bias_load.buffer, scale_bias_base),
length=scale_bias_length,
)
# Insert -1s into the spec to denote the absence of the other pointer
Expand All @@ -130,7 +130,7 @@ def get_conv2d_params(stmt, producers_consumers):
weight_length = channels * serial_kernel[0] * serial_kernel[1] * rc.extent.value

serial_weight = SerialAddressRange(
address=tvm.tir.BufferLoad("uint8", weight_load.buffer, weight_base),
address=tvm.tir.BufferLoad(weight_load.buffer, weight_base),
length=weight_length,
)
serial_weight2 = SerialAddressRange(
Expand Down
168 changes: 103 additions & 65 deletions python/tvm/relay/backend/contrib/ethosu/tir/passes.py
Original file line number Diff line number Diff line change
Expand Up @@ -228,20 +228,33 @@ def DivideConstants(const_dict):

def _visit(stmt):
new_args = []
# We don't want to divide the constant that will be executed on two cores in parallel
is_u65_conv2d = (
vela_api.get_accelerator_config() == vapi.NpuAccelerator.Ethos_U65_512
and stmt.args[0] == "ethosu_conv2d"
)
for i, arg in enumerate(stmt.args):
if isinstance(arg, tvm.tir.expr.BufferLoad):
# If we're trying to load a buffer that maps to a constant
if arg.buffer.data in buffer_to_const:
const = buffer_to_const[arg.buffer.data]

assert len(arg.indices) == 1, "Ethos-U passes expects flattened buffers"
flattened_const_shape = np.prod(const.shape)

offset = int(arg.indices[0])
# Note by convention the arg after a constant read is the length of the read
length = int(stmt.args[i + 1])
# If it's anything other than a full read, create a new buffer
if offset != 0 or len(const) != length:
new_consts.append(const[offset : offset + length])
if (offset != 0 or flattened_const_shape != length) and not is_u65_conv2d:
out_channels = const.shape[0]
offset_channels = int((offset * out_channels) / flattened_const_shape)
length_channels = int((length * out_channels) / flattened_const_shape)
# split the constant up across channels
split_const = np.split(const, out_channels, axis=0)
# create a new const out of the channels we want to keep
new_const = np.concatenate(
split_const[offset_channels : offset_channels + length_channels], axis=0
)
new_consts.append(new_const)
new_buffer = tvm.tir.decl_buffer(
(length,), arg.dtype, scope=arg.buffer.scope()
)
Expand All @@ -257,8 +270,8 @@ def _visit(stmt):
def _ftransform(f, mod, ctx):
for i, param in enumerate(f.params):
if i in const_dict:
buffer_to_const[param] = const_dict[i].flatten()
buffer_to_const[f.buffer_map[param].data] = const_dict[i].flatten()
buffer_to_const[param] = const_dict[i]
buffer_to_const[f.buffer_map[param].data] = const_dict[i]

new_body = tvm.tir.stmt_functor.ir_transform(f.body, _visit, None, ["tir.Call"])
# Both the params and buffer map need updating for the newly introduced buffers
Expand Down Expand Up @@ -312,7 +325,6 @@ def EncodeConstants(const_dict):
"""
new_const_dict = {}
buffer_to_offset = {}

def collect_encoding_definitions(stmt, old_buffer_to_const):
# Map from copy destination to copy source.
Expand Down Expand Up @@ -341,7 +353,7 @@ def _encode_weights(tir_extern_call, weights):
value = np.frombuffer(value_bytes, dtype="uint8")
return value

def _declare_constant_buffer(old_buffer, encoded_constants):
def _declare_constant_buffer(old_buffer, encoded_constants, split_idx):
"""Create a new buffer and add the old buffer and its pointer to the
rewriting maps."""
new_buffer = tvm.tir.decl_buffer(
Expand All @@ -356,30 +368,30 @@ def _declare_constant_buffer(old_buffer, encoded_constants):
"old_buffer": old_buffer,
"new_buffer": new_buffer,
"encoded_constants": encoded_constants,
"split_idx": split_idx,
}
)

def _encode_weights_or_bias(buffer1, buffer2, stmt, encode_func):
"""Encode the weights or align the bias either for one or two cores,
depending on the variant."""
#assert ptr1 in pointer_to_buffer
#buffer = pointer_to_buffer[ptr1]
constant = old_buffer_to_const[buffer1]

# If we have just one core, encode the whole constant
if buffer2 is None:
new_const = encode_func(stmt, constant)
return new_const, len(new_const)
return new_const, None

# Assume OHWI
# Assume that the constant tensor has not been flattened yet
assert len(constant.shape) != 1
channels = constant.shape[0]
split_const = np.split(constant, channels, axis=0)

const_list = [split_const[i] for i in range(channels) if i % 2 == 0]
const_to_encode = np.concatenate(const_list, axis=0)

new_const = encode_func(stmt, const_to_encode)
new_const_length = len(new_const)
split_idx = len(new_const)

# Encode half of the constant separately for the other core if it exists
assert buffer1.same_as(buffer2)
Expand All @@ -389,7 +401,7 @@ def _encode_weights_or_bias(buffer1, buffer2, stmt, encode_func):
new_const2 = encode_func(stmt, const2_to_encode)
new_const = np.append(new_const, new_const2).astype("uint8")

return new_const, new_const_length
return new_const, split_idx

def _visit(stmt):
if isinstance(stmt, tvm.tir.Call):
Expand All @@ -405,51 +417,50 @@ def _visit(stmt):
copied_buffers.append({"source": read_buffer, "dest": write_buffer})
copy_map[write_buffer] = read_buffer

ops_with_weights = {
"ethosu_conv2d": tirtocs.translate_ethosu_conv2d,
"ethosu_depthwise_conv2d": tirtocs.translate_ethosu_depthwise_conv2d,
}
if op in ops_with_weights.keys():
npu_op, _ = ops_with_weights[op](stmt)

# Encode the weights
weights_buffer = npu_op.weights[0].address.buffer
if weights_buffer in copy_map:
weights_buffer = copy_map[weights_buffer]
weights2_buffer = (
npu_op.weights[1].address.buffer
if accel_config == vapi.NpuAccelerator.Ethos_U65_512
else None
)
if weights2_buffer in copy_map:
weights2_buffer = copy_map[weights2_buffer]

new_weights, new_weights_length = _encode_weights_or_bias(
weights_buffer, weights2_buffer, stmt, _encode_weights
)
_declare_constant_buffer(weights_buffer, new_weights)
buffer_to_offset[weights_buffer] = new_weights_length

# Align the scale_bias to 16 bytes
scale_bias_buffer = npu_op.biases[0].address.buffer
if scale_bias_buffer in copy_map:
scale_bias_buffer = copy_map[scale_bias_buffer]
scale_bias2_buffer = (
npu_op.biases[1].address.buffer
if accel_config == vapi.NpuAccelerator.Ethos_U65_512
else None
)
if scale_bias2_buffer in copy_map:
scale_bias2_buffer = copy_map[scale_bias2_buffer]
ops_with_weights = {
"ethosu_conv2d": tirtocs.translate_ethosu_conv2d,
"ethosu_depthwise_conv2d": tirtocs.translate_ethosu_depthwise_conv2d,
}
if op in ops_with_weights:
npu_op, _ = ops_with_weights[op](stmt)

# Encode the weights
weights_buffer = npu_op.weights[0].address.buffer
if weights_buffer in copy_map:
weights_buffer = copy_map[weights_buffer]

# In case of U65 512 mac variant the weights are split across two cores
# and need to be encoded separately
weights2_buffer = (
npu_op.weights[1].address.buffer
if accel_config == vapi.NpuAccelerator.Ethos_U65_512
else None
)
if weights2_buffer in copy_map:
weights2_buffer = copy_map[weights2_buffer]

new_scale_bias, new_scale_bias_length = _encode_weights_or_bias(
scale_bias_buffer, scale_bias2_buffer, stmt, _align_scale_bias
)
new_weights, split_idx = _encode_weights_or_bias(
weights_buffer, weights2_buffer, stmt, _encode_weights
)
_declare_constant_buffer(weights_buffer, new_weights, split_idx)

# Align the scale_bias to 16 bytes
scale_bias_buffer = npu_op.biases[0].address.buffer
if scale_bias_buffer in copy_map:
scale_bias_buffer = copy_map[scale_bias_buffer]
scale_bias2_buffer = (
npu_op.biases[1].address.buffer
if accel_config == vapi.NpuAccelerator.Ethos_U65_512
else None
)
if scale_bias2_buffer in copy_map:
scale_bias2_buffer = copy_map[scale_bias2_buffer]

#scale_bias_buffer = pointer_to_buffer[scale_bias_pointer]
new_scale_bias, split_idx = _encode_weights_or_bias(
scale_bias_buffer, scale_bias2_buffer, stmt, _align_scale_bias
)

_declare_constant_buffer(scale_bias_buffer, new_scale_bias)
buffer_to_offset[scale_bias_buffer] = new_scale_bias_length
_declare_constant_buffer(scale_bias_buffer, new_scale_bias, split_idx)

tvm.tir.stmt_functor.post_order_visit(stmt, _visit)

Expand All @@ -458,7 +469,9 @@ def _visit(stmt):
"constant_buffer_replacements": constant_buffer_replacements,
}

def transform_stmt(stmt, buf_remap, var_remap, pointer_to_buffer, new_buffer_to_const):
def transform_stmt(
stmt, buf_remap, var_remap, pointer_to_buffer, new_buffer_to_const, new_buffer_to_split_idx
):
def _visit_rewrite(stmt):
if isinstance(stmt, tvm.tir.Call):
# For extern calls, we need to rewrite pairs of arguments corresponding to
Expand All @@ -473,7 +486,19 @@ def _visit_rewrite(stmt):
isinstance(prev_arg, tvm.tir.BufferLoad)
and prev_arg.buffer in new_buffer_to_const
):
arg = np.prod(list(prev_arg.buffer.shape))
buffer_size = np.prod(list(prev_arg.buffer.shape))
arg = buffer_size
# We have to check for split weights/bias for conv2d and depthwise_conv2d
if old_args[0] in ("ethosu_conv2d", "depthwise_conv2d"):
# We have split weights/bias
if prev_arg.buffer in new_buffer_to_split_idx:
split_idx = new_buffer_to_split_idx[prev_arg.buffer]
# The first half of the split buffer
if prev_arg.indices[0] == 0:
arg = split_idx
# the second half of the split buffer
else:
arg = buffer_size - split_idx

new_args.append(arg)

Expand Down Expand Up @@ -501,11 +526,12 @@ def _visit_rewrite(stmt):
# rewrite the nodes which contain the Buffers.
if isinstance(stmt, tvm.tir.BufferLoad):
if stmt.buffer in buf_remap:
offset = stmt.index
if offset != 0:
offset = buffer_to_offset[stmt.buffer]
# TODO: integrate the offset
return tvm.tir.BufferLoad(buf_remap[stmt.buffer], stmt.indices, stmt.span)
new_buffer = buf_remap[stmt.buffer]
new_indices = stmt.indices
offset = new_indices[0]
if offset != 0 and new_buffer in new_buffer_to_split_idx:
offset = new_buffer_to_split_idx[new_buffer]
return tvm.tir.BufferLoad(buf_remap[stmt.buffer], [offset], stmt.span)

if isinstance(stmt, tvm.tir.AttrStmt):
node_pointer = stmt.node
Expand Down Expand Up @@ -533,7 +559,7 @@ def _ftransform(f, mod, ctx):
old_buffer_to_const = {}
for i, param in enumerate(f.params):
if i in const_dict:
old_buffer_to_const[f.buffer_map[param]] = const_dict[i].flatten()
old_buffer_to_const[f.buffer_map[param]] = const_dict[i]

# Step 1: Collect information on the buffers that will be
# replaced by encodings.
Expand All @@ -543,12 +569,16 @@ def _ftransform(f, mod, ctx):
# collected information.
buf_remap = {}
new_buffer_to_const = {}
new_buffer_to_split_idx = {}

# Any encoded buffers must be replaced
for info in buffer_information["constant_buffer_replacements"]:
buf_remap[info["old_buffer"]] = info["new_buffer"]
new_buffer_to_const[info["new_buffer"]] = info["encoded_constants"]

if info["split_idx"]:
new_buffer_to_split_idx[info["new_buffer"]] = info["split_idx"]

# Any buffers that are copied into from an encoded buffer must
# be replaced.
for info in buffer_information["copied_buffers"]:
Expand All @@ -569,6 +599,9 @@ def _ftransform(f, mod, ctx):
if copy_source in new_buffer_to_const:
new_buffer_to_const[new_dest] = new_buffer_to_const[copy_source]

if copy_source in new_buffer_to_split_idx:
new_buffer_to_split_idx[new_dest] = new_buffer_to_split_idx[copy_source]

# Define additional dependent lookup tables.
var_remap = {old.data: new.data for (old, new) in buf_remap.items()}
pointer_to_buffer = {
Expand All @@ -577,7 +610,12 @@ def _ftransform(f, mod, ctx):

# Step 3: Then perform the rewrites
new_body = transform_stmt(
f.body, buf_remap, var_remap, pointer_to_buffer, new_buffer_to_const
f.body,
buf_remap,
var_remap,
pointer_to_buffer,
new_buffer_to_const,
new_buffer_to_split_idx,
)

# Step 4: Rewrite the buffer map and const dict to instead use the encoded versions
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -441,7 +441,7 @@ def replace_npu_address_range_with_address(npu_addr_range):
)
assert buffer in buffer_addresses.keys(), f"searching for buffer : {buffer}, but not found"
address, buffer_type = buffer_addresses[buffer]
address = address + int(npu_addr_range.address.index.value)
address = address + int(npu_addr_range.address.indices[0].value)
return vapi.NpuAddressRange(_get_region(buffer_type), address, npu_addr_range.length)

def replace_tir_loads(npu_object):
Expand Down
2 changes: 1 addition & 1 deletion python/tvm/relay/backend/contrib/ethosu/vela_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -158,7 +158,7 @@ def encode_weights(
# The weight layout is assumed to be OHWI, always.
weights_layout="OHWI",
ifm_bitdepth=npu_op.ifm.data_type.size_in_bits(),
block_depth=npu_op.block_config.depth//2,
block_depth=npu_op.block_config.depth,
dilation=(npu_op.kernel.dilation_x, npu_op.kernel.dilation_y),
accel_config=accel_config,
is_depthwise=is_depthwise,
Expand Down
2 changes: 1 addition & 1 deletion tests/python/contrib/test_ethosu/test_codegen.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,7 +79,7 @@ def conv2d(x):
@pytest.mark.parametrize("kernel_shape", [(3, 2), (1, 3)])
@pytest.mark.parametrize("strides, dilation", [((1, 1), (2, 1)), ((3, 2), (1, 1))])
@pytest.mark.parametrize("padding", ["SAME", "VALID"])
@pytest.mark.parametrize("accel_type", ACCEL_TYPES)
@pytest.mark.parametrize("accel_type", ACCEL_TYPES + ["ethos-u65-512"])
@pytest.mark.parametrize("activation", ["NONE", "RELU"])
def test_ethosu_conv2d_double(
ifm_shape,
Expand Down
Loading

0 comments on commit a3d1858

Please sign in to comment.