Skip to content

Commit

Permalink
[PatternLang] Add Syntatic Sugar to the C++ pattern API and support D…
Browse files Browse the repository at this point in the history
…ataType Attribute Matching (apache#7120)

* Add Syntatic Sugar for C++ Pattern API, Support DataType Attribute match

* add missing tests

* fix lint

* fix license edit

* fix bad rebase
  • Loading branch information
Matthew Brookhart authored and masahi committed Jan 14, 2021
1 parent d55aa4a commit 55702de
Show file tree
Hide file tree
Showing 7 changed files with 319 additions and 37 deletions.
56 changes: 43 additions & 13 deletions include/tvm/relay/dataflow_pattern.h
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,9 @@
#include <tvm/relay/expr.h>
#include <tvm/relay/type.h>

#include <string>
#include <vector>

namespace tvm {
namespace relay {

Expand All @@ -46,6 +49,29 @@ class DFPatternNode : public Object {
*/
class DFPattern : public ObjectRef {
public:
/*! \brief Syntatic Sugar for creating a CallPattern */
DFPattern operator()(const std::vector<DFPattern>& args);
/*! \brief Syntatic Sugar for creating a CallPattern with an "add" op */
DFPattern operator+(const DFPattern& other);
/*! \brief Syntatic Sugar for creating a CallPattern with a "subtract" op */
DFPattern operator-(const DFPattern& other);
/*! \brief Syntatic Sugar for creating a CallPattern with a "multiply" op */
DFPattern operator*(const DFPattern& other);
/*! \brief Syntatic Sugar for creating a CallPattern with a "divide" op */
DFPattern operator/(const DFPattern& other);
/*! \brief Syntatic Sugar for creating an AltPattern */
DFPattern operator||(const DFPattern& other);
/*! \brief Syntatic Sugar for creating an AttrPattern */
DFPattern HasAttr(const Map<String, ObjectRef>& attrs);
/*! \brief Syntatic Sugar for creating a TypePattern */
DFPattern HasType(const Type& type);
/*! \brief Syntatic Sugar for creating a DataTypePattern with a DataType */
DFPattern HasDtype(const DataType& dtype);
/*! \brief Syntatic Sugar for creating a DataTypePattern with a data type's name */
DFPattern HasDtype(const std::string& dtype);
/*! \brief Syntatic Sugar for creating a ShapePattern */
DFPattern HasShape(const Array<PrimExpr> shape);

TVM_DEFINE_OBJECT_REF_METHODS(DFPattern, ObjectRef, DFPatternNode);
};

Expand Down Expand Up @@ -86,28 +112,19 @@ class VarPatternNode : public DFPatternNode {
* \brief The name of the Var (optional).
*/
String name;
/*!
* \brief type annotation of the variable.
* This field records user provided type annotation of the Var.
* This field is optional and can be None.
*/
Type type_annotation;

/*! \return The name hint of the variable */
const String& name_hint() const { return name; }

void VisitAttrs(tvm::AttrVisitor* v) {
v->Visit("name", &name);
v->Visit("type_annotation", &type_annotation);
}
void VisitAttrs(tvm::AttrVisitor* v) { v->Visit("name", &name); }

static constexpr const char* _type_key = "relay.dataflow_pattern.VarPattern";
TVM_DECLARE_FINAL_OBJECT_INFO(VarPatternNode, DFPatternNode);
};

class VarPattern : public DFPattern {
public:
TVM_DLL VarPattern(String name_hint, Type type_annotation);
TVM_DLL VarPattern(String name_hint);
TVM_DEFINE_OBJECT_REF_METHODS(VarPattern, DFPattern, VarPatternNode);
};

Expand Down Expand Up @@ -393,7 +410,7 @@ class AttrPatternNode : public DFPatternNode {
/*! \brief The pattern. */
DFPattern pattern;
/*! \brief The attribute to match */
Attrs attrs;
DictAttrs attrs;

void VisitAttrs(tvm::AttrVisitor* v) {
v->Visit("pattern", &pattern);
Expand All @@ -409,7 +426,7 @@ class AttrPatternNode : public DFPatternNode {
*/
class AttrPattern : public DFPattern {
public:
TVM_DLL AttrPattern(DFPattern pattern, Attrs attrs);
TVM_DLL AttrPattern(DFPattern pattern, DictAttrs attrs);
TVM_DEFINE_OBJECT_REF_METHODS(AttrPattern, DFPattern, AttrPatternNode);
};

Expand Down Expand Up @@ -447,6 +464,19 @@ class DominatorPattern : public DFPattern {
TVM_DEFINE_OBJECT_REF_METHODS(DominatorPattern, DFPattern, DominatorPatternNode);
};

/*! \brief Syntatic Sugar for creating a VarPattern with a name */
DFPattern IsVar(const String& name);
/*! \brief Syntatic Sugar for creating a ConstantPattern */
DFPattern IsConstant();
/*! \brief Syntatic Sugar for creating a ExprPattern */
DFPattern IsExpr(const Expr& expr);
/*! \brief Syntatic Sugar for creating a ExprPattern base on an Op*/
DFPattern IsOp(const String& op_name);
/*! \brief Syntatic Sugar for creating a TuplePattern*/
DFPattern IsTuple(const Array<DFPattern>& fields);
/*! \brief Syntatic Sugar for creating a TupleGetItemPattern*/
DFPattern IsTupleGetItem(const DFPattern tuple, int index = -1);

} // namespace relay
} // namespace tvm
#endif // TVM_RELAY_DATAFLOW_PATTERN_H_
4 changes: 2 additions & 2 deletions python/tvm/relay/dataflow_pattern/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -480,8 +480,8 @@ class VarPattern(DFPattern):
The type annotation on the variable.
"""

def __init__(self, name_hint: str = "", type_annotation: Optional[tvm.ir.type.Type] = None):
self.__init_handle_by_constructor__(ffi.VarPattern, name_hint, type_annotation)
def __init__(self, name_hint: str = ""):
self.__init_handle_by_constructor__(ffi.VarPattern, name_hint)


@register_df_node
Expand Down
14 changes: 12 additions & 2 deletions src/relay/ir/dataflow_matcher.cc
Original file line number Diff line number Diff line change
Expand Up @@ -124,6 +124,13 @@ bool MatchRetValue(const ObjectRef& lhs, const TVMRetValue& rhs) {
return val->data == rhs.operator std::string();
}
break;
case kTVMDataType:
if (auto* val = lhs.as<tir::StringImmNode>()) {
return rhs.operator std::string() == val->value;
} else if (auto* val = lhs.as<StringObj>()) {
return rhs.operator std::string() == val->data;
}
break;
case kTVMObjectHandle:
if (rhs.IsObjectRef<String>()) {
if (auto* val = lhs.as<tir::StringImmNode>()) {
Expand All @@ -140,7 +147,10 @@ bool MatchRetValue(const ObjectRef& lhs, const TVMRetValue& rhs) {
}

bool DFPatternMatcher::VisitDFPattern_(const AttrPatternNode* attr_pattern, const Expr& expr) {
bool matches = false;
bool matches = VisitDFPattern(attr_pattern->pattern, expr);
if (!matches) {
return matches;
}
auto attributes = attr_pattern->attrs.as<DictAttrsNode>()->dict;
if (const auto* op_node = expr.as<OpNode>()) {
Op op = GetRef<Op>(op_node);
Expand Down Expand Up @@ -179,7 +189,7 @@ bool DFPatternMatcher::VisitDFPattern_(const AttrPatternNode* attr_pattern, cons
}
}
}
return matches && VisitDFPattern(attr_pattern->pattern, expr);
return matches;
}

Array<DFPattern> reverse(const Array<DFPattern>& args) {
Expand Down
67 changes: 53 additions & 14 deletions src/relay/ir/dataflow_pattern.cc
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
* \brief The dataflow pattern language for Relay.
*/
#include <tvm/relay/dataflow_pattern.h>
#include <tvm/runtime/data_type.h>

namespace tvm {
namespace relay {
Expand All @@ -44,29 +45,22 @@ TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
p->Print(node->expr);
});

VarPattern::VarPattern(String name_hint, Type type_annotation) {
VarPattern::VarPattern(String name_hint) {
ObjectPtr<VarPatternNode> n = make_object<VarPatternNode>();
n->name = std::move(name_hint);
n->type_annotation = std::move(type_annotation);
data_ = std::move(n);
}

TVM_REGISTER_NODE_TYPE(VarPatternNode);

TVM_REGISTER_GLOBAL("relay.dataflow_pattern.VarPattern")
.set_body_typed([](String name_hint, Type type_annotation) {
return VarPattern(name_hint, type_annotation);
});
TVM_REGISTER_GLOBAL("relay.dataflow_pattern.VarPattern").set_body_typed([](String name_hint) {
return VarPattern(name_hint);
});

TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
.set_dispatch<VarPatternNode>([](const ObjectRef& ref, ReprPrinter* p) {
auto* node = static_cast<const VarPatternNode*>(ref.get());
p->stream << "VarPattern(" << node->name_hint();
if (node->type_annotation.defined()) {
p->stream << ", ty=";
p->Print(node->type_annotation);
}
p->stream << ")";
p->stream << "VarPattern(" << node->name_hint() << ")";
});

TVM_REGISTER_NODE_TYPE(ConstantPatternNode);
Expand Down Expand Up @@ -241,7 +235,7 @@ TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
p->stream << "TypePattern(" << node->pattern << " has dtype " << node->dtype << ")";
});

AttrPattern::AttrPattern(DFPattern pattern, Attrs attrs) {
AttrPattern::AttrPattern(DFPattern pattern, DictAttrs attrs) {
ObjectPtr<AttrPatternNode> n = make_object<AttrPatternNode>();
n->pattern = std::move(pattern);
n->attrs = std::move(attrs);
Expand All @@ -251,7 +245,7 @@ AttrPattern::AttrPattern(DFPattern pattern, Attrs attrs) {
TVM_REGISTER_NODE_TYPE(AttrPatternNode);

TVM_REGISTER_GLOBAL("relay.dataflow_pattern.AttrPattern")
.set_body_typed([](DFPattern pattern, Attrs attrs) { return AttrPattern(pattern, attrs); });
.set_body_typed([](DFPattern pattern, DictAttrs attrs) { return AttrPattern(pattern, attrs); });

TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
.set_dispatch<AttrPatternNode>([](const ObjectRef& ref, ReprPrinter* p) {
Expand All @@ -263,6 +257,7 @@ DominatorPattern::DominatorPattern(DFPattern parent, DFPattern path, DFPattern c
ObjectPtr<DominatorPatternNode> n = make_object<DominatorPatternNode>();
n->parent = std::move(parent);
n->path = std::move(path);

n->child = std::move(child);
data_ = std::move(n);
}
Expand All @@ -281,5 +276,49 @@ TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
<< ")";
});

// Syntatic Sugar
DFPattern DFPattern::operator()(const std::vector<DFPattern>& args) {
return CallPattern(GetRef<DFPattern>(this->get()), Array<DFPattern>(args));
}
DFPattern DFPattern::operator+(const DFPattern& other) {
return IsOp("add")({GetRef<DFPattern>(this->get()), other});
}
DFPattern DFPattern::operator-(const DFPattern& other) {
return IsOp("subtract")({GetRef<DFPattern>(this->get()), other});
}
DFPattern DFPattern::operator*(const DFPattern& other) {
return IsOp("multiply")({GetRef<DFPattern>(this->get()), other});
}
DFPattern DFPattern::operator/(const DFPattern& other) {
return IsOp("divide")({GetRef<DFPattern>(this->get()), other});
}
DFPattern DFPattern::operator||(const DFPattern& other) {
return AltPattern(GetRef<DFPattern>(this->get()), other);
}

DFPattern DFPattern::HasAttr(const Map<String, ObjectRef>& attrs) {
return AttrPattern(GetRef<DFPattern>(this->get()), DictAttrs(attrs));
}
DFPattern DFPattern::HasType(const Type& type) {
return TypePattern(GetRef<DFPattern>(this->get()), type);
}
DFPattern DFPattern::HasDtype(const DataType& dtype) {
return DataTypePattern(GetRef<DFPattern>(this->get()), dtype);
}
DFPattern DFPattern::HasDtype(const std::string& dtype) {
return HasDtype(DataType(runtime::String2DLDataType(dtype)));
}
DFPattern DFPattern::HasShape(const Array<PrimExpr> shape) {
return ShapePattern(GetRef<DFPattern>(this->get()), shape);
}
DFPattern IsVar(const String& name) { return VarPattern(name); }
DFPattern IsConstant() { return ConstantPattern(make_object<ConstantPatternNode>()); }
DFPattern IsExpr(const Expr& expr) { return ExprPattern(expr); }
DFPattern IsOp(const String& op_name) { return IsExpr(Op::Get(op_name)); }
DFPattern IsTuple(const Array<DFPattern>& fields) { return TuplePattern(fields); }
DFPattern IsTupleGetItem(const DFPattern tuple, int index) {
return TupleGetItemPattern(tuple, index);
}

} // namespace relay
} // namespace tvm
9 changes: 3 additions & 6 deletions src/relay/transforms/simplify_expr.cc
Original file line number Diff line number Diff line change
Expand Up @@ -33,9 +33,6 @@
namespace tvm {
namespace relay {

static Op reshape_op = Op::Get("reshape");
static Op reverse_reshape_op = Op::Get("contrib_reverse_reshape");

/*!
* \brief SimplifyReshape matches the pattern of consecutive reshape or reverse_reshape ops,
* and merges into one reshape op.
Expand All @@ -44,9 +41,9 @@ class SimplifyReshape {
public:
SimplifyReshape() {
x_ = WildcardPattern(make_object<WildcardPatternNode>());
auto reshape1 = AltPattern(ExprPattern(reshape_op), ExprPattern(reverse_reshape_op));
auto reshape2 = AltPattern(ExprPattern(reshape_op), ExprPattern(reverse_reshape_op));
pattern_ = CallPattern(reshape1, {CallPattern(reshape2, {x_})});
auto reshape1 = IsOp("reshape") || IsOp("contrib_reverse_reshape");
auto reshape2 = IsOp("reshape") || IsOp("contrib_reverse_reshape");
pattern_ = reshape1({reshape2({x_})});
}

Expr callback(const Expr& pre, const Expr& post, const Map<DFPattern, Array<Expr>>& node_map) {
Expand Down
Loading

0 comments on commit 55702de

Please sign in to comment.