diff --git a/src/relay/backend/build_module.cc b/src/relay/backend/build_module.cc index f407436e58683..b2b73e9bad020 100644 --- a/src/relay/backend/build_module.cc +++ b/src/relay/backend/build_module.cc @@ -490,6 +490,11 @@ class RelayBuildModule : public runtime::ModuleNode { auto lowered_funcs = executor_codegen_->GetIRModule(); + // No need to build for external functions. + if (lowered_funcs.find("ext_dev") != lowered_funcs.end()) { + lowered_funcs.Set("ext_dev", IRModule()); + } + // Generate a placeholder function that attaches linked params as its arguments. if (target_host->GetAttr("link-params").value_or(Bool(false))) { CHECK(pf != nullptr) << "Unable to link-params with no target_host and no llvm codegen."; diff --git a/src/relay/backend/te_compiler.cc b/src/relay/backend/te_compiler.cc index 208e6356355de..7840960ec268e 100644 --- a/src/relay/backend/te_compiler.cc +++ b/src/relay/backend/te_compiler.cc @@ -195,6 +195,7 @@ class TECompilerImpl : public TECompilerNode { auto target = Target("ext_dev"); auto global_var = GlobalVar(func_name); global_var->checked_type_ = key->source_func->checked_type(); + ir_module->Add(global_var, key->source_func); value->cached_func = CachedFunc(target, global_var, {}, {}, te::Schedule(), {}, ir_module); return value; } @@ -347,12 +348,6 @@ class LowerTensorExpr : public ExprMutator { << ext_func->prim_fn_var->name_hint; Map prim_fns; - - for (auto prim_fn : ext_func->funcs->functions) { - CHECK(prim_fn.second.as()) << "must be a prim fn"; - prim_fns.Set(prim_fn.first, Downcast(prim_fn.second)); - } - relay::Function func_with_metadata = func; func_with_metadata = WithAttr(func_with_metadata, "prim_fn_var", ext_func->prim_fn_var); func_with_metadata = WithAttr(func_with_metadata, "prim_funcs", prim_fns); diff --git a/src/relay/backend/vm/compiler.cc b/src/relay/backend/vm/compiler.cc index ddb1911a6b71c..b3eab91d202c2 100644 --- a/src/relay/backend/vm/compiler.cc +++ b/src/relay/backend/vm/compiler.cc @@ -45,7 +45,6 @@ #include #include "../../../target/source/codegen_source_base.h" -#include "../../backend/compile_engine.h" #include "../../op/op_common.h" #include "../../transforms/pass_utils.h" #include "../utils.h" @@ -79,6 +78,7 @@ namespace vm { using namespace tvm::runtime; using namespace tvm::runtime::vm; using namespace relay::transform; +using namespace tec; // (@jroesch): VM passes, eventually declare as passes. bool IsClosure(const Function& func); @@ -253,7 +253,6 @@ class VMFunctionCompiler : ExprFunctor { ExprDeviceMap expr_device_map) : last_register_(0), registers_num_(0), - engine_(CompileEngine::Global()), context_(context), target_host_(target_host), expr_device_map_(std::move(expr_device_map)) { @@ -465,7 +464,7 @@ class VMFunctionCompiler : ExprFunctor { void EmitShapeFunc(Function func, Array inputs, Array outputs) { // Lower shape function CCacheKey key(func, target_host_); - auto cfunc = engine_->LowerShapeFunc(key); + auto cfunc = context_->compiler->LowerShapeFunc(key); int op_index = -1; // pick the only function inside the context ICHECK_EQ(cfunc->funcs->functions.size(), 1); @@ -551,7 +550,7 @@ class VMFunctionCompiler : ExprFunctor { CCacheKey key(func, target); auto mangle_fn = [](String name) { return name; }; - auto cfunc = engine_->Lower(key, mangle_fn); + auto cfunc = context_->compiler->Lower(key, mangle_fn); auto op_index = -1; if (func->GetAttr(attr::kCompiler).defined()) { @@ -857,8 +856,6 @@ class VMFunctionCompiler : ExprFunctor { size_t last_register_; /*! \brief Total number of virtual registers allocated. */ size_t registers_num_; - /*! \brief Compiler engine to lower primitive functions. */ - CompileEngine engine_; /*! \brief Global shared meta data */ VMCompilerContext* context_; /*! \brief Target devices. */ @@ -1134,8 +1131,8 @@ void VMCompiler::Codegen() { } } - auto compile_engine = CompileEngine::Global(); - auto ext_mods = compile_engine->LowerExternalFunctions(); + auto ext_mods = context_.compiler->LowerExternalFunctions(); + runtime::Module lib; if (funcs.size() > 0) { lib = tvm::build(funcs, target_host_); @@ -1146,7 +1143,6 @@ void VMCompiler::Codegen() { } lib = codegen::CreateMetadataModule(params_, lib, ext_mods, target_host_, runtime::Metadata()); exec_->SetLib(lib); - CompileEngine::Global()->Clear(); } ExprDeviceMap VMCompiler::AnalyzeContext() const { diff --git a/src/relay/backend/vm/compiler.h b/src/relay/backend/vm/compiler.h index 3a3796373a614..a05c52ced07f9 100644 --- a/src/relay/backend/vm/compiler.h +++ b/src/relay/backend/vm/compiler.h @@ -43,8 +43,9 @@ #include "../../../runtime/vm/naive_allocator.h" #include "../../../runtime/vm/profiler/vm.h" -#include "../../backend/compile_engine.h" #include "../../transforms/pass_utils.h" +#include "../te_compiler.h" +#include "../te_compiler_cache.h" namespace tvm { namespace relay { @@ -75,12 +76,14 @@ struct VMCompilerContext { TagMap tag_map; // Map from global var to a unique integer GlobalMap global_map; + // TEcompiler for lowering + tec::TECompiler compiler; // List of constants std::vector constants; // Device type for constants std::vector const_device_type; // List of cached functions - std::vector cached_funcs; + std::vector cached_funcs; // The functions that have been lowered. std::unordered_map seen_funcs; }; diff --git a/src/relay/transforms/memory_alloc.cc b/src/relay/transforms/memory_alloc.cc index b61567d0bae08..657e2c3924555 100644 --- a/src/relay/transforms/memory_alloc.cc +++ b/src/relay/transforms/memory_alloc.cc @@ -41,7 +41,8 @@ #include #include -#include "../backend/compile_engine.h" +#include "../backend/te_compiler.h" +#include "../backend/te_compiler_cache.h" #include "../op/memory/memory.h" #include "../op/vm/vm.h" #include "./pass_utils.h" @@ -49,6 +50,7 @@ #include "pattern_utils.h" using namespace tvm::runtime; +using namespace tvm::relay::tec; namespace tvm { namespace relay { @@ -271,9 +273,11 @@ class DialectRewriter : public ExprMutator { Array EmitShapeFunc(LetList* scope, const Function& func, const std::vector& new_args) { Array shape_func_ins; - auto engine = CompileEngine::Global(); + + TECompiler compiler; + CCacheKey key(func, target_host_); - auto cfunc = engine->LowerShapeFunc(key); + auto cfunc = compiler->LowerShapeFunc(key); auto input_states = cfunc->shape_func_param_states; Array is_inputs;