Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Relay][Virtual Device] Store function parameter virtual devices in virtual_device_ field #9907

Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 12 additions & 0 deletions include/tvm/ir/expr.h
Original file line number Diff line number Diff line change
Expand Up @@ -178,6 +178,11 @@ class RelayExprNode : public BaseExprNode {
*
* For expressions that have the function type, the virtual device describes where the result of
* the call to the function or closure is stored (instead of where the function itself is stored).
* For example, the virtual device of f = fn(x) { body } is the virtual device of f(y), not where
* the function itself is stored. Note that f(y)'s virtual device will be the same as the virtual
* device of body. For more details, see the documentation in
* src/relay/transforms/device_planner.cc.
*
* The VirtualDevice's Target field describes how the body of the function should be compiled.
*
* Set to VirtualDevice::FullyUnconstrained by default.
Expand All @@ -190,6 +195,13 @@ class RelayExprNode : public BaseExprNode {
/*!
* \return The virtual device (VirtualDevice).
* If the virtual device is not defined, returns VirtualDevice::FullyUnconstrained().
* Note that for function types, the virtual device is the device where the result of a
* call to the function is stored, not where the function itself lives.
* For example, the virtual device of f = fn(x) { body } is the virtual device of f(y), not where
* the function itself is stored. Note that f(y)'s virtual device will be the same as the virtual
* device of body.
*
* See the documentation of the virtual_device_ field (above) for more details.
*/
VirtualDevice virtual_device() const;

Expand Down
20 changes: 0 additions & 20 deletions include/tvm/ir/function.h
Original file line number Diff line number Diff line change
Expand Up @@ -190,26 +190,6 @@ constexpr const char* kTarget = "target";
*/
constexpr const char* kGlobalSymbol = "global_symbol";

/*!
* \brief The \p VirtualDevice which will hold each of the functions parameters.
*
* Only supported on Relay \p Functions. Generally added by the \p PlanDevices pass, but
* may be included as an annotation on user programs.
*
* Type: Array<VirtualDevice>
*/
constexpr const char* kParamVirtualDevice = "param_virtual_devices";

/*!
* \brief The \p VirtualDevice which will hold the function result.
*
* Only supported on Relay \p Functions. Generally added by the \p PlanDevices pass, but
* may be included as an annotation on user programs.
*
* Type: VirtualDevice
*/
constexpr const char* kResultVirtualDevice = "result_virtual_device";

} // namespace attr
} // namespace tvm
#endif // TVM_IR_FUNCTION_H_
2 changes: 2 additions & 0 deletions src/relay/backend/graph_plan_memory.cc
Original file line number Diff line number Diff line change
Expand Up @@ -279,6 +279,8 @@ class StorageAllocator : public StorageAllocaBaseVisitor {
smap.Set(GetRef<Expr>(kv.first), storage_info);
}
// Either all or none of the nodes should be annotated.
VLOG(1) << "num annotated nodes / num_nodes: " << num_annotated_nodes << " / " << num_nodes
<< std::endl;
if (num_annotated_nodes != 0 && num_annotated_nodes != num_nodes) {
LOG(FATAL) << num_annotated_nodes << " out of " << num_nodes
<< "expressions are assigned with virtual device types. Either all "
Expand Down
24 changes: 11 additions & 13 deletions src/relay/op/memory/on_device.cc
Original file line number Diff line number Diff line change
Expand Up @@ -144,9 +144,15 @@ OnDeviceProps GetOnDeviceProps(const Expr& expr) {

Function FunctionOnDevice(Function function, Array<VirtualDevice> param_virtual_devices,
VirtualDevice result_virtual_device) {
return WithAttrs(std::move(function),
{{tvm::attr::kParamVirtualDevice, std::move(param_virtual_devices)},
{tvm::attr::kResultVirtualDevice, std::move(result_virtual_device)}});
ICHECK_EQ(param_virtual_devices.size(), function->params.size())
<< "There should be one virtual device per function parameter.";
Array<Var> annotated_params;
for (size_t i = 0; i < function->params.size(); i++) {
annotated_params.push_back(WithFields(function->params[i], {}, {}, param_virtual_devices[i]));
}
auto func = WithFields(function, annotated_params, {}, {}, {}, {}, result_virtual_device);
VLOG(1) << "Annotated func: " << PrettyPrint(func);
return func;
}

TVM_REGISTER_GLOBAL("relay.op.annotation._make.FunctionOnDevice").set_body_typed(FunctionOnDevice);
Expand All @@ -166,22 +172,14 @@ Function MaybeFunctionOnDevice(Function function, Array<VirtualDevice> param_vir
}

VirtualDevice GetFunctionResultVirtualDevice(const FunctionNode* function_node) {
auto opt_virtual_device = function_node->GetAttr<VirtualDevice>(tvm::attr::kResultVirtualDevice);
return opt_virtual_device.value_or(VirtualDevice::FullyUnconstrained());
return function_node->virtual_device();
}

VirtualDevice GetFunctionParamVirtualDevice(const FunctionNode* function_node, size_t i) {
ICHECK_LT(i, function_node->params.size())
<< "param index " << i << " out of range for function of arity "
<< function_node->params.size();
auto opt_array = function_node->GetAttr<Array<VirtualDevice>>(tvm::attr::kParamVirtualDevice);
if (!opt_array) {
// No annotation.
return VirtualDevice::FullyUnconstrained();
}
ICHECK_EQ(opt_array.value().size(), function_node->params.size())
<< "annotation parameters do not match function arity";
return opt_array.value()[i];
return function_node->params[i]->virtual_device();
}

} // namespace relay
Expand Down