diff --git a/src/nnfusion/engine/pass/codegen/base_codegen_pass.cpp b/src/nnfusion/engine/pass/codegen/base_codegen_pass.cpp
index 672e4a3a8..cc2d6509e 100644
--- a/src/nnfusion/engine/pass/codegen/base_codegen_pass.cpp
+++ b/src/nnfusion/engine/pass/codegen/base_codegen_pass.cpp
@@ -297,22 +297,33 @@ std::pair<LanguageUnit_p, LanguageUnit_p>
     return std::make_pair(lup_alloc, lup_free);
 }
 
-nnfusion::LanguageUnit_p BaseCodegenPass::codegen_mem_ref(KernelEmitter::Pointer kernel)
+nnfusion::LanguageUnit_p BaseCodegenPass::codegen_mem_ref(nnfusion::ir::Instruction::Pointer ins)
 {
-    if (!kernel || FLAGS_fcustomized_mem_imp)
+    auto kernel = ins->getKernel();
+    if (!kernel || FLAGS_fcustomized_mem_imp || ins->getGNode()->get_op_type() == "Result")
         return nullptr;
     LanguageUnit_p _lu(new LanguageUnit(kernel->get_function_name() + "_mem_ref"));
     auto& lu = *_lu;
     bool empty = true;
-    if (auto annotations = kernel->m_context->annotations)
+    if ((*ins)["InplaceTensorMapping"].is_valid())
     {
-        for (auto oi_pair : annotations->get_in_place_oi_pairs())
+        auto in_place_outputs =
+            (*ins)["InplaceTensorMapping"]
+                .as<std::unordered_map<std::shared_ptr<nnfusion::descriptor::Tensor>,
+                                       std::pair<std::shared_ptr<nnfusion::descriptor::Tensor>, size_t>>>();
+        for (auto output : kernel->m_context->outputs)
         {
-            if (oi_pair.force_inplace == true)
+            if (is_ref_tensor(ins, output))
             {
-                auto input = kernel->m_context->inputs[oi_pair.input];
-                auto output = kernel->m_context->outputs[oi_pair.output];
-                lu << output->get_name() << " = " << input->get_name() << ";\n";
+                auto parent_tensor = in_place_outputs.at(output).first;
+                size_t tensor_offset = in_place_outputs.at(output).second;
+
+                auto root_tensor = parent_tensor->get_root_tensor()
+                                       ? parent_tensor->get_root_tensor()
+                                       : parent_tensor;
+                lu << output->get_name() << " = " << root_tensor->get_name()
+                   << ((tensor_offset > 0) ? (" + " + std::to_string(tensor_offset)) : (""))
+                   << ";\n";
                 empty = false;
             }
         }
@@ -323,6 +334,23 @@ nnfusion::LanguageUnit_p BaseCodegenPass::codegen_mem_ref(KernelEmitter::Pointer
     return _lu;
 }
 
+bool BaseCodegenPass::is_ref_tensor(nnfusion::ir::Instruction::Pointer ins,
+                                    shared_ptr<nnfusion::descriptor::Tensor> output)
+{
+    if ((*ins)["InplaceTensorMapping"].is_valid())
+    {
+        auto in_place_outputs =
+            (*ins)["InplaceTensorMapping"]
+                .as<std::unordered_map<std::shared_ptr<nnfusion::descriptor::Tensor>,
+                                       std::pair<std::shared_ptr<nnfusion::descriptor::Tensor>, size_t>>>();
+        // input tensor is unallocated (e.g., Parameter), need to assign address at runtime
+        if (in_place_outputs.count(output) > 0 &&
+            (in_place_outputs.at(output).first)->get_pool_offset() == SIZE_MAX)
+            return true;
+    }
+    return false;
+}
+
 LanguageUnit_p BaseCodegenPass::codegen_device_type()
 {
     auto lu_devtype = make_shared<LanguageUnit>("device_type");
diff --git a/src/nnfusion/engine/pass/codegen/base_codegen_pass.hpp b/src/nnfusion/engine/pass/codegen/base_codegen_pass.hpp
index 40cab2b48..515e3c20e 100644
--- a/src/nnfusion/engine/pass/codegen/base_codegen_pass.hpp
+++ b/src/nnfusion/engine/pass/codegen/base_codegen_pass.hpp
@@ -92,7 +92,10 @@ namespace nnfusion
             virtual NNFusion_DeviceType device_type() { return NNFusion_DeviceType::UNKNOWN; }
             virtual std::pair<LanguageUnit_p, LanguageUnit_p>
                 get_customized_mem_imp(nnfusion::ir::Instruction::Pointer ins);
-            LanguageUnit_p codegen_mem_ref(KernelEmitter::Pointer kernel);
+            LanguageUnit_p codegen_mem_ref(nnfusion::ir::Instruction::Pointer ins);
+            // check if an output tensor of ins is ref_tensor, that needs to assign address at runtime
+            bool is_ref_tensor(nnfusion::ir::Instruction::Pointer ins,
+                               shared_ptr<nnfusion::descriptor::Tensor> out);
             LanguageUnit_p codegen_device_type();
             LanguageUnit_p codegen_workspace_size(std::shared_ptr<TranslationUnit> tu);
             CodeGenerator::Pointer projgen;
diff --git a/src/nnfusion/engine/pass/codegen/cuda_codegen_pass.cpp b/src/nnfusion/engine/pass/codegen/cuda_codegen_pass.cpp
index 361d553f7..a80139f61 100644
--- a/src/nnfusion/engine/pass/codegen/cuda_codegen_pass.cpp
+++ b/src/nnfusion/engine/pass/codegen/cuda_codegen_pass.cpp
@@ -327,13 +327,17 @@ bool CudaCodegenPass::collect_funcs(std::shared_ptr<InterpreterContext> ctx,
             {
                 for (size_t i = 0; i < gnode->get_out_edges().size(); i++)
                 {
-                    if (gnode->get_out_edges()[i]->get_dst()->get_op_ptr()->is_output())
+                    auto out_tensor =
+                        kernel->m_context->outputs[gnode->get_out_edges()[i]->get_src_output()];
+                    if (gnode->get_out_edges()[i]->get_dst()->get_op_ptr()->is_output() &&
+                        !is_ref_tensor(ins, out_tensor))
                     {
                         std::shared_ptr<GNode> output = gnode->get_out_edges()[i]->get_dst();
                         std::string in_name = output->get_input_tensor(0).get_name();
                         std::string out_name = output->get_output_tensor(0).get_name();
                         int pos = call_str.find(", " + in_name);
                         call_str.replace(pos, in_name.size() + 2, ", " + out_name);
+                        (*output)["is_eliminative"] = true;
                     }
                 }
             }
@@ -716,9 +720,9 @@ nnfusion::LanguageUnit_p CudaCodegenPass::func_call_codegen(nnfusion::ir::Instru
         }
     }
 
-    auto mem_ref = codegen_mem_ref(kernel);
+    auto mem_ref = codegen_mem_ref(ins);
     if (mem_ref != nullptr)
-        lu << codegen_mem_ref(kernel)->get_code();
+        lu << codegen_mem_ref(ins)->get_code();
 
     if (ins->name() == "Memcpy")
     {
@@ -757,15 +761,16 @@ nnfusion::LanguageUnit_p CudaCodegenPass::func_call_codegen(nnfusion::ir::Instru
     }
     else
     {
-        if (ins->getKernel()->is_eliminative())
-        {
-            lu << "// eliminated: " << func_call;
-        }
-        // todo: this hack is to eliminate d2d copy caused by extern result memory
-        else if (FLAGS_fextern_result_memory && gnode && gnode->get_op_ptr()->is_output())
+        if (ins->getKernel()->is_eliminative() ||
+            (*(ins->getGNode()))["is_eliminative"].is_valid_as<bool>())
         {
             lu << "// eliminated: " << func_call;
         }
+        // // todo: this hack is to eliminate d2d copy caused by extern result memory
+        // else if (FLAGS_fextern_result_memory && gnode && gnode->get_op_ptr()->is_output())
+        // {
+        //     lu << "// eliminated: " << func_call;
+        // }
         else
         {
diff --git a/src/nnfusion/engine/pass/codegen/hlsl_cpp_codegen_pass.cpp b/src/nnfusion/engine/pass/codegen/hlsl_cpp_codegen_pass.cpp
index 1cfb91ee9..26f17d374 100644
--- a/src/nnfusion/engine/pass/codegen/hlsl_cpp_codegen_pass.cpp
+++ b/src/nnfusion/engine/pass/codegen/hlsl_cpp_codegen_pass.cpp
@@ -287,7 +287,7 @@ bool HLSLCPPCodegenPass::collect_funcs(std::shared_ptr<InterpreterContext> ctx,
 
             if (FLAGS_fcustomized_mem_imp)
                 lup_func_calls->unit_vec.push_back(get_customized_mem_imp(ins).first);
-            auto mem_ref = codegen_mem_ref(kernel);
+            auto mem_ref = codegen_mem_ref(ins);
             if (mem_ref != nullptr)
                 lup_func_calls->unit_vec.push_back(mem_ref);
             lup_func_calls->unit_vec.push_back(kernel_func_call);
diff --git a/src/nnfusion/engine/pass/codegen/hlsl_cs_codegen_pass.cpp b/src/nnfusion/engine/pass/codegen/hlsl_cs_codegen_pass.cpp
index 6ba1914cd..b2e60cdd4 100644
--- a/src/nnfusion/engine/pass/codegen/hlsl_cs_codegen_pass.cpp
+++ b/src/nnfusion/engine/pass/codegen/hlsl_cs_codegen_pass.cpp
@@ -240,7 +240,7 @@ bool HLSLCSCodegenPass::collect_funcs(std::shared_ptr<InterpreterContext> ctx,
                 std::make_shared<LanguageUnit>(fu->call_unit->get_symbol(), call_str);
             if (FLAGS_fcustomized_mem_imp)
                 lup_func_calls->unit_vec.push_back(get_customized_mem_imp(ins).first);
-            auto mem_ref = codegen_mem_ref(kernel);
+            auto mem_ref = codegen_mem_ref(ins);
             if (mem_ref != nullptr)
                 lup_func_calls->unit_vec.push_back(mem_ref);
             lup_func_calls->unit_vec.push_back(kernel_func_call);