From 7e50fc42c94ab990a810b378f34386413cd42087 Mon Sep 17 00:00:00 2001 From: Manupa Karunaratne Date: Wed, 26 Jan 2022 12:19:28 +0000 Subject: [PATCH] [microNPU] Refactor base address determination to codegen (#9929) This commit introduces BaseAddress ObjectRef to determine base addresses in the codegen for microNPU. This is required when multiple memory pools become available. Thus, base addresses could not be statically determined in the source module. --- .../relay/backend/contrib/ethosu/codegen.py | 12 +-- .../contrib/ethosu/tir_to_cs_translator.py | 62 ++++++++++++- .../tvm/relay/backend/contrib/ethosu/util.py | 39 ++++++-- .../backend/contrib/ethosu/source_module.cc | 93 +++++++++---------- src/relay/backend/contrib/ethosu/utils.cc | 49 ++++++---- src/relay/backend/contrib/ethosu/utils.h | 86 ++++++++++++----- src/tir/transforms/make_unpacked_api.cc | 5 + .../test_ethosu/test_tir_to_cs_translator.py | 32 ++++--- 8 files changed, 258 insertions(+), 120 deletions(-) diff --git a/python/tvm/relay/backend/contrib/ethosu/codegen.py b/python/tvm/relay/backend/contrib/ethosu/codegen.py index 7666691aa19f..98ee41f428b2 100644 --- a/python/tvm/relay/backend/contrib/ethosu/codegen.py +++ b/python/tvm/relay/backend/contrib/ethosu/codegen.py @@ -297,8 +297,6 @@ def relay_to_tir_func(ext_func: relay.Function) -> tvm.tir.PrimFunc: This returns the scheduled PrimFunc """ assert len(ext_func.params) == 1 - input_size = util.calculate_size_bytes(ext_func.params[0]) - output_size = util.calculate_size_bytes(ext_func.body) mod = tvm.IRModule() mod["main"] = ext_func mod = LegalizeEthosU()(mod) @@ -317,8 +315,6 @@ def relay_to_tir_func(ext_func: relay.Function) -> tvm.tir.PrimFunc: primfunc = tir_mod["main"] primfunc = primfunc.with_attr("global_symbol", ext_func.attrs["global_symbol"]) primfunc = primfunc.with_attr("ethos-u.constants", const_dict) - primfunc = primfunc.with_attr("ethos-u.input_size", input_size) - primfunc = primfunc.with_attr("ethos-u.output_size", output_size) return primfunc @@ -342,8 +338,6 @@ def primfunc_to_artifact(primfunc: tvm.tir.PrimFunc) -> util.CompilationArtifact """ symbol = str(primfunc.attrs["global_symbol"]) const_dict = primfunc.attrs["ethos-u.constants"] - input_size = primfunc.attrs["ethos-u.input_size"] - output_size = primfunc.attrs["ethos-u.output_size"] tir_mod = tvm.IRModule() tir_mod[symbol] = primfunc @@ -351,9 +345,7 @@ def primfunc_to_artifact(primfunc: tvm.tir.PrimFunc) -> util.CompilationArtifact for idx in const_dict.keys(): const_dict_with_int_keys[int(idx)] = const_dict[idx].numpy() - cmms, encoded_constants, scratch_size = tir_to_cs_translator.translate( + cmms, encoded_constants, base_addresses = tir_to_cs_translator.translate( tir_mod, const_dict_with_int_keys ) - return util.CompilationArtifact( - cmms, encoded_constants, scratch_size, input_size, output_size, symbol - ) + return util.CompilationArtifact(symbol, cmms, encoded_constants, base_addresses) diff --git a/python/tvm/relay/backend/contrib/ethosu/tir_to_cs_translator.py b/python/tvm/relay/backend/contrib/ethosu/tir_to_cs_translator.py index 77fbc3e8628d..d7254511ebfc 100644 --- a/python/tvm/relay/backend/contrib/ethosu/tir_to_cs_translator.py +++ b/python/tvm/relay/backend/contrib/ethosu/tir_to_cs_translator.py @@ -18,7 +18,7 @@ the Relay to TIR compilation process, to Vela API calls to generate command stream. """ -from typing import Dict, NamedTuple, Tuple, Union +from typing import Dict, NamedTuple, Tuple, Union, List from enum import auto from enum import Enum import numpy as np # type: ignore @@ -102,8 +102,8 @@ def translate(tir_module, params): encoded_constants : str An hex string of the bytes that includes concat'd encoded weights, encoded biases and scales. - scratch_size : int - The size of the scratch buffer needed. + base_addresses : List[util.BaseAddress] + base addresses to be used by the driver """ buffer_info = extract_buffer_info(tir_module, params) @@ -112,10 +112,60 @@ def translate(tir_module, params): for call_extern in call_extern_list: _npu_ops.append(translate_ethosu_tir_call_extern(call_extern)) _npu_ops, constant_data, scratch_size = assign_addresses(buffer_info, _npu_ops) + base_addresses = extract_param_base_addresses(tir_module, buffer_info) + if scratch_size > 0: + base_addresses.append( + util.BaseAddress( + "scratch", + None, + _REGION_MAP[BufferType.scratch], + scratch_size, + True, + ) + ) target_accel_config = vela_api.get_accelerator_config() cmds = vapi.npu_generate_register_command_stream(_npu_ops, target_accel_config) payload = vapi.npu_create_driver_payload(cmds, target_accel_config) - return payload.hex(), constant_data, scratch_size + return payload.hex(), constant_data, base_addresses + + +def extract_param_base_addresses(mod, buffer_info) -> List[util.BaseAddress]: + """This function extracts base addresses to be used by the driver + + Parameters + ---------- + mod : tvm.IRModule + The TIR Module for NPU + buffer_info : Dict[tvm.tir.Var, BufferInfo] + Information regarding buffer vars used in the PrimFunc + + Returns + ------- + List[util.BaseAddress] + base addresses to be used by the driver + """ + # There should only be a single function + assert len(mod.functions.items()) == 1 + primfunc = mod.functions.items()[0][1] + + base_addresses = list() + idx = 0 + for param in primfunc.params: + # constants are pooled together and handled specially + # this will change after tir.allocate_const. + # For now, we are skipping generating buffer addresses here + if buffer_info[param].btype == BufferType.constant: + continue + buffer = primfunc.buffer_map[param] + dtype = buffer.dtype + element_size_bytes = np.iinfo(dtype).bits // 8 + size_bytes = element_size_bytes * np.prod(list(buffer.shape)) + base_addresses.append( + util.BaseAddress(param.name, idx, _REGION_MAP[buffer_info[param].btype], size_bytes) + ) + idx += 1 + + return base_addresses def extract_call_extern_list(mod): @@ -171,6 +221,7 @@ def extract_buffer_info( # There should only be a single function assert len(mod.functions.items()) == 1 primfunc = mod.functions.items()[0][1] + for idx, const_data in param_dict.items(): param = primfunc.params[idx] buffer_info[param] = BufferInfo( @@ -301,6 +352,9 @@ def classify_io(buffer): assert buffer_type in (BufferType.input, BufferType.output) address = 0 buffer_addresses[_buffer] = (address, buffer_type) + buffer_info[_buffer] = BufferInfo( + values=None, shape=info.dtype, dtype=info.dtype, btype=buffer_type + ) elif info.btype == BufferType.shram: accl_config = util.get_accelerator_config() arch_config = get_accelerator_arch_config(accl_config) diff --git a/python/tvm/relay/backend/contrib/ethosu/util.py b/python/tvm/relay/backend/contrib/ethosu/util.py index 21b0ecf789d2..fcc8e9e9df30 100644 --- a/python/tvm/relay/backend/contrib/ethosu/util.py +++ b/python/tvm/relay/backend/contrib/ethosu/util.py @@ -23,7 +23,7 @@ from inspect import signature from enum import Enum -from typing import Union, Tuple +from typing import Union, Tuple, List import numpy as np # type: ignore import tvm # type: ignore @@ -239,6 +239,31 @@ def calculate_size_bytes(expr): return element_size * elements +@register_object("relay.ext.ethos-u.BaseAddress") +class BaseAddress(Object): + """ + This is a structure to hold base addresses for pointers + provided for the driver. + """ + + def __init__( + self, + name: str, + primfunc_param_idx: int, + region: int, + size: int, + is_runtime_allocation: bool = False, + ): + self.__init_handle_by_constructor__( + _ffi_api.BaseAddress, # type: ignore # pylint: disable=no-member + name, + primfunc_param_idx, + region, + size, + is_runtime_allocation, + ) + + @register_object("relay.ext.ethos-u.CompilationArtifact") class CompilationArtifact(Object): """ @@ -248,19 +273,15 @@ class CompilationArtifact(Object): def __init__( self, + function_name: str, command_stream: str, encoded_constants: str, - scratch_size: int, - input_size: int, - output_size: int, - function_name: str, + base_addresses: List[BaseAddress], ): self.__init_handle_by_constructor__( _ffi_api.CompilationArtifact, # type: ignore # pylint: disable=no-member + function_name, command_stream, encoded_constants, - scratch_size, - input_size, - output_size, - function_name, + base_addresses, ) diff --git a/src/relay/backend/contrib/ethosu/source_module.cc b/src/relay/backend/contrib/ethosu/source_module.cc index 66955f8b201f..7d25505ab59c 100644 --- a/src/relay/backend/contrib/ethosu/source_module.cc +++ b/src/relay/backend/contrib/ethosu/source_module.cc @@ -124,7 +124,9 @@ class EthosUModuleNode : public ModuleNode { private: std::string c_source; Array compilation_artifacts_; + Map pool_var_names_; int indent_{0}; + constexpr static int kMaxBaseAddresses_ = 6; /*! * \brief Convert the raw string of hex values into a hex string @@ -150,7 +152,7 @@ class EthosUModuleNode : public ModuleNode { * \return string of code that updates the base_addrs array with the base address of the given * array */ - std::string SetBaseAddress(int index, std::string name, std::string size) { + std::string SetBaseAddress(int index, std::string name, int size) { std::stringstream ss; ss << " base_addrs[" << index << "] = (uintptr_t)(" << name << ");\n"; ss << " base_addrs_size[" << index << "] = " << size << ";\n"; @@ -178,11 +180,24 @@ class EthosUModuleNode : public ModuleNode { } /*! - * \brief Creates a runtime function header + * \brief Creates a runtime function signature */ - void PrintRuntimeFunctionHeader(std::stringstream& ss, std::string func_name) { - ss << "TVM_DLL int32_t "; - ss << func_name << "(void* input, void* output, void* resource_handle) {\n"; + void PrintRuntimeFunctionSignature(std::stringstream& ss, + const relay::contrib::ethosu::CompilationArtifact& artifact, + std::string func_name) { + ss << "TVM_DLL int32_t " << func_name; + ss << "("; + std::unordered_map param_idx_to_base_address; + for (const relay::contrib::ethosu::BaseAddress& base_address : artifact->base_addresses) { + if (base_address->primfunc_param_idx.defined()) { + param_idx_to_base_address[base_address->primfunc_param_idx] = base_address; + } + } + for (unsigned int i = 0; i < param_idx_to_base_address.size(); i++) { + relay::contrib::ethosu::BaseAddress base_address = param_idx_to_base_address[i]; + ss << "void* " << base_address->name << ","; + } + ss << "void* resource_handle) {\n"; } /*! @@ -216,7 +231,6 @@ class EthosUModuleNode : public ModuleNode { std::stringstream ss; size_t weights_size = (compilation_artifact->encoded_constants.size() / 2); - size_t scratch_size = compilation_artifact->scratch_size; ss << "// Update linker script to place .rodata.tvm in memory that can be accessed by the " "NPU\n"; if (weights_size > 0) { @@ -234,61 +248,44 @@ class EthosUModuleNode : public ModuleNode { ss << "\n"; PrintExternCPrefix(ss); - ss << "static int32_t " << func_no_dashes + "_(int8_t* in0, " - << "size_t in0_size, int8_t* out0, size_t out0_size, void* resource_handle) {\n"; - ss << " int num_tensors = 5;\n"; + PrintRuntimeFunctionSignature(ss, compilation_artifact, func_no_dashes); ss << " void* cms_data = (void*)(" << func_no_dashes << "_cms_data_data);\n"; ss << " int64_t device_type = kDLCPU;\n"; ss << " int64_t device_id = 0;\n"; - ss << " const size_t weights_size = " << std::to_string(weights_size) << ";\n"; - ss << " const size_t scratch_size = " << std::to_string(scratch_size) << ";\n"; ss << " const size_t cms_data_size = sizeof(" << func_no_dashes << "_cms_data_data);\n"; - if (scratch_size > 0) { - ss << " int8_t* scratch = (int8_t*) TVMBackendAllocWorkspace(device_type, device_id, " - "(uint64_t)scratch_size, 0, 16);\n"; - } else { - ss << " int8_t* scratch = NULL;\n"; - } - ss << " size_t base_addrs_size[num_tensors];\n"; - ss << " uint64_t base_addrs[num_tensors];\n"; + ss << " size_t base_addrs_size[" << kMaxBaseAddresses_ << "] = {0};\n"; + ss << " uint64_t base_addrs[" << kMaxBaseAddresses_ << "] = {0};\n"; ss << "\n"; - ss << SetBaseAddress(0, func_no_dashes + "_weights", "weights_size"); - ss << SetBaseAddress(1, "scratch", "scratch_size"); - ss << SetBaseAddress(2, "scratch", "scratch_size"); - ss << SetBaseAddress(3, "in0", "in0_size"); - ss << SetBaseAddress(4, "out0", "out0_size"); + + ss << SetBaseAddress(0, func_no_dashes + "_weights", weights_size); + for (const relay::contrib::ethosu::BaseAddress& base_address : + compilation_artifact->base_addresses) { + if (base_address->is_runtime_allocation) { + ss << " int8_t* " << base_address->name + << " = (int8_t*) TVMBackendAllocWorkspace(device_type, device_id, " + "(uint64_t)" + << base_address->size << ", 0, 16);\n"; + } + ss << SetBaseAddress(base_address->region->value, base_address->name.c_str(), + base_address->size->value); + } ss << "\n"; + ss << " int32_t result = TVMEthosULaunch(resource_handle, cms_data, cms_data_size, " - "base_addrs, base_addrs_size, num_tensors);\n"; - if (scratch_size > 0) { - ss << " TVMBackendFreeWorkspace(device_type, device_id, scratch);\n"; + "base_addrs, base_addrs_size, " + << kMaxBaseAddresses_ << ");\n"; + + for (const relay::contrib::ethosu::BaseAddress& base_address : + compilation_artifact->base_addresses) { + if (base_address->is_runtime_allocation) { + ss << " TVMBackendFreeWorkspace(device_type, device_id, " << base_address->name << ");\n"; + } } ss << " return result;\n"; ss << "}\n"; ss << "\n"; PrintExternCPostfix(ss); ss << "\n"; - PrintExternCPrefix(ss); - ss << "// Wrapper function is provided to allow for easier debugging\n"; - ss << "inline static int32_t " + func_no_dashes + - "_wrapper_(void* input, void* output, void* resource_handle) {\n"; - ss << " size_t input_data_size = " << compilation_artifact->input_size << ";\n"; - ss << " size_t output_data_size = " << compilation_artifact->output_size << ";\n"; - ss << " return " + func_no_dashes + - "_((int8_t*)input, input_data_size, (int8_t*)output, output_data_size, " + - "resource_handle);\n"; - ss << "}\n"; - PrintExternCPostfix(ss); - ss << "\n"; - PrintExternCPrefix(ss); - PrintRuntimeFunctionHeader(ss, func_no_dashes); - EnterScope(); - PrintIndents(ss); - ss << "return " << func_no_dashes << "_wrapper_(input, output, resource_handle);\n"; - ExitScope(); - ss << "}\n"; - PrintExternCPostfix(ss); - return ss.str(); } }; diff --git a/src/relay/backend/contrib/ethosu/utils.cc b/src/relay/backend/contrib/ethosu/utils.cc index 7e6c1c2ac840..01bd4d10324d 100644 --- a/src/relay/backend/contrib/ethosu/utils.cc +++ b/src/relay/backend/contrib/ethosu/utils.cc @@ -36,37 +36,54 @@ namespace relay { namespace contrib { namespace ethosu { -CompilationArtifact::CompilationArtifact(String command_stream, String encoded_constants, - Integer scratch_size, Integer input_size, - Integer output_size, String function_name) { +BaseAddress::BaseAddress(String name, Integer primfunc_param_idx, Integer region, Integer size, + Bool is_runtime_allocation) { + auto base_address_node = make_object(); + base_address_node->name = name; + base_address_node->primfunc_param_idx = primfunc_param_idx; + base_address_node->region = region; + base_address_node->size = size; + base_address_node->is_runtime_allocation = is_runtime_allocation; + data_ = std::move(base_address_node); +} + +TVM_REGISTER_NODE_TYPE(BaseAddressNode); +TVM_REGISTER_GLOBAL("relay.ext.ethos-u.BaseAddress") + .set_body_typed([](String name, Integer primfunc_param_idx, Integer region, Integer size, + Bool is_runtime_allocation) { + if (is_runtime_allocation.defined()) { + return BaseAddress(name, primfunc_param_idx, region, size, is_runtime_allocation); + } else { + return BaseAddress(name, primfunc_param_idx, region, size); + } + }); + +CompilationArtifact::CompilationArtifact(String function_name, String command_stream, + String encoded_constants, + Array base_addresses) { auto compilation_artifact_node = make_object(); + compilation_artifact_node->function_name = function_name; compilation_artifact_node->command_stream = command_stream; compilation_artifact_node->encoded_constants = encoded_constants; - compilation_artifact_node->scratch_size = scratch_size; - compilation_artifact_node->input_size = input_size; - compilation_artifact_node->output_size = output_size; - compilation_artifact_node->function_name = function_name; + compilation_artifact_node->base_addresses = base_addresses; data_ = std::move(compilation_artifact_node); } TVM_REGISTER_NODE_TYPE(CompilationArtifactNode); TVM_REGISTER_GLOBAL("relay.ext.ethos-u.CompilationArtifact") - .set_body_typed([](String command_stream, String encoded_constants, Integer scratch_size, - Integer input_size, Integer output_size, String function_name) { - return CompilationArtifact(command_stream, encoded_constants, scratch_size, input_size, - output_size, function_name); + .set_body_typed([](String function_name, String command_stream, String encoded_constants, + Array base_addresses) { + return CompilationArtifact(function_name, command_stream, encoded_constants, base_addresses); }); TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) .set_dispatch([](const ObjectRef& ref, ReprPrinter* p) { auto* node = static_cast(ref.get()); p->stream << "CompilationArtifactNode(\n" - << "command_stream=" << node->command_stream + << "function_name=" << node->function_name + << ",\n command_stream=" << node->command_stream << ",\n encoded_constants=" << node->encoded_constants - << ",\n scratch_size=" << node->scratch_size - << ",\n input_size=" << node->input_size - << ",\n output_size=" << node->output_size - << ",\n function_name=" << node->function_name << ")"; + << ",\n base_addresses=" << node->base_addresses << ")"; }); } // namespace ethosu diff --git a/src/relay/backend/contrib/ethosu/utils.h b/src/relay/backend/contrib/ethosu/utils.h index 5e9e337c3f17..5c61271d3425 100644 --- a/src/relay/backend/contrib/ethosu/utils.h +++ b/src/relay/backend/contrib/ethosu/utils.h @@ -34,47 +34,91 @@ namespace relay { namespace contrib { namespace ethosu { +/*! + * \brief Base addresses are input pointers to + * the driver that get accessed by the command stream + * using offsets to read/write data. + */ +struct BaseAddressNode : public Object { + /*! \brief The identifier, usually it the param name of the PrimFunc that gets lowered */ + String name; + /*! \brief The index in the params array of the PrimFunc. This is needed to keep aligned + * between the PrimFunc arguments ordering and argument ordering of generated code */ + Integer primfunc_param_idx; + /*! \brief The region used by the command stream. This needs to match with base address + * index passed into the driver */ + Integer region; + /*! \brief The size of the buffer accessible by this base address */ + Integer size; + /*! \brief This is a runtime allocation that needs to be done in the function */ + Bool is_runtime_allocation{Bool(false)}; + + void VisitAttrs(tvm::AttrVisitor* v) { + v->Visit("name", &name); + v->Visit("primfunc_param_idx", &primfunc_param_idx); + v->Visit("region", ®ion); + v->Visit("size", &size); + v->Visit("is_runtime_allocation", &is_runtime_allocation); + } + + bool SEqualReduce(const BaseAddressNode* other, SEqualReducer equal) const { + return equal(name, other->name) && equal(primfunc_param_idx, other->primfunc_param_idx) && + equal(region, other->region) && equal(size, other->size) && + equal(is_runtime_allocation, other->is_runtime_allocation); + } + + void SHashReduce(SHashReducer hash_reduce) const { + hash_reduce(name); + hash_reduce(primfunc_param_idx); + hash_reduce(region); + hash_reduce(size); + hash_reduce(is_runtime_allocation); + } + + static constexpr const char* _type_key = "relay.ext.ethos-u.BaseAddress"; + TVM_DECLARE_FINAL_OBJECT_INFO(BaseAddressNode, Object); +}; + +class BaseAddress : public ObjectRef { + public: + TVM_DLL BaseAddress(String name, Integer primfunc_param_idx, Integer region, Integer size, + Bool is_runtime_allocation = Bool(false)); + TVM_DEFINE_MUTABLE_OBJECT_REF_METHODS(BaseAddress, ObjectRef, BaseAddressNode); +}; + /*! * \brief Captures all the binary artifactes required to create * the C-source runtime module */ struct CompilationArtifactNode : public Object { + /*! \brief The function name for this artifact belongs to */ + String function_name; /*! \brief The binary command stream (CS) in hex format */ String command_stream; /*! \brief The encoded biases and weights in hex format */ String encoded_constants; - /*! \brief The intermediary scratch area required for the execution of the CS */ - Integer scratch_size; - /*! \brief The size of the input tensor in bytes */ - Integer input_size; - /*! \brief The size of the output tensor in bytes */ - Integer output_size; - /*! \brief The name of the function */ - String function_name; + /*! \brief The information regarding the base addresses */ + Array base_addresses; void VisitAttrs(tvm::AttrVisitor* v) { + v->Visit("function_name", &function_name); v->Visit("command_stream", &command_stream); v->Visit("encoded_constants", &encoded_constants); - v->Visit("scratch_size", &scratch_size); - v->Visit("input_size", &input_size); - v->Visit("output_size", &output_size); - v->Visit("function_name", &function_name); + v->Visit("base_addresses", &base_addresses); } bool SEqualReduce(const CompilationArtifactNode* other, SEqualReducer equal) const { - return equal(command_stream, other->command_stream) && + return equal(function_name, other->function_name) && + equal(command_stream, other->command_stream) && equal(encoded_constants, other->encoded_constants) && - equal(scratch_size, other->scratch_size) && equal(input_size, other->input_size) && - equal(output_size, other->output_size) && equal(function_name, other->function_name); + equal(base_addresses, other->base_addresses); } void SHashReduce(SHashReducer hash_reduce) const { + hash_reduce(function_name); hash_reduce(command_stream); hash_reduce(encoded_constants); - hash_reduce(scratch_size); - hash_reduce(input_size); - hash_reduce(output_size); - hash_reduce(function_name); + hash_reduce(base_addresses); } static constexpr const char* _type_key = "relay.ext.ethos-u.CompilationArtifact"; @@ -83,8 +127,8 @@ struct CompilationArtifactNode : public Object { class CompilationArtifact : public ObjectRef { public: - TVM_DLL CompilationArtifact(String command_stream, String encoded_constants, Integer scratch_size, - Integer input_size, Integer output_size, String function_name); + TVM_DLL CompilationArtifact(String function_name, String command_stream, String encoded_constants, + Array base_addresses); TVM_DEFINE_MUTABLE_OBJECT_REF_METHODS(CompilationArtifact, ObjectRef, CompilationArtifactNode); }; diff --git a/src/tir/transforms/make_unpacked_api.cc b/src/tir/transforms/make_unpacked_api.cc index 6365e09246fc..fc43e1449d6a 100644 --- a/src/tir/transforms/make_unpacked_api.cc +++ b/src/tir/transforms/make_unpacked_api.cc @@ -60,12 +60,16 @@ PrimFunc MakeUnpackedAPI(PrimFunc&& func) { // Collect variables and buffers to map between Array args; + Map new_buffer_map; for (const Var& param : func->params) { // Ideally all func params should have Buffers defined in the buffer_map // We should look to insert buffer_maps for all PrimFuncs that are returned // to the core compiler. if (func->buffer_map.find(param) != func->buffer_map.end()) { args.push_back(func->buffer_map[param]->data); + // Rewiring the buffer_var to map to Buffers for low-level passes + // retain information about the buffer. + new_buffer_map.Set(func->buffer_map[param]->data, func->buffer_map[param]); } else { args.push_back(param); } @@ -79,6 +83,7 @@ PrimFunc MakeUnpackedAPI(PrimFunc&& func) { func_ptr->body = MergeNest(device_init, func_ptr->body); func_ptr->params = args; func_ptr->ret_type = PrimType(DataType::Int(32)); + func_ptr->buffer_map = new_buffer_map; // return the function. return std::move(func); diff --git a/tests/python/contrib/test_ethosu/test_tir_to_cs_translator.py b/tests/python/contrib/test_ethosu/test_tir_to_cs_translator.py index c14deb636c25..0cadf96e7a18 100644 --- a/tests/python/contrib/test_ethosu/test_tir_to_cs_translator.py +++ b/tests/python/contrib/test_ethosu/test_tir_to_cs_translator.py @@ -227,8 +227,16 @@ def test_buffer_info_extraction(): "uint8", tir_to_cs_translator.BufferType.input_or_output, ), - "ethosu_conv2d_2": ([1024], "uint8", tir_to_cs_translator.BufferType.scratch), - "ethosu_conv2d_3": ([2048], "uint8", tir_to_cs_translator.BufferType.scratch), + "ethosu_conv2d_2": ( + [1024], + "uint8", + tir_to_cs_translator.BufferType.scratch, + ), + "ethosu_conv2d_3": ( + [2048], + "uint8", + tir_to_cs_translator.BufferType.scratch, + ), }, }, ] @@ -805,12 +813,10 @@ def _check_buffer(address, region, length, buffer_var): # Every buffer is adjusted to align to 16 bytes size_in_bytes = util.round_up(size_in_bytes, 16) assert address + size_in_bytes <= scratch_size - # The scratch area should not be used by anyother buffer - assert not scratch_allocation_mask[address : address + size_in_bytes].any() + # The scratch area should not be used by any other buffer + assert not scratch_mask[address : address + size_in_bytes].any() # The scratch area is marked as used - scratch_allocation_mask[address : address + size_in_bytes] = np.ones( - size_in_bytes, dtype="uint8" - ) + scratch_mask[address : address + size_in_bytes] = np.ones(size_in_bytes, dtype="uint8") elif buffer_type == tir_to_cs_translator.BufferType.input: assert address == 0 else: @@ -887,14 +893,16 @@ def check_buffer(address, region, length, buffer_var): for extern_call in extern_calls: _npu_ops.append(tir_to_cs_translator.translate_ethosu_tir_call_extern(extern_call)) npu_op_tir_buffers = collect_tir_buffer_info(_npu_ops) - _npu_ops, constant_hex_string, scratch_size = tir_to_cs_translator.assign_addresses( - buffer_info, _npu_ops - ) - scratch_allocation_mask = np.zeros(scratch_size, dtype="uint8") + ( + _npu_ops, + constant_hex_string, + scratch_size, + ) = tir_to_cs_translator.assign_addresses(buffer_info, _npu_ops) + scratch_mask = np.zeros(scratch_size, dtype="uint8") constant_tensor_read_mask = np.zeros(len(constant_hex_string) // 2, dtype="uint8") verify(_npu_ops) # This will be only 1 if all allocated scratch is used. - assert np.prod(scratch_allocation_mask) == 1 + assert np.prod(scratch_mask) == 1 # This will be only 1 if all constant tensors is read at least once. assert np.prod(constant_tensor_read_mask) == 1