Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[BYOC][Verilator] Refactor Verilator runtime #7406

Merged
merged 10 commits into from
Feb 16, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
56 changes: 34 additions & 22 deletions src/relay/backend/contrib/verilator/codegen.cc
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@
#include <sstream>

#include "../../../../runtime/contrib/json/json_node.h"
#include "../../../../runtime/contrib/verilator/verilator_runtime.h"
#include "../../utils.h"
#include "../codegen_json/codegen_json.h"

Expand Down Expand Up @@ -75,29 +76,34 @@ class VerilatorJSONSerializer : public backend::contrib::JSONSerializer {
}
};

/*! \brief Attributes to store the compiler options for Verilator */
struct VerilatorCompilerConfigNode : public tvm::AttrsNode<VerilatorCompilerConfigNode> {
String lib;

TVM_DECLARE_ATTRS(VerilatorCompilerConfigNode, "ext.attrs.VerilatorCompilerConfigNode") {
TVM_ATTR_FIELD(lib).set_default("libverilator.so");
/*! \brief Attributes to store options for Verilator */
struct VerilatorOptionsNode : public tvm::AttrsNode<VerilatorOptionsNode> {
String lib_path;
int reset_cycles;
bool profiler_enable;
int profiler_cycle_counter_id;

TVM_DECLARE_ATTRS(VerilatorOptionsNode, "ext.attrs.VerilatorOptionsNode") {
TVM_ATTR_FIELD(lib_path).describe("the design library path").set_default("libverilator.so");
TVM_ATTR_FIELD(reset_cycles).describe("the number of reset cycles").set_default(1);
TVM_ATTR_FIELD(profiler_enable).describe("enable profiler").set_default(false);
TVM_ATTR_FIELD(profiler_cycle_counter_id).describe("profiler cycle counter id").set_default(0);
}
};

class VerilatorCompilerConfig : public Attrs {
class VerilatorOptions : public Attrs {
public:
TVM_DEFINE_NOTNULLABLE_OBJECT_REF_METHODS(VerilatorCompilerConfig, Attrs,
VerilatorCompilerConfigNode);
TVM_DEFINE_NOTNULLABLE_OBJECT_REF_METHODS(VerilatorOptions, Attrs, VerilatorOptionsNode);
};

TVM_REGISTER_NODE_TYPE(VerilatorCompilerConfigNode);
TVM_REGISTER_PASS_CONFIG_OPTION("relay.ext.verilator.options", VerilatorCompilerConfig);
TVM_REGISTER_NODE_TYPE(VerilatorOptionsNode);
TVM_REGISTER_PASS_CONFIG_OPTION("relay.ext.verilator.options", VerilatorOptions);

/*!
* \brief The external compiler/codegen tool. It takes a Relay expression/module and
* compile it into a runtime module.
* \brief The Verilator codegen tool. It takes a Relay expression/module and
* compile it into a Verilator runtime module.
*/
runtime::Module VerilatorCompiler(const ObjectRef& ref) {
runtime::Module VerilatorBackend(const ObjectRef& ref) {
CHECK(ref->IsInstance<FunctionNode>());
auto func = Downcast<Function>(ref);
auto func_name = GetExtSymbol(func);
Expand All @@ -106,22 +112,28 @@ runtime::Module VerilatorCompiler(const ObjectRef& ref) {
std::string graph_json = serializer.GetJSON();
auto params = serializer.GetParams();

// Create runtime object
auto n = make_object<runtime::contrib::VerilatorRuntime>(func_name, graph_json, params);

// Get Verilator compiler options
auto ctx = transform::PassContext::Current();
auto cfg = ctx->GetConfig<VerilatorCompilerConfig>("relay.ext.verilator.options");
auto cfg = ctx->GetConfig<VerilatorOptions>("relay.ext.verilator.options");
if (!cfg.defined()) {
cfg = AttrsWithDefaultValues<VerilatorCompilerConfig>();
cfg = AttrsWithDefaultValues<VerilatorOptions>();
}

auto lib_name = cfg.value()->lib;
n->SetLibrary(cfg.value()->lib_path);
n->SetResetCycles(cfg.value()->reset_cycles);

if (cfg.value()->profiler_enable) {
n->EnableProfiler();
n->SetProfilerCycleCounterId(cfg.value()->profiler_cycle_counter_id);
}

const auto* pf = runtime::Registry::Get("runtime.verilator_runtime_create");
CHECK(pf != nullptr) << "Cannot find JSON runtime module to create";
auto mod = (*pf)(lib_name, func_name, graph_json, params);
return mod;
return runtime::Module(n);
}

TVM_REGISTER_GLOBAL("relay.ext.verilator").set_body_typed(VerilatorCompiler);
TVM_REGISTER_GLOBAL("relay.ext.verilator").set_body_typed(VerilatorBackend);

} // namespace contrib
} // namespace relay
Expand Down
39 changes: 33 additions & 6 deletions src/runtime/contrib/verilator/verilator_device.h
Original file line number Diff line number Diff line change
Expand Up @@ -31,24 +31,51 @@ namespace tvm {
namespace runtime {
namespace contrib {

/*! \brief Verilator device resource context */
typedef void* VerilatorHandle;

/* allocate Verilator object */
/*!
* \brief Allocate a verilator device resource handle
* \return The verilator device handle.
*/
extern "C" TVM_DLL VerilatorHandle VerilatorAlloc();

/* deallocate Verilator object */
/*!
* \brief Free a verilator device handle
* \param handle The verilator device handle to be freed.
*/
extern "C" TVM_DLL void VerilatorDealloc(VerilatorHandle handle);

/* read Verilator register or memory */
/*!
* \brief Read verilator register or memory
* \param handle The verilator device handle.
* \param id The register or memory identifier.
* \param addr The register or memory address (word-level).
* \return The value of register or memory.
*/
extern "C" TVM_DLL int VerilatorRead(VerilatorHandle handle, int id, int addr);

/* write Verilator register or memory */
/*!
* \brief Write verilator register or memory
* \param handle The verilator device handle.
* \param id The register or memory identifier.
* \param addr The register or memory address (word-level).
* \param value The value of register or memory.
*/
extern "C" TVM_DLL void VerilatorWrite(VerilatorHandle handle, int id, int addr, int value);

/* reset Verilator for n clock cycles */
/*!
* \brief Reset Verilator for n clock cycles
* \param handle The verilator device handle.
* \param n The number of reset cycles.
*/
extern "C" TVM_DLL void VerilatorReset(VerilatorHandle handle, int n);

/* run Verilator for n clock cycles */
/*!
* \brief Run Verilator for n clock cycles
* \param handle The verilator device handle.
* \param n The number of run cycles.
*/
extern "C" TVM_DLL void VerilatorRun(VerilatorHandle handle, int n);

} // namespace contrib
Expand Down
197 changes: 99 additions & 98 deletions src/runtime/contrib/verilator/verilator_runtime.cc
Original file line number Diff line number Diff line change
Expand Up @@ -19,9 +19,11 @@

/*!
* \file src/runtime/contrib/verilator/verilator_runtime.cc
* \brief A simple JSON runtime for Verilator.
* \brief A runtime for Verilator.
*/

#include "verilator_runtime.h"

#include <dlfcn.h>
#include <tvm/runtime/ndarray.h>
#include <tvm/runtime/registry.h>
Expand All @@ -40,124 +42,123 @@ namespace tvm {
namespace runtime {
namespace contrib {

typedef VerilatorHandle (*VerilatorAllocFunc)();
typedef void (*VerilatorResetFunc)(VerilatorHandle, int);
typedef void (*VerilatorAddFunc)(VerilatorHandle, int*, int*, int*, int, int);

using namespace tvm::runtime;
using namespace tvm::runtime::contrib;
using namespace tvm::runtime::json;

class VerilatorLibrary : public Library {
public:
~VerilatorLibrary() {
if (lib_handle_) Unload();
}
void Init(const std::string& name) { Load(name); }

void* GetSymbol(const char* name) final { return GetSymbol_(name); }

private:
// Library handle
void* lib_handle_{nullptr};
// load the library
void Load(const std::string& name) {
lib_handle_ = dlopen(name.c_str(), RTLD_LAZY | RTLD_LOCAL);
ICHECK(lib_handle_ != nullptr)
<< "Failed to load dynamic shared library " << name << " " << dlerror();
}

void* GetSymbol_(const char* name) { return dlsym(lib_handle_, name); }

void Unload() {
VerilatorLibrary::~VerilatorLibrary() {
if (lib_handle_) {
dlclose(lib_handle_);
lib_handle_ = nullptr;
}
};
}

class VerilatorJSONRuntime : public JSONRuntimeBase {
public:
VerilatorJSONRuntime(const std::string& symbol_name, const std::string& graph_json,
const Array<String> const_names)
: JSONRuntimeBase(symbol_name, graph_json, const_names) {}
void VerilatorLibrary::Load(const std::string& name) {
lib_handle_ = dlopen(name.c_str(), RTLD_LAZY | RTLD_LOCAL);
ICHECK(lib_handle_ != nullptr) << "Failed to load dynamic shared library " << name << " "
<< dlerror();
}

const char* type_key() const { return "verilator_json"; }
void* VerilatorLibrary::GetSymbol(const char* name) { return dlsym(lib_handle_, name); }

void LoadLibrary(const std::string& lib_name) {
lib_ = new VerilatorLibrary();
lib_->Init(lib_name);
}
void VerilatorProfiler::Clear() { cycle_counter = 0; }

void Init(const Array<NDArray>& consts) override {
// get symbols
auto alloc_func = reinterpret_cast<VerilatorAllocFunc>(lib_->GetSymbol("VerilatorAlloc"));
ICHECK(alloc_func != nullptr);
auto reset_func = reinterpret_cast<VerilatorResetFunc>(lib_->GetSymbol("VerilatorReset"));
ICHECK(reset_func != nullptr);
vadd_func_ = reinterpret_cast<VerilatorAddFunc>(lib_->GetSymbol("verilator_add"));
ICHECK(vadd_func_ != nullptr);
std::string VerilatorProfiler::AsJSON() {
std::ostringstream os;
os << "{\n"
<< " \"cycle_counter\":" << cycle_counter << "\n"
<< "}\n";
return os.str();
}

// alloc device
device_ = (*alloc_func)();
VerilatorProfiler* VerilatorProfiler::ThreadLocal() {
static thread_local VerilatorProfiler inst;
return &inst;
}

// reset for 10 cycles
(*reset_func)(device_, 10);
VerilatorRuntime::~VerilatorRuntime() {
auto dealloc = reinterpret_cast<VerilatorDeallocFunc>(lib_->GetSymbol("VerilatorDealloc"));
ICHECK(dealloc != nullptr);
dealloc(device_);
lib_->~VerilatorLibrary();
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

just out of curiosity, why do we explicitly call the deallocator instead of delete lib_?

}

CHECK_EQ(consts.size(), const_idx_.size())
<< "The number of input constants must match the number of required.";
void VerilatorRuntime::SetLibrary(const std::string& lib_path) { lib_path_ = lib_path; }

// Setup constants entries for weights.
SetupConstants(consts);
}
void VerilatorRuntime::SetResetCycles(const int cycles) { reset_cycles_ = cycles; }

void Run() override {
std::vector<int*> in_ptr;
std::vector<int*> out_ptr;
for (size_t i = 0; i < input_nodes_.size(); ++i) {
uint32_t eid = EntryID(input_nodes_[i], 0);
int* data = static_cast<int*>(data_entry_[eid]->data);
in_ptr.push_back(data);
}
for (size_t i = 0; i < outputs_.size(); ++i) {
uint32_t eid = EntryID(outputs_[i]);
int* data = static_cast<int*>(data_entry_[eid]->data);
out_ptr.push_back(data);
}
for (size_t nid = 0; nid < nodes_.size(); ++nid) {
const auto& node = nodes_[nid];
if (node.GetOpType() == "kernel") {
CHECK_EQ(node.GetOpType(), "kernel");
auto op_name = node.GetOpName();
if ("add" == op_name) {
auto entry = node.GetInputs()[0];
auto shape = nodes_[entry.id_].GetOpShape()[entry.index_];
(*vadd_func_)(device_, in_ptr[0], in_ptr[1], out_ptr[0], shape[0], shape[1]);
} else {
LOG(FATAL) << "Unsupported op: " << op_name;
}
void VerilatorRuntime::EnableProfiler() { prof_enable_ = true; }

void VerilatorRuntime::SetProfilerCycleCounterId(const int id) { prof_cycle_counter_id_ = id; }

void VerilatorRuntime::Init(const Array<NDArray>& consts) {
lib_ = new VerilatorLibrary();
lib_->Load(lib_path_);
auto alloc = reinterpret_cast<VerilatorAllocFunc>(lib_->GetSymbol("VerilatorAlloc"));
ICHECK(alloc != nullptr);
auto reset = reinterpret_cast<VerilatorResetFunc>(lib_->GetSymbol("VerilatorReset"));
ICHECK(reset != nullptr);
read_ = reinterpret_cast<VerilatorReadFunc>(lib_->GetSymbol("VerilatorRead"));
ICHECK(read_ != nullptr);
add_op_ = reinterpret_cast<VerilatorAddFunc>(lib_->GetSymbol("verilator_add"));

// alloc verilator device
device_ = alloc();

// enable profiler
if (prof_enable_) prof_ = VerilatorProfiler::ThreadLocal();

// reset verilator device
reset(device_, reset_cycles_);

CHECK_EQ(consts.size(), const_idx_.size())
<< "The number of input constants must match the number of required.";

// Setup constants entries for weights.
SetupConstants(consts);
}

void VerilatorRuntime::Run() {
std::vector<int*> in_ptr;
std::vector<int*> out_ptr;
for (size_t i = 0; i < input_nodes_.size(); ++i) {
uint32_t eid = EntryID(input_nodes_[i], 0);
int* data = static_cast<int*>(data_entry_[eid]->data);
in_ptr.push_back(data);
}
for (size_t i = 0; i < outputs_.size(); ++i) {
uint32_t eid = EntryID(outputs_[i]);
int* data = static_cast<int*>(data_entry_[eid]->data);
out_ptr.push_back(data);
}
for (size_t nid = 0; nid < nodes_.size(); ++nid) {
const auto& node = nodes_[nid];
if (node.GetOpType() == "kernel") {
CHECK_EQ(node.GetOpType(), "kernel");
auto op_name = node.GetOpName();
if ("add" == op_name) {
auto entry = node.GetInputs()[0];
auto shape = nodes_[entry.id_].GetOpShape()[entry.index_];
ICHECK(add_op_ != nullptr);
add_op_(device_, in_ptr[0], in_ptr[1], out_ptr[0], shape[0], shape[1]);
} else {
LOG(FATAL) << "Unsupported op: " << op_name;
}
}
}

private:
/* The verilator device handle. */
VerilatorHandle device_{nullptr};
/* The verilator library handle. */
VerilatorLibrary* lib_{nullptr};
/* The verilator vadd function handle. */
VerilatorAddFunc vadd_func_{nullptr};
};

runtime::Module VerilatorJSONRuntimeCreate(String lib_name, String symbol_name, String graph_json,
const Array<String>& const_names) {
auto n = make_object<VerilatorJSONRuntime>(symbol_name, graph_json, const_names);
n->LoadLibrary(lib_name);
return runtime::Module(n);
if (prof_enable_) {
int cycles = read_(device_, prof_cycle_counter_id_, 0);
prof_->cycle_counter += cycles;
}
}

TVM_REGISTER_GLOBAL("runtime.verilator_runtime_create").set_body_typed(VerilatorJSONRuntimeCreate);
TVM_REGISTER_GLOBAL("verilator.profiler_clear").set_body([](TVMArgs args, TVMRetValue* rv) {
VerilatorProfiler::ThreadLocal()->Clear();
});

TVM_REGISTER_GLOBAL("runtime.module.loadbinary_verilator_json")
.set_body_typed(JSONRuntimeBase::LoadFromBinary<VerilatorJSONRuntime>);
TVM_REGISTER_GLOBAL("verilator.profiler_status").set_body([](TVMArgs args, TVMRetValue* rv) {
*rv = VerilatorProfiler::ThreadLocal()->AsJSON();
});

} // namespace contrib
} // namespace runtime
Expand Down
Loading