Skip to content

Commit

Permalink
parent b77d3e0
Browse files Browse the repository at this point in the history
author Lily Orth-Smith <lilyorthsmith@gmail.com> 1641339403 -0800
committer Lily Orth-Smith <lilyorthsmith@gmail.com> 1641949451 -0800

Make function result virtual_device_ stored in virtual_device_ field
  • Loading branch information
electriclilies committed Jan 12, 2022
1 parent b77d3e0 commit 00ca9ad
Show file tree
Hide file tree
Showing 5 changed files with 11 additions and 15 deletions.
2 changes: 2 additions & 0 deletions include/tvm/ir/expr.h
Original file line number Diff line number Diff line change
Expand Up @@ -182,6 +182,8 @@ class RelayExprNode : public BaseExprNode {
*
* Set to VirtualDevice::FullyUnconstrained by default.
*
* Set to VirtualDevice::FullyUnconstrained by default.
*
* \note Unfortunately, the type of virtual_device_ needs to be ObjectRef to avoid a circular
* import.
*/
Expand Down
10 changes: 0 additions & 10 deletions include/tvm/ir/function.h
Original file line number Diff line number Diff line change
Expand Up @@ -200,16 +200,6 @@ constexpr const char* kGlobalSymbol = "global_symbol";
*/
constexpr const char* kParamVirtualDevice = "param_virtual_devices";

/*!
* \brief The \p VirtualDevice which will hold the function result.
*
* Only supported on Relay \p Functions. Generally added by the \p PlanDevices pass, but
* may be included as an annotation on user programs.
*
* Type: VirtualDevice
*/
constexpr const char* kResultVirtualDevice = "result_virtual_device";

} // namespace attr
} // namespace tvm
#endif // TVM_IR_FUNCTION_H_
1 change: 1 addition & 0 deletions src/printer/relay_text_printer.cc
Original file line number Diff line number Diff line change
Expand Up @@ -458,6 +458,7 @@ Doc RelayTextPrinter::PrintFunc(const Doc& prefix, const relay::Function& fn) {
if (fn->ret_type.defined()) {
doc << "-> " << Print(fn->ret_type) << " ";
}
doc << "Virtual Device: " << Print(fn->virtual_device()) << " \n";
doc << PrintBody(fn->body);
return doc;
}
Expand Down
2 changes: 2 additions & 0 deletions src/relay/backend/graph_plan_memory.cc
Original file line number Diff line number Diff line change
Expand Up @@ -279,6 +279,8 @@ class StorageAllocator : public StorageAllocaBaseVisitor {
smap.Set(GetRef<Expr>(kv.first), storage_info);
}
// Either all or none of the nodes should be annotated.
VLOG(1) << "num annotated nodes / num_nodes: " << num_annotated_nodes << " / " << num_nodes
<< std::endl;
if (num_annotated_nodes != 0 && num_annotated_nodes != num_nodes) {
LOG(FATAL) << num_annotated_nodes << " out of " << num_nodes
<< "expressions are assigned with virtual device types. Either all "
Expand Down
11 changes: 6 additions & 5 deletions src/relay/op/memory/on_device.cc
Original file line number Diff line number Diff line change
Expand Up @@ -144,9 +144,11 @@ OnDeviceProps GetOnDeviceProps(const Expr& expr) {

Function FunctionOnDevice(Function function, Array<VirtualDevice> param_virtual_devices,
VirtualDevice result_virtual_device) {
return WithAttrs(std::move(function),
{{tvm::attr::kParamVirtualDevice, std::move(param_virtual_devices)},
{tvm::attr::kResultVirtualDevice, std::move(result_virtual_device)}});
auto func = WithAttrs(
WithFields(std::move(function), {}, {}, {}, {}, {}, std::move(result_virtual_device)),
{{tvm::attr::kParamVirtualDevice, std::move(param_virtual_devices)}});
VLOG(1) << "Annotated func: " << PrettyPrint(func);
return func;
}

TVM_REGISTER_GLOBAL("relay.op.annotation._make.FunctionOnDevice").set_body_typed(FunctionOnDevice);
Expand All @@ -166,8 +168,7 @@ Function MaybeFunctionOnDevice(Function function, Array<VirtualDevice> param_vir
}

VirtualDevice GetFunctionResultVirtualDevice(const FunctionNode* function_node) {
auto opt_virtual_device = function_node->GetAttr<VirtualDevice>(tvm::attr::kResultVirtualDevice);
return opt_virtual_device.value_or(VirtualDevice::FullyUnconstrained());
return function_node->virtual_device();
}

VirtualDevice GetFunctionParamVirtualDevice(const FunctionNode* function_node, size_t i) {
Expand Down

0 comments on commit 00ca9ad

Please sign in to comment.