Skip to content

Commit

Permalink
Change function constructors to WithFields (apache#9690)
Browse files Browse the repository at this point in the history
* Change function constructors to WithFields

Get rid of std::moves, they were causing problems

* Fix bad rebase

* flaky

* try to trigger ci

* try again
  • Loading branch information
electriclilies authored and ylc committed Feb 16, 2022
1 parent 27e7004 commit 905aa88
Show file tree
Hide file tree
Showing 25 changed files with 94 additions and 94 deletions.
1 change: 1 addition & 0 deletions python/tvm/ir/json_compact.py
Original file line number Diff line number Diff line change
Expand Up @@ -86,6 +86,7 @@ def _initialize_virtual_device(item, _):
"relay.RefRead": _initialize_virtual_device,
"relay.RefWrite": _initialize_virtual_device,
"relay.Match": _initialize_virtual_device,
"relay.Constant": _initialize_virtual_device,
}

return create_updater(node_map, "0.8", "0.9")
Expand Down
7 changes: 3 additions & 4 deletions src/relay/backend/contrib/cmsisnn/extract_constants.cc
Original file line number Diff line number Diff line change
Expand Up @@ -67,8 +67,8 @@ class ExtractConstantsMutator : public MixedModeMutator {
auto new_body = VisitExpr(func->body);
functions_.pop_back();
if (function_to_constants_[func].size()) {
func = Function(FreeVars(new_body), new_body, func->ret_type, FreeTypeVars(new_body, mod_),
func->attrs);
func = WithFields(func, FreeVars(new_body), new_body, func->ret_type,
FreeTypeVars(new_body, mod_), func->attrs);
}
return std::move(func);
}
Expand Down Expand Up @@ -159,8 +159,7 @@ IRModule ExtractConstants(const IRModule& mod) {
auto new_main_body = extract_constants.VisitExpr(main_func->body);
if (!new_main_body.same_as(main_func->body)) {
auto main_var = mod->GetGlobalVar("main");
auto new_main_func = Function(main_func->params, new_main_body, main_func->ret_type,
main_func->type_params, main_func->attrs);
Function new_main_func = WithFields(main_func, main_func->params, new_main_body);
mod->Update(main_var, new_main_func);
}
return mod;
Expand Down
9 changes: 2 additions & 7 deletions src/relay/backend/contrib/cmsisnn/relay_to_tir.cc
Original file line number Diff line number Diff line change
Expand Up @@ -46,13 +46,8 @@ class RelayToTIRVisitor : public MixedModeMutator {

IRModule Mutate() {
GlobalVar main_global_var = ir_module_->GetGlobalVar("main");
BaseFunc main = ir_module_->Lookup(main_global_var);
Function main_func = GetRef<Function>(main.as<FunctionNode>());

// Copy everything across and mutate the body
Function mutated_main =
Function(main_func->params, VisitExpr(main_func->body), main_func->ret_type,
main_func->type_params, main_func->attrs, main_func->span);
Function main = Downcast<Function>(ir_module_->Lookup(main_global_var));
Function mutated_main = WithFields(main, main->params, VisitExpr(main->body));

ir_module_->Update(main_global_var, mutated_main);

Expand Down
8 changes: 2 additions & 6 deletions src/relay/backend/contrib/ethosu/codegen.cc
Original file line number Diff line number Diff line change
Expand Up @@ -56,12 +56,8 @@ class RelayToTIRMutator : public MixedModeMutator {

IRModule operator()() {
GlobalVar main_global_var = ir_module_->GetGlobalVar("main");
Function main_func = Downcast<Function>(ir_module_->Lookup(main_global_var));

// Copy everything across and mutate the body
Function mutated_main =
Function(main_func->params, VisitExpr(main_func->body), main_func->ret_type,
main_func->type_params, main_func->attrs, main_func->span);
Function main = Downcast<Function>(ir_module_->Lookup(main_global_var));
Function mutated_main = WithFields(main, main->params, VisitExpr(main->body));

ir_module_->Update(main_global_var, mutated_main);
ir_module_ = WithAttr(ir_module_, "device_contexts", device_contexts_);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -43,13 +43,8 @@ class ConvertAddToSubtract : public MixedModeMutator {

IRModule Mutate() {
GlobalVar main_global_var = ir_module_->GetGlobalVar("main");
BaseFunc main = ir_module_->Lookup(main_global_var);
Function main_func = GetRef<Function>(main.as<FunctionNode>());

// Copy everything across and mutate the body
Function mutated_main =
Function(main_func->params, VisitExpr(main_func->body), main_func->ret_type,
main_func->type_params, main_func->attrs, main_func->span);
Function main = GetRef<Function>(ir_module_->Lookup(main_global_var).as<FunctionNode>());
Function mutated_main = WithFields(main, main->params, VisitExpr(main->body));

ir_module_->Update(main_global_var, mutated_main);

Expand Down
12 changes: 10 additions & 2 deletions src/relay/backend/te_compiler.cc
Original file line number Diff line number Diff line change
Expand Up @@ -100,6 +100,7 @@ class TECompilerImpl : public TECompilerNode {
}

IRModule GetLoweredFunctions() {
VLOG(1) << "GetLoweredFunctions";
IRModule mod;
// Extract lowered functions from the cache
for (const auto& it : cache_) {
Expand Down Expand Up @@ -164,8 +165,15 @@ class TECompilerImpl : public TECompilerNode {
for (const auto& kv2 : kv1.second->cached_func->funcs->functions) {
if (const auto* function_node = kv2.second.as<FunctionNode>()) {
// Abandon the existing function annotations.
Function function(function_node->params, function_node->body, function_node->ret_type,
function_node->type_params, /*attrs=*/{}, function_node->span);

// Unfortunately, Optional<DictAttrs>() is indistinguishable from
// NullValue<DictAttrs>(), and DictAttrs() is nullptr, so to erase the attributes, we
// need to pass in DictAttrs(Map<String, ObjectRef>()), which is a DictAttrs containing no
// attributes.
Function function =
WithFields(GetRef<Function>(function_node), function_node->params,
function_node->body, function_node->ret_type, function_node->type_params,
/* erase attributes */ DictAttrs(Map<String, ObjectRef>()));
// Mark function as 'extern' using the "ExternalSymbol" attribute.
function = WithAttr(std::move(function), attr::kExternalSymbol, kv2.first->name_hint);
module->Add(kv2.first, function);
Expand Down
6 changes: 2 additions & 4 deletions src/relay/backend/vm/lambda_lift.cc
Original file line number Diff line number Diff line change
Expand Up @@ -102,8 +102,7 @@ class LambdaLifter : public transform::DeviceAwareExprMutator {

if (function_nesting() == 1) {
// We don't need to lift global functions.
return Function(func_node->params, VisitExpr(func_node->body), func_node->ret_type,
func_node->type_params, func_node->attrs, func_node->span);
return WithFields(GetRef<Function>(func_node), func_node->params, VisitExpr(func_node->body));
}

auto name = GenerateName(func);
Expand Down Expand Up @@ -188,8 +187,7 @@ class LambdaLifter : public transform::DeviceAwareExprMutator {
// construct the "closure" function with fully annotated arguments, no longer relying
// on type inference.
size_t before_arity = body->params.size();
auto rebound_body = Function(func->params, Bind(body->body, rebinding_map), func->ret_type,
func->type_params, func->attrs, func->span);
auto rebound_body = WithFields(func, func->params, Bind(body->body, rebinding_map));
size_t after_arity = rebound_body->params.size();
CHECK_EQ(before_arity, after_arity);
lifted_func =
Expand Down
1 change: 1 addition & 0 deletions src/relay/ir/expr.cc
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,7 @@ using namespace tvm::runtime;
// Builds a Constant expression node that takes ownership of `data`, records
// the optional source `span`, and initializes the virtual device to the
// fully-unconstrained value so the field is never left unset.
Constant::Constant(runtime::NDArray data, Span span) {
  auto node = make_object<ConstantNode>();
  node->data = std::move(data);
  node->span = std::move(span);
  // Every freshly constructed Constant starts with no device constraint.
  node->virtual_device_ = VirtualDevice::FullyUnconstrained();
  data_ = std::move(node);
}
Expand Down
2 changes: 1 addition & 1 deletion src/relay/quantize/annotate.cc
Original file line number Diff line number Diff line change
Expand Up @@ -98,7 +98,7 @@ Pass QuantizeAnnotate() {
for (const auto& x : FreeVars(func)) {
new_params.push_back(x);
}
return Function(new_params, func->body, func->ret_type, func->type_params, func->attrs);
return WithFields(func, new_params);
};
return CreateFunctionPass(pass_func, 1, "QuantizeAnnotate", {});
}
Expand Down
9 changes: 7 additions & 2 deletions src/relay/quantize/calibrate.cc
Original file line number Diff line number Diff line change
Expand Up @@ -152,8 +152,13 @@ class StatsCollector : private ExprMutator {
const FunctionNode* func = new_e.as<FunctionNode>();
ICHECK(func) << "Input shoule be Function";
Expr new_body = Tuple(std::move(profile_data_));
return Function(FreeVars(new_body), new_body, NullValue<Type>(), func->type_params,
func->attrs);
Function ret_func = WithFields(GetRef<Function>(func), FreeVars(new_body), new_body);

// We are changing the function's ret_type to an empty type. Unfortunately, Optional<Type>() is
// indistinguishable from NullValue<Type>(), so we can't express "update to nullptr" in
// WithFields.
ret_func.CopyOnWrite()->ret_type = NullValue<Type>();
return ret_func;
}

private:
Expand Down
2 changes: 1 addition & 1 deletion src/relay/transforms/annotate_target.cc
Original file line number Diff line number Diff line change
Expand Up @@ -295,7 +295,7 @@ class AnnotateTargetRewriter : public ExprRewriter {
func = Downcast<Function>(post);
new_body = InsertCompilerEndAndPropogateTarget(func->body);
}
return Function(func->params, new_body, func->ret_type, func->type_params, func->attrs);
return WithFields(func, func->params, new_body);
}

Expr Rewrite_(const LetNode* op, const Expr& post) override {
Expand Down
4 changes: 2 additions & 2 deletions src/relay/transforms/convert_sparse_conv2d.cc
Original file line number Diff line number Diff line change
Expand Up @@ -292,12 +292,12 @@ Pass Conv2dToSparse(const Array<ObjectRef>& weight_name, const Array<Array<PrimE
auto f0 =
Downcast<Function>(Conv2dToSparse(f, weight_name, weight_shape, layout, kernel_size));
Array<Var> sparse_params = FreeVars(f0);
auto f1 = Function(sparse_params, f0->body, f0->ret_type, f0->type_params, f0->attrs);
auto f1 = WithFields(f0, sparse_params);
Array<Var> params = FreeVars(f1);
for (const auto& var : sparse_params) {
params.push_back(var);
}
return Function(params, f1->body, f1->ret_type, f1->type_params, f1->attrs);
return WithFields(f1, params);
};
return CreateFunctionPass(pass_func, 4, "Conv2dToSparse", {"DeadCodeElimination"});
}
Expand Down
4 changes: 2 additions & 2 deletions src/relay/transforms/convert_sparse_dense.cc
Original file line number Diff line number Diff line change
Expand Up @@ -135,12 +135,12 @@ Pass DenseToSparse(const Array<ObjectRef>& weight_name,
// Remove FreeVar warnings
auto f0 = Downcast<Function>(DenseToSparse(f, weight_name, weight_shape));
Array<Var> sparse_params = FreeVars(f0);
auto f1 = Function(sparse_params, f0->body, f0->ret_type, f0->type_params, f0->attrs);
auto f1 = WithFields(f0, sparse_params);
Array<Var> params = FreeVars(f1);
for (const auto& var : sparse_params) {
params.push_back(var);
}
return Function(params, f1->body, f1->ret_type, f1->type_params, f1->attrs);
return WithFields(f1, params);
};
return CreateFunctionPass(pass_func, 4, "DenseToSparse", {"DeadCodeElimination"});
}
Expand Down
9 changes: 5 additions & 4 deletions src/relay/transforms/de_duplicate.cc
Original file line number Diff line number Diff line change
Expand Up @@ -82,16 +82,17 @@ Expr DeDup(const Expr& e) {

Type VisitType(const Type& t) final { return t.defined() ? TypeMutator::VisitType(t) : t; }

Expr VisitExpr_(const FunctionNode* op) final {
Expr VisitExpr_(const FunctionNode* func_node) final {
tvm::Array<TypeVar> type_params;
for (const TypeVar& type_param : op->type_params) {
for (const TypeVar& type_param : func_node->type_params) {
type_params.push_back(Fresh(type_param));
}
tvm::Array<Var> params;
for (const Var& param : op->params) {
for (const Var& param : func_node->params) {
params.push_back(Fresh(param));
}
return Function(params, VisitExpr(op->body), VisitType(op->ret_type), type_params, op->attrs);
return WithFields(GetRef<Function>(func_node), params, VisitExpr(func_node->body),
VisitType(func_node->ret_type), type_params);
}

Pattern VisitPattern(const Pattern& p) final { return PatternFunctor::VisitPattern(p); }
Expand Down
7 changes: 4 additions & 3 deletions src/relay/transforms/defunctionalization.cc
Original file line number Diff line number Diff line change
Expand Up @@ -283,7 +283,7 @@ class DefuncMutator : public ExprMutator {

auto apply_gv = GetApplyFunction(ft);
auto body = this->VisitExpr(Bind(fn->body, free_var_bind_map));
AddApplyCase(apply_gv, ft, c, Function(fn->params, body, fn->ret_type, fn->type_params),
AddApplyCase(apply_gv, ft, c, WithFields(GetRef<Function>(fn), fn->params, body),
pattern_vars);

return Call(c, call_args);
Expand Down Expand Up @@ -380,7 +380,7 @@ class DefuncMutator : public ExprMutator {
map.Set(f->type_params[i], type_args[i]);
}
// copy with typevars removed
auto copy = TypeSubst(Function(f->params, f->body, f->ret_type, {}), map);
auto copy = TypeSubst(WithFields(f, {}, {}, {}, /* erase type params */ Array<TypeVar>()), map);
return Downcast<Function>(copy);
}

Expand Down Expand Up @@ -410,7 +410,8 @@ class DefuncMutator : public ExprMutator {
}

auto bind = Downcast<Function>(Bind(f, var_bind_map));
return Function(params, this->VisitExpr(bind->body), bind->ret_type, {});
return WithFields(bind, params, this->VisitExpr(bind->body), bind->ret_type,
/* erase type params */ Array<TypeVar>());
}
};

Expand Down
3 changes: 1 addition & 2 deletions src/relay/transforms/eta_expand.cc
Original file line number Diff line number Diff line change
Expand Up @@ -129,8 +129,7 @@ class EtaExpander : public ExprMutator {
params.push_back(var);
args.push_back(var);
}

return Function(args, Call(gvar, params), func->ret_type, func->type_params);
return WithFields(func, args, Call(gvar, params));
} else {
return std::move(gvar);
}
Expand Down
5 changes: 3 additions & 2 deletions src/relay/transforms/first_order_gradient.cc
Original file line number Diff line number Diff line change
Expand Up @@ -307,8 +307,9 @@ Pass FirstOrderGradient() {
});
return Pair(res.forward, grad_tuple);
});
ad_mod->Update(pr.first,
Function(func->params, body, GradRetType(GetRef<Function>(func)), {}));
ad_mod->Update(pr.first, WithFields(GetRef<Function>(func), func->params, body,
GradRetType(GetRef<Function>(func)),
/* erase type params */ Array<TypeVar>()));
}

return ad_mod;
Expand Down
19 changes: 10 additions & 9 deletions src/relay/transforms/higher_order_gradient.cc
Original file line number Diff line number Diff line change
Expand Up @@ -341,28 +341,28 @@ struct ReverseAD : ExprMutator {
GlobalVar gv(op->name_hint + "_grad");
(*ad_gvars)[orig_gv] = gv;
Function orig_f = Downcast<Function>(DeDup(mod.value()->Lookup(orig_gv)));
std::vector<Var> params;
Array<Var> params;
for (const auto& p : orig_f->params) {
params.push_back(Downcast<Var>(VisitExpr(p)));
}
params.push_back(bp);
Expr body = VisitExpr(orig_f->body);
Function f(params, body, VisitType(orig_f->ret_type), orig_f->type_params, orig_f->attrs);
Function f = WithFields(orig_f, params, VisitExpr(orig_f->body), VisitType(orig_f->ret_type));
std::cout << "gv " << op->name_hint << ": " << AsText(f, false) << std::endl;
mod.value()->Add(gv, f);
}
return ad_gvars->at(orig_gv);
}

Expr VisitExpr_(const FunctionNode* op) final {
std::vector<Var> params;
for (const auto& var : op->params) {
Expr VisitExpr_(const FunctionNode* func_node) final {
Array<Var> params;
for (const auto& var : func_node->params) {
params.push_back(Downcast<Var>(VisitExpr(var)));
}
auto new_bp = Var("bp", bpt);
params.push_back(new_bp);
return Function(params, ReverseAD(mod, new_bp, ad_vars, ad_gvars)(op->body),
VisitType(op->ret_type), op->type_params, op->attrs);
return WithFields(GetRef<Function>(func_node), params,
ReverseAD(mod, new_bp, ad_vars, ad_gvars)(func_node->body),
VisitType(func_node->ret_type));
}

Type VisitType(const Type& t) final { return t.defined() ? ReverseType(t) : t; }
Expand Down Expand Up @@ -456,7 +456,8 @@ Expr Gradient(const Expr& re, const Optional<IRModule>& mod) {
};
return Pair(get_final_result(c, f->body->checked_type()), Tuple(ret));
});
auto ret = Function(f->params, body, GradRetType(GetRef<Function>(f)), {});
Function ret = WithFields(GetRef<Function>(f), f->params, body, GradRetType(GetRef<Function>(f)),
/* erase type params */ Array<TypeVar>());
CheckFeature(ret, FeatureSet::All() - fGraph);
return std::move(ret);
}
Expand Down
5 changes: 3 additions & 2 deletions src/relay/transforms/inline.cc
Original file line number Diff line number Diff line change
Expand Up @@ -91,8 +91,7 @@ class Inliner : ExprMutator {
}

Function Inline(const Function& func) {
return Function(func->params, VisitExpr(func->body), func->ret_type, func->type_params,
func->attrs);
return WithFields(func, func->params, VisitExpr(func->body));
}

private:
Expand Down Expand Up @@ -131,6 +130,8 @@ class Inliner : ExprMutator {
const auto* fn = base_func.as<FunctionNode>();
ICHECK(fn) << "Expected to work on a Relay function.";

// There is an inconsistency here, the function itself gets shallow-copied but the body is not
// shallow-copied.
auto func = Function(fn->params, fn->body, fn->ret_type, fn->type_params, fn->attrs);
// Inline the function body to the caller if this function uses default
// compiler, i.e. no external codegen is needed.
Expand Down
24 changes: 12 additions & 12 deletions src/relay/transforms/partial_eval.cc
Original file line number Diff line number Diff line change
Expand Up @@ -827,18 +827,18 @@ class PartialEvaluator : public ExprFunctor<PStatic(const Expr& e, LetList* ll)>
Expr VisitFuncDynamic(const Function& func, const Func& f, const Expr& self) {
return store_.Extend<Expr>([&]() {
store_.Invalidate();
return Function(func->params, LetList::With([&](LetList* ll) {
std::vector<PStatic> pv;
for (const auto& v : func->params) {
pv.push_back(NoStatic(v));
}
tvm::Array<Type> type_args;
for (const auto& tp : func->type_params) {
type_args.push_back(tp);
}
return f(HasStatic(MkSFunc(f), self), pv, Attrs(), type_args, ll)->dynamic;
}),
func->ret_type, func->type_params, func->attrs);
return WithFields(
func, func->params, LetList::With([&](LetList* ll) {
std::vector<PStatic> pv;
for (const auto& v : func->params) {
pv.push_back(NoStatic(v));
}
tvm::Array<Type> type_args;
for (const auto& tp : func->type_params) {
type_args.push_back(tp);
}
return f(HasStatic(MkSFunc(f), self), pv, Attrs(), type_args, ll)->dynamic;
}));
});
}

Expand Down
Loading

0 comments on commit 905aa88

Please sign in to comment.