diff --git a/python/tvm/relay/backend/contrib/ethosu/codegen.py b/python/tvm/relay/backend/contrib/ethosu/codegen.py index 7666691aa19f..98ee41f428b2 100644 --- a/python/tvm/relay/backend/contrib/ethosu/codegen.py +++ b/python/tvm/relay/backend/contrib/ethosu/codegen.py @@ -297,8 +297,6 @@ def relay_to_tir_func(ext_func: relay.Function) -> tvm.tir.PrimFunc: This returns the scheduled PrimFunc """ assert len(ext_func.params) == 1 - input_size = util.calculate_size_bytes(ext_func.params[0]) - output_size = util.calculate_size_bytes(ext_func.body) mod = tvm.IRModule() mod["main"] = ext_func mod = LegalizeEthosU()(mod) @@ -317,8 +315,6 @@ def relay_to_tir_func(ext_func: relay.Function) -> tvm.tir.PrimFunc: primfunc = tir_mod["main"] primfunc = primfunc.with_attr("global_symbol", ext_func.attrs["global_symbol"]) primfunc = primfunc.with_attr("ethos-u.constants", const_dict) - primfunc = primfunc.with_attr("ethos-u.input_size", input_size) - primfunc = primfunc.with_attr("ethos-u.output_size", output_size) return primfunc @@ -342,8 +338,6 @@ def primfunc_to_artifact(primfunc: tvm.tir.PrimFunc) -> util.CompilationArtifact """ symbol = str(primfunc.attrs["global_symbol"]) const_dict = primfunc.attrs["ethos-u.constants"] - input_size = primfunc.attrs["ethos-u.input_size"] - output_size = primfunc.attrs["ethos-u.output_size"] tir_mod = tvm.IRModule() tir_mod[symbol] = primfunc @@ -351,9 +345,7 @@ def primfunc_to_artifact(primfunc: tvm.tir.PrimFunc) -> util.CompilationArtifact for idx in const_dict.keys(): const_dict_with_int_keys[int(idx)] = const_dict[idx].numpy() - cmms, encoded_constants, scratch_size = tir_to_cs_translator.translate( + cmms, encoded_constants, base_addresses = tir_to_cs_translator.translate( tir_mod, const_dict_with_int_keys ) - return util.CompilationArtifact( - cmms, encoded_constants, scratch_size, input_size, output_size, symbol - ) + return util.CompilationArtifact(symbol, cmms, encoded_constants, base_addresses) diff --git a/python/tvm/relay/backend/contrib/ethosu/tir_to_cs_translator.py b/python/tvm/relay/backend/contrib/ethosu/tir_to_cs_translator.py index 77fbc3e8628d..d7254511ebfc 100644 --- a/python/tvm/relay/backend/contrib/ethosu/tir_to_cs_translator.py +++ b/python/tvm/relay/backend/contrib/ethosu/tir_to_cs_translator.py @@ -18,7 +18,7 @@ the Relay to TIR compilation process, to Vela API calls to generate command stream. """ -from typing import Dict, NamedTuple, Tuple, Union +from typing import Dict, NamedTuple, Tuple, Union, List from enum import auto from enum import Enum import numpy as np # type: ignore @@ -102,8 +102,8 @@ def translate(tir_module, params): encoded_constants : str An hex string of the bytes that includes concat'd encoded weights, encoded biases and scales. - scratch_size : int - The size of the scratch buffer needed. + base_addresses : List[util.BaseAddress] + base addresses to be used by the driver """ buffer_info = extract_buffer_info(tir_module, params) @@ -112,10 +112,60 @@ def translate(tir_module, params): for call_extern in call_extern_list: _npu_ops.append(translate_ethosu_tir_call_extern(call_extern)) _npu_ops, constant_data, scratch_size = assign_addresses(buffer_info, _npu_ops) + base_addresses = extract_param_base_addresses(tir_module, buffer_info) + if scratch_size > 0: + base_addresses.append( + util.BaseAddress( + "scratch", + None, + _REGION_MAP[BufferType.scratch], + scratch_size, + True, + ) + ) target_accel_config = vela_api.get_accelerator_config() cmds = vapi.npu_generate_register_command_stream(_npu_ops, target_accel_config) payload = vapi.npu_create_driver_payload(cmds, target_accel_config) - return payload.hex(), constant_data, scratch_size + return payload.hex(), constant_data, base_addresses + + +def extract_param_base_addresses(mod, buffer_info) -> List[util.BaseAddress]: + """This function extracts base addresses to be used by the driver + + Parameters + ---------- + mod : tvm.IRModule + The TIR Module for NPU + buffer_info : Dict[tvm.tir.Var, BufferInfo] + Information regarding buffer vars used in the PrimFunc + + Returns + ------- + List[util.BaseAddress] + base addresses to be used by the driver + """ + # There should only be a single function + assert len(mod.functions.items()) == 1 + primfunc = mod.functions.items()[0][1] + + base_addresses = list() + idx = 0 + for param in primfunc.params: + # constants are pooled together and handled specially + # this will change after tir.allocate_const. + # For now, we are skipping generating buffer addresses here + if buffer_info[param].btype == BufferType.constant: + continue + buffer = primfunc.buffer_map[param] + dtype = buffer.dtype + element_size_bytes = np.iinfo(dtype).bits // 8 + size_bytes = element_size_bytes * np.prod(list(buffer.shape)) + base_addresses.append( + util.BaseAddress(param.name, idx, _REGION_MAP[buffer_info[param].btype], size_bytes) + ) + idx += 1 + + return base_addresses def extract_call_extern_list(mod): @@ -171,6 +221,7 @@ def extract_buffer_info( # There should only be a single function assert len(mod.functions.items()) == 1 primfunc = mod.functions.items()[0][1] + for idx, const_data in param_dict.items(): param = primfunc.params[idx] buffer_info[param] = BufferInfo( @@ -301,6 +352,9 @@ def classify_io(buffer): assert buffer_type in (BufferType.input, BufferType.output) address = 0 buffer_addresses[_buffer] = (address, buffer_type) + buffer_info[_buffer] = BufferInfo( + values=None, shape=info.dtype, dtype=info.dtype, btype=buffer_type + ) elif info.btype == BufferType.shram: accl_config = util.get_accelerator_config() arch_config = get_accelerator_arch_config(accl_config) diff --git a/python/tvm/relay/backend/contrib/ethosu/util.py b/python/tvm/relay/backend/contrib/ethosu/util.py index 21b0ecf789d2..fcc8e9e9df30 100644 --- a/python/tvm/relay/backend/contrib/ethosu/util.py +++ b/python/tvm/relay/backend/contrib/ethosu/util.py @@ -23,7 +23,7 @@ from inspect import signature from enum import Enum -from typing import Union, Tuple +from typing import Union, Tuple, List import numpy as np # type: ignore import tvm # type: ignore @@ -239,6 +239,31 @@ def calculate_size_bytes(expr): return element_size * elements +@register_object("relay.ext.ethos-u.BaseAddress") +class BaseAddress(Object): + """ + This is a structure to hold base addresses for pointers + provided for the driver. + """ + + def __init__( + self, + name: str, + primfunc_param_idx: int, + region: int, + size: int, + is_runtime_allocation: bool = False, + ): + self.__init_handle_by_constructor__( + _ffi_api.BaseAddress, # type: ignore # pylint: disable=no-member + name, + primfunc_param_idx, + region, + size, + is_runtime_allocation, + ) + + @register_object("relay.ext.ethos-u.CompilationArtifact") class CompilationArtifact(Object): """ @@ -248,19 +273,15 @@ class CompilationArtifact(Object): def __init__( self, + function_name: str, command_stream: str, encoded_constants: str, - scratch_size: int, - input_size: int, - output_size: int, - function_name: str, + base_addresses: List[BaseAddress], ): self.__init_handle_by_constructor__( _ffi_api.CompilationArtifact, # type: ignore # pylint: disable=no-member + function_name, command_stream, encoded_constants, - scratch_size, - input_size, - output_size, - function_name, + base_addresses, ) diff --git a/src/relay/backend/contrib/ethosu/source_module.cc b/src/relay/backend/contrib/ethosu/source_module.cc index 66955f8b201f..7d25505ab59c 100644 --- a/src/relay/backend/contrib/ethosu/source_module.cc +++ b/src/relay/backend/contrib/ethosu/source_module.cc @@ -124,7 +124,9 @@ class EthosUModuleNode : public ModuleNode { private: std::string c_source; Array compilation_artifacts_; + Map pool_var_names_; int indent_{0}; + constexpr static int kMaxBaseAddresses_ = 6; /*! * \brief Convert the raw string of hex values into a hex string @@ -150,7 +152,7 @@ class EthosUModuleNode : public ModuleNode { * \return string of code that updates the base_addrs array with the base address of the given * array */ - std::string SetBaseAddress(int index, std::string name, std::string size) { + std::string SetBaseAddress(int index, std::string name, int size) { std::stringstream ss; ss << " base_addrs[" << index << "] = (uintptr_t)(" << name << ");\n"; ss << " base_addrs_size[" << index << "] = " << size << ";\n"; @@ -178,11 +180,24 @@ class EthosUModuleNode : public ModuleNode { } /*! - * \brief Creates a runtime function header + * \brief Creates a runtime function signature */ - void PrintRuntimeFunctionHeader(std::stringstream& ss, std::string func_name) { - ss << "TVM_DLL int32_t "; - ss << func_name << "(void* input, void* output, void* resource_handle) {\n"; + void PrintRuntimeFunctionSignature(std::stringstream& ss, + const relay::contrib::ethosu::CompilationArtifact& artifact, + std::string func_name) { + ss << "TVM_DLL int32_t " << func_name; + ss << "("; + std::unordered_map param_idx_to_base_address; + for (const relay::contrib::ethosu::BaseAddress& base_address : artifact->base_addresses) { + if (base_address->primfunc_param_idx.defined()) { + param_idx_to_base_address[base_address->primfunc_param_idx] = base_address; + } + } + for (unsigned int i = 0; i < param_idx_to_base_address.size(); i++) { + relay::contrib::ethosu::BaseAddress base_address = param_idx_to_base_address[i]; + ss << "void* " << base_address->name << ","; + } + ss << "void* resource_handle) {\n"; } /*! @@ -216,7 +231,6 @@ class EthosUModuleNode : public ModuleNode { std::stringstream ss; size_t weights_size = (compilation_artifact->encoded_constants.size() / 2); - size_t scratch_size = compilation_artifact->scratch_size; ss << "// Update linker script to place .rodata.tvm in memory that can be accessed by the " "NPU\n"; if (weights_size > 0) { @@ -234,61 +248,44 @@ class EthosUModuleNode : public ModuleNode { ss << "\n"; PrintExternCPrefix(ss); - ss << "static int32_t " << func_no_dashes + "_(int8_t* in0, " - << "size_t in0_size, int8_t* out0, size_t out0_size, void* resource_handle) {\n"; - ss << " int num_tensors = 5;\n"; + PrintRuntimeFunctionSignature(ss, compilation_artifact, func_no_dashes); ss << " void* cms_data = (void*)(" << func_no_dashes << "_cms_data_data);\n"; ss << " int64_t device_type = kDLCPU;\n"; ss << " int64_t device_id = 0;\n"; - ss << " const size_t weights_size = " << std::to_string(weights_size) << ";\n"; - ss << " const size_t scratch_size = " << std::to_string(scratch_size) << ";\n"; ss << " const size_t cms_data_size = sizeof(" << func_no_dashes << "_cms_data_data);\n"; - if (scratch_size > 0) { - ss << " int8_t* scratch = (int8_t*) TVMBackendAllocWorkspace(device_type, device_id, " - "(uint64_t)scratch_size, 0, 16);\n"; - } else { - ss << " int8_t* scratch = NULL;\n"; - } - ss << " size_t base_addrs_size[num_tensors];\n"; - ss << " uint64_t base_addrs[num_tensors];\n"; + ss << " size_t base_addrs_size[" << kMaxBaseAddresses_ << "] = {0};\n"; + ss << " uint64_t base_addrs[" << kMaxBaseAddresses_ << "] = {0};\n"; ss << "\n"; - ss << SetBaseAddress(0, func_no_dashes + "_weights", "weights_size"); - ss << SetBaseAddress(1, "scratch", "scratch_size"); - ss << SetBaseAddress(2, "scratch", "scratch_size"); - ss << SetBaseAddress(3, "in0", "in0_size"); - ss << SetBaseAddress(4, "out0", "out0_size"); + + ss << SetBaseAddress(0, func_no_dashes + "_weights", weights_size); + for (const relay::contrib::ethosu::BaseAddress& base_address : + compilation_artifact->base_addresses) { + if (base_address->is_runtime_allocation) { + ss << " int8_t* " << base_address->name + << " = (int8_t*) TVMBackendAllocWorkspace(device_type, device_id, " + "(uint64_t)" + << base_address->size << ", 0, 16);\n"; + } + ss << SetBaseAddress(base_address->region->value, base_address->name.c_str(), + base_address->size->value); + } ss << "\n"; + ss << " int32_t result = TVMEthosULaunch(resource_handle, cms_data, cms_data_size, " - "base_addrs, base_addrs_size, num_tensors);\n"; - if (scratch_size > 0) { - ss << " TVMBackendFreeWorkspace(device_type, device_id, scratch);\n"; + "base_addrs, base_addrs_size, " + << kMaxBaseAddresses_ << ");\n"; + + for (const relay::contrib::ethosu::BaseAddress& base_address : + compilation_artifact->base_addresses) { + if (base_address->is_runtime_allocation) { + ss << " TVMBackendFreeWorkspace(device_type, device_id, " << base_address->name << ");\n"; + } } ss << " return result;\n"; ss << "}\n"; ss << "\n"; PrintExternCPostfix(ss); ss << "\n"; - PrintExternCPrefix(ss); - ss << "// Wrapper function is provided to allow for easier debugging\n"; - ss << "inline static int32_t " + func_no_dashes + - "_wrapper_(void* input, void* output, void* resource_handle) {\n"; - ss << " size_t input_data_size = " << compilation_artifact->input_size << ";\n"; - ss << " size_t output_data_size = " << compilation_artifact->output_size << ";\n"; - ss << " return " + func_no_dashes + - "_((int8_t*)input, input_data_size, (int8_t*)output, output_data_size, " + - "resource_handle);\n"; - ss << "}\n"; - PrintExternCPostfix(ss); - ss << "\n"; - PrintExternCPrefix(ss); - PrintRuntimeFunctionHeader(ss, func_no_dashes); - EnterScope(); - PrintIndents(ss); - ss << "return " << func_no_dashes << "_wrapper_(input, output, resource_handle);\n"; - ExitScope(); - ss << "}\n"; - PrintExternCPostfix(ss); - return ss.str(); } }; diff --git a/src/relay/backend/contrib/ethosu/utils.cc b/src/relay/backend/contrib/ethosu/utils.cc index 7e6c1c2ac840..01bd4d10324d 100644 --- a/src/relay/backend/contrib/ethosu/utils.cc +++ b/src/relay/backend/contrib/ethosu/utils.cc @@ -36,37 +36,54 @@ namespace relay { namespace contrib { namespace ethosu { -CompilationArtifact::CompilationArtifact(String command_stream, String encoded_constants, - Integer scratch_size, Integer input_size, - Integer output_size, String function_name) { +BaseAddress::BaseAddress(String name, Integer primfunc_param_idx, Integer region, Integer size, + Bool is_runtime_allocation) { + auto base_address_node = make_object(); + base_address_node->name = name; + base_address_node->primfunc_param_idx = primfunc_param_idx; + base_address_node->region = region; + base_address_node->size = size; + base_address_node->is_runtime_allocation = is_runtime_allocation; + data_ = std::move(base_address_node); +} + +TVM_REGISTER_NODE_TYPE(BaseAddressNode); +TVM_REGISTER_GLOBAL("relay.ext.ethos-u.BaseAddress") + .set_body_typed([](String name, Integer primfunc_param_idx, Integer region, Integer size, + Bool is_runtime_allocation) { + if (is_runtime_allocation.defined()) { + return BaseAddress(name, primfunc_param_idx, region, size, is_runtime_allocation); + } else { + return BaseAddress(name, primfunc_param_idx, region, size); + } + }); + +CompilationArtifact::CompilationArtifact(String function_name, String command_stream, + String encoded_constants, + Array base_addresses) { auto compilation_artifact_node = make_object(); + compilation_artifact_node->function_name = function_name; compilation_artifact_node->command_stream = command_stream; compilation_artifact_node->encoded_constants = encoded_constants; - compilation_artifact_node->scratch_size = scratch_size; - compilation_artifact_node->input_size = input_size; - compilation_artifact_node->output_size = output_size; - compilation_artifact_node->function_name = function_name; + compilation_artifact_node->base_addresses = base_addresses; data_ = std::move(compilation_artifact_node); } TVM_REGISTER_NODE_TYPE(CompilationArtifactNode); TVM_REGISTER_GLOBAL("relay.ext.ethos-u.CompilationArtifact") - .set_body_typed([](String command_stream, String encoded_constants, Integer scratch_size, - Integer input_size, Integer output_size, String function_name) { - return CompilationArtifact(command_stream, encoded_constants, scratch_size, input_size, - output_size, function_name); + .set_body_typed([](String function_name, String command_stream, String encoded_constants, + Array base_addresses) { + return CompilationArtifact(function_name, command_stream, encoded_constants, base_addresses); }); TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) .set_dispatch([](const ObjectRef& ref, ReprPrinter* p) { auto* node = static_cast(ref.get()); p->stream << "CompilationArtifactNode(\n" - << "command_stream=" << node->command_stream + << "function_name=" << node->function_name + << ",\n command_stream=" << node->command_stream << ",\n encoded_constants=" << node->encoded_constants - << ",\n scratch_size=" << node->scratch_size - << ",\n input_size=" << node->input_size - << ",\n output_size=" << node->output_size - << ",\n function_name=" << node->function_name << ")"; + << ",\n base_addresses=" << node->base_addresses << ")"; }); } // namespace ethosu diff --git a/src/relay/backend/contrib/ethosu/utils.h b/src/relay/backend/contrib/ethosu/utils.h index 5e9e337c3f17..5c61271d3425 100644 --- a/src/relay/backend/contrib/ethosu/utils.h +++ b/src/relay/backend/contrib/ethosu/utils.h @@ -34,47 +34,91 @@ namespace relay { namespace contrib { namespace ethosu { +/*! + * \brief Base addresses are input pointers to + * the driver that get accessed by the command stream + * using offsets to read/write data. + */ +struct BaseAddressNode : public Object { + /*! \brief The identifier, usually it the param name of the PrimFunc that gets lowered */ + String name; + /*! \brief The index in the params array of the PrimFunc. This is needed to keep aligned + * between the PrimFunc arguments ordering and argument ordering of generated code */ + Integer primfunc_param_idx; + /*! \brief The region used by the command stream. This needs to match with base address + * index passed into the driver */ + Integer region; + /*! \brief The size of the buffer accessible by this base address */ + Integer size; + /*! \brief This is a runtime allocation that needs to be done in the function */ + Bool is_runtime_allocation{Bool(false)}; + + void VisitAttrs(tvm::AttrVisitor* v) { + v->Visit("name", &name); + v->Visit("primfunc_param_idx", &primfunc_param_idx); + v->Visit("region", ®ion); + v->Visit("size", &size); + v->Visit("is_runtime_allocation", &is_runtime_allocation); + } + + bool SEqualReduce(const BaseAddressNode* other, SEqualReducer equal) const { + return equal(name, other->name) && equal(primfunc_param_idx, other->primfunc_param_idx) && + equal(region, other->region) && equal(size, other->size) && + equal(is_runtime_allocation, other->is_runtime_allocation); + } + + void SHashReduce(SHashReducer hash_reduce) const { + hash_reduce(name); + hash_reduce(primfunc_param_idx); + hash_reduce(region); + hash_reduce(size); + hash_reduce(is_runtime_allocation); + } + + static constexpr const char* _type_key = "relay.ext.ethos-u.BaseAddress"; + TVM_DECLARE_FINAL_OBJECT_INFO(BaseAddressNode, Object); +}; + +class BaseAddress : public ObjectRef { + public: + TVM_DLL BaseAddress(String name, Integer primfunc_param_idx, Integer region, Integer size, + Bool is_runtime_allocation = Bool(false)); + TVM_DEFINE_MUTABLE_OBJECT_REF_METHODS(BaseAddress, ObjectRef, BaseAddressNode); +}; + /*! * \brief Captures all the binary artifactes required to create * the C-source runtime module */ struct CompilationArtifactNode : public Object { + /*! \brief The function name for this artifact belongs to */ + String function_name; /*! \brief The binary command stream (CS) in hex format */ String command_stream; /*! \brief The encoded biases and weights in hex format */ String encoded_constants; - /*! \brief The intermediary scratch area required for the execution of the CS */ - Integer scratch_size; - /*! \brief The size of the input tensor in bytes */ - Integer input_size; - /*! \brief The size of the output tensor in bytes */ - Integer output_size; - /*! \brief The name of the function */ - String function_name; + /*! \brief The information regarding the base addresses */ + Array base_addresses; void VisitAttrs(tvm::AttrVisitor* v) { + v->Visit("function_name", &function_name); v->Visit("command_stream", &command_stream); v->Visit("encoded_constants", &encoded_constants); - v->Visit("scratch_size", &scratch_size); - v->Visit("input_size", &input_size); - v->Visit("output_size", &output_size); - v->Visit("function_name", &function_name); + v->Visit("base_addresses", &base_addresses); } bool SEqualReduce(const CompilationArtifactNode* other, SEqualReducer equal) const { - return equal(command_stream, other->command_stream) && + return equal(function_name, other->function_name) && + equal(command_stream, other->command_stream) && equal(encoded_constants, other->encoded_constants) && - equal(scratch_size, other->scratch_size) && equal(input_size, other->input_size) && - equal(output_size, other->output_size) && equal(function_name, other->function_name); + equal(base_addresses, other->base_addresses); } void SHashReduce(SHashReducer hash_reduce) const { + hash_reduce(function_name); hash_reduce(command_stream); hash_reduce(encoded_constants); - hash_reduce(scratch_size); - hash_reduce(input_size); - hash_reduce(output_size); - hash_reduce(function_name); + hash_reduce(base_addresses); } static constexpr const char* _type_key = "relay.ext.ethos-u.CompilationArtifact"; @@ -83,8 +127,8 @@ struct CompilationArtifactNode : public Object { class CompilationArtifact : public ObjectRef { public: - TVM_DLL CompilationArtifact(String command_stream, String encoded_constants, Integer scratch_size, - Integer input_size, Integer output_size, String function_name); + TVM_DLL CompilationArtifact(String function_name, String command_stream, String encoded_constants, + Array base_addresses); TVM_DEFINE_MUTABLE_OBJECT_REF_METHODS(CompilationArtifact, ObjectRef, CompilationArtifactNode); }; diff --git a/src/tir/transforms/make_unpacked_api.cc b/src/tir/transforms/make_unpacked_api.cc index 6365e09246fc..fc43e1449d6a 100644 --- a/src/tir/transforms/make_unpacked_api.cc +++ b/src/tir/transforms/make_unpacked_api.cc @@ -60,12 +60,16 @@ PrimFunc MakeUnpackedAPI(PrimFunc&& func) { // Collect variables and buffers to map between Array args; + Map new_buffer_map; for (const Var& param : func->params) { // Ideally all func params should have Buffers defined in the buffer_map // We should look to insert buffer_maps for all PrimFuncs that are returned // to the core compiler. if (func->buffer_map.find(param) != func->buffer_map.end()) { args.push_back(func->buffer_map[param]->data); + // Rewiring the buffer_var to map to Buffers for low-level passes + // retain information about the buffer. + new_buffer_map.Set(func->buffer_map[param]->data, func->buffer_map[param]); } else { args.push_back(param); } @@ -79,6 +83,7 @@ PrimFunc MakeUnpackedAPI(PrimFunc&& func) { func_ptr->body = MergeNest(device_init, func_ptr->body); func_ptr->params = args; func_ptr->ret_type = PrimType(DataType::Int(32)); + func_ptr->buffer_map = new_buffer_map; // return the function. return std::move(func); diff --git a/tests/python/contrib/test_ethosu/test_tir_to_cs_translator.py b/tests/python/contrib/test_ethosu/test_tir_to_cs_translator.py index c14deb636c25..0cadf96e7a18 100644 --- a/tests/python/contrib/test_ethosu/test_tir_to_cs_translator.py +++ b/tests/python/contrib/test_ethosu/test_tir_to_cs_translator.py @@ -227,8 +227,16 @@ def test_buffer_info_extraction(): "uint8", tir_to_cs_translator.BufferType.input_or_output, ), - "ethosu_conv2d_2": ([1024], "uint8", tir_to_cs_translator.BufferType.scratch), - "ethosu_conv2d_3": ([2048], "uint8", tir_to_cs_translator.BufferType.scratch), + "ethosu_conv2d_2": ( + [1024], + "uint8", + tir_to_cs_translator.BufferType.scratch, + ), + "ethosu_conv2d_3": ( + [2048], + "uint8", + tir_to_cs_translator.BufferType.scratch, + ), }, }, ] @@ -805,12 +813,10 @@ def _check_buffer(address, region, length, buffer_var): # Every buffer is adjusted to align to 16 bytes size_in_bytes = util.round_up(size_in_bytes, 16) assert address + size_in_bytes <= scratch_size - # The scratch area should not be used by anyother buffer - assert not scratch_allocation_mask[address : address + size_in_bytes].any() + # The scratch area should not be used by any other buffer + assert not scratch_mask[address : address + size_in_bytes].any() # The scratch area is marked as used - scratch_allocation_mask[address : address + size_in_bytes] = np.ones( - size_in_bytes, dtype="uint8" - ) + scratch_mask[address : address + size_in_bytes] = np.ones(size_in_bytes, dtype="uint8") elif buffer_type == tir_to_cs_translator.BufferType.input: assert address == 0 else: @@ -887,14 +893,16 @@ def check_buffer(address, region, length, buffer_var): for extern_call in extern_calls: _npu_ops.append(tir_to_cs_translator.translate_ethosu_tir_call_extern(extern_call)) npu_op_tir_buffers = collect_tir_buffer_info(_npu_ops) - _npu_ops, constant_hex_string, scratch_size = tir_to_cs_translator.assign_addresses( - buffer_info, _npu_ops - ) - scratch_allocation_mask = np.zeros(scratch_size, dtype="uint8") + ( + _npu_ops, + constant_hex_string, + scratch_size, + ) = tir_to_cs_translator.assign_addresses(buffer_info, _npu_ops) + scratch_mask = np.zeros(scratch_size, dtype="uint8") constant_tensor_read_mask = np.zeros(len(constant_hex_string) // 2, dtype="uint8") verify(_npu_ops) # This will be only 1 if all allocated scratch is used. - assert np.prod(scratch_allocation_mask) == 1 + assert np.prod(scratch_mask) == 1 # This will be only 1 if all constant tensors is read at least once. assert np.prod(constant_tensor_read_mask) == 1