Skip to content

Commit

Permalink
[microNPU] Refactor base address determination to codegen (apache#9929)
Browse files Browse the repository at this point in the history
This commit introduces a BaseAddress ObjectRef so that
base addresses are determined in the codegen for microNPU. This is
required once multiple memory pools become available, because
base addresses can then no longer be statically determined in the
source module.
  • Loading branch information
manupak authored and ylc committed Feb 16, 2022
1 parent 40869dd commit 7e50fc4
Show file tree
Hide file tree
Showing 8 changed files with 258 additions and 120 deletions.
12 changes: 2 additions & 10 deletions python/tvm/relay/backend/contrib/ethosu/codegen.py
Original file line number Diff line number Diff line change
Expand Up @@ -297,8 +297,6 @@ def relay_to_tir_func(ext_func: relay.Function) -> tvm.tir.PrimFunc:
This returns the scheduled PrimFunc
"""
assert len(ext_func.params) == 1
input_size = util.calculate_size_bytes(ext_func.params[0])
output_size = util.calculate_size_bytes(ext_func.body)
mod = tvm.IRModule()
mod["main"] = ext_func
mod = LegalizeEthosU()(mod)
Expand All @@ -317,8 +315,6 @@ def relay_to_tir_func(ext_func: relay.Function) -> tvm.tir.PrimFunc:
primfunc = tir_mod["main"]
primfunc = primfunc.with_attr("global_symbol", ext_func.attrs["global_symbol"])
primfunc = primfunc.with_attr("ethos-u.constants", const_dict)
primfunc = primfunc.with_attr("ethos-u.input_size", input_size)
primfunc = primfunc.with_attr("ethos-u.output_size", output_size)
return primfunc


Expand All @@ -342,18 +338,14 @@ def primfunc_to_artifact(primfunc: tvm.tir.PrimFunc) -> util.CompilationArtifact
"""
symbol = str(primfunc.attrs["global_symbol"])
const_dict = primfunc.attrs["ethos-u.constants"]
input_size = primfunc.attrs["ethos-u.input_size"]
output_size = primfunc.attrs["ethos-u.output_size"]
tir_mod = tvm.IRModule()
tir_mod[symbol] = primfunc

const_dict_with_int_keys = dict()
for idx in const_dict.keys():
const_dict_with_int_keys[int(idx)] = const_dict[idx].numpy()

cmms, encoded_constants, scratch_size = tir_to_cs_translator.translate(
cmms, encoded_constants, base_addresses = tir_to_cs_translator.translate(
tir_mod, const_dict_with_int_keys
)
return util.CompilationArtifact(
cmms, encoded_constants, scratch_size, input_size, output_size, symbol
)
return util.CompilationArtifact(symbol, cmms, encoded_constants, base_addresses)
62 changes: 58 additions & 4 deletions python/tvm/relay/backend/contrib/ethosu/tir_to_cs_translator.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
the Relay to TIR compilation process, to Vela API calls to
generate command stream.
"""
from typing import Dict, NamedTuple, Tuple, Union
from typing import Dict, NamedTuple, Tuple, Union, List
from enum import auto
from enum import Enum
import numpy as np # type: ignore
Expand Down Expand Up @@ -102,8 +102,8 @@ def translate(tir_module, params):
encoded_constants : str
An hex string of the bytes that includes concat'd
encoded weights, encoded biases and scales.
scratch_size : int
The size of the scratch buffer needed.
base_addresses : List[util.BaseAddress]
base addresses to be used by the driver
"""

buffer_info = extract_buffer_info(tir_module, params)
Expand All @@ -112,10 +112,60 @@ def translate(tir_module, params):
for call_extern in call_extern_list:
_npu_ops.append(translate_ethosu_tir_call_extern(call_extern))
_npu_ops, constant_data, scratch_size = assign_addresses(buffer_info, _npu_ops)
base_addresses = extract_param_base_addresses(tir_module, buffer_info)
if scratch_size > 0:
base_addresses.append(
util.BaseAddress(
"scratch",
None,
_REGION_MAP[BufferType.scratch],
scratch_size,
True,
)
)
target_accel_config = vela_api.get_accelerator_config()
cmds = vapi.npu_generate_register_command_stream(_npu_ops, target_accel_config)
payload = vapi.npu_create_driver_payload(cmds, target_accel_config)
return payload.hex(), constant_data, scratch_size
return payload.hex(), constant_data, base_addresses


def extract_param_base_addresses(mod, buffer_info) -> List[util.BaseAddress]:
    """Collect the base addresses the driver needs for the PrimFunc parameters.

    Parameters
    ----------
    mod : tvm.IRModule
        The TIR module for the NPU; it must contain exactly one function.
    buffer_info : Dict[tvm.tir.Var, BufferInfo]
        Information regarding the buffer vars used in the PrimFunc.

    Returns
    -------
    List[util.BaseAddress]
        One entry per non-constant parameter, in parameter order.
    """
    # The module is expected to hold a single function.
    functions = mod.functions.items()
    assert len(functions) == 1
    primfunc = functions[0][1]

    # Constants are pooled together and handled specially (this is expected
    # to change after tir.allocate_const), so they get no base address here
    # and do not consume an index.
    non_constant_params = [
        param for param in primfunc.params if buffer_info[param].btype != BufferType.constant
    ]

    base_addresses = []
    for idx, param in enumerate(non_constant_params):
        buffer = primfunc.buffer_map[param]
        bytes_per_element = np.iinfo(buffer.dtype).bits // 8
        size_bytes = bytes_per_element * np.prod(list(buffer.shape))
        region = _REGION_MAP[buffer_info[param].btype]
        base_addresses.append(util.BaseAddress(param.name, idx, region, size_bytes))

    return base_addresses


def extract_call_extern_list(mod):
Expand Down Expand Up @@ -171,6 +221,7 @@ def extract_buffer_info(
# There should only be a single function
assert len(mod.functions.items()) == 1
primfunc = mod.functions.items()[0][1]

for idx, const_data in param_dict.items():
param = primfunc.params[idx]
buffer_info[param] = BufferInfo(
Expand Down Expand Up @@ -301,6 +352,9 @@ def classify_io(buffer):
assert buffer_type in (BufferType.input, BufferType.output)
address = 0
buffer_addresses[_buffer] = (address, buffer_type)
buffer_info[_buffer] = BufferInfo(
values=None, shape=info.dtype, dtype=info.dtype, btype=buffer_type
)
elif info.btype == BufferType.shram:
accl_config = util.get_accelerator_config()
arch_config = get_accelerator_arch_config(accl_config)
Expand Down
39 changes: 30 additions & 9 deletions python/tvm/relay/backend/contrib/ethosu/util.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@

from inspect import signature
from enum import Enum
from typing import Union, Tuple
from typing import Union, Tuple, List
import numpy as np # type: ignore

import tvm # type: ignore
Expand Down Expand Up @@ -239,6 +239,31 @@ def calculate_size_bytes(expr):
return element_size * elements


@register_object("relay.ext.ethos-u.BaseAddress")
class BaseAddress(Object):
    """Describes one base address pointer that is provided to the driver.

    Parameters
    ----------
    name : str
        Name of the pointer.
    primfunc_param_idx : int
        Index of the PrimFunc parameter the pointer corresponds to.
        NOTE(review): the translator appears to pass None for the scratch
        entry, which is not a PrimFunc parameter -- confirm.
    region : int
        Region the pointer is assigned to.
    size : int
        Size of the pointed-to memory (presumably in bytes -- confirm).
    is_runtime_allocation : bool
        Whether the memory is allocated at runtime; defaults to False.
    """

    def __init__(
        self,
        name: str,
        primfunc_param_idx: int,
        region: int,
        size: int,
        is_runtime_allocation: bool = False,
    ):
        # Forwarded positionally, in this exact order, to the FFI constructor.
        constructor_args = (name, primfunc_param_idx, region, size, is_runtime_allocation)
        self.__init_handle_by_constructor__(
            _ffi_api.BaseAddress,  # type: ignore # pylint: disable=no-member
            *constructor_args,
        )


@register_object("relay.ext.ethos-u.CompilationArtifact")
class CompilationArtifact(Object):
"""
Expand All @@ -248,19 +273,15 @@ class CompilationArtifact(Object):

def __init__(
self,
function_name: str,
command_stream: str,
encoded_constants: str,
scratch_size: int,
input_size: int,
output_size: int,
function_name: str,
base_addresses: List[BaseAddress],
):
self.__init_handle_by_constructor__(
_ffi_api.CompilationArtifact, # type: ignore # pylint: disable=no-member
function_name,
command_stream,
encoded_constants,
scratch_size,
input_size,
output_size,
function_name,
base_addresses,
)
93 changes: 45 additions & 48 deletions src/relay/backend/contrib/ethosu/source_module.cc
Original file line number Diff line number Diff line change
Expand Up @@ -124,7 +124,9 @@ class EthosUModuleNode : public ModuleNode {
private:
std::string c_source;
Array<CompilationArtifact> compilation_artifacts_;
Map<Integer, String> pool_var_names_;
int indent_{0};
constexpr static int kMaxBaseAddresses_ = 6;

/*!
* \brief Convert the raw string of hex values into a hex string
Expand All @@ -150,7 +152,7 @@ class EthosUModuleNode : public ModuleNode {
* \return string of code that updates the base_addrs array with the base address of the given
* array
*/
std::string SetBaseAddress(int index, std::string name, std::string size) {
std::string SetBaseAddress(int index, std::string name, int size) {
std::stringstream ss;
ss << " base_addrs[" << index << "] = (uintptr_t)(" << name << ");\n";
ss << " base_addrs_size[" << index << "] = " << size << ";\n";
Expand Down Expand Up @@ -178,11 +180,24 @@ class EthosUModuleNode : public ModuleNode {
}

/*!
* \brief Creates a runtime function header
* \brief Creates a runtime function signature
*/
void PrintRuntimeFunctionHeader(std::stringstream& ss, std::string func_name) {
ss << "TVM_DLL int32_t ";
ss << func_name << "(void* input, void* output, void* resource_handle) {\n";
// Writes the opening of the runtime function to `ss`:
//   TVM_DLL int32_t <func_name>(void* <name>,...,void* resource_handle) {
// One `void*` parameter is emitted for every base address of `artifact` whose
// primfunc_param_idx is defined, ordered by that index.
void PrintRuntimeFunctionSignature(std::stringstream& ss,
const relay::contrib::ethosu::CompilationArtifact& artifact,
std::string func_name) {
ss << "TVM_DLL int32_t " << func_name;
ss << "(";
// Index the base addresses by PrimFunc parameter position; entries without a
// defined index are not function parameters and are left out.
std::unordered_map<int, relay::contrib::ethosu::BaseAddress> param_idx_to_base_address;
for (const relay::contrib::ethosu::BaseAddress& base_address : artifact->base_addresses) {
if (base_address->primfunc_param_idx.defined()) {
param_idx_to_base_address[base_address->primfunc_param_idx] = base_address;
}
}
// Emit the parameters in ascending index order.
// NOTE(review): operator[] inserts a default entry for a missing key, so this
// loop appears to assume the indices are contiguous starting at 0 -- confirm.
for (unsigned int i = 0; i < param_idx_to_base_address.size(); i++) {
relay::contrib::ethosu::BaseAddress base_address = param_idx_to_base_address[i];
ss << "void* " << base_address->name << ",";
}
// resource_handle is always the last parameter.
ss << "void* resource_handle) {\n";
}

/*!
Expand Down Expand Up @@ -216,7 +231,6 @@ class EthosUModuleNode : public ModuleNode {
std::stringstream ss;

size_t weights_size = (compilation_artifact->encoded_constants.size() / 2);
size_t scratch_size = compilation_artifact->scratch_size;
ss << "// Update linker script to place .rodata.tvm in memory that can be accessed by the "
"NPU\n";
if (weights_size > 0) {
Expand All @@ -234,61 +248,44 @@ class EthosUModuleNode : public ModuleNode {
ss << "\n";

PrintExternCPrefix(ss);
ss << "static int32_t " << func_no_dashes + "_(int8_t* in0, "
<< "size_t in0_size, int8_t* out0, size_t out0_size, void* resource_handle) {\n";
ss << " int num_tensors = 5;\n";
PrintRuntimeFunctionSignature(ss, compilation_artifact, func_no_dashes);
ss << " void* cms_data = (void*)(" << func_no_dashes << "_cms_data_data);\n";
ss << " int64_t device_type = kDLCPU;\n";
ss << " int64_t device_id = 0;\n";
ss << " const size_t weights_size = " << std::to_string(weights_size) << ";\n";
ss << " const size_t scratch_size = " << std::to_string(scratch_size) << ";\n";
ss << " const size_t cms_data_size = sizeof(" << func_no_dashes << "_cms_data_data);\n";
if (scratch_size > 0) {
ss << " int8_t* scratch = (int8_t*) TVMBackendAllocWorkspace(device_type, device_id, "
"(uint64_t)scratch_size, 0, 16);\n";
} else {
ss << " int8_t* scratch = NULL;\n";
}
ss << " size_t base_addrs_size[num_tensors];\n";
ss << " uint64_t base_addrs[num_tensors];\n";
ss << " size_t base_addrs_size[" << kMaxBaseAddresses_ << "] = {0};\n";
ss << " uint64_t base_addrs[" << kMaxBaseAddresses_ << "] = {0};\n";
ss << "\n";
ss << SetBaseAddress(0, func_no_dashes + "_weights", "weights_size");
ss << SetBaseAddress(1, "scratch", "scratch_size");
ss << SetBaseAddress(2, "scratch", "scratch_size");
ss << SetBaseAddress(3, "in0", "in0_size");
ss << SetBaseAddress(4, "out0", "out0_size");

ss << SetBaseAddress(0, func_no_dashes + "_weights", weights_size);
for (const relay::contrib::ethosu::BaseAddress& base_address :
compilation_artifact->base_addresses) {
if (base_address->is_runtime_allocation) {
ss << " int8_t* " << base_address->name
<< " = (int8_t*) TVMBackendAllocWorkspace(device_type, device_id, "
"(uint64_t)"
<< base_address->size << ", 0, 16);\n";
}
ss << SetBaseAddress(base_address->region->value, base_address->name.c_str(),
base_address->size->value);
}
ss << "\n";

ss << " int32_t result = TVMEthosULaunch(resource_handle, cms_data, cms_data_size, "
"base_addrs, base_addrs_size, num_tensors);\n";
if (scratch_size > 0) {
ss << " TVMBackendFreeWorkspace(device_type, device_id, scratch);\n";
"base_addrs, base_addrs_size, "
<< kMaxBaseAddresses_ << ");\n";

for (const relay::contrib::ethosu::BaseAddress& base_address :
compilation_artifact->base_addresses) {
if (base_address->is_runtime_allocation) {
ss << " TVMBackendFreeWorkspace(device_type, device_id, " << base_address->name << ");\n";
}
}
ss << " return result;\n";
ss << "}\n";
ss << "\n";
PrintExternCPostfix(ss);
ss << "\n";
PrintExternCPrefix(ss);
ss << "// Wrapper function is provided to allow for easier debugging\n";
ss << "inline static int32_t " + func_no_dashes +
"_wrapper_(void* input, void* output, void* resource_handle) {\n";
ss << " size_t input_data_size = " << compilation_artifact->input_size << ";\n";
ss << " size_t output_data_size = " << compilation_artifact->output_size << ";\n";
ss << " return " + func_no_dashes +
"_((int8_t*)input, input_data_size, (int8_t*)output, output_data_size, " +
"resource_handle);\n";
ss << "}\n";
PrintExternCPostfix(ss);
ss << "\n";
PrintExternCPrefix(ss);
PrintRuntimeFunctionHeader(ss, func_no_dashes);
EnterScope();
PrintIndents(ss);
ss << "return " << func_no_dashes << "_wrapper_(input, output, resource_handle);\n";
ExitScope();
ss << "}\n";
PrintExternCPostfix(ss);

return ss.str();
}
};
Expand Down
Loading

0 comments on commit 7e50fc4

Please sign in to comment.