From ef123c80197bcaf5edf28fb9055b19818bf48cf7 Mon Sep 17 00:00:00 2001
From: Matthew Brookhart
Date: Wed, 10 Feb 2021 12:54:43 -0700
Subject: [PATCH] Simplify full broadcast (#7423)

* convert argwhere(full(const)) to reshape(arange())
* Add IsWildcard syntactic sugar
* add a simplify expression to fold full into broadcast ops
* Allow constant folding of full-like ops after SimplifyExpr
* fix a bug with the Attr Pattern matching
* remove skip_list
---
 include/tvm/relay/dataflow_pattern.h          |   2 +
 src/relay/ir/dataflow_matcher.cc              |   8 +-
 src/relay/ir/dataflow_pattern.cc              |   1 +
 src/relay/op/make_op.h                        |   6 +
 src/relay/op/tensor/unary.cc                  |   6 +-
 src/relay/transforms/fold_constant.cc         |   5 -
 src/relay/transforms/simplify_expr.cc         | 111 ++++++++++++++++--
 tests/python/relay/test_dataflow_pattern.py   |   2 +
 tests/python/relay/test_pass_fold_constant.py |  16 ---
 tests/python/relay/test_pass_simplify_expr.py |  65 ++++++++++
 10 files changed, 185 insertions(+), 37 deletions(-)

diff --git a/include/tvm/relay/dataflow_pattern.h b/include/tvm/relay/dataflow_pattern.h
index 1e6cecfd041b..99ef9a237de2 100644
--- a/include/tvm/relay/dataflow_pattern.h
+++ b/include/tvm/relay/dataflow_pattern.h
@@ -524,6 +524,8 @@ class DominatorPattern : public DFPattern {
 DFPattern IsVar(const String& name);
 /*! \brief Syntatic Sugar for creating a ConstantPattern */
 DFPattern IsConstant();
+/*! \brief Syntactic Sugar for creating a WildcardPattern */
+DFPattern IsWildcard();
 /*! \brief Syntatic Sugar for creating a ExprPattern */
 DFPattern IsExpr(const Expr& expr);
 /*! \brief Syntatic Sugar for creating a ExprPattern base on an Op*/
diff --git a/src/relay/ir/dataflow_matcher.cc b/src/relay/ir/dataflow_matcher.cc
index a43f50f600df..ac716579f2ab 100644
--- a/src/relay/ir/dataflow_matcher.cc
+++ b/src/relay/ir/dataflow_matcher.cc
@@ -162,8 +162,12 @@ bool DFPatternMatcher::VisitDFPattern_(const AttrPatternNode* attr_pattern, cons
       if (Op::HasAttrMap(attr_name)) {
         auto op_map = Op::GetAttrMap<TVMRetValue>(attr_name);
         if (op_map.count(op)) {
-          matches = MatchRetValue(attr_value, op_map[op]);
+          matches &= MatchRetValue(attr_value, op_map[op]);
+        } else {
+          matches = false;
         }
+      } else {
+        matches = false;
       }
     }
   } else if (auto* op = expr.as<CallNode>()) {
@@ -196,6 +200,8 @@ bool DFPatternMatcher::VisitDFPattern_(const AttrPatternNode* attr_pattern, cons
         break;
       }
     }
+  } else {
+    matches = false;
   }
   return matches;
 }
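Note on the matcher fix above: previously, when an expression carried no op attributes at all, the AttrPattern visitor fell through every branch and left `matches` set to true, so an attribute pattern over a wildcard could spuriously match anything. Both fall-through paths now force `matches = false`. A minimal Python-level sketch of the tightened semantics; the names here are illustrative, and the value 1 is TVM's kBroadcast entry in the TOpPattern enum:

    from tvm import relay
    from tvm.relay.dataflow_pattern import wildcard

    # "any call to an op whose TOpPattern attribute equals kBroadcast (1)"
    bcast_op = wildcard().has_attr({"TOpPattern": 1})
    pat = bcast_op(wildcard(), wildcard())

    x, y = relay.var("x"), relay.var("y")
    assert pat.match(x + y)                    # add carries TOpPattern == kBroadcast
    assert not bcast_op.match(relay.var("z"))  # no op attrs at all: now rejected

This is what makes the wildcard-with-attribute pattern used by FullElementwise (further below) safe to deploy.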
diff --git a/src/relay/ir/dataflow_pattern.cc b/src/relay/ir/dataflow_pattern.cc
index 4c3b82cc19d4..9c65c490d855 100644
--- a/src/relay/ir/dataflow_pattern.cc
+++ b/src/relay/ir/dataflow_pattern.cc
@@ -357,6 +357,7 @@ DFPattern DFPattern::HasShape(const Array<PrimExpr> shape) {
 }
 DFPattern IsVar(const String& name) { return VarPattern(name); }
 DFPattern IsConstant() { return ConstantPattern(make_object<ConstantPatternNode>()); }
+DFPattern IsWildcard() { return WildcardPattern(make_object<WildcardPatternNode>()); }
 DFPattern IsExpr(const Expr& expr) { return ExprPattern(expr); }
 DFPattern IsOp(const String& op_name) { return IsExpr(Op::Get(op_name)); }
 DFPattern IsTuple(const Array<DFPattern>& fields) { return TuplePattern(fields); }
diff --git a/src/relay/op/make_op.h b/src/relay/op/make_op.h
index 2b05290b270c..79f7e135e29d 100644
--- a/src/relay/op/make_op.h
+++ b/src/relay/op/make_op.h
@@ -100,6 +100,12 @@ Expr MakeResize(Expr data, Array<IndexExpr> size, String layout, String method,
 Expr MakeSparseToDense(Expr indices, Array<Integer> output_shape, Expr values, Expr default_value);
 
+Expr MakeArange(Expr start, Expr stop, Expr step, DataType dtype);
+
+Expr MakeShapeOf(Expr data, DataType dtype);
+
+Expr MakeTake(Expr data, Expr indices, Integer axis, String mode);
+
 }  // namespace relay
 }  // namespace tvm
 #endif  // TVM_RELAY_OP_MAKE_OP_H_
diff --git a/src/relay/op/tensor/unary.cc b/src/relay/op/tensor/unary.cc
index e17bdc0e0906..3e82b92a5f03 100644
--- a/src/relay/op/tensor/unary.cc
+++ b/src/relay/op/tensor/unary.cc
@@ -430,12 +430,14 @@ Array<te::Tensor> ShapeOfCompute(const Attrs& attrs, const Array<te::Tensor>& in
   return {topi::shape(inputs[0], param->dtype)};
 }
 
-TVM_REGISTER_GLOBAL("relay.op._make.shape_of").set_body_typed([](Expr data, DataType dtype) {
+Expr MakeShapeOf(Expr data, DataType dtype) {
   auto attrs = make_object<ShapeOfAttrs>();
   attrs->dtype = dtype;
   static const Op& op = Op::Get("shape_of");
   return Call(op, {data}, Attrs(attrs), {});
-});
+}
+
+TVM_REGISTER_GLOBAL("relay.op._make.shape_of").set_body_typed(MakeShapeOf);
 
 RELAY_REGISTER_OP("shape_of")
     .describe(R"code(Returns a tensor representing the shape of a tensor.
diff --git a/src/relay/transforms/fold_constant.cc b/src/relay/transforms/fold_constant.cc
index 657d4db993b0..4454c9c0459a 100644
--- a/src/relay/transforms/fold_constant.cc
+++ b/src/relay/transforms/fold_constant.cc
@@ -148,8 +148,6 @@ class ConstantFolder : public MixedModeMutator {
     }
 
     static auto op_stateful = Op::GetAttrMap<TOpIsStateful>("TOpIsStateful");
-    std::unordered_set<std::string> skip_list{"zeros_like", "ones_like", "full_like", "full"};
-
     auto origin_args = call->args;
     call = post.as<CallNode>();
     // We don't constant fold function with zero arguments.
@@ -158,9 +156,6 @@ class ConstantFolder : public MixedModeMutator {
     if (call->args.size() == 0) return post;
     const OpNode* op = call->op.as<OpNode>();
    if (op == nullptr) return post;
-    if (skip_list.count(op->name)) {
-      return post;
-    }
     // skip stateful ops.
     if (op_stateful.get(GetRef<Op>(op), false)) return post;
     // Try to evaluate shape_of op
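Note: with the skip_list gone, FoldConstant may now evaluate full and full_like calls outright (zeros/ones take no tensor arguments, so they are still stopped by the zero-argument check visible above). The SimplifyExpr rewrite in the next file is what keeps such constants from being materialized inside elementwise expressions first. A minimal sketch of the new folding behavior, mirroring the test deleted from test_pass_fold_constant.py further below:

    import tvm
    from tvm import relay

    f = relay.full(relay.const(1.0, "float32"), (8, 9, 10), dtype="float32")
    mod = tvm.IRModule.from_expr(f)
    mod = relay.transform.InferType()(mod)
    mod = relay.transform.FoldConstant()(mod)
    # the body of main is now a relay.Constant rather than a call to full
    print(mod["main"].body)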
diff --git a/src/relay/transforms/simplify_expr.cc b/src/relay/transforms/simplify_expr.cc
index 0f78c260378c..74e48dc4bc54 100644
--- a/src/relay/transforms/simplify_expr.cc
+++ b/src/relay/transforms/simplify_expr.cc
@@ -29,24 +29,38 @@
 #include <tvm/relay/transform.h>
 
 #include "../op/tensor/transform.h"
+#include "pattern_utils.h"
 
 namespace tvm {
 namespace relay {
 
+class SimplifyPattern {
+ public:
+  virtual Expr callback(const Expr& pre, const Expr& post,
+                        const Map<DFPattern, Array<Expr>>& node_map) const = 0;
+
+  DFPattern pattern() const { return pattern_; }
+
+ protected:
+  /*! \brief Pattern for rewriting */
+  DFPattern pattern_;
+};
+
 /*!
  * \brief SimplifyReshape matches the pattern of consecutive reshape or reverse_reshape ops,
  * and merges into one reshape op.
  */
-class SimplifyReshape {
+class SimplifyReshape : public SimplifyPattern {
  public:
   SimplifyReshape() {
-    x_ = WildcardPattern(make_object<WildcardPatternNode>());
+    x_ = IsWildcard();
     auto reshape1 = IsOp("reshape") || IsOp("contrib_reverse_reshape");
     auto reshape2 = IsOp("reshape") || IsOp("contrib_reverse_reshape");
     pattern_ = reshape1({reshape2({x_})});
   }
 
-  Expr callback(const Expr& pre, const Expr& post, const Map<DFPattern, Array<Expr>>& node_map) {
+  Expr callback(const Expr& pre, const Expr& post,
+                const Map<DFPattern, Array<Expr>>& node_map) const override {
     auto x = node_map[x_][0];
     bool const_shape = true;
     Array<Integer> newshape;
@@ -63,13 +77,82 @@ class SimplifyReshape {
     return post;
   }
 
-  DFPattern pattern() const { return pattern_; }
-
  private:
   /*! \brief Pattern input */
   DFPattern x_;
-  /*! \brief Pattern for consecutive reshape or reverse_reshape ops */
-  DFPattern pattern_;
+};
+
+/*!
+ * \brief FullElementwise folds full ops (full/full_like/ones/zeros) that feed a broadcast op
+ * into scalar constants
+ */
+class FullElementwise : public SimplifyPattern {
+ public:
+  FullElementwise() {
+    x_ = IsWildcard();
+    data_ = IsWildcard();
+    value_ = IsConstant();
+
+    full_ = IsOp("full")({value_}) || IsOp("full_like")({data_, value_});
+    ones_ = IsOp("ones")({}) || IsOp("ones_like")({data_});
+    zeros_ = IsOp("zeros")({}) || IsOp("zeros_like")({data_});
+
+    Map<String, ObjectRef> attrs;
+    attrs.Set("TOpPattern", Integer(static_cast<int>(kBroadcast)));
+    DFPattern op = IsWildcard().HasAttr(attrs);
+    DFPattern full = full_ || ones_ || zeros_;
+    pattern_ = op({full, x_}) || op({x_, full});
+  }
+
+  Expr callback(const Expr& pre, const Expr& post,
+                const Map<DFPattern, Array<Expr>>& node_map) const override {
+    const CallNode* call = pre.as<CallNode>();
+    ICHECK(call);
+    Type pre_type = pre->checked_type_;
+    ICHECK(pre_type.as<TensorTypeNode>());
+    auto dtype = pre_type.as<TensorTypeNode>()->dtype;
+    auto x = node_map[x_][0];
+    bool is_left = post.as<CallNode>()->args[1] == x;
+    Type x_type;
+    if (is_left) {
+      x_type = call->args[1]->checked_type_;
+    } else {
+      x_type = call->args[0]->checked_type_;
+    }
+
+    if (StructuralEqual()(x_type, pre_type)) {
+      Expr value;
+      if (node_map.count(full_)) {
+        value = node_map[value_][0];
+        ICHECK(IsConstScalar(value));
+      } else if (node_map.count(ones_)) {
+        value = MakeConstantScalar(dtype, 1);
+      } else if (node_map.count(zeros_)) {
+        value = MakeConstantScalar(dtype, 0);
+      } else {
+        ICHECK(false) << "Didn't find a full op while matching full + elementwise";
+      }
+      if (is_left) {
+        return Call(call->op, {value, x}, call->attrs, call->type_args, call->span);
+      } else {
+        return Call(call->op, {x, value}, call->attrs, call->type_args, call->span);
+      }
+    }
+    return post;
+  }
+
+ private:
+  /*! \brief binary argument */
+  DFPattern x_;
+  /*! \brief data tensor the *_like ops take their shape from */
+  DFPattern data_;
+  /*! \brief constant input */
+  DFPattern value_;
+  /*! \brief full op */
+  DFPattern full_;
+  /*! \brief ones op */
+  DFPattern ones_;
+  /*! \brief zeros op */
+  DFPattern zeros_;
 };
 
 /*!
@@ -78,22 +161,24 @@ class SimplifyReshape {
 class ExprSimplifier {
  public:
   explicit ExprSimplifier(IRModule mod) : mod_(mod) {
-    auto reshape_func = [this](TVMArgs args, TVMRetValue* rv) {
+    CreateCallback(SimplifyReshape());
+    CreateCallback(FullElementwise());
+  }
+  template <typename T>
+  void CreateCallback(const T& pattern) {
+    auto func = [pattern](TVMArgs args, TVMRetValue* rv) {
       Expr pre = args[0];
       Expr post = args[1];
       Map<DFPattern, Array<Expr>> node_map = args[2];
-      *rv = simplify_reshape_.callback(pre, post, node_map);
+      *rv = pattern.callback(pre, post, node_map);
     };
-    callbacks_.push_back(
-        DFPatternCallback(simplify_reshape_.pattern(), PackedFunc(reshape_func), true));
+    callbacks_.push_back(DFPatternCallback(pattern.pattern(), PackedFunc(func), true));
   }
 
   Expr Simplify(const Expr& expr) { return RewritePatterns(callbacks_, expr, mod_); }
 
  private:
   IRModule mod_;
-  /*! \brief Simplify reshape pattern */
-  SimplifyReshape simplify_reshape_;
   /*! \brief Callbacks for expr simplification */
   Array<DFPatternCallback> callbacks_;
 };
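Note: FullElementwise only fires when the non-full argument already has the same type as the whole expression (the StructuralEqual check above), so the rewrite never changes broadcasting behavior. A minimal sketch of its effect through the public pass, assuming the standard Relay Python API:

    import tvm
    from tvm import relay

    x = relay.var("x", shape=(10, 10), dtype="float32")
    f = relay.Function([x], relay.add(relay.ones_like(x), x))
    mod = tvm.IRModule.from_expr(f)
    mod = relay.transform.InferType()(mod)
    mod = relay.transform.SimplifyExpr()(mod)
    # ones_like(x) + x has become (scalar 1) + x, which FoldConstant can
    # now evaluate further whenever x itself is constant
    print(mod["main"].body)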
diff --git a/tests/python/relay/test_dataflow_pattern.py b/tests/python/relay/test_dataflow_pattern.py
index b39c03a6160e..a8e4b65f1bc6 100644
--- a/tests/python/relay/test_dataflow_pattern.py
+++ b/tests/python/relay/test_dataflow_pattern.py
@@ -437,6 +437,8 @@ def test_no_match_op_attr():
     x = relay.var("x")
     y = relay.var("y")
     assert not op_pat.match(x - y)
+    z = relay.var("z")
+    assert not op_pat.match(relay.Let(z, x + y, z))
 
 
 def test_match_func_attr():
diff --git a/tests/python/relay/test_pass_fold_constant.py b/tests/python/relay/test_pass_fold_constant.py
index 76182d2c3e08..14ad419e80c6 100644
--- a/tests/python/relay/test_pass_fold_constant.py
+++ b/tests/python/relay/test_pass_fold_constant.py
@@ -231,22 +231,6 @@ def expected(dtype):
     assert tvm.ir.structural_equal(zz, zexpected)
 
 
-def test_fold_full():
-    c_shape = (8, 9, 10)
-
-    def before():
-        dtype = "float32"
-        return relay.full(relay.const(1.0, dtype), c_shape, dtype=dtype)
-
-    def expected():
-        # expect no changes
-        return before()
-
-    zz = run_opt_pass(before(), transform.FoldConstant())
-    zexpected = run_opt_pass(expected(), transform.InferType())
-    assert tvm.ir.structural_equal(zz, zexpected)
-
-
 def test_fold_batch_norm():
     def expected():
         data = relay.var("data", relay.TensorType((1, 3, 224, 224), "float32"))
diff --git a/tests/python/relay/test_pass_simplify_expr.py b/tests/python/relay/test_pass_simplify_expr.py
index b57abc6942d7..3d925bcfc759 100644
--- a/tests/python/relay/test_pass_simplify_expr.py
+++ b/tests/python/relay/test_pass_simplify_expr.py
@@ -58,5 +58,70 @@ def symbolic():
     assert tvm.ir.structural_equal(zz, after)
 
 
+def test_simplify_full_elementwise():
+    def validate(shape, value, dtype):
+        def before_left(x, elem_op, full):
+            return elem_op(full, x)
+
+        def after_left(x, elem_op, value):
+            return elem_op(relay.const(value, dtype), x)
+
+        def before_right(x, elem_op, full):
+            return elem_op(x, full)
+
+        def after_right(x, elem_op, value):
+            return elem_op(x, relay.const(value, dtype))
+
+        x = relay.var("x", shape=shape, dtype=dtype)
+        elem_ops = [relay.add, relay.multiply, relay.subtract, relay.divide]
+        full_ops = []
+        if value == 0:
+            full_ops.append(relay.zeros(shape, dtype))
+            full_ops.append(relay.zeros_like(x))
+        if value == 1:
+            full_ops.append(relay.ones(shape, dtype))
+            full_ops.append(relay.ones_like(x))
+        else:
+            full_ops.append(relay.full(relay.const(value, dtype), shape))
+            full_ops.append(relay.full_like(x, relay.const(value, dtype)))
+        for op in elem_ops:
+            for full in full_ops:
+                z = before_left(x, op, full)
+                zz = run_opt_pass(z, transform.SimplifyExpr())
+                after = run_opt_pass(after_left(x, op, value), transform.InferType())
+                assert tvm.ir.structural_equal(zz, after)
+
+                z = before_right(x, op, full)
+                zz = run_opt_pass(z, transform.SimplifyExpr())
+                after = run_opt_pass(after_right(x, op, value), transform.InferType())
+                assert tvm.ir.structural_equal(zz, after)
+
+        # Test the case in which x is broadcast to full's shape
+        full_ops = []
+        if value == 0:
+            full_ops.append(relay.zeros(shape * 2, dtype))
+        if value == 1:
+            full_ops.append(relay.ones(shape * 2, dtype))
+        else:
+            full_ops.append(relay.full(relay.const(value, dtype), shape * 2))
+        for op in elem_ops:
+            for full in full_ops:
+                z = before_left(x, op, full)
+                zz = run_opt_pass(z, transform.SimplifyExpr())
+                after = run_opt_pass(before_left(x, op, full), transform.InferType())
+                assert tvm.ir.structural_equal(zz, after)
+
+                z = before_right(x, op, full)
+                zz = run_opt_pass(z, transform.SimplifyExpr())
+                after = run_opt_pass(before_right(x, op, full), transform.InferType())
+                assert tvm.ir.structural_equal(zz, after)
+
+    for shape in [[10], [10, 10], [10, 10, 10]]:
+        for dtype in ["float32", "int32"]:
+            for value in [0, 1, 2]:
+                validate(shape, value, dtype)
+
+
 if __name__ == "__main__":
     test_simplify_reshape()
+    test_simplify_full_elementwise()