Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Relay][Pass] Update SimplifyTranspose to correctly simplify rank changing layout transforms #7807

Merged
merged 5 commits into from
May 4, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
175 changes: 144 additions & 31 deletions src/relay/transforms/simplify_expr.cc
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,8 @@
#include <tvm/runtime/logging.h>

#include <limits>
#include <memory>
#include <string>
#include <utility>

#include "../op/tensor/transform.h"
Expand Down Expand Up @@ -117,36 +119,20 @@ class SimplifyTranspose : public DFPatternRewrite {

Expr Callback(const Expr& pre, const Expr& post,
const Map<DFPattern, Array<Expr>>& node_map) const override {
// Helper function to get the axes from call node attribute
auto get_axes_from_call = [](const Call trans_call, int ndim) {
std::vector<int> attr_axes;
if (auto attr = trans_call->attrs.as<TransposeAttrs>()) {
if (attr->axes.defined()) {
for (int i = 0; i < ndim; ++i) {
int64_t axis = attr->axes[i];
axis += (axis < 0) ? ndim : 0;
attr_axes.push_back(axis);
}
} else {
// Empty axes means reverse
for (int i = ndim - 1; i >= 0; --i) {
attr_axes.push_back(i);
}
}
} else if (auto attr = trans_call->attrs.as<LayoutTransformAttrs>()) {
Layout src_layout(attr->src_layout);
Layout dst_layout(attr->dst_layout);
for (int i = 0; i < ndim; ++i) {
attr_axes.push_back(src_layout.IndexOf(dst_layout[i]));
auto x = node_map[x_][0];

Call trans_call = Downcast<Call>(post);

// Try to fuse any rank changing layout transformations
if (auto layout_trans = FoldRankChangingLayoutTrans(x, trans_call)) {
csullivan marked this conversation as resolved.
Show resolved Hide resolved
if (auto attr = layout_trans.value()->attrs.as<LayoutTransformAttrs>()) {
// Prune any trivial layout transformation
if (attr->src_layout == attr->dst_layout) {
return x;
}
} else {
CHECK(false) << "Expected transpose or layout_transform, but got "
<< Downcast<Op>(trans_call->op)->name;
}
return std::move(attr_axes);
};

auto x = node_map[x_][0];
return layout_trans.value();
}

// Initialize axes
int ndim = Downcast<TensorType>(pre->checked_type())->shape.size();
Expand All @@ -157,10 +143,9 @@ class SimplifyTranspose : public DFPatternRewrite {

// Collect axes changes from the matched pattern, including two consecutive transposes.
std::vector<std::vector<int>> interm_axes;
Call trans_call = Downcast<Call>(post);
interm_axes.push_back(get_axes_from_call(trans_call, ndim));
interm_axes.push_back(GetTransposeAxisOrder(trans_call, ndim));
trans_call = Downcast<Call>(trans_call->args[0]);
interm_axes.push_back(get_axes_from_call(trans_call, ndim));
interm_axes.push_back(GetTransposeAxisOrder(trans_call, ndim));

// Calculate the final axes in reverse order (from root to output)
auto it = interm_axes.rbegin();
Expand Down Expand Up @@ -190,6 +175,134 @@ class SimplifyTranspose : public DFPatternRewrite {
return x;
}

/*!
 * \brief Permute the characters of a layout string by a transpose axis order.
 *
 * E.g. layout "NCHW" with axes_order {0, 2, 3, 1} yields "NHWC".
 *
 * \param layout The named-axis layout string to permute.
 * \param axes_order The permutation; axes_order[i] is the index in `layout`
 *        of the i-th output axis. Must have one entry per layout character.
 * \return The permuted layout string.
 */
String PermuteLayout(const String& layout, const std::vector<int>& axes_order) const {
  // Take the axis order by const reference: the callers keep ownership and
  // copying the vector on every call is an unnecessary allocation.
  std::string old_layout{layout};
  ICHECK_EQ(axes_order.size(), layout.size())
      << "Number of axes must match the number of named axes in the layout to permute: length("
      << old_layout << ") != " << axes_order.size();
  std::string new_layout;
  new_layout.reserve(old_layout.size());  // exactly layout.size() appends below
  std::stringstream order;
  for (auto axis : axes_order) {
    new_layout += old_layout[axis];
    order << axis << ", ";
  }
  DLOG(INFO) << "Using transpose axes order {" << order.str()
             << "} to permute layout: " << old_layout << " to " << new_layout;
  return new_layout;
}

// Describes a matched pair of ops in which at least one is a rank changing
// layout_transform. Produced by GetRankChangeDescriptor and consumed by
// FoldRankChangingLayoutTrans to build the single fused layout_transform.
struct RankChangingLayoutDescriptor {
  // Source layout of the (possibly fused) rank changing transform.
  Layout src_layout;
  // Destination layout of the (possibly fused) rank changing transform.
  Layout dst_layout;
  // Either a rank changing layout transform or a transpose
  Call other_transform;
};

// Inspect the matched two-op pattern (`call` and its first argument) and, if
// either op is a layout_transform whose source and destination layouts differ
// in rank, return a descriptor of the rank change; otherwise return nullptr.
std::unique_ptr<RankChangingLayoutDescriptor> GetRankChangeDescriptor(const Call& call) const {
  std::unique_ptr<RankChangingLayoutDescriptor> desc{nullptr};
  // Case 1: the outer (second) op is a rank changing layout_transform; the
  // inner op then plays the role of the "other" transform to be fused.
  if (auto attr = call->attrs.as<LayoutTransformAttrs>()) {
    if (attr->src_layout.length() != attr->dst_layout.length()) {
      desc = std::make_unique<RankChangingLayoutDescriptor>();
      desc->src_layout = Layout(attr->src_layout);
      desc->dst_layout = Layout(attr->dst_layout);
      desc->other_transform = Downcast<Call>(call->args[0]);
    }
  }
  // Case 2: the inner (first) op is a rank changing layout_transform.
  if (auto attr = Downcast<Call>(call->args[0])->attrs.as<LayoutTransformAttrs>()) {
    if (attr->src_layout.length() != attr->dst_layout.length()) {
      if (!desc) {
        desc = std::make_unique<RankChangingLayoutDescriptor>();
        desc->src_layout = Layout(attr->src_layout);
        desc->dst_layout = Layout(attr->dst_layout);
        desc->other_transform = call;
      } else {
        // Both ops change rank (back-to-back layout_transforms): keep the
        // outer op's dst_layout, but take the inner op's src_layout so the
        // descriptor spans the whole chain. The intermediate layouts must
        // agree for the fusion to be valid.
        ICHECK(desc->src_layout->name == attr->dst_layout)
            << "Back-to-back layout transforms must have the same intermediate layout: "
            << desc->src_layout->name << " != " << attr->dst_layout;
        desc->src_layout = Layout(attr->src_layout);
      }
    }
  }
  return desc;
}

/*
 * \brief Fuse call and its argument into a single layout_transform operator
 * when either call or its argument is a rank changing layout_transform, e.g.,
 *
 * Simplify
 *
 * [N, H, W, C] -> Transpose -> [N, C, H, W] -> LayoutTrans -> [N, C, H, W, 4c]
 *
 * to,
 *
 * [N, H, W, C] -> LayoutTrans -> [N, C, H, W, 4c].
 *
 * \param data The input expression to the matched pattern
 * \param call The pattern root; the second of two consecutive Transpose/LayoutTransform ops
 * \return The fused layout_transform, or nullptr if no rank changing
 *         layout_transform is present in the matched pattern.
 */
Optional<Call> FoldRankChangingLayoutTrans(const Expr& data, const Call& call) const {
  // Check to see if either the first or second call in matched pattern
  // is a rank changing layout transform. If so, return a descriptor containing
  // the layouts and any additional transpose or layout transform op.
  auto desc = GetRankChangeDescriptor(call);
  if (desc == nullptr) {
    // No rank changing layout transform
    return Optional<Call>{nullptr};
  }

  Optional<Expr> output_layout_trans;
  // Fuse a rank increasing layout transform and a preceding transpose
  if (desc->src_layout->axes.size() < desc->dst_layout->axes.size()) {
    auto axes = GetTransposeAxisOrder(desc->other_transform, desc->src_layout->axes.size());
    // Calculate the reverse axis order and apply to the source layout:
    // undoing the transpose first lets the fused transform start from the
    // pattern's original input layout.
    std::vector<int> inverse(axes.size());
    for (size_t i = 0; i < axes.size(); i++) {
      inverse[axes[i]] = i;
    }
    String new_layout = PermuteLayout(desc->src_layout->name, inverse);
    output_layout_trans = MakeLayoutTransform(data, new_layout, desc->dst_layout->name);
    // Fuse a rank decreasing layout transform followed by a transpose
  } else if (desc->src_layout->axes.size() > desc->dst_layout->axes.size()) {
    auto axes = GetTransposeAxisOrder(desc->other_transform, desc->dst_layout->axes.size());
    String new_layout = PermuteLayout(desc->dst_layout->name, axes);
    output_layout_trans = MakeLayoutTransform(data, desc->src_layout->name, new_layout);
    // Fuse two back-to-back layout transformations which change rank
  } else if (desc->other_transform->attrs.as<LayoutTransformAttrs>()) {
    output_layout_trans =
        MakeLayoutTransform(data, desc->src_layout->name, desc->dst_layout->name);
  }
  // NOTE(review): if none of the three branches above fires (equal ranks and
  // `other_transform` is a transpose), `output_layout_trans` is still null
  // here — presumably unreachable given GetRankChangeDescriptor returned a
  // descriptor; confirm Downcast on a null Optional is the intended behavior.
  return Downcast<Call>(output_layout_trans);
}

/*!
 * \brief Extract the axis permutation performed by a transpose or
 *        layout_transform call as a vector of non-negative axis indices.
 *
 * \param call A call to transpose (axes attribute, possibly undefined meaning
 *        "reverse all axes") or layout_transform (permutation derived from the
 *        src/dst layout strings).
 * \param ndim Number of axes to extract.
 * \return The permutation; element i is the input axis producing output axis i.
 */
std::vector<int> GetTransposeAxisOrder(const Call& call, int ndim) const {
  std::vector<int> attr_axes;
  attr_axes.reserve(ndim);  // exactly ndim entries are pushed in every branch
  if (auto attr = call->attrs.as<TransposeAttrs>()) {
    if (attr->axes.defined()) {
      for (int i = 0; i < ndim; ++i) {
        // Normalize negative axes to their non-negative equivalents.
        int64_t axis = attr->axes[i];
        axis += (axis < 0) ? ndim : 0;
        attr_axes.push_back(axis);
      }
    } else {
      // Empty axes means reverse
      for (int i = ndim - 1; i >= 0; --i) {
        attr_axes.push_back(i);
      }
    }
  } else if (auto attr = call->attrs.as<LayoutTransformAttrs>()) {
    // For a layout_transform, output axis i comes from wherever the i-th
    // destination axis name sits in the source layout.
    Layout src_layout(attr->src_layout);
    Layout dst_layout(attr->dst_layout);
    for (int i = 0; i < ndim; ++i) {
      attr_axes.push_back(src_layout.IndexOf(dst_layout[i]));
    }
  } else {
    CHECK(false) << "Expected transpose or layout_transform, but got "
                 << Downcast<Op>(call->op)->name;
  }
  // Plain return (not std::move): returning std::move(local) is a
  // pessimizing move that disables NRVO.
  return attr_axes;
}

private:
/*! \brief Pattern input */
DFPattern x_;
Expand Down
166 changes: 166 additions & 0 deletions tests/python/relay/test_pass_simplify_expr.py
Original file line number Diff line number Diff line change
Expand Up @@ -106,10 +106,176 @@ def expected3():
y = relay.transpose(y, axes=[0, 2, 3, 1])
return relay.Function([x], y)

# Test a series of transpose and rank changing layout_transform
csullivan marked this conversation as resolved.
Show resolved Hide resolved
def before4():
    """Transpose/layout_transform pair and its inverse around an op.

    Input:
        NHWC -> NCHW -> NCHW4c -> op -> NCHW4c -> NCHW -> NHWC

    Simplified:
        NHWC -> NCHW4c -> op -> NCHW4c -> NHWC
    """
    inp = relay.var("x", shape=(1, 56, 56, 128), dtype="float32")
    out = relay.transpose(inp, axes=[0, 3, 1, 2])
    out = relay.layout_transform(out, "NCHW", "NCHW4c")
    out = relay.nn.relu(out)
    out = relay.layout_transform(out, "NCHW4c", "NCHW")
    out = relay.transpose(out, axes=[0, 2, 3, 1])
    return relay.Function([inp], out)

def expected4():
    # Both transpose+layout_transform pairs fold into single layout_transforms.
    inp = relay.var("x", shape=(1, 56, 56, 128), dtype="float32")  # NHWC
    out = relay.layout_transform(inp, "NHWC", "NCHW4c")
    out = relay.nn.relu(out)
    out = relay.layout_transform(out, "NCHW4c", "NHWC")
    return relay.Function([inp], out)

def before5():
    """Back-to-back layout_transforms and their inverses around an op.

    Input:
        NHWC -> NCHW -> NCHW4c -> op -> NCHW4c -> NCHW -> NHWC

    Simplified:
        NHWC -> NCHW4c -> op -> NCHW4c -> NHWC
    """
    inp = relay.var("x", shape=(1, 56, 56, 128), dtype="float32")  # NHWC
    out = relay.layout_transform(inp, "NHWC", "NCHW")
    out = relay.layout_transform(out, "NCHW", "NCHW4c")
    out = relay.nn.relu(out)
    out = relay.layout_transform(out, "NCHW4c", "NCHW")
    out = relay.layout_transform(out, "NCHW", "NHWC")
    return relay.Function([inp], out)

def expected5():
    # Each chain of two layout_transforms collapses into one.
    inp = relay.var("x", shape=(1, 56, 56, 128), dtype="float32")  # NHWC
    out = relay.layout_transform(inp, "NHWC", "NCHW4c")
    out = relay.nn.relu(out)
    out = relay.layout_transform(out, "NCHW4c", "NHWC")
    return relay.Function([inp], out)

def before6():
    """A layout_transform immediately undone by its inverse is removed.

    Input:
        NCHW -> NHWC -> NCHW -> op

    Simplified:
        NHWC -> op
    """
    inp = relay.var("x", shape=(1, 128, 56, 56), dtype="float32")
    out = relay.layout_transform(inp, "NCHW", "NHWC")
    out = relay.layout_transform(out, "NHWC", "NCHW")
    return relay.Function([inp], relay.nn.relu(out))

def expected6():
    # The round-trip layout_transforms cancel entirely.
    inp = relay.var("x", shape=(1, 128, 56, 56), dtype="float32")
    return relay.Function([inp], relay.nn.relu(inp))

def before7():
    """A packed-layout round trip (4c -> 8c -> 4c) is removed.

    Input:
        NCHW4c -> NCHW8c -> NCHW4c -> op

    Simplified:
        NCHW4c -> op
    """
    inp = relay.var("x", shape=(1, 32, 56, 56, 4), dtype="float32")
    out = relay.layout_transform(inp, "NCHW4c", "NCHW8c")
    out = relay.layout_transform(out, "NCHW8c", "NCHW4c")
    return relay.Function([inp], relay.nn.relu(out))

def expected7():
    # The inverse pair of packed-layout transforms cancels.
    inp = relay.var("x", shape=(1, 32, 56, 56, 4), dtype="float32")
    return relay.Function([inp], relay.nn.relu(inp))

def before8():
    """Rank contraction followed by expansion folds into one transform.

    Input:
        NCHW4c -> NCHW -> NCHW8c -> op

    Simplified:
        NCHW4c -> NCHW8c -> op
    """
    inp = relay.var("x", shape=(1, 32, 56, 56, 4), dtype="float32")
    out = relay.layout_transform(inp, "NCHW4c", "NCHW")
    out = relay.layout_transform(out, "NCHW", "NCHW8c")
    return relay.Function([inp], relay.nn.relu(out))

def expected8():
    # The intermediate NCHW step disappears; one direct repack remains.
    inp = relay.var("x", shape=(1, 32, 56, 56, 4), dtype="float32")
    out = relay.layout_transform(inp, "NCHW4c", "NCHW8c")
    return relay.Function([inp], relay.nn.relu(out))

def before9():
    """A rank-expanding transform undone by its inverse is removed.

    Input:
        NCHW -> NCHW4c -> NCHW -> op

    Simplified:
        NCHW -> op
    """
    inp = relay.var("x", shape=(1, 128, 56, 56), dtype="float32")
    out = relay.layout_transform(inp, "NCHW", "NCHW4c")
    out = relay.layout_transform(out, "NCHW4c", "NCHW")
    return relay.Function([inp], relay.nn.relu(out))

def expected9():
    # Pack then unpack is the identity; both transforms are pruned.
    inp = relay.var("x", shape=(1, 128, 56, 56), dtype="float32")
    return relay.Function([inp], relay.nn.relu(inp))

def before10():
    """Two same-rank layout_transforms simplify to a single transpose.

    Input:
        NCHW -> NHWC -> CHWN -> op

    Simplified:
        NCHW -> CHWN -> op
    """
    inp = relay.var("x", shape=(1, 128, 56, 56), dtype="float32")
    out = relay.layout_transform(inp, "NCHW", "NHWC")
    out = relay.layout_transform(out, "NHWC", "CHWN")
    return relay.Function([inp], relay.nn.relu(out))

def expected10():
    # NCHW -> CHWN is a pure axis permutation, expressed as one transpose.
    inp = relay.var("x", shape=(1, 128, 56, 56), dtype="float32")
    out = relay.transpose(inp, axes=[1, 2, 3, 0])
    return relay.Function([inp], relay.nn.relu(out))

for before, expected in [
[before1(), expected1()],
[before2(), expected2()],
[before3(), expected3()],
[before4(), expected4()],
[before5(), expected5()],
[before6(), expected6()],
[before7(), expected7()],
[before8(), expected8()],
[before9(), expected9()],
[before10(), expected10()],
]:
after = run_opt_pass(before, transform.SimplifyExpr())
expected = run_opt_pass(expected, transform.InferType())
Expand Down