diff --git a/python/tvm/relay/backend/contrib/ethosu/tir/convolution.py b/python/tvm/relay/backend/contrib/ethosu/tir/convolution.py
index 553f7779bb2a6..27aa462d60f26 100644
--- a/python/tvm/relay/backend/contrib/ethosu/tir/convolution.py
+++ b/python/tvm/relay/backend/contrib/ethosu/tir/convolution.py
@@ -89,12 +89,12 @@ def get_conv2d_params(stmt, producers_consumers):
         scale_bias2_length = SCALE_BIAS_LENGTH * math.floor(channels / 2)
 
         serial_scale_bias = SerialAddressRange(
-            address=tvm.tir.BufferLoad("uint8", scale_bias_load.buffer, scale_bias_base),
+            address=tvm.tir.BufferLoad(scale_bias_load.buffer, scale_bias_base),
             length=scale_bias_length,
         )
         serial_scale_bias2 = SerialAddressRange(
             address=tvm.tir.BufferLoad(
-                "uint8", scale_bias_load.buffer, scale_bias_base + scale_bias_length
+                scale_bias_load.buffer, [scale_bias_base[0] + scale_bias_length]
             ),
             length=scale_bias2_length,
         )
@@ -107,18 +107,18 @@ def get_conv2d_params(stmt, producers_consumers):
         )
 
         serial_weight = SerialAddressRange(
-            address=tvm.tir.BufferLoad("uint8", weight_load.buffer, weight_base),
+            address=tvm.tir.BufferLoad(weight_load.buffer, weight_base),
             length=weight_length,
        )
         serial_weight2 = SerialAddressRange(
-            address=tvm.tir.BufferLoad("uint8", weight_load.buffer, weight_base + weight_length),
+            address=tvm.tir.BufferLoad(weight_load.buffer, [weight_base[0] + weight_length]),
             length=weight2_length,
         )
     else:
         scale_bias_length = SCALE_BIAS_LENGTH * channels
 
         serial_scale_bias = SerialAddressRange(
-            address=tvm.tir.BufferLoad("uint8", scale_bias_load.buffer, scale_bias_base),
+            address=tvm.tir.BufferLoad(scale_bias_load.buffer, scale_bias_base),
             length=scale_bias_length,
         )
         # Insert -1s into the spec to denote the absence of the other pointer
@@ -130,7 +130,7 @@ def get_conv2d_params(stmt, producers_consumers):
         weight_length = channels * serial_kernel[0] * serial_kernel[1] * rc.extent.value
 
         serial_weight = SerialAddressRange(
-            address=tvm.tir.BufferLoad("uint8", weight_load.buffer, weight_base),
+            address=tvm.tir.BufferLoad(weight_load.buffer, weight_base),
             length=weight_length,
         )
         serial_weight2 = SerialAddressRange(
diff --git a/python/tvm/relay/backend/contrib/ethosu/tir/passes.py b/python/tvm/relay/backend/contrib/ethosu/tir/passes.py
index 6e0cfc13077fb..baadede08d668 100644
--- a/python/tvm/relay/backend/contrib/ethosu/tir/passes.py
+++ b/python/tvm/relay/backend/contrib/ethosu/tir/passes.py
@@ -228,20 +228,33 @@ def DivideConstants(const_dict):
 
     def _visit(stmt):
         new_args = []
+        # We don't want to divide the constant that will be executed on two cores in parallel
+        is_u65_conv2d = (
+            vela_api.get_accelerator_config() == vapi.NpuAccelerator.Ethos_U65_512
+            and stmt.args[0] == "ethosu_conv2d"
+        )
         for i, arg in enumerate(stmt.args):
             if isinstance(arg, tvm.tir.expr.BufferLoad):
                 # If we're trying to load a buffer that maps to a constant
                 if arg.buffer.data in buffer_to_const:
                     const = buffer_to_const[arg.buffer.data]
-
-                    assert len(arg.indices) == 1, "Ethos-U passes expects flattened buffers"
+                    flattened_const_shape = np.prod(const.shape)
 
                     offset = int(arg.indices[0])
                     # Note by convention the arg after a constant read is the length of the read
                     length = int(stmt.args[i + 1])
                     # If it's anything other than a full read, create a new buffer
-                    if offset != 0 or len(const) != length:
-                        new_consts.append(const[offset : offset + length])
+                    if (offset != 0 or flattened_const_shape != length) and not is_u65_conv2d:
+                        out_channels = const.shape[0]
+                        offset_channels = int((offset * out_channels) / flattened_const_shape)
+                        length_channels = int((length * out_channels) / flattened_const_shape)
+                        # split the constant up across channels
+                        split_const = np.split(const, out_channels, axis=0)
+                        # create a new const out of the channels we want to keep
+                        new_const = np.concatenate(
+                            split_const[offset_channels : offset_channels + length_channels], axis=0
+                        )
+                        new_consts.append(new_const)
                         new_buffer = tvm.tir.decl_buffer(
                             (length,), arg.dtype, scope=arg.buffer.scope()
                         )
@@ -257,8 +270,8 @@ def _visit(stmt):
     def _ftransform(f, mod, ctx):
         for i, param in enumerate(f.params):
             if i in const_dict:
-                buffer_to_const[param] = const_dict[i].flatten()
-                buffer_to_const[f.buffer_map[param].data] = const_dict[i].flatten()
+                buffer_to_const[param] = const_dict[i]
+                buffer_to_const[f.buffer_map[param].data] = const_dict[i]
 
         new_body = tvm.tir.stmt_functor.ir_transform(f.body, _visit, None, ["tir.Call"])
         # Both the params and buffer map need updating for the newly introduced buffers
@@ -312,7 +325,6 @@ def EncodeConstants(const_dict):
     """
 
     new_const_dict = {}
-    buffer_to_offset = {}
 
     def collect_encoding_definitions(stmt, old_buffer_to_const):
         # Map from copy destination to copy source.
@@ -341,7 +353,7 @@ def _encode_weights(tir_extern_call, weights):
             value = np.frombuffer(value_bytes, dtype="uint8")
             return value
 
-        def _declare_constant_buffer(old_buffer, encoded_constants):
+        def _declare_constant_buffer(old_buffer, encoded_constants, split_idx):
             """Create a new buffer and add the old buffer and its pointer to the
             rewriting maps."""
             new_buffer = tvm.tir.decl_buffer(
@@ -356,22 +368,22 @@ def _declare_constant_buffer(old_buffer, encoded_constants):
                     "old_buffer": old_buffer,
                     "new_buffer": new_buffer,
                     "encoded_constants": encoded_constants,
+                    "split_idx": split_idx,
                 }
             )
 
         def _encode_weights_or_bias(buffer1, buffer2, stmt, encode_func):
             """Encode the weights or align the bias either for one or two cores,
             depending on the variant."""
-            #assert ptr1 in pointer_to_buffer
-            #buffer = pointer_to_buffer[ptr1]
             constant = old_buffer_to_const[buffer1]
 
             # If we have just one core, encode the whole constant
             if buffer2 is None:
                 new_const = encode_func(stmt, constant)
-                return new_const, len(new_const)
+                return new_const, None
 
-            # Assume OHWI
+            # Assume that the constant tensor has not been flattened yet
+            assert len(constant.shape) != 1
             channels = constant.shape[0]
             split_const = np.split(constant, channels, axis=0)
 
@@ -379,7 +391,7 @@ def _encode_weights_or_bias(buffer1, buffer2, stmt, encode_func):
             const_to_encode = np.concatenate(const_list, axis=0)
 
             new_const = encode_func(stmt, const_to_encode)
-            new_const_length = len(new_const)
+            split_idx = len(new_const)
 
             # Encode half of the constant separately for the other core if it exists
             assert buffer1.same_as(buffer2)
@@ -389,7 +401,7 @@ def _encode_weights_or_bias(buffer1, buffer2, stmt, encode_func):
             new_const2 = encode_func(stmt, const2_to_encode)
             new_const = np.append(new_const, new_const2).astype("uint8")
 
-            return new_const, new_const_length
+            return new_const, split_idx
 
         def _visit(stmt):
             if isinstance(stmt, tvm.tir.Call):
@@ -405,51 +417,50 @@ def _visit(stmt):
                     copied_buffers.append({"source": read_buffer, "dest": write_buffer})
                     copy_map[write_buffer] = read_buffer
 
-                ops_with_weights = {
-                    "ethosu_conv2d": tirtocs.translate_ethosu_conv2d,
-                    "ethosu_depthwise_conv2d": tirtocs.translate_ethosu_depthwise_conv2d,
-                }
-                if op in ops_with_weights.keys():
-                    npu_op, _ = ops_with_weights[op](stmt)
-
-                    # Encode the weights
-                    weights_buffer = npu_op.weights[0].address.buffer
-                    if weights_buffer in copy_map:
-                        weights_buffer = copy_map[weights_buffer]
-                    weights2_buffer = (
-                        npu_op.weights[1].address.buffer
-                        if accel_config == vapi.NpuAccelerator.Ethos_U65_512
-                        else None
-                    )
-                    if weights2_buffer in copy_map:
-                        weights2_buffer = copy_map[weights2_buffer]
-
-                    new_weights, new_weights_length = _encode_weights_or_bias(
-                        weights_buffer, weights2_buffer, stmt, _encode_weights
-                    )
-                    _declare_constant_buffer(weights_buffer, new_weights)
-                    buffer_to_offset[weights_buffer] = new_weights_length
-
-                    # Align the scale_bias to 16 bytes
-                    scale_bias_buffer = npu_op.biases[0].address.buffer
-                    if scale_bias_buffer in copy_map:
-                        scale_bias_buffer = copy_map[scale_bias_buffer]
-                    scale_bias2_buffer = (
-                        npu_op.biases[1].address.buffer
-                        if accel_config == vapi.NpuAccelerator.Ethos_U65_512
-                        else None
-                    )
-                    if scale_bias2_buffer in copy_map:
-                        scale_bias2_buffer = copy_map[scale_bias2_buffer]
+                ops_with_weights = {
+                    "ethosu_conv2d": tirtocs.translate_ethosu_conv2d,
+                    "ethosu_depthwise_conv2d": tirtocs.translate_ethosu_depthwise_conv2d,
+                }
+                if op in ops_with_weights:
+                    npu_op, _ = ops_with_weights[op](stmt)
+
+                    # Encode the weights
+                    weights_buffer = npu_op.weights[0].address.buffer
+                    if weights_buffer in copy_map:
+                        weights_buffer = copy_map[weights_buffer]
+
+                    # In the case of the U65 512 mac variant, the weights are split across
+                    # the two cores and need to be encoded separately
+                    weights2_buffer = (
+                        npu_op.weights[1].address.buffer
+                        if accel_config == vapi.NpuAccelerator.Ethos_U65_512
+                        else None
+                    )
+                    if weights2_buffer in copy_map:
+                        weights2_buffer = copy_map[weights2_buffer]
 
-                    new_scale_bias, new_scale_bias_length = _encode_weights_or_bias(
-                        scale_bias_buffer, scale_bias2_buffer, stmt, _align_scale_bias
-                    )
+                    new_weights, split_idx = _encode_weights_or_bias(
+                        weights_buffer, weights2_buffer, stmt, _encode_weights
+                    )
+                    _declare_constant_buffer(weights_buffer, new_weights, split_idx)
+
+                    # Align the scale_bias to 16 bytes
+                    scale_bias_buffer = npu_op.biases[0].address.buffer
+                    if scale_bias_buffer in copy_map:
+                        scale_bias_buffer = copy_map[scale_bias_buffer]
+                    scale_bias2_buffer = (
+                        npu_op.biases[1].address.buffer
+                        if accel_config == vapi.NpuAccelerator.Ethos_U65_512
+                        else None
+                    )
+                    if scale_bias2_buffer in copy_map:
+                        scale_bias2_buffer = copy_map[scale_bias2_buffer]
 
-                    #scale_bias_buffer = pointer_to_buffer[scale_bias_pointer]
+                    new_scale_bias, split_idx = _encode_weights_or_bias(
+                        scale_bias_buffer, scale_bias2_buffer, stmt, _align_scale_bias
+                    )
 
-                    _declare_constant_buffer(scale_bias_buffer, new_scale_bias)
-                    buffer_to_offset[scale_bias_buffer] = new_scale_bias_length
+                    _declare_constant_buffer(scale_bias_buffer, new_scale_bias, split_idx)
 
         tvm.tir.stmt_functor.post_order_visit(stmt, _visit)
 
@@ -458,7 +469,9 @@ def _visit(stmt):
             "constant_buffer_replacements": constant_buffer_replacements,
         }
 
-    def transform_stmt(stmt, buf_remap, var_remap, pointer_to_buffer, new_buffer_to_const):
+    def transform_stmt(
+        stmt, buf_remap, var_remap, pointer_to_buffer, new_buffer_to_const, new_buffer_to_split_idx
+    ):
         def _visit_rewrite(stmt):
             if isinstance(stmt, tvm.tir.Call):
                 # For extern calls, we need to rewrite pairs of arguments corresponding to
@@ -473,7 +486,19 @@ def _visit_rewrite(stmt):
                         isinstance(prev_arg, tvm.tir.BufferLoad)
                         and prev_arg.buffer in new_buffer_to_const
                     ):
-                        arg = np.prod(list(prev_arg.buffer.shape))
+                        buffer_size = np.prod(list(prev_arg.buffer.shape))
+                        arg = buffer_size
+                        # We have to check for split weights/bias for conv2d and depthwise_conv2d
+                        if old_args[0] in ("ethosu_conv2d", "ethosu_depthwise_conv2d"):
("ethosu_conv2d", "depthwise_conv2d"): + # We have split weights/bias + if prev_arg.buffer in new_buffer_to_split_idx: + split_idx = new_buffer_to_split_idx[prev_arg.buffer] + # The first half of the split buffer + if prev_arg.indices[0] == 0: + arg = split_idx + # the second half of the split buffer + else: + arg = buffer_size - split_idx new_args.append(arg) @@ -501,11 +526,12 @@ def _visit_rewrite(stmt): # rewrite the nodes which contain the Buffers. if isinstance(stmt, tvm.tir.BufferLoad): if stmt.buffer in buf_remap: - offset = stmt.index - if offset != 0: - offset = buffer_to_offset[stmt.buffer] - # TODO: integrate the offset - return tvm.tir.BufferLoad(buf_remap[stmt.buffer], stmt.indices, stmt.span) + new_buffer = buf_remap[stmt.buffer] + new_indices = stmt.indices + offset = new_indices[0] + if offset != 0 and new_buffer in new_buffer_to_split_idx: + offset = new_buffer_to_split_idx[new_buffer] + return tvm.tir.BufferLoad(buf_remap[stmt.buffer], [offset], stmt.span) if isinstance(stmt, tvm.tir.AttrStmt): node_pointer = stmt.node @@ -533,7 +559,7 @@ def _ftransform(f, mod, ctx): old_buffer_to_const = {} for i, param in enumerate(f.params): if i in const_dict: - old_buffer_to_const[f.buffer_map[param]] = const_dict[i].flatten() + old_buffer_to_const[f.buffer_map[param]] = const_dict[i] # Step 1: Collect information on the buffers that will be # replaced by encodings. @@ -543,12 +569,16 @@ def _ftransform(f, mod, ctx): # collected information. buf_remap = {} new_buffer_to_const = {} + new_buffer_to_split_idx = {} # Any encoded buffers must be replaced for info in buffer_information["constant_buffer_replacements"]: buf_remap[info["old_buffer"]] = info["new_buffer"] new_buffer_to_const[info["new_buffer"]] = info["encoded_constants"] + if info["split_idx"]: + new_buffer_to_split_idx[info["new_buffer"]] = info["split_idx"] + # Any buffers that are copied into from an encoded buffer must # be replaced. for info in buffer_information["copied_buffers"]: @@ -569,6 +599,9 @@ def _ftransform(f, mod, ctx): if copy_source in new_buffer_to_const: new_buffer_to_const[new_dest] = new_buffer_to_const[copy_source] + if copy_source in new_buffer_to_split_idx: + new_buffer_to_split_idx[new_dest] = new_buffer_to_split_idx[copy_source] + # Define additional dependent lookup tables. 
         var_remap = {old.data: new.data for (old, new) in buf_remap.items()}
 
         pointer_to_buffer = {
@@ -577,7 +610,12 @@ def _ftransform(f, mod, ctx):
 
         # Step 3: Then perform the rewrites
         new_body = transform_stmt(
-            f.body, buf_remap, var_remap, pointer_to_buffer, new_buffer_to_const
+            f.body,
+            buf_remap,
+            var_remap,
+            pointer_to_buffer,
+            new_buffer_to_const,
+            new_buffer_to_split_idx,
         )
 
         # Step 4: Rewrite the buffer map and const dict to instead use the encoded versions
diff --git a/python/tvm/relay/backend/contrib/ethosu/tir_to_cs_translator.py b/python/tvm/relay/backend/contrib/ethosu/tir_to_cs_translator.py
index 8b7dcf12ad822..a3d46170dfcaf 100644
--- a/python/tvm/relay/backend/contrib/ethosu/tir_to_cs_translator.py
+++ b/python/tvm/relay/backend/contrib/ethosu/tir_to_cs_translator.py
@@ -441,7 +441,7 @@ def replace_npu_address_range_with_address(npu_addr_range):
         )
         assert buffer in buffer_addresses.keys(), f"searching for buffer : {buffer}, but not found"
         address, buffer_type = buffer_addresses[buffer]
-        address = address + int(npu_addr_range.address.index.value)
+        address = address + int(npu_addr_range.address.indices[0].value)
         return vapi.NpuAddressRange(_get_region(buffer_type), address, npu_addr_range.length)
 
     def replace_tir_loads(npu_object):
diff --git a/python/tvm/relay/backend/contrib/ethosu/vela_api.py b/python/tvm/relay/backend/contrib/ethosu/vela_api.py
index e2d42e03a3b7f..6d01e8de57b54 100644
--- a/python/tvm/relay/backend/contrib/ethosu/vela_api.py
+++ b/python/tvm/relay/backend/contrib/ethosu/vela_api.py
@@ -158,7 +158,7 @@ def encode_weights(
         # The weight layout is assumed to be OHWI, always.
         weights_layout="OHWI",
         ifm_bitdepth=npu_op.ifm.data_type.size_in_bits(),
-        block_depth=npu_op.block_config.depth//2,
+        block_depth=npu_op.block_config.depth,
         dilation=(npu_op.kernel.dilation_x, npu_op.kernel.dilation_y),
         accel_config=accel_config,
         is_depthwise=is_depthwise,
diff --git a/tests/python/contrib/test_ethosu/test_codegen.py b/tests/python/contrib/test_ethosu/test_codegen.py
index 68eca6951243b..4268392f1b788 100644
--- a/tests/python/contrib/test_ethosu/test_codegen.py
+++ b/tests/python/contrib/test_ethosu/test_codegen.py
@@ -79,7 +79,7 @@ def conv2d(x):
 @pytest.mark.parametrize("kernel_shape", [(3, 2), (1, 3)])
 @pytest.mark.parametrize("strides, dilation", [((1, 1), (2, 1)), ((3, 2), (1, 1))])
 @pytest.mark.parametrize("padding", ["SAME", "VALID"])
-@pytest.mark.parametrize("accel_type", ACCEL_TYPES)
+@pytest.mark.parametrize("accel_type", ACCEL_TYPES + ["ethos-u65-512"])
 @pytest.mark.parametrize("activation", ["NONE", "RELU"])
 def test_ethosu_conv2d_double(
     ifm_shape,
diff --git a/tests/python/contrib/test_ethosu/test_encode_constants.py b/tests/python/contrib/test_ethosu/test_encode_constants.py
index 0fc7e57b18c54..457edc861d067 100644
--- a/tests/python/contrib/test_ethosu/test_encode_constants.py
+++ b/tests/python/contrib/test_ethosu/test_encode_constants.py
@@ -32,7 +32,7 @@
 
 # fmt: off
 @tvm.script.ir_module
-class WeightStreamOnly:
+class WeightStreamOnlyU55:
     @T.prim_func
     def main(placeholder: T.Buffer[(8192,), "int8"], ethosu_write: T.Buffer[(2048,), "int8"]) -> None:
         # function attr dict
@@ -65,10 +65,66 @@ def main(placeholder: T.Buffer[(8192,), "int8"], ethosu_write: T.Buffer[(2048,),
         T.evaluate(T.call_extern("ethosu_copy", buffer_7[0], 32, p2_global_1[0], dtype="handle"))
         T.evaluate(T.call_extern("ethosu_conv2d", "int8", 16, 16, 32, 16, 0, 16, placeholder[0], 0, 0, 0, T.float32(0.5), 10, "NHWC", 512, 32, 1, "int8", 16, 16, 2, 16, 0, 16, ethosu_write[6], 0, 0, 0, T.float32(0.25), 14, "NHWC", 128, 8, 1, 1, 1, 1, 1, 1, 1, p1_global_1[0], 112, T.int8(-1), T.int8(-1), 12, p2_global_1[0], 32, T.int8(-1), T.int8(-1), 0, 0, 0, 0, "NONE", 0, 0, "TFL", "NONE", 0, 0, 0, dtype="handle"))
T.float32(0.25), 14, "NHWC", 128, 8, 1, 1, 1, 1, 1, 1, 1, p1_global_1[0], 112, T.int8(-1), T.int8(-1), 12, p2_global_1[0], 32, T.int8(-1), T.int8(-1), 0, 0, 0, 0, "NONE", 0, 0, "TFL", "NONE", 0, 0, 0, dtype="handle")) __tvm_meta__ = None + + +@tvm.script.ir_module +class WeightStreamOnlyU65: + @T.prim_func + def main(placeholder: T.Buffer[(8192,), "int8"], ethosu_write: T.Buffer[(2048,), "int8"]) -> None: + # function attr dict + T.func_attr({"from_legacy_te_schedule": True, "global_symbol": "main", "tir.noalias": True}) + # buffer definition + buffer_encoded = T.buffer_decl([160], dtype="uint8") + buffer_encoded_1 = T.buffer_decl([32], dtype="uint8") + buffer_encoded_2 = T.buffer_decl([160], dtype="uint8") + buffer_encoded_3 = T.buffer_decl([32], dtype="uint8") + buffer_encoded_4 = T.buffer_decl([176], dtype="uint8") + buffer_encoded_5 = T.buffer_decl([32], dtype="uint8") + buffer_encoded_6 = T.buffer_decl([160], dtype="uint8") + buffer_encoded_7 = T.buffer_decl([32], dtype="uint8") + T.preflattened_buffer(placeholder, [1, 16, 16, 32], dtype="int8", data=placeholder.data) + T.preflattened_buffer(ethosu_write, [1, 16, 16, 8], dtype="int8", data=ethosu_write.data) + # body + placeholder_global = T.allocate([176], "uint8", "global", annotations={"disable_lower_builtin":True}) + placeholder_global_1 = T.buffer_decl([160], dtype="uint8", data=placeholder_global.data) + placeholder_global_2 = T.buffer_decl([160], dtype="uint8", data=placeholder_global.data) + placeholder_global_3 = T.buffer_decl([160], dtype="uint8", data=placeholder_global.data) + placeholder_d_global = T.allocate([32], "uint8", "global", annotations={"disable_lower_builtin":True}) + placeholder_d_global_1 = T.buffer_decl([32], dtype="uint8", data=placeholder_d_global.data) + placeholder_d_global_2 = T.buffer_decl([32], dtype="uint8", data=placeholder_d_global.data) + placeholder_d_global_3 = T.buffer_decl([32], dtype="uint8", data=placeholder_d_global.data) + T.evaluate(T.call_extern("ethosu_copy", buffer_encoded[0], 160, placeholder_global_1[0], dtype="handle")) + T.evaluate(T.call_extern("ethosu_copy", buffer_encoded_1[0], 32, placeholder_d_global[0], dtype="handle")) + T.evaluate(T.call_extern("ethosu_conv2d", "int8", 16, 16, 32, 16, 0, 16, placeholder[0], 0, 0, 0, T.float32(0.5), 10, "NHWC", 512, 32, 1, "int8", 16, 16, 2, 16, 0, 16, ethosu_write[0], 0, 0, 0, T.float32(0.25), 14, "NHWC", 128, 8, 1, 1, 1, 1, 1, 1, 1, placeholder_global_1[0], 80, placeholder_global_1[80], 80, 12, placeholder_d_global[0], 16, placeholder_d_global[16], 16, 0, 0, 0, 0, "NONE", 0, 0, "TFL", "NONE", 0, 0, 0, dtype="handle")) + T.evaluate(T.call_extern("ethosu_copy", buffer_encoded_2[0], 160, placeholder_global_2[0], dtype="handle")) + T.evaluate(T.call_extern("ethosu_copy", buffer_encoded_3[0], 32, placeholder_d_global_1[0], dtype="handle")) + T.evaluate(T.call_extern("ethosu_conv2d", "int8", 16, 16, 32, 16, 0, 16, placeholder[0], 0, 0, 0, T.float32(0.5), 10, "NHWC", 512, 32, 1, "int8", 16, 16, 2, 16, 0, 16, ethosu_write[2], 0, 0, 0, T.float32(0.25), 14, "NHWC", 128, 8, 1, 1, 1, 1, 1, 1, 1, placeholder_global_2[0], 80, placeholder_global_2[80], 80, 12, placeholder_d_global_1[0], 16, placeholder_d_global_1[16], 16, 0, 0, 0, 0, "NONE", 0, 0, "TFL", "NONE", 0, 0, 0, dtype="handle")) + T.evaluate(T.call_extern("ethosu_copy", buffer_encoded_4[0], 176, placeholder_global[0], dtype="handle")) + T.evaluate(T.call_extern("ethosu_copy", buffer_encoded_5[0], 32, placeholder_d_global_2[0], dtype="handle")) + T.evaluate(T.call_extern("ethosu_conv2d", "int8", 
+        T.evaluate(T.call_extern("ethosu_copy", buffer_encoded_6[0], 160, placeholder_global_3[0], dtype="handle"))
+        T.evaluate(T.call_extern("ethosu_copy", buffer_encoded_7[0], 32, placeholder_d_global_3[0], dtype="handle"))
+        T.evaluate(T.call_extern("ethosu_conv2d", "int8", 16, 16, 32, 16, 0, 16, placeholder[0], 0, 0, 0, T.float32(0.5), 10, "NHWC", 512, 32, 1, "int8", 16, 16, 2, 16, 0, 16, ethosu_write[6], 0, 0, 0, T.float32(0.25), 14, "NHWC", 128, 8, 1, 1, 1, 1, 1, 1, 1, placeholder_global_3[0], 80, placeholder_global_3[80], 80, 12, placeholder_d_global_3[0], 16, placeholder_d_global_3[16], 16, 0, 0, 0, 0, "NONE", 0, 0, "TFL", "NONE", 0, 0, 0, dtype="handle"))
+    __tvm_meta__ = None
 # fmt: on
 
 
-def test_weight_stream_only():
+@pytest.mark.parametrize(
+    "accelerator, reference_mod, reference_const_sizes",
+    [
+        (
+            "ethos-u55-128",
+            WeightStreamOnlyU55,
+            [128, 32, 112, 32, 112, 32, 112, 32],
+        ),
+        (
+            "ethos-u65-512",
+            WeightStreamOnlyU65,
+            [160, 32, 160, 32, 176, 32, 160, 32],
+        ),
+    ],
+)
+def test_weight_stream_only(accelerator, reference_mod, reference_const_sizes):
     def _planner(cached_func, const_dict, sch):
         weights = cached_func.inputs[1]
         bias = cached_func.inputs[2]
@@ -95,21 +151,23 @@ def _get_func():
         func = run_opt_pass(func, relay.transform.InferType())
         return func
 
-    func = _get_func()
-    mod, consts = _lower_to_tir(func, cascader=_planner)
-    script = mod.script(show_meta=True)
-    test_mod = tvm.script.from_source(script)
-    reference_mod = WeightStreamOnly
-    tvm.ir.assert_structural_equal(test_mod["main"], reference_mod["main"], True)
+    config = {
+        "accelerator_config": accelerator,
+    }
+    with tvm.transform.PassContext(config={"relay.ext.ethos-u.options": config}):
+        func = _get_func()
+        mod, consts = _lower_to_tir(func, cascader=_planner)
+        script = mod.script(show_meta=True)
+        test_mod = tvm.script.from_source(script)
+        tvm.ir.assert_structural_equal(test_mod["main"], reference_mod["main"], True)
 
-    reference_const_sizes = [128, 32, 112, 32, 112, 32, 112, 32]
-    test_const_size = [value.size for value in list(consts.values())]
-    assert reference_const_sizes == test_const_size
+        test_const_size = [value.size for value in list(consts.values())]
+        assert reference_const_sizes == test_const_size
 
 
 # fmt: off
 @tvm.script.ir_module
-class RereadWeights:
+class RereadWeightsU55:
     @T.prim_func
     def main(placeholder: T.Buffer[(8192,), "int8"], ethosu_write: T.Buffer[(2048,), "int8"]) -> None:
         # function attr dict
@@ -128,10 +186,51 @@ def main(placeholder: T.Buffer[(8192,), "int8"], ethosu_write: T.Buffer[(2048,),
         T.evaluate(T.call_extern("ethosu_copy", buffer_1[0], 80, placeholder_d_global[0], dtype="handle"))
         T.evaluate(T.call_extern("ethosu_conv2d", "int8", 16, 8, 32, 16, 0, 8, placeholder[256], 0, 0, 0, T.float32(0.5), 10, "NHWC", 512, 32, 1, "int8", 16, 8, 8, 16, 0, 8, ethosu_write[64], 0, 0, 0, T.float32(0.25), 14, "NHWC", 128, 8, 1, 1, 1, 1, 1, 1, 1, placeholder_global[0], 304, T.int8(-1), T.int8(-1), 12, placeholder_d_global[0], 80, T.int8(-1), T.int8(-1), 0, 0, 0, 0, "NONE", 0, 0, "TFL", "NONE", 0, 0, 0, dtype="handle"))
     __tvm_meta__ = None
+
+
+@tvm.script.ir_module
+class RereadWeightsU65:
+    @T.prim_func
+    def main(placeholder: T.Buffer[(8192,), "int8"], ethosu_write: T.Buffer[(2048,), "int8"]) -> None:
+        # function attr dict
+        T.func_attr({"from_legacy_te_schedule": True, "global_symbol": "main", "tir.noalias": True})
+        # buffer definition
+        placeholder_encoded = T.buffer_decl([368], dtype="uint8")
+        placeholder_encoded_1 = T.buffer_decl([96], dtype="uint8")
+        T.preflattened_buffer(placeholder, [1, 16, 16, 32], dtype="int8", data=placeholder.data)
+        T.preflattened_buffer(ethosu_write, [1, 16, 16, 8], dtype="int8", data=ethosu_write.data)
+        # body
+        placeholder_global = T.allocate([368], "uint8", "global", annotations={"disable_lower_builtin":True})
+        placeholder_global_1 = T.buffer_decl([368], dtype="uint8", data=placeholder_global.data)
+        placeholder_d_global = T.allocate([96], "uint8", "global", annotations={"disable_lower_builtin":True})
+        placeholder_d_global_1 = T.buffer_decl([96], dtype="uint8", data=placeholder_d_global.data)
+        T.evaluate(T.call_extern("ethosu_copy", placeholder_encoded[0], 368, placeholder_global[0], dtype="handle"))
+        T.evaluate(T.call_extern("ethosu_copy", placeholder_encoded_1[0], 96, placeholder_d_global[0], dtype="handle"))
+        T.evaluate(T.call_extern("ethosu_conv2d", "int8", 16, 8, 32, 16, 0, 8, placeholder[0], 0, 0, 0, T.float32(0.5), 10, "NHWC", 512, 32, 1, "int8", 16, 8, 8, 16, 0, 8, ethosu_write[0], 0, 0, 0, T.float32(0.25), 14, "NHWC", 128, 8, 1, 1, 1, 1, 1, 1, 1, placeholder_global[0], 192, placeholder_global[192], 176, 12, placeholder_d_global[0], 48, placeholder_d_global[48], 48, 0, 0, 0, 0, "NONE", 0, 0, "TFL", "NONE", 0, 0, 0, dtype="handle"))
+        T.evaluate(T.call_extern("ethosu_copy", placeholder_encoded[0], 368, placeholder_global_1[0], dtype="handle"))
+        T.evaluate(T.call_extern("ethosu_copy", placeholder_encoded_1[0], 96, placeholder_d_global_1[0], dtype="handle"))
+        T.evaluate(T.call_extern("ethosu_conv2d", "int8", 16, 8, 32, 16, 0, 8, placeholder[256], 0, 0, 0, T.float32(0.5), 10, "NHWC", 512, 32, 1, "int8", 16, 8, 8, 16, 0, 8, ethosu_write[64], 0, 0, 0, T.float32(0.25), 14, "NHWC", 128, 8, 1, 1, 1, 1, 1, 1, 1, placeholder_global_1[0], 192, placeholder_global_1[192], 176, 12, placeholder_d_global_1[0], 48, placeholder_d_global_1[48], 48, 0, 0, 0, 0, "NONE", 0, 0, "TFL", "NONE", 0, 0, 0, dtype="handle"))
+
+    __tvm_meta__ = None
 # fmt: on
 
 
-def test_re_read_weights():
+@pytest.mark.parametrize(
+    "accelerator, reference_mod, reference_const_sizes",
+    [
+        (
+            "ethos-u55-128",
+            RereadWeightsU55,
+            [304, 80],
+        ),
+        (
+            "ethos-u65-512",
+            RereadWeightsU65,
+            [368, 96],
+        ),
+    ],
+)
+def test_re_read_weights(accelerator, reference_mod, reference_const_sizes):
     def _cascader(cached_func, const_dict, sch):
         weights = cached_func.inputs[1]
         bias = cached_func.inputs[2]
@@ -158,21 +257,23 @@ def _get_func():
         func = run_opt_pass(func, relay.transform.InferType())
         return func
 
-    func = _get_func()
-    mod, consts = _lower_to_tir(func, cascader=_cascader)
-    script = mod.script(show_meta=True)
-    test_mod = tvm.script.from_source(script)
-    reference_mod = RereadWeights
-    tvm.ir.assert_structural_equal(test_mod["main"], reference_mod["main"], True)
+    config = {
+        "accelerator_config": accelerator,
+    }
+    with tvm.transform.PassContext(config={"relay.ext.ethos-u.options": config}):
+        func = _get_func()
+        mod, consts = _lower_to_tir(func, cascader=_cascader)
+        script = mod.script(show_meta=True)
+        test_mod = tvm.script.from_source(script)
+        tvm.ir.assert_structural_equal(test_mod["main"], reference_mod["main"], True)
 
-    reference_const_sizes = [304, 80]
-    test_const_size = [value.size for value in list(consts.values())]
-    assert reference_const_sizes == test_const_size
+        test_const_size = [value.size for value in list(consts.values())]
+        assert reference_const_sizes == test_const_size
 
 
 # fmt: off
 @tvm.script.ir_module
-class DirectReadOnly:
+class DirectReadOnlyU55:
     @T.prim_func
     def main(placeholder: T.Buffer[(8192,), "int8"], ethosu_write: T.Buffer[(2048,), "int8"]) -> None:
         # function attr dict
@@ -188,10 +289,45 @@ def main(placeholder: T.Buffer[(8192,), "int8"], ethosu_write: T.Buffer[(2048,),
         T.evaluate(T.call_extern("ethosu_conv2d", "int8", 16, 16, 32, 16, 0, 16, placeholder[0], 0, 0, 0, T.float32(0.5), 10, "NHWC", 512, 32, 1, "int8", 16, 16, 16, 16, 0, 16, ethosu_write_1[0], 0, 0, 0, T.float32(0.25), 14, "NHWC", 256, 16, 1, 1, 1, 1, 1, 1, 1, buffer[0], 592, T.int8(-1), T.int8(-1), 12, buffer_1[0], 160, T.int8(-1), T.int8(-1), 0, 0, 0, 0, "NONE", 0, 0, "TFL", "NONE", 0, 0, 0, dtype="handle"))
         T.evaluate(T.call_extern("ethosu_conv2d", "int8", 16, 16, 16, 16, 0, 16, ethosu_write_1[0], 0, 0, 0, T.float32(0.5), 10, "NHWC", 256, 16, 1, "int8", 16, 16, 8, 16, 0, 16, ethosu_write[0], 0, 0, 0, T.float32(0.25), 14, "NHWC", 128, 8, 1, 1, 1, 1, 1, 1, 1, buffer_2[0], 160, T.int8(-1), T.int8(-1), 12, buffer_3[0], 80, T.int8(-1), T.int8(-1), 0, 0, 0, 0, "NONE", 0, 0, "TFL", "NONE", 0, 0, 0, dtype="handle"))
     __tvm_meta__ = None
+
+
+@tvm.script.ir_module
+class DirectReadOnlyU65:
+    @T.prim_func
+    def main(placeholder: T.Buffer[(8192,), "int8"], ethosu_write: T.Buffer[(2048,), "int8"]) -> None:
+        # function attr dict
+        T.func_attr({"from_legacy_te_schedule": True, "global_symbol": "main", "tir.noalias": True})
+        # buffer definition
+        placeholder_encoded = T.buffer_decl([608], dtype="uint8")
+        placeholder_encoded_1 = T.buffer_decl([160], dtype="uint8")
+        placeholder_encoded_2 = T.buffer_decl([208], dtype="uint8")
+        placeholder_encoded_3 = T.buffer_decl([96], dtype="uint8")
+        T.preflattened_buffer(placeholder, [1, 16, 16, 32], dtype="int8", data=placeholder.data)
+        T.preflattened_buffer(ethosu_write, [1, 16, 16, 8], dtype="int8", data=ethosu_write.data)
+        # body
+        ethosu_write_2 = T.allocate([4096], "int8", "global", annotations={"disable_lower_builtin":True})
+        T.evaluate(T.call_extern("ethosu_conv2d", "int8", 16, 16, 32, 16, 0, 16, placeholder[0], 0, 0, 0, T.float32(0.5), 10, "NHWC", 512, 32, 1, "int8", 16, 16, 16, 16, 0, 16, ethosu_write_2[0], 0, 0, 0, T.float32(0.25), 14, "NHWC", 256, 16, 1, 1, 1, 1, 1, 1, 1, placeholder_encoded[0], 304, placeholder_encoded[304], 304, 12, placeholder_encoded_1[0], 80, placeholder_encoded_1[80], 80, 0, 0, 0, 0, "NONE", 0, 0, "TFL", "NONE", 0, 0, 0, dtype="handle"))
+        T.evaluate(T.call_extern("ethosu_conv2d", "int8", 16, 16, 16, 16, 0, 16, ethosu_write_2[0], 0, 0, 0, T.float32(0.5), 10, "NHWC", 256, 16, 1, "int8", 16, 16, 8, 16, 0, 16, ethosu_write[0], 0, 0, 0, T.float32(0.25), 14, "NHWC", 128, 8, 1, 1, 1, 1, 1, 1, 1, placeholder_encoded_2[0], 112, placeholder_encoded_2[112], 96, 12, placeholder_encoded_3[0], 48, placeholder_encoded_3[48], 48, 0, 0, 0, 0, "NONE", 0, 0, "TFL", "NONE", 0, 0, 0, dtype="handle"))
+    __tvm_meta__ = None
 # fmt: on
 
 
-def test_direct_read_only():
+@pytest.mark.parametrize(
+    "accelerator, reference_mod, reference_const_sizes",
+    [
+        (
+            "ethos-u55-128",
+            DirectReadOnlyU55,
+            [592, 160, 160, 80],
+        ),
+        (
+            "ethos-u65-512",
+            DirectReadOnlyU65,
+            [608, 160, 208, 96],
+        ),
+    ],
+)
+def test_direct_read_only(accelerator, reference_mod, reference_const_sizes):
     def _get_func():
         ifm = relay.var("ifm", shape=(1, 16, 16, 32), dtype="int8")
relay.var("ifm", shape=(1, 16, 16, 32), dtype="int8") conv1 = make_ethosu_conv2d( @@ -216,22 +352,25 @@ def _get_func(): func = run_opt_pass(func, relay.transform.InferType()) return func - func = _get_func() - mod, consts = _lower_to_tir(func) + config = { + "accelerator_config": accelerator, + } + with tvm.transform.PassContext(config={"relay.ext.ethos-u.options": config}): + func = _get_func() + mod, consts = _lower_to_tir(func) - script = mod.script(show_meta=True) - test_mod = tvm.script.from_source(script) - reference_mod = DirectReadOnly - tvm.ir.assert_structural_equal(test_mod["main"], reference_mod["main"], True) + print(mod.script()) + script = mod.script(show_meta=True) + test_mod = tvm.script.from_source(script) + tvm.ir.assert_structural_equal(test_mod["main"], reference_mod["main"], True) - reference_const_sizes = [592, 160, 160, 80] - test_const_size = [value.size for value in list(consts.values())] - assert reference_const_sizes == test_const_size + test_const_size = [value.size for value in list(consts.values())] + assert reference_const_sizes == test_const_size # fmt: off @tvm.script.ir_module -class MixedRead: +class MixedReadU55: @T.prim_func def main(placeholder: T.Buffer[(8192,), "int8"], ethosu_write: T.Buffer[(2048,), "int8"]) -> None: # function attr dict @@ -266,10 +405,70 @@ def main(placeholder: T.Buffer[(8192,), "int8"], ethosu_write: T.Buffer[(2048,), T.evaluate(T.call_extern("ethosu_copy", buffer_9[0], 32, placeholder_d_global[0], dtype="handle")) T.evaluate(T.call_extern("ethosu_conv2d", "int8", 16, 16, 16, 16, 0, 16, ethosu_write_1[0], 0, 0, 0, T.float32(0.5), 10, "NHWC", 256, 16, 1, "int8", 16, 16, 2, 16, 0, 16, ethosu_write[6], 0, 0, 0, T.float32(0.25), 14, "NHWC", 128, 8, 1, 1, 1, 1, 1, 1, 1, placeholder_global[0], 80, T.int8(-1), T.int8(-1), 12, placeholder_d_global[0], 32, T.int8(-1), T.int8(-1), 0, 0, 0, 0, "NONE", 0, 0, "TFL", "NONE", 0, 0, 0, dtype="handle")) __tvm_meta__ = None + + +@tvm.script.ir_module +class MixedReadU65: + @T.prim_func + def main(placeholder: T.Buffer[(8192,), "int8"], ethosu_write: T.Buffer[(2048,), "int8"]) -> None: + # function attr dict + T.func_attr({"from_legacy_te_schedule": True, "global_symbol": "main", "tir.noalias": True}) + # buffer definition + buffer_encoded = T.buffer_decl([96], dtype="uint8") + buffer_encoded_1 = T.buffer_decl([32], dtype="uint8") + buffer_encoded_2 = T.buffer_decl([96], dtype="uint8") + buffer_encoded_3 = T.buffer_decl([32], dtype="uint8") + buffer_encoded_4 = T.buffer_decl([96], dtype="uint8") + buffer_encoded_5 = T.buffer_decl([32], dtype="uint8") + buffer_encoded_6 = T.buffer_decl([96], dtype="uint8") + buffer_encoded_7 = T.buffer_decl([32], dtype="uint8") + placeholder_encoded = T.buffer_decl([608], dtype="uint8") + placeholder_encoded_1 = T.buffer_decl([160], dtype="uint8") + T.preflattened_buffer(placeholder, [1, 16, 16, 32], dtype="int8", data=placeholder.data) + T.preflattened_buffer(ethosu_write, [1, 16, 16, 8], dtype="int8", data=ethosu_write.data) + # body + ethosu_write_2 = T.allocate([4096], "int8", "global", annotations={"disable_lower_builtin":True}) + placeholder_global = T.allocate([96], "uint8", "global", annotations={"disable_lower_builtin":True}) + placeholder_global_1 = T.buffer_decl([96], dtype="uint8", data=placeholder_global.data) + placeholder_global_2 = T.buffer_decl([96], dtype="uint8", data=placeholder_global.data) + placeholder_global_3 = T.buffer_decl([96], dtype="uint8", data=placeholder_global.data) + placeholder_d_global = T.allocate([32], "uint8", "global", 
annotations={"disable_lower_builtin":True}) + placeholder_d_global_1 = T.buffer_decl([32], dtype="uint8", data=placeholder_d_global.data) + placeholder_d_global_2 = T.buffer_decl([32], dtype="uint8", data=placeholder_d_global.data) + placeholder_d_global_3 = T.buffer_decl([32], dtype="uint8", data=placeholder_d_global.data) + T.evaluate(T.call_extern("ethosu_conv2d", "int8", 16, 16, 32, 16, 0, 16, placeholder[0], 0, 0, 0, T.float32(0.5), 10, "NHWC", 512, 32, 1, "int8", 16, 16, 16, 16, 0, 16, ethosu_write_2[0], 0, 0, 0, T.float32(0.25), 14, "NHWC", 256, 16, 1, 1, 1, 1, 1, 1, 1, placeholder_encoded[0], 304, placeholder_encoded[304], 304, 12, placeholder_encoded_1[0], 80, placeholder_encoded_1[80], 80, 0, 0, 0, 0, "NONE", 0, 0, "TFL", "NONE", 0, 0, 0, dtype="handle")) + T.evaluate(T.call_extern("ethosu_copy", buffer_encoded[0], 96, placeholder_global[0], dtype="handle")) + T.evaluate(T.call_extern("ethosu_copy", buffer_encoded_1[0], 32, placeholder_d_global[0], dtype="handle")) + T.evaluate(T.call_extern("ethosu_conv2d", "int8", 16, 16, 16, 16, 0, 16, ethosu_write_2[0], 0, 0, 0, T.float32(0.5), 10, "NHWC", 256, 16, 1, "int8", 16, 16, 2, 16, 0, 16, ethosu_write[0], 0, 0, 0, T.float32(0.25), 14, "NHWC", 128, 8, 1, 1, 1, 1, 1, 1, 1, placeholder_global[0], 48, placeholder_global[48], 48, 12, placeholder_d_global[0], 16, placeholder_d_global[16], 16, 0, 0, 0, 0, "NONE", 0, 0, "TFL", "NONE", 0, 0, 0, dtype="handle")) + T.evaluate(T.call_extern("ethosu_copy", buffer_encoded_2[0], 96, placeholder_global_1[0], dtype="handle")) + T.evaluate(T.call_extern("ethosu_copy", buffer_encoded_3[0], 32, placeholder_d_global_1[0], dtype="handle")) + T.evaluate(T.call_extern("ethosu_conv2d", "int8", 16, 16, 16, 16, 0, 16, ethosu_write_2[0], 0, 0, 0, T.float32(0.5), 10, "NHWC", 256, 16, 1, "int8", 16, 16, 2, 16, 0, 16, ethosu_write[2], 0, 0, 0, T.float32(0.25), 14, "NHWC", 128, 8, 1, 1, 1, 1, 1, 1, 1, placeholder_global_1[0], 48, placeholder_global_1[48], 48, 12, placeholder_d_global_1[0], 16, placeholder_d_global_1[16], 16, 0, 0, 0, 0, "NONE", 0, 0, "TFL", "NONE", 0, 0, 0, dtype="handle")) + T.evaluate(T.call_extern("ethosu_copy", buffer_encoded_4[0], 96, placeholder_global_2[0], dtype="handle")) + T.evaluate(T.call_extern("ethosu_copy", buffer_encoded_5[0], 32, placeholder_d_global_2[0], dtype="handle")) + T.evaluate(T.call_extern("ethosu_conv2d", "int8", 16, 16, 16, 16, 0, 16, ethosu_write_2[0], 0, 0, 0, T.float32(0.5), 10, "NHWC", 256, 16, 1, "int8", 16, 16, 2, 16, 0, 16, ethosu_write[4], 0, 0, 0, T.float32(0.25), 14, "NHWC", 128, 8, 1, 1, 1, 1, 1, 1, 1, placeholder_global_2[0], 48, placeholder_global_2[48], 48, 12, placeholder_d_global_2[0], 16, placeholder_d_global_2[16], 16, 0, 0, 0, 0, "NONE", 0, 0, "TFL", "NONE", 0, 0, 0, dtype="handle")) + T.evaluate(T.call_extern("ethosu_copy", buffer_encoded_6[0], 96, placeholder_global_3[0], dtype="handle")) + T.evaluate(T.call_extern("ethosu_copy", buffer_encoded_7[0], 32, placeholder_d_global_3[0], dtype="handle")) + T.evaluate(T.call_extern("ethosu_conv2d", "int8", 16, 16, 16, 16, 0, 16, ethosu_write_2[0], 0, 0, 0, T.float32(0.5), 10, "NHWC", 256, 16, 1, "int8", 16, 16, 2, 16, 0, 16, ethosu_write[6], 0, 0, 0, T.float32(0.25), 14, "NHWC", 128, 8, 1, 1, 1, 1, 1, 1, 1, placeholder_global_3[0], 48, placeholder_global_3[48], 48, 12, placeholder_d_global_3[0], 16, placeholder_d_global_3[16], 16, 0, 0, 0, 0, "NONE", 0, 0, "TFL", "NONE", 0, 0, 0, dtype="handle")) + __tvm_meta__ = None # fmt: on -def test_mixed_read(): +@pytest.mark.parametrize( + "accelerator, 
reference_mod, reference_const_sizes", + [ + ( + "ethos-u55-128", + MixedReadU55, + [592, 160, 80, 32, 80, 32, 80, 32, 80, 32], + ), + ( + "ethos-u65-512", + MixedReadU65, + [608, 160, 96, 32, 96, 32, 96, 32, 96, 32], + ), + ], +) +def test_mixed_read(accelerator, reference_mod, reference_const_sizes): def _planner(cached_func, const_dict, sch): weight = cached_func.inputs[4] scale_bias = cached_func.inputs[5] @@ -305,28 +504,20 @@ def _get_func(): func = run_opt_pass(func, relay.transform.InferType()) return func - func = _get_func() - mod, consts = _lower_to_tir(func, cascader=_planner) - - script = mod.script(show_meta=True) - test_mod = tvm.script.from_source(script) - reference_mod = MixedRead - tvm.ir.assert_structural_equal(test_mod["main"], reference_mod["main"], True) - - reference_const_sizes = [ - 592, - 160, - 80, - 32, - 80, - 32, - 80, - 32, - 80, - 32, - ] - test_const_size = [value.size for value in list(consts.values())] - assert reference_const_sizes == test_const_size + config = { + "accelerator_config": accelerator, + } + with tvm.transform.PassContext(config={"relay.ext.ethos-u.options": config}): + func = _get_func() + mod, consts = _lower_to_tir(func, cascader=_planner) + + script = mod.script(show_meta=True) + test_mod = tvm.script.from_source(script) + print(mod.script()) + tvm.ir.assert_structural_equal(test_mod["main"], reference_mod["main"], True) + + test_const_size = [value.size for value in list(consts.values())] + assert reference_const_sizes == test_const_size def test_constant_as_input(): diff --git a/tests/python/contrib/test_ethosu/test_remove_concatenates.py b/tests/python/contrib/test_ethosu/test_remove_concatenates.py index c09afdfbdaab1..cc996e59412ce 100644 --- a/tests/python/contrib/test_ethosu/test_remove_concatenates.py +++ b/tests/python/contrib/test_ethosu/test_remove_concatenates.py @@ -46,10 +46,10 @@ def main(placeholder: T.Buffer[(1536,), "int8"], placeholder_1: T.Buffer[(1280,) T.preflattened_buffer(T_concat, [1, 8, 32, 16], "int8", data=T_concat.data) # body T_concat_1 = T.allocate([2816], "int8", "global", annotations={"disable_lower_builtin":True}) - T.evaluate(T.call_extern("ethosu_conv2d", "int8", 8, 10, 16, 8, 0, 10, T.load("int8", placeholder_1.data, 0), 0, 0, 0, T.float32(0.5), 10, "NHWC", 160, 16, 1, "int8", 8, 10, 16, 8, 0, 10, T.load("int8", T_concat_1, 192), 0, 0, 0, T.float32(0.25), 14, "NHWC", 352, 16, 1, 3, 3, 1, 1, 1, 1, T.load("uint8", buffer, 0), 2992, T.int8(-1), T.int8(-1), 12, T.load("uint8", buffer_1, 0), 160, T.int8(-1), T.int8(-1), 1, 1, 1, 1, "NONE", 0, 0, "TFL", "NONE", dtype="handle")) - T.evaluate(T.call_extern("ethosu_conv2d", "int8", 8, 10, 16, 8, 0, 10, T.load("int8", T_concat_1, 192), 0, 0, 0, T.float32(0.5), 10, "NHWC", 352, 16, 1, "int8", 8, 10, 16, 8, 0, 10, T.load("int8", T_concat.data, 352), 0, 0, 0, T.float32(0.25), 14, "NHWC", 512, 16, 1, 3, 3, 1, 1, 1, 1, T.load("uint8", buffer_2, 0), 2992, T.int8(-1), T.int8(-1), 12, T.load("uint8", buffer_3, 0), 160, T.int8(-1), T.int8(-1), 1, 1, 1, 1, "NONE", 0, 0, "TFL", "NONE", dtype="handle")) - T.evaluate(T.call_extern("ethosu_conv2d", "int8", 8, 12, 16, 8, 0, 12, T.load("int8", placeholder.data, 0), 0, 0, 0, T.float32(0.5), 10, "NHWC", 192, 16, 1, "int8", 8, 12, 16, 8, 0, 12, T.load("int8", T_concat_1, 0), 0, 0, 0, T.float32(0.25), 14, "NHWC", 352, 16, 1, 3, 3, 1, 1, 1, 1, T.load("uint8", buffer_4, 0), 2992, T.int8(-1), T.int8(-1), 12, T.load("uint8", buffer_5, 0), 160, T.int8(-1), T.int8(-1), 1, 1, 1, 1, "NONE", 0, 0, "TFL", "NONE", dtype="handle")) - 
T.evaluate(T.call_extern("ethosu_conv2d", "int8", 8, 22, 16, 8, 0, 22, T.load("int8", T_concat_1, 0), 0, 0, 0, T.float32(0.5), 10, "NHWC", 352, 16, 1, "int8", 8, 22, 16, 8, 0, 22, T.load("int8", T_concat.data, 0), 0, 0, 0, T.float32(0.25), 14, "NHWC", 512, 16, 1, 3, 3, 1, 1, 1, 1, T.load("uint8", buffer_6, 0), 2992, T.int8(-1), T.int8(-1), 12, T.load("uint8", buffer_7, 0), 160, T.int8(-1), T.int8(-1), 1, 1, 1, 1, "NONE", 0, 0, "TFL", "NONE", dtype="handle")) + T.evaluate(T.call_extern("ethosu_conv2d", "int8", 8, 10, 16, 8, 0, 10, placeholder_1[0], 0, 0, 0, T.float32(0.5), 10, "NHWC", 160, 16, 1, "int8", 8, 10, 16, 8, 0, 10, T_concat_1[192], 0, 0, 0, T.float32(0.25), 14, "NHWC", 352, 16, 1, 3, 3, 1, 1, 1, 1, buffer[0], 2992, T.int8(-1), T.int8(-1), 12, buffer_1[0], 160, T.int8(-1), T.int8(-1), 1, 1, 1, 1, "NONE", 0, 0, "TFL", "NONE", 0, 0, 0, dtype="handle")) + T.evaluate(T.call_extern("ethosu_conv2d", "int8", 8, 10, 16, 8, 0, 10, T_concat_1[192], 0, 0, 0, T.float32(0.5), 10, "NHWC", 352, 16, 1, "int8", 8, 10, 16, 8, 0, 10, T_concat[352], 0, 0, 0, T.float32(0.25), 14, "NHWC", 512, 16, 1, 3, 3, 1, 1, 1, 1, buffer_2[0], 2992, T.int8(-1), T.int8(-1), 12, buffer_3[0], 160, T.int8(-1), T.int8(-1), 1, 1, 1, 1, "NONE", 0, 0, "TFL", "NONE", 0, 0, 0, dtype="handle")) + T.evaluate(T.call_extern("ethosu_conv2d", "int8", 8, 12, 16, 8, 0, 12, placeholder[0], 0, 0, 0, T.float32(0.5), 10, "NHWC", 192, 16, 1, "int8", 8, 12, 16, 8, 0, 12, T_concat_1[0], 0, 0, 0, T.float32(0.25), 14, "NHWC", 352, 16, 1, 3, 3, 1, 1, 1, 1, buffer_4[0], 2992, T.int8(-1), T.int8(-1), 12, buffer_5[0], 160, T.int8(-1), T.int8(-1), 1, 1, 1, 1, "NONE", 0, 0, "TFL", "NONE", 0, 0, 0, dtype="handle")) + T.evaluate(T.call_extern("ethosu_conv2d", "int8", 8, 22, 16, 8, 0, 22, T_concat_1[0], 0, 0, 0, T.float32(0.5), 10, "NHWC", 352, 16, 1, "int8", 8, 22, 16, 8, 0, 22, T_concat[0], 0, 0, 0, T.float32(0.25), 14, "NHWC", 512, 16, 1, 3, 3, 1, 1, 1, 1, buffer_6[0], 2992, T.int8(-1), T.int8(-1), 12, buffer_7[0], 160, T.int8(-1), T.int8(-1), 1, 1, 1, 1, "NONE", 0, 0, "TFL", "NONE", 0, 0, 0, dtype="handle")) __tvm_meta__ = None # fmt: on diff --git a/tests/python/contrib/test_ethosu/test_replace_conv2d.py b/tests/python/contrib/test_ethosu/test_replace_conv2d.py index 3646b99190c98..b49890a9cf366 100644 --- a/tests/python/contrib/test_ethosu/test_replace_conv2d.py +++ b/tests/python/contrib/test_ethosu/test_replace_conv2d.py @@ -146,8 +146,8 @@ def get_conv2d_args(call, include_buffers=False, remove_constants=False): continue elif isinstance(arg, tvm.tir.expr.IntImm) or isinstance(arg, tvm.tir.expr.FloatImm): conv_args.append(arg.value) - elif isinstance(arg, tvm.tir.expr.Load) and not include_buffers: - conv_args.append(arg.index) + elif isinstance(arg, tvm.tir.expr.BufferLoad) and not include_buffers: + conv_args.append(arg.indices[0]) else: conv_args.append(arg)