diff --git a/include/tvm/ir/function.h b/include/tvm/ir/function.h index 72dc8a5c9bf95..1493544e73242 100644 --- a/include/tvm/ir/function.h +++ b/include/tvm/ir/function.h @@ -190,16 +190,6 @@ constexpr const char* kTarget = "target"; */ constexpr const char* kGlobalSymbol = "global_symbol"; -/*! - * \brief The \p VirtualDevice which will hold each of the functions parameters. - * - * Only supported on Relay \p Functions. Generally added by the \p PlanDevices pass, but - * may be included as an annotation on user programs. - * - * Type: Array - */ -constexpr const char* kParamVirtualDevice = "param_virtual_devices"; - } // namespace attr } // namespace tvm #endif // TVM_IR_FUNCTION_H_ diff --git a/src/relay/op/memory/on_device.cc b/src/relay/op/memory/on_device.cc index 4536725b20733..7d1d7f39bc079 100644 --- a/src/relay/op/memory/on_device.cc +++ b/src/relay/op/memory/on_device.cc @@ -144,9 +144,14 @@ OnDeviceProps GetOnDeviceProps(const Expr& expr) { Function FunctionOnDevice(Function function, Array param_virtual_devices, VirtualDevice result_virtual_device) { - auto func = WithAttrs( - WithFields(std::move(function), {}, {}, {}, {}, {}, std::move(result_virtual_device)), - {{tvm::attr::kParamVirtualDevice, std::move(param_virtual_devices)}}); + ICHECK_EQ(param_virtual_devices.size(), function->params.size()) + << "There should be one virtual device per function parameter."; + Array annotated_params; + for (size_t i = 0; i < function->params.size(); i++) { + annotated_params.push_back(WithFields(function->params[i], {}, {}, param_virtual_devices[i])); + } + auto func = WithFields(std::move(function), std::move(annotated_params), {}, {}, {}, {}, + std::move(result_virtual_device)); VLOG(1) << "Annotated func: " << PrettyPrint(func); return func; }