Simplify full broadcast #7423

Merged · 6 commits · Feb 10, 2021
Changes from 5 commits
2 changes: 2 additions & 0 deletions include/tvm/relay/dataflow_pattern.h
@@ -524,6 +524,8 @@ class DominatorPattern : public DFPattern {
DFPattern IsVar(const String& name);
/*! \brief Syntatic Sugar for creating a ConstantPattern */
DFPattern IsConstant();
/*! \brief Syntatic Sugar for creating a WildcardPattern */
DFPattern IsWildcard();
/*! \brief Syntatic Sugar for creating a ExprPattern */
DFPattern IsExpr(const Expr& expr);
/*! \brief Syntatic Sugar for creating a ExprPattern base on an Op*/
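The new `IsWildcard()` helper mirrors the `wildcard()` sugar already exposed to Python. As a rough sketch of how these pieces compose (assuming a standard TVM install; everything below comes from `tvm.relay.dataflow_pattern`):

```python
from tvm import relay
from tvm.relay.dataflow_pattern import is_constant, is_op, wildcard

# Python analogue of IsOp("add")({IsWildcard(), IsConstant()}) in the C++ sugar:
# match add(<anything>, <constant>).
pat = is_op("add")(wildcard(), is_constant())

x = relay.var("x", shape=(2, 2))
assert pat.match(relay.add(x, relay.const(1.0)))  # constant on the right
assert not pat.match(relay.add(x, x))             # no constant argument
```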
8 changes: 7 additions & 1 deletion src/relay/ir/dataflow_matcher.cc
@@ -162,8 +162,12 @@ bool DFPatternMatcher::VisitDFPattern_(const AttrPatternNode* attr_pattern, cons
if (Op::HasAttrMap(attr_name)) {
auto op_map = Op::GetAttrMap<TVMRetValue>(attr_name);
if (op_map.count(op)) {
matches = MatchRetValue(attr_value, op_map[op]);
matches &= MatchRetValue(attr_value, op_map[op]);
} else {
matches = false;
}
} else {
matches = false;
}
}
} else if (auto* op = expr.as<CallNode>()) {
@@ -196,6 +200,8 @@ bool DFPatternMatcher::VisitDFPattern_(const AttrPatternNode* attr_pattern, cons
break;
}
}
} else {
matches = false;
}
return matches;
}
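To illustrate the stricter matching above: if the queried attribute is not registered for the op at all, the attribute pattern now fails instead of silently passing. A minimal sketch (assuming a standard TVM install; `some_missing_attr` is a made-up attribute name used only for illustration):

```python
from tvm import relay
from tvm.relay.dataflow_pattern import is_op, wildcard

# "nn.dense" has no attribute map called "some_missing_attr", so the
# pattern no longer matches; previously the check was skipped entirely.
pat = is_op("nn.dense").has_attr({"some_missing_attr": 1})(wildcard(), wildcard())

x = relay.var("x", shape=(4, 8))
w = relay.var("w", shape=(16, 8))
assert not pat.match(relay.nn.dense(x, w))
```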
1 change: 1 addition & 0 deletions src/relay/ir/dataflow_pattern.cc
@@ -357,6 +357,7 @@ DFPattern DFPattern::HasShape(const Array<PrimExpr> shape) {
}
DFPattern IsVar(const String& name) { return VarPattern(name); }
DFPattern IsConstant() { return ConstantPattern(make_object<ConstantPatternNode>()); }
DFPattern IsWildcard() { return WildcardPattern(make_object<WildcardPatternNode>()); }
DFPattern IsExpr(const Expr& expr) { return ExprPattern(expr); }
DFPattern IsOp(const String& op_name) { return IsExpr(Op::Get(op_name)); }
DFPattern IsTuple(const Array<DFPattern>& fields) { return TuplePattern(fields); }
6 changes: 6 additions & 0 deletions src/relay/op/make_op.h
@@ -100,6 +100,12 @@ Expr MakeResize(Expr data, Array<IndexExpr> size, String layout, String method,

Expr MakeSparseToDense(Expr indices, Array<Integer> output_shape, Expr values, Expr default_value);

Expr MakeArange(Expr start, Expr stop, Expr step, DataType dtype);

Expr MakeShapeOf(Expr data, DataType dtype);

Expr MakeTake(Expr data, Expr indices, Integer axis, String mode);

} // namespace relay
} // namespace tvm
#endif // TVM_RELAY_OP_MAKE_OP_H_
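These declarations expose existing op constructors to other C++ passes. For orientation, the same ops from the Python API (a sketch, assuming a standard TVM install):

```python
from tvm import relay

start = relay.const(0, "int32")
stop = relay.const(10, "int32")
step = relay.const(1, "int32")

a = relay.arange(start, stop, step, dtype="int32")       # tensor [0, 1, ..., 9]
s = relay.shape_of(a, dtype="int32")                     # shape tensor, here [10]
t = relay.take(a, relay.const([0, 5], "int32"), axis=0)  # gather elements 0 and 5
```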
6 changes: 4 additions & 2 deletions src/relay/op/tensor/unary.cc
@@ -430,12 +430,14 @@ Array<te::Tensor> ShapeOfCompute(const Attrs& attrs, const Array<te::Tensor>& in
return {topi::shape(inputs[0], param->dtype)};
}

TVM_REGISTER_GLOBAL("relay.op._make.shape_of").set_body_typed([](Expr data, DataType dtype) {
Expr MakeShapeOf(Expr data, DataType dtype) {
auto attrs = make_object<ShapeOfAttrs>();
attrs->dtype = dtype;
static const Op& op = Op::Get("shape_of");
return Call(op, {data}, Attrs(attrs), {});
});
}

TVM_REGISTER_GLOBAL("relay.op._make.shape_of").set_body_typed(MakeShapeOf);

RELAY_REGISTER_OP("shape_of")
.describe(R"code(Returns a tensor representing the shape of a tensor.
2 changes: 1 addition & 1 deletion src/relay/transforms/fold_constant.cc
@@ -148,7 +148,7 @@ class ConstantFolder : public MixedModeMutator {
}
static auto op_stateful = Op::GetAttrMap<TOpIsStateful>("TOpIsStateful");

std::unordered_set<std::string> skip_list{"zeros_like", "ones_like", "full_like", "full"};
std::unordered_set<std::string> skip_list{};

auto origin_args = call->args;
call = post.as<CallNode>();
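With the skip list emptied, full/ones/zeros-style calls whose inputs are constant fold like any other constant subexpression; this is the behavior change that the deleted test_fold_full below used to pin down. A minimal sketch, assuming a standard TVM install:

```python
import tvm
from tvm import relay
from tvm.relay import transform

expr = relay.full(relay.const(1.0, "float32"), (8, 9, 10), dtype="float32")
mod = transform.FoldConstant()(tvm.IRModule.from_expr(expr))

# The full() call is evaluated away into a single constant tensor.
assert isinstance(mod["main"].body, relay.Constant)
```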
111 changes: 98 additions & 13 deletions src/relay/transforms/simplify_expr.cc
@@ -29,24 +29,38 @@
#include <tvm/support/logging.h>

#include "../op/tensor/transform.h"
#include "pattern_utils.h"

namespace tvm {
namespace relay {

class SimplifyPattern {
public:
virtual Expr callback(const Expr& pre, const Expr& post,
const Map<DFPattern, Array<Expr>>& node_map) const = 0;

DFPattern pattern() const { return pattern_; }

protected:
/*! \brief Pattern for rewriting */
DFPattern pattern_;
};

/*!
* \brief SimplifyReshape matches the pattern of consecutive reshape or reverse_reshape ops,
* and merges them into one reshape op.
*/
class SimplifyReshape {
class SimplifyReshape : public SimplifyPattern {
public:
SimplifyReshape() {
x_ = WildcardPattern(make_object<WildcardPatternNode>());
x_ = IsWildcard();
auto reshape1 = IsOp("reshape") || IsOp("contrib_reverse_reshape");
auto reshape2 = IsOp("reshape") || IsOp("contrib_reverse_reshape");
pattern_ = reshape1({reshape2({x_})});
}

Expr callback(const Expr& pre, const Expr& post, const Map<DFPattern, Array<Expr>>& node_map) {
Expr callback(const Expr& pre, const Expr& post,
const Map<DFPattern, Array<Expr>>& node_map) const override {
auto x = node_map[x_][0];
bool const_shape = true;
Array<Integer> newshape;
@@ -63,13 +77,82 @@ class SimplifyReshape {
return post;
}

DFPattern pattern() const { return pattern_; }

private:
/*! \brief Pattern input */
DFPattern x_;
/*! \brief Pattern for consecutive reshape or reverse_reshape ops */
DFPattern pattern_;
};

/*!
* \brief FullElementwise simplifies full, ones, and zeros fed into a broadcast elementwise op by replacing them with a scalar constant when the output shape is unchanged
*/
class FullElementwise : public SimplifyPattern {
public:
FullElementwise() {
x_ = IsWildcard();
data_ = IsWildcard();
value_ = IsConstant();

full_ = IsOp("full")({value_}) || IsOp("full_like")({data_, value_});
ones_ = IsOp("ones")({}) || IsOp("ones_like")({data_});
zeros_ = IsOp("zeros")({}) || IsOp("zeros_like")({data_});

Map<String, ObjectRef> attrs;
attrs.Set("TOpPattern", Integer(static_cast<int>(kBroadcast)));
DFPattern op = IsWildcard().HasAttr(attrs);
DFPattern full = full_ || ones_ || zeros_;
pattern_ = op({full, x_}) || op({x_, full});
}

Expr callback(const Expr& pre, const Expr& post,
const Map<DFPattern, Array<Expr>>& node_map) const override {
const CallNode* call = pre.as<CallNode>();
ICHECK(call);
Type pre_type = pre->checked_type_;
ICHECK(pre_type.as<TensorTypeNode>());
auto dtype = pre_type.as<TensorTypeNode>()->dtype;
auto x = node_map[x_][0];
bool is_left = post.as<CallNode>()->args[1] == x;
Type x_type;
if (is_left) {
x_type = call->args[1]->checked_type_;
} else {
x_type = call->args[0]->checked_type_;
}

if (StructuralEqual()(x_type, pre_type)) {
Expr value;
if (node_map.count(full_)) {
value = node_map[value_][0];
ICHECK(IsConstScalar(value));
} else if (node_map.count(ones_)) {
value = MakeConstantScalar(dtype, 1);
} else if (node_map.count(zeros_)) {
value = MakeConstantScalar(dtype, 0);
} else {
ICHECK(false) << "Didn't find a full op while matching full + elementwise";
}
if (is_left) {
return Call(call->op, {value, x}, call->attrs, call->type_args, call->span);
} else {
return Call(call->op, {x, value}, call->attrs, call->type_args, call->span);
}
}
return post;
}

private:
/*! \brief binary argument */
DFPattern x_;
/*! \brief data that the *_like ops take their shape from */
DFPattern data_;
/*! \brief constant input */
DFPattern value_;
/*! \brief full op */
DFPattern full_;
/*! \brief ones op */
DFPattern ones_;
/*! \brief zeros op */
DFPattern zeros_;
};

/*!
@@ -78,22 +161,24 @@ class SimplifyReshape {
class ExprSimplifier {
public:
explicit ExprSimplifier(IRModule mod) : mod_(mod) {
auto reshape_func = [this](TVMArgs args, TVMRetValue* rv) {
CreateCallback(SimplifyReshape());
CreateCallback(FullElementwise());
}
template <typename T>
void CreateCallback(const T& pattern) {
auto func = [pattern](TVMArgs args, TVMRetValue* rv) {
Expr pre = args[0];
Expr post = args[1];
Map<DFPattern, Array<Expr>> node_map = args[2];
*rv = simplify_reshape_.callback(pre, post, node_map);
*rv = pattern.callback(pre, post, node_map);
};
callbacks_.push_back(
DFPatternCallback(simplify_reshape_.pattern(), PackedFunc(reshape_func), true));
callbacks_.push_back(DFPatternCallback(pattern.pattern(), PackedFunc(func), true));
}

Expr Simplify(const Expr& expr) { return RewritePatterns(callbacks_, expr, mod_); }

private:
IRModule mod_;
/*! \brief Simplify reshape pattern */
SimplifyReshape simplify_reshape_;
/*! \brief Callbacks for expr simplification */
Array<DFPatternCallback> callbacks_;
};
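End to end, the new FullElementwise rewrite means a broadcast elementwise op fed by full/ones/zeros collapses to the same op fed by a scalar constant whenever the output shape is unchanged, so the full tensor is never materialized. A rough sketch from the Python side (assuming a standard TVM install; the helper below is just for brevity):

```python
import tvm
from tvm import relay
from tvm.relay import transform

def simplify(expr):
    mod = transform.InferType()(tvm.IRModule.from_expr(expr))
    return transform.SimplifyExpr()(mod)["main"].body

x = relay.var("x", shape=(10, 10), dtype="float32")
out = simplify(relay.add(relay.ones((10, 10), "float32"), x))

# ones((10, 10)) has been replaced by the scalar constant 1.0.
expected = simplify(relay.add(relay.const(1.0, "float32"), x))
assert tvm.ir.structural_equal(out, expected)
```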
2 changes: 2 additions & 0 deletions tests/python/relay/test_dataflow_pattern.py
@@ -437,6 +437,8 @@ def test_no_match_op_attr():
x = relay.var("x")
y = relay.var("y")
assert not op_pat.match(x - y)
z = relay.var("z")
assert not op_pat.match(relay.Let(z, x + y, z))


def test_match_func_attr():
16 changes: 0 additions & 16 deletions tests/python/relay/test_pass_fold_constant.py
@@ -231,22 +231,6 @@ def expected(dtype):
assert tvm.ir.structural_equal(zz, zexpected)


def test_fold_full():
c_shape = (8, 9, 10)

def before():
dtype = "float32"
return relay.full(relay.const(1.0, dtype), c_shape, dtype=dtype)

def expected():
# expect no changes
return before()

zz = run_opt_pass(before(), transform.FoldConstant())
zexpected = run_opt_pass(expected(), transform.InferType())
assert tvm.ir.structural_equal(zz, zexpected)


def test_fold_batch_norm():
def expected():
data = relay.var("data", relay.TensorType((1, 3, 224, 224), "float32"))
65 changes: 65 additions & 0 deletions tests/python/relay/test_pass_simplify_expr.py
@@ -58,5 +58,70 @@ def symbolic():
assert tvm.ir.structural_equal(zz, after)


def test_simplify_full_elementwise():
def validate(shape, value, dtype):
def before_left(x, elem_op, full):
return elem_op(full, x)

def after_left(x, elem_op, value):
return elem_op(relay.const(value, dtype), x)

def before_right(x, elem_op, full):
return elem_op(x, full)

def after_right(x, elem_op, value):
return elem_op(x, relay.const(value, dtype))

x = relay.var("x", shape=shape, dtype=dtype)
elem_ops = [relay.add, relay.multiply, relay.subtract, relay.divide]
full_ops = []
if value == 0:
full_ops.append(relay.zeros(shape, dtype))
full_ops.append(relay.zeros_like(x))
if value == 1:
full_ops.append(relay.ones(shape, dtype))
full_ops.append(relay.ones_like(x))
else:
full_ops.append(relay.full(relay.const(value, dtype), shape))
full_ops.append(relay.full_like(x, relay.const(value, dtype)))
for op in elem_ops:
for full in full_ops:
z = before_left(x, op, full)
zz = run_opt_pass(z, transform.SimplifyExpr())
after = run_opt_pass(after_left(x, op, value), transform.InferType())
assert tvm.ir.structural_equal(zz, after)

z = before_right(x, op, full)
zz = run_opt_pass(z, transform.SimplifyExpr())
after = run_opt_pass(after_right(x, op, value), transform.InferType())
assert tvm.ir.structural_equal(zz, after)

# Test the case in which x is broadcast to full's shape
full_ops = []
if value == 0:
full_ops.append(relay.zeros(shape * 2, dtype))
if value == 1:
full_ops.append(relay.ones(shape * 2, dtype))
else:
full_ops.append(relay.full(relay.const(value, dtype), shape * 2))
for op in elem_ops:
for full in full_ops:
z = before_left(x, op, full)
zz = run_opt_pass(z, transform.SimplifyExpr())
after = run_opt_pass(before_left(x, op, full), transform.InferType())
assert tvm.ir.structural_equal(zz, after)

z = before_right(x, op, full)
zz = run_opt_pass(z, transform.SimplifyExpr())
after = run_opt_pass(before_right(x, op, full), transform.InferType())
assert tvm.ir.structural_equal(zz, after)

for shape in [[10], [10, 10], [10, 10, 10]]:
for dtype in ["float32", "int32"]:
for value in [0, 1, 2]:
validate(shape, value, dtype)


if __name__ == "__main__":
test_simplify_reshape()
test_simplify_full_elementwise()