From d7a72d2cd4b479b1f98fb7a642813bd0dce26f7a Mon Sep 17 00:00:00 2001 From: masahi Date: Sat, 13 Feb 2021 19:40:00 +0900 Subject: [PATCH] [VM] Move param bind to OptimizeModule (#7451) * [VM] Move param bind to OptimizeModule * add test to verify the number of free vars after opt * remove const from OptimizeModule --- src/relay/backend/vm/compiler.cc | 20 ++++++++++---------- src/relay/backend/vm/compiler.h | 3 +-- tests/python/relay/test_vm.py | 4 ++++ 3 files changed, 15 insertions(+), 12 deletions(-) diff --git a/src/relay/backend/vm/compiler.cc b/src/relay/backend/vm/compiler.cc index 7861502965a8..7697b59437f0 100644 --- a/src/relay/backend/vm/compiler.cc +++ b/src/relay/backend/vm/compiler.cc @@ -892,15 +892,6 @@ void VMCompiler::SetParam(const std::string& name, runtime::NDArray data_in) { } void VMCompiler::Lower(IRModule mod, const TargetsMap& targets, const tvm::Target& target_host) { - if (params_.size()) { - BaseFunc base_func = mod->Lookup("main"); - ICHECK(base_func->IsInstance()) - << "VM compiler expects to compile relay::Function"; - auto f = relay::backend::BindParamsByName(Downcast(base_func), params_); - auto gvar = mod->GetGlobalVar("main"); - mod->Add(gvar, f); - } - exec_ = make_object(); targets_ = targets; target_host_ = target_host; @@ -1005,8 +996,17 @@ transform::Sequential MemoryOpt(tvm::Target host_target, TargetsMap targets) { return transform::Sequential(pass_seqs); } -IRModule VMCompiler::OptimizeModule(const IRModule& mod, const TargetsMap& targets, +IRModule VMCompiler::OptimizeModule(IRModule mod, const TargetsMap& targets, const Target& target_host) { + if (params_.size()) { + BaseFunc base_func = mod->Lookup("main"); + ICHECK(base_func->IsInstance()) + << "VM compiler expects to compile relay::Function"; + auto f = relay::backend::BindParamsByName(Downcast(base_func), params_); + auto gvar = mod->GetGlobalVar("main"); + mod->Add(gvar, f); + } + Array pass_seqs; Array entry_functions{"main"}; pass_seqs.push_back(transform::RemoveUnusedFunctions(entry_functions)); diff --git a/src/relay/backend/vm/compiler.h b/src/relay/backend/vm/compiler.h index 56965c544701..615a8181b387 100644 --- a/src/relay/backend/vm/compiler.h +++ b/src/relay/backend/vm/compiler.h @@ -125,8 +125,7 @@ class VMCompiler : public runtime::ModuleNode { * * \return The optimized IRModule. */ - IRModule OptimizeModule(const IRModule& mod, const TargetsMap& targets, - const Target& target_host); + IRModule OptimizeModule(IRModule mod, const TargetsMap& targets, const Target& target_host); /*! * \brief Populate the global function names in a map where the value is used diff --git a/tests/python/relay/test_vm.py b/tests/python/relay/test_vm.py index 6958010176e3..975070ad1aaa 100644 --- a/tests/python/relay/test_vm.py +++ b/tests/python/relay/test_vm.py @@ -678,6 +678,10 @@ def test_vm_optimize(): comp = relay.vm.VMCompiler() opt_mod, _ = comp.optimize(mod, target="llvm", params=params) + free_vars = relay.analysis.free_vars(opt_mod["main"].body) + # Paremeters should all be bound, so the only free var is data + assert len(free_vars) == 1 + @tvm.testing.uses_gpu def test_loop_free_var():