Skip to content

Commit

Permalink
Progress on parsing & printing
Browse files Browse the repository at this point in the history
  • Loading branch information
electriclilies committed Feb 16, 2022
1 parent 3faecd6 commit 3a462b2
Show file tree
Hide file tree
Showing 3 changed files with 16 additions and 10 deletions.
8 changes: 8 additions & 0 deletions include/tvm/target/virtual_device.h
Original file line number Diff line number Diff line change
Expand Up @@ -162,6 +162,7 @@ using MemoryScope = String;
*
* These operations are needed during device planning.
*/

class VirtualDeviceNode : public AttrsNode<VirtualDeviceNode> {
private:
/*!
Expand Down Expand Up @@ -361,6 +362,13 @@ class VirtualDeviceCache {
std::unordered_set<VirtualDevice, StructuralHash, StructuralEqual> cache_;
};

/*! brief The attribute key for the virtual device. This key will be promoted to first class on
* functions. For use in the parser and printer only.
*
* Type: VirtualDevice
*/
constexpr const char* kVirtualDevice = "result_virtual_device";

} // namespace tvm

#endif // TVM_TARGET_VIRTUAL_DEVICE_H_
11 changes: 2 additions & 9 deletions src/parser/parser.cc
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@
#include <tvm/runtime/logging.h>
#include <tvm/runtime/object.h>
#include <tvm/runtime/registry.h>
#include <tvm/target/virtual_device.h>

#include <fstream>

Expand Down Expand Up @@ -230,13 +231,6 @@ GlobalTypeVar AddOrGet(InternTable<GlobalTypeVar>* table, const std::string& nam
}
}

/*! brief The attribute key for the virtual device. This key will be promoted to first class on
* functions.
*
* Type: VirtualDevice
*/
constexpr const char* kVirtualDevice = "result_virtual_device";

/*! \brief The parser class is the main interface to the parser.
* the parser is not currently exposed beyond this .cc file.
*
Expand Down Expand Up @@ -1145,8 +1139,7 @@ class Parser {
// TODO(@jroesch): attributes should never be null, they should always be empty.
if (raw_attrs.size()) {
// Promote kVirtualDevice to first-class
String vid_key = kVirtualDevice;
if (raw_attrs.count(vid_key)) {
if (raw_attrs.count(kVirtualDevice)) {
ObjectRef vid = raw_attrs.at(kVirtualDevice);
ICHECK(vid.as<VirtualDeviceNode>())
<< "Expected the " << kVirtualDevice << " to have type VirtualDeviceNode, but got "
Expand Down
7 changes: 6 additions & 1 deletion src/printer/relay_text_printer.cc
Original file line number Diff line number Diff line change
Expand Up @@ -445,11 +445,16 @@ Doc RelayTextPrinter::PrintFunc(const Doc& prefix, const relay::Function& fn) {
for (const Doc& d : PrintDictAttrs(fn->attrs)) {
params.push_back(d);
}
if (fn->virtual_device() != VirtualDevice::FullyUnconstrained()) {
Doc vid_doc;
vid_doc << kVirtualDevice << "=" << Print(fn->virtual_device());
params.push_back(vid_doc);
}
doc << Doc::Concat(params) << ") ";
if (fn->ret_type.defined()) {
doc << "-> " << Print(fn->ret_type) << " ";
}
doc << "Virtual Device: " << Print(fn->virtual_device()) << " \n";

doc << PrintBody(fn->body);
return doc;
}
Expand Down

0 comments on commit 3a462b2

Please sign in to comment.