Skip to content

Commit

Permalink
[Core] Move constant-folded variable from exec scope to root scope fo…
Browse files Browse the repository at this point in the history
…r fix predictor clone error (#9062)
  • Loading branch information
shentanyue committed Jun 8, 2022
1 parent 3554f73 commit ef8d53c
Show file tree
Hide file tree
Showing 5 changed files with 34 additions and 7 deletions.
8 changes: 4 additions & 4 deletions lite/api/cxx_api.cc
Original file line number Diff line number Diff line change
Expand Up @@ -176,7 +176,7 @@ lite::Tensor *Predictor::GetInput(size_t offset) {
<< "The network has " << input_names_.size() << " inputs"
<< ", the offset should be less than this.";
auto *in_var = exec_scope_->FindVar(input_names_[offset]);
CHECK(in_var) << "no fatch variable " << input_names_[offset]
CHECK(in_var) << "no feed variable " << input_names_[offset]
<< " in exec_scope";
return in_var->GetMutable<lite::Tensor>();
}
Expand Down Expand Up @@ -249,7 +249,7 @@ const lite::Tensor *Predictor::GetOutput(size_t offset) const {
<< ", the offset should be less than this.";
const std::string name = output_names_.at(offset);
auto *out_var = exec_scope_->FindVar(name);
CHECK(out_var) << "no fatch variable " << name << " in exec_scope";
CHECK(out_var) << "no fetch variable " << name << " in exec_scope";
return out_var->GetMutable<lite::Tensor>();
}

Expand All @@ -265,15 +265,15 @@ std::vector<const lite::Tensor *> Predictor::GetOutputs() const {
#else
const lite::Tensor *Predictor::GetOutput(size_t offset) const {
auto *_fetch_list = exec_scope_->FindVar("fetch");
CHECK(_fetch_list) << "no fatch variable in exec_scope";
CHECK(_fetch_list) << "no fetch variable in exec_scope";
auto &fetch_list = *_fetch_list->GetMutable<std::vector<lite::Tensor>>();
CHECK_LT(offset, fetch_list.size()) << "offset " << offset << " overflow";
return &fetch_list.at(offset);
}

std::vector<const lite::Tensor *> Predictor::GetOutputs() const {
auto *_fetch_list = exec_scope_->FindVar("fetch");
CHECK(_fetch_list) << "no fatch variable in exec_scope";
CHECK(_fetch_list) << "no fetch variable in exec_scope";
auto &fetch_list = *_fetch_list->GetMutable<std::vector<lite::Tensor>>();

std::vector<const lite::Tensor *> outputs;
Expand Down
2 changes: 1 addition & 1 deletion lite/api/cxx_api.h
Original file line number Diff line number Diff line change
Expand Up @@ -113,7 +113,7 @@ class LITE_API Predictor {
if (!program_generated_) {
GenRuntimeProgram();
}
// step 2. Create a predictor friom current program_desc_ and
// step 2. Create a predictor from current program_desc_ and
// runtime_program.
auto predictor =
std::make_shared<Predictor>(program_desc_, scope_, valid_places_);
Expand Down
16 changes: 14 additions & 2 deletions lite/core/program.cc
Original file line number Diff line number Diff line change
Expand Up @@ -87,6 +87,18 @@ void UpdateVarDescFromTensorInfo(cpp::VarDesc* var,
var->SetType(cpp::VarDesc::Type::LOD_TENSOR);
auto tensor = scope->FindVar(var_name)->GetMutable<Tensor>();
var->SetPersistable(tensor->persistable());
// Move the persistable var from exec scope to the root scope
auto root_scope = scope->MutableParent();
if (tensor->persistable() && root_scope != scope &&
!root_scope->FindLocalVar(var_name)) {
// Find or create new var in root scope
auto root_tensor = root_scope->LocalVar(var_name)->GetMutable<Tensor>();
if (root_tensor != tensor) {
root_tensor->CopyDataFrom(*tensor);
scope->DeleteLocalVar(var_name);
}
}

if (var_name != "feed" && var_name != "fetch") {
var->SetShape(tensor->dims().data());
auto precision = tensor->precision();
Expand Down Expand Up @@ -579,8 +591,7 @@ void Program::PrepareWorkspace(
#if defined(LITE_WITH_XPU) || defined(LITE_WITH_CUDA)
}
#endif

// Create tensors or wights from variable description.
// Create tensors or weights from variable description.
if (!var_desc->Persistable()) {
vars_.push_back(var_name);
auto* var = exec_scope_->Var(var_name);
Expand All @@ -603,6 +614,7 @@ void Program::PrepareWorkspace(
VLOG(4) << " - dims " << tensor->dims().repr();
}
tensor->set_precision(var_data_type);
tensor->set_persistable(var_desc->Persistable());
} else if (var_type == lite::VarDescAPI::Type::LOD_TENSOR_ARRAY) {
var_type_map_[var_name] = LiteType::GetTensorListTy(
TARGET(kUnk), PRECISION(kUnk), DATALAYOUT(kUnk));
Expand Down
13 changes: 13 additions & 0 deletions lite/core/scope.cc
Original file line number Diff line number Diff line change
Expand Up @@ -83,6 +83,19 @@ Variable *Scope::FindLocalVar(const std::string &name) const {
return nullptr;
}

void Scope::DeleteLocalVar(const std::string &name) {
rwlock_->RDLock();
if (FindLocalVar(name)) {
auto *p = vars_[name].release();
if (!p) {
delete p;
p = nullptr;
}
vars_.erase(name);
}
rwlock_->UNLock();
}

// AttributeVarNames will get persistive attribute names stored in parent scope
std::vector<std::string> Scope::AttributeVarNames() const {
std::vector<std::string> resulted_keys;
Expand Down
2 changes: 2 additions & 0 deletions lite/core/scope.h
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,8 @@ class Scope final {

Variable* FindLocalVar(const std::string& name) const;

void DeleteLocalVar(const std::string& name);

const Scope* parent() const { return parent_; }
Scope* MutableParent() { return const_cast<Scope*>(parent_); }

Expand Down

0 comments on commit ef8d53c

Please sign in to comment.