Skip to content

Commit

Permalink
store function result's virtual device in virtual_device_ field
Browse files Browse the repository at this point in the history
  • Loading branch information
electriclilies committed Jan 14, 2022
1 parent b899af5 commit 1205fd0
Show file tree
Hide file tree
Showing 7 changed files with 47 additions and 16 deletions.
10 changes: 0 additions & 10 deletions include/tvm/ir/function.h
Original file line number Diff line number Diff line change
Expand Up @@ -200,16 +200,6 @@ constexpr const char* kGlobalSymbol = "global_symbol";
*/
constexpr const char* kParamVirtualDevice = "param_virtual_devices";

/*!
* \brief The \p VirtualDevice which will hold the function result.
*
* Only supported on Relay \p Functions. Generally added by the \p PlanDevices pass, but
* may be included as an annotation on user programs.
*
* Type: VirtualDevice
*/
constexpr const char* kResultVirtualDevice = "result_virtual_device";

} // namespace attr
} // namespace tvm
#endif // TVM_IR_FUNCTION_H_
34 changes: 34 additions & 0 deletions python/tvm/ir/json_compact.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,39 @@ def _updater(data):
return _updater


def create_updater_08_to_09():
"""
Create an update to upgrade json from v0.8 to v0.9
Returns
-------
fupdater : function
The updater function
"""

def _initialize_virtual_device(item, _):
if "virtual_device_" not in item["attrs"].keys():
item["attrs"]["virtual_device_"] = "0"
return item

node_map = {
# Base IR
"GlobalVar": _initialize_virtual_device,
"relay.Var": _initialize_virtual_device,
"relay.Function": _initialize_virtual_device,
"relay.Tuple": _initialize_virtual_device,
"relay.Call": _initialize_virtual_device,
"relay.Let": _initialize_virtual_device,
"relay.If": _initialize_virtual_device,
"relay.TupleGetItem": _initialize_virtual_device,
"relay.RefCreate": _initialize_virtual_device,
"relay.RefRead": _initialize_virtual_device,
"relay.RefWrite": _initialize_virtual_device,
"relay.Match": _initialize_virtual_device,
}

return create_updater(node_map, "0.8", "0.9")

def create_updater_08_to_09():
"""
Create an update to upgrade json from v0.8 to v0.9
Expand Down Expand Up @@ -86,6 +119,7 @@ def _initialize_virtual_device(item, _):
"relay.RefRead": _initialize_virtual_device,
"relay.RefWrite": _initialize_virtual_device,
"relay.Match": _initialize_virtual_device,
"relay.Constant": _initialize_virtual_device,
}

return create_updater(node_map, "0.8", "0.9")
Expand Down
1 change: 1 addition & 0 deletions src/printer/relay_text_printer.cc
Original file line number Diff line number Diff line change
Expand Up @@ -458,6 +458,7 @@ Doc RelayTextPrinter::PrintFunc(const Doc& prefix, const relay::Function& fn) {
if (fn->ret_type.defined()) {
doc << "-> " << Print(fn->ret_type) << " ";
}
doc << "Virtual Device: " << Print(fn->virtual_device()) << " \n";
doc << PrintBody(fn->body);
return doc;
}
Expand Down
2 changes: 2 additions & 0 deletions src/relay/backend/graph_plan_memory.cc
Original file line number Diff line number Diff line change
Expand Up @@ -279,6 +279,8 @@ class StorageAllocator : public StorageAllocaBaseVisitor {
smap.Set(GetRef<Expr>(kv.first), storage_info);
}
// Either all or none of the nodes should be annotated.
VLOG(1) << "num annotated nodes / num_nodes: " << num_annotated_nodes << " / " << num_nodes
<< std::endl;
if (num_annotated_nodes != 0 && num_annotated_nodes != num_nodes) {
LOG(FATAL) << num_annotated_nodes << " out of " << num_nodes
<< "expressions are assigned with virtual device types. Either all "
Expand Down
1 change: 1 addition & 0 deletions src/relay/ir/expr.cc
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,7 @@ using namespace tvm::runtime;
Constant::Constant(runtime::NDArray data, Span span) {
ObjectPtr<ConstantNode> n = make_object<ConstantNode>();
n->data = std::move(data);
n->virtual_device_ = VirtualDevice::FullyUnconstrained();
n->span = std::move(span);
data_ = std::move(n);
}
Expand Down
11 changes: 6 additions & 5 deletions src/relay/op/memory/on_device.cc
Original file line number Diff line number Diff line change
Expand Up @@ -144,9 +144,11 @@ OnDeviceProps GetOnDeviceProps(const Expr& expr) {

Function FunctionOnDevice(Function function, Array<VirtualDevice> param_virtual_devices,
VirtualDevice result_virtual_device) {
return WithAttrs(std::move(function),
{{tvm::attr::kParamVirtualDevice, std::move(param_virtual_devices)},
{tvm::attr::kResultVirtualDevice, std::move(result_virtual_device)}});
auto func = WithAttrs(
WithFields(std::move(function), {}, {}, {}, {}, {}, std::move(result_virtual_device)),
{{tvm::attr::kParamVirtualDevice, std::move(param_virtual_devices)}});
VLOG(1) << "Annotated func: " << PrettyPrint(func);
return func;
}

TVM_REGISTER_GLOBAL("relay.op.annotation._make.FunctionOnDevice").set_body_typed(FunctionOnDevice);
Expand All @@ -166,8 +168,7 @@ Function MaybeFunctionOnDevice(Function function, Array<VirtualDevice> param_vir
}

VirtualDevice GetFunctionResultVirtualDevice(const FunctionNode* function_node) {
auto opt_virtual_device = function_node->GetAttr<VirtualDevice>(tvm::attr::kResultVirtualDevice);
return opt_virtual_device.value_or(VirtualDevice::FullyUnconstrained());
return function_node->virtual_device();
}

VirtualDevice GetFunctionParamVirtualDevice(const FunctionNode* function_node, size_t i) {
Expand Down
4 changes: 3 additions & 1 deletion src/relay/transforms/to_cps.cc
Original file line number Diff line number Diff line change
Expand Up @@ -315,7 +315,9 @@ Function UnCPS(const Function& f) {
ICHECK_GT(f->params.size(), 0);
Array<Var> new_params;
for (const auto& p : f->params) {
new_params.push_back(Var(p->name_hint(), p->checked_type()));
// TODO(@electriclilies): Not sure if this is correct, it was copying before,
// but seems like we just need to make a copy to pop so should be fine?
new_params.push_back(WithFields(std::move(p)));
}
auto cont_type = Downcast<FuncType>(new_params.back()->type_annotation);
new_params.pop_back();
Expand Down

0 comments on commit 1205fd0

Please sign in to comment.