Skip to content

Commit

Permalink
return csourcemodule from external codegen
Browse files Browse the repository at this point in the history
  • Loading branch information
zhiics committed Nov 25, 2019
1 parent 9867f79 commit 73bebb5
Show file tree
Hide file tree
Showing 17 changed files with 544 additions and 1,016 deletions.
20 changes: 11 additions & 9 deletions cmake/modules/contrib/Extern.cmake
Original file line number Diff line number Diff line change
Expand Up @@ -17,8 +17,8 @@

message(STATUS "Build with relay.backend.contrib")

list(FIND USE_EXTERN "gcc" _gcc_idx)
if(_gcc_idx GREATER -1)
list(FIND USE_EXTERN "gcc" GCC_IDX)
if(GCC_IDX GREATER -1)
file(GLOB GCC_RELAY_CONTRIB_SRC src/relay/backend/contrib/gcc/codegen.cc)
list(APPEND COMPILER_SRCS ${GCC_RELAY_CONTRIB_SRC})

Expand All @@ -27,13 +27,15 @@ if(_gcc_idx GREATER -1)
message(STATUS "Use extern library: GCC")
endif()

list(FIND USE_EXTERN "dnnl" _dnnl_idx)
if(_dnnl_idx GREATER -1)
file(GLOB DNNL_RELAY_CONTRIB_SRC src/relay/backend/contrib/dnnl/codegen.cc)
list(APPEND COMPILER_SRCS ${DNNL_RELAY_CONTRIB_SRC})
list(FIND USE_EXTERN "dnnl" DNNL_IDX)
if(DNNL_IDX GREATER -1)
file(GLOB DNNL_RELAY_CONTRIB_SRC src/relay/backend/contrib/dnnl/codegen.cc)
list(APPEND COMPILER_SRCS ${DNNL_RELAY_CONTRIB_SRC})

file(GLOB DNNL_CONTRIB_SRC src/runtime/contrib/dnnl/*)
list(APPEND RUNTIME_SRCS ${DNNL_CONTRIB_SRC})
message(STATUS "Use extern library: MKLDNN")
find_library(EXTERN_LIBRARY_DNNL dnnl)
list(APPEND TVM_RUNTIME_LINKER_LIBS ${EXTERN_LIBRARY_DNNL})
file(GLOB DNNL_CONTRIB_SRC src/runtime/contrib/dnnl/*)
list(APPEND RUNTIME_SRCS ${DNNL_CONTRIB_SRC})
message(STATUS "Use extern library: MKLDNN" ${EXTERN_LIBRARY_DNNL})
endif()

41 changes: 3 additions & 38 deletions include/tvm/runtime/vm.h
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,6 @@
#ifndef TVM_RUNTIME_VM_H_
#define TVM_RUNTIME_VM_H_

#include <tvm/relay/expr.h>
#include <tvm/runtime/object.h>
#include <tvm/runtime/packed_func.h>
#include <tvm/runtime/registry.h>
Expand Down Expand Up @@ -140,7 +139,6 @@ enum class Opcode {
LoadConsti = 14U,
Fatal = 15U,
AllocStorage = 16U,
InvokeExternal = 17U,
};

/*! \brief A single virtual machine instruction.
Expand Down Expand Up @@ -204,16 +202,6 @@ struct Instruction {
/*! \brief The arguments to pass to the packed function. */
RegName* packed_args;
};
struct /* InvokeExternal Operands */ {
/*! \brief The index into the external function table. */
Index ext_index;
/*! \brief The arity of the external function. */
Index ext_arity;
/*! \brief The number of outputs produced by the external function. */
Index ext_output_size;
/*! \brief The arguments to pass to the external function. */
RegName* ext_args;
};
struct /* If Operands */ {
/*! \brief The register containing the test value. */
RegName test;
Expand Down Expand Up @@ -301,7 +289,7 @@ struct Instruction {
*/
static Instruction InvokePacked(Index packed_index, Index arity, Index output_size,
const std::vector<RegName>& args);
/*!
/*!
* \brief Construct an allocate tensor instruction with constant shape.
* \param storage The storage to allocate out of.
* \param shape The shape of the tensor.
Expand All @@ -311,16 +299,6 @@ struct Instruction {
*/
static Instruction AllocTensor(RegName storage,
const std::vector<int64_t>& shape, DLDataType dtype, RegName dst);
/*!
* \brief Construct an invoke external instruction.
* \param packed_index The index of the external function.
* \param ext_arity The arity of the function.
* \param ext_output_size The number of outputs of the external function.
* \param args The argument registers.
* \return The invoke external instruction.
*/
static Instruction InvokeExternal(Index external_index, Index ext_arity, Index ext_output_size,
const std::vector<RegName>& args);
/*!
* \brief Construct an allocate tensor instruction with register.
* \param storage The storage to allocate out of.
Expand Down Expand Up @@ -611,13 +589,9 @@ class Executable : public ModuleNode {
return "VMExecutable";
}

/*!
* \brief The runtime module/library that contains both the host and also the device
* code when executing on non-CPU devices.
*/
/*! \brief The runtime module/library that contains both the host and also the device
* code when executing on non-CPU devices. */
runtime::Module lib;
/*! \brief The external module/library. */
std::vector<runtime::Module> ext_libs;
/*! \brief The global constant pool. */
std::vector<ObjectRef> constants;
/*! \brief A map from globals (as strings) to their index in the function map. */
Expand All @@ -628,13 +602,6 @@ class Executable : public ModuleNode {
std::unordered_map<std::string, Index> primitive_map;
/*! \brief The virtual machine's function table. */
std::vector<VMFunction> functions;
/*! \brief A mapping from the subgraph id to the external library index in the
* `ext_libs`.
*/
std::unordered_map<Index, Index> external_map;
/*! \brief A mapping from the subgraph id to the external function name.
*/
std::unordered_map<Index, std::string> external_func_map;

private:
/*!
Expand Down Expand Up @@ -747,8 +714,6 @@ class VirtualMachine : public runtime::ModuleNode {
protected:
/*! \brief The virtual machine's packed function table. */
std::vector<PackedFunc> packed_funcs_;
/*! \brief The virtual machine's external function table. */
std::vector<PackedFunc> external_funcs;
/*! \brief The current stack of call frames. */
std::vector<VMFrame> frames_;
/*! \brief The fuction table index of the current function. */
Expand Down
142 changes: 133 additions & 9 deletions src/relay/backend/contrib/contrib_codegen.h
Original file line number Diff line number Diff line change
Expand Up @@ -25,8 +25,8 @@
#define TVM_RELAY_BACKEND_CONTRIB_CONTRIB_CODEGEN_H_

#include <tvm/relay/expr.h>
#include <sstream>
#include <string>
#include "../../../runtime/contrib/extern_common.h"

namespace tvm {
namespace relay {
Expand All @@ -37,17 +37,16 @@ class ExternCodegenBase {
ExternCodegenBase() = default;

/*!
* \brief Compile the external library.
*/
virtual void CompileExternLib() = 0;

/*!
* \brief Build the shared library of external ops.
* \brief Create a runtime module for the external library. For example, it
* could be a CSourceModule that can be directly compiled and linked together
* with a DSOModule, or a json style module that emitts a json artifact that
* is able to be executed by a customized json runtime.
*
* \param ref The subgraph Relay expression/module to be executed using extern ops.
*
* \return A runtime module.
*/
virtual void Build(const NodeRef& ref) = 0;
virtual runtime::Module CreateExternModule(const NodeRef& ref) = 0;

/*!
* \brief Split the Relay function name to tokens.
Expand All @@ -61,8 +60,133 @@ class ExternCodegenBase {
FunctionGetAttr(func, attr::kFuncName).as<tvm::ir::StringImm>();
CHECK(name_node != nullptr) << "Fail to retrieve subgraph name.";
std::string name = name_node->value;
return runtime::contrib::GetSubgraphID(name);
return GetSubgraphID(name);
}

/*!
* \brief Split the encoded function name to tokens.
*
* \param the function name string.
*
* \return a vector of tokenized function name splitted by "_".
*/
std::string GetSubgraphID(const std::string& name) const {
std::string temp = name;
std::vector<std::string> tokens;
std::string delimiter = "_";
size_t pos = 0;
std::string token;
while ((pos = temp.find(delimiter)) != std::string::npos) {
token = temp.substr(0, pos);
tokens.push_back(token);
temp.erase(0, pos + delimiter.length());
}
tokens.push_back(temp);

CHECK(tokens.size() >= 2) << "Invalid subgraph name: " << name;
CHECK(tokens[0] == "subgraph")
<< "Function name does not start with \"subgraph\": " << name;
return tokens[1];
}
};

// A helper class to write the declaration of external functions.
class ExternSourcePrinter {
protected:
/*! \brief Print indents using spaces. */
void PrintIndents() {
for (int i = 0; i < indent_; i++) {
code_stream_ << ' ';
}
}

/*!
* \brief Enter a new scope.
*/
void EnterScope() { indent_ += 2; }

/*!
* \brief Exit a scope.
*/
void ExitScope() {
CHECK_GE(indent_, 2U) << "Wrong ident found.";
indent_ -= 2;
}

/*!
* \brief Gerenate a wrapper for the subgraph that will use external codegen.
*
* \param func_name The name of wrapper function.
* \param arg_cnt The expected number of arguments for the wrapper.
*
* \code
*
* // An example code for the wrapper.
* extern "C" void foo(TVMValue* value, int* type_code, int nargs) {
* if (nargs != 3) {
* printf("foo expects 3 args, but received %d\n", nargs);
* return 1;
* }
*
* DLTensor* arg0 = static_cast<DLTensor*>(value[0].v_handle);
* DLTensor* arg1 = static_cast<DLTensor*>(value[1].v_handle);
* DLTensor* out = static_cast<DLTensor*>(value[2].v_handle);
*
* foo_(static_cast<float*>(arg0->data),
* static_cast<float*>(arg1->data),
* static_cast<float*>(out->data));
* return 0;
* }
*
* \endcode
*/
void GenerateSubgraphWrapper(const std::string& func_name, int arg_cnt) {
// Print signature
code_stream_ << "\n";
code_stream_ << "extern \"C\" int " << func_name;
code_stream_ << "(TVMValue* value, int* type_code, int nargs) {\n";
EnterScope();
// Print guard
PrintIndents();
code_stream_ << "if (nargs != " << arg_cnt << "){\n";
EnterScope();
PrintIndents();
code_stream_ << "printf(\"" << func_name << " expects " << arg_cnt
<< "arguments, but received %d\\n\", nargs);\n";
PrintIndents();
code_stream_ << "return 1;\n";
ExitScope();
PrintIndents();
code_stream_ << "}\n";

// According to TVM's calling convention, the last one is output.
for (int i = 0; i < arg_cnt; i++) {
PrintIndents();
code_stream_ << "DLTensor* arg" << i << " = "
<< "static_cast<DLTensor*>(value[" << i << "].v_handle);\n";
}
// Generate the call.
PrintIndents();
code_stream_ << func_name << "_(";
for (int i = 0; i < arg_cnt - 1; i++) {
code_stream_ << "static_cast<float*>(arg" << i << "->data), ";
}
if (arg_cnt > 0) {
code_stream_ << "static_cast<float*>(arg" << arg_cnt - 1 << "->data)";
}
code_stream_ << ");\n\n";
PrintIndents();
code_stream_ << "return 0;\n";
ExitScope();
code_stream_ << "}";
}

/*! \brief The external function source code stream. */
std::ostringstream code_stream_;

private:
/*! \brief Indent of the source code. */
int indent_{0};
};

} // namespace contrib
Expand Down
Loading

0 comments on commit 73bebb5

Please sign in to comment.