Skip to content

Commit

Permalink
FoldScaleAxis became non-recursive (apache#8325)
Browse files Browse the repository at this point in the history
* FoldScaleAxis became non-recursive

FoldScaleAxis moved from ExprVisitor and ExprMutator
to non-recursive MixedModeVisitor and MixedModeMutator.
The specific transforming part itself is still recursive;
however, the underlying traversal machinery is non-recursive.

Change-Id: I8bf40bd1f3f039ef0705c665a34a4624067048a1

* Added extra empty lines as requested

Change-Id: I242ec95f92b3dfc7fa3dd89385f56ab07c6e72a8
  • Loading branch information
d-smirnov authored and ylc committed Sep 29, 2021
1 parent a69fba1 commit a2aa033
Showing 1 changed file with 78 additions and 59 deletions.
137 changes: 78 additions & 59 deletions src/relay/transforms/fold_scale_axis.cc
Original file line number Diff line number Diff line change
Expand Up @@ -202,7 +202,7 @@ using FForwardRewrite = TypedPackedFunc<Expr(const Call& ref_call, const Array<E
//----------------------------------------------
// Generic Visitors for FScaleAxisForward
//----------------------------------------------
class ForwardPrep : private ExprVisitor {
class ForwardPrep : private MixedModeVisitor {
public:
std::unordered_map<const Object*, Message> Prepare(const Expr& body) {
this->Update(body, NullValue<Message>());
Expand Down Expand Up @@ -585,15 +585,22 @@ RELAY_REGISTER_OP("nn.conv2d")

Expr ForwardFoldScaleAxis(const Expr& data) {
auto message = ForwardPrep().Prepare(data);
auto fcontext = [&](const Call& call) -> ObjectRef {
auto it = message.find(call.get());
if (it != message.end()) {
return it->second;
} else {
return ObjectRef(nullptr);
for (const auto& m : message) {
if (m.second.defined()) {
// run optimization
auto fcontext = [&](const Call& call) -> ObjectRef {
auto it = message.find(call.get());
if (it != message.end()) {
return it->second;
} else {
return ObjectRef(nullptr);
}
};
return ForwardRewrite(data, "FScaleAxisForwardRewrite", fcontext);
}
};
return ForwardRewrite(data, "FScaleAxisForwardRewrite", fcontext);
}
// no messages - no optimization
return data;
}

//----------------------------------------
Expand All @@ -618,7 +625,7 @@ using FBackwardTransform =
// Generic Visitors for FScaleAxisBackward
//----------------------------------------------

class BackwardPrep : private ExprVisitor {
class BackwardPrep : private MixedModeVisitor {
public:
// The message on each node.
std::unordered_map<const Object*, Message> Prepare(const Expr& body) {
Expand All @@ -643,6 +650,14 @@ class BackwardPrep : private ExprVisitor {
// We only allow propagation of scale backward
// if the expression is only referred by a single parent.
if (rit->second != 1) return;
Array<Message> in_messages = GetInMessages(call);
Message out_message = f(GetRef<Call>(call), in_messages);
if (out_message.defined()) {
message_[call] = out_message;
}
}

Array<Message> GetInMessages(const CallNode* call) {
Array<Message> in_messages;
for (Expr arg : call->args) {
auto it = message_.find(arg.get());
Expand All @@ -652,52 +667,34 @@ class BackwardPrep : private ExprVisitor {
in_messages.push_back(NullValue<Message>());
}
}
Message out_message = f(GetRef<Call>(call), in_messages);
if (out_message.defined()) {
message_[call] = out_message;
}
return in_messages;
}
};

class BackwardTransformerNode : public Object, private ExprMutator {
/*
* A hybrid approach is used: the transformation
* itself is recursive, but the traversal is non-recursive.
*/
class BackwardTransformerNode : public Object, private MixedModeMutator {
public:
using MixedModeMutator::Mutate;
// Run backward transform.
Expr Fold(Expr expr) {
message_ = BackwardPrep().Prepare(expr);
return this->Mutate(expr);
}
/*!
* \brief Transform the expr to consider the scaling.
*
* \param expr The input expression.
* \param axes The axes to scale.
* \param scale The scale applied to the axes.
* \return The result of transformation.
*/
Expr Transform(const Expr& expr, Message message, Expr scale) {
// NOTE: the result of Transform is memoized.
if (const CallNode* call_node = expr.as<CallNode>()) {
return Transform(call_node, message, scale);
} else {
ICHECK(!message.defined()) << "outstanding scale";
return ExprMutator::VisitExpr(expr);
for (const auto& m : message_) {
if (m.second.defined()) {
// run optimization
return this->Mutate(expr);
}
}
// no messages - no optimization
return expr;
}

/*!
* \brief Normal way of mutating call node.
* \param call_node The call node to be mutated.
* \return the result of the call Mutation.
* \brief Transform the expr to consider the scaling.
*/
Expr NormalCallTransform(const CallNode* call_node) {
const Call call = GetRef<Call>(call_node);
const auto it = memo_.find(call);
if (it != memo_.end()) {
return it->second;
}
Expr new_expr = ExprMutator::VisitExpr_(call_node);
memo_[call] = new_expr;
return new_expr;
}
Expr Transform(const Expr& expr, Message message, Expr scale);
/*!
* \brief Get the message propagated to the expr.
* \param expr The expression.
Expand All @@ -719,11 +716,12 @@ class BackwardTransformerNode : public Object, private ExprMutator {
// Valid axes on each node.
std::unordered_map<const Object*, Message> message_;
// Override mutation of call.
Expr VisitExpr_(const CallNode* call_node) final {
return Transform(call_node, NullValue<Message>(), NullValue<Expr>());
Expr Rewrite_(const CallNode* call_node, const Expr& post) final {
return Transform(GetRef<Call>(call_node), NullValue<Message>(), NullValue<Expr>());
}
// Transform of CallNode.
Expr Transform(const CallNode* call_node, Message message, Expr scale);

public:
Expr NormalCallTransform(const CallNode* call_node) { return ExprMutator::VisitExpr_(call_node); }
};

class BackwardTransformer : public ObjectRef {
Expand All @@ -736,21 +734,39 @@ class BackwardTransformer : public ObjectRef {
using ContainerType = BackwardTransformerNode;
};

Expr BackwardTransformerNode::Transform(const CallNode* call_node, Message message, Expr scale) {
static const auto& ftransform = Op::GetAttrMap<FBackwardTransform>("FScaleAxisBackwardTransform");
auto f = ftransform.get(call_node->op, nullptr);
if (f != nullptr) {
/*!
* \brief Transform the expr to consider the scaling.
*
* \param expr The input expression.
* \param message The axes to scale.
* \param scale The scale applied to the axes.
* \return The result of transformation.
*/
Expr BackwardTransformerNode::Transform(const Expr& expr, Message message, Expr scale) {
if (const CallNode* call_node = expr.as<CallNode>()) {
static const auto& ftransform =
Op::GetAttrMap<FBackwardTransform>("FScaleAxisBackwardTransform");
auto f = ftransform.get(call_node->op, nullptr);
const Call call = GetRef<Call>(call_node);
const auto it = memo_.find(call);
if (it != memo_.end()) {
return it->second;
// skip the memo lookup if there is a message
if (!message.defined()) {
const auto it = memo_.find(call);
if (it != memo_.end()) {
return it->second;
}
}
Expr new_expr = NullValue<Expr>();
if (f != nullptr) {
new_expr = f(call, message, scale, GetRef<BackwardTransformer>(this));
} else {
ICHECK(!message.defined()) << "outstanding scale";
new_expr = NormalCallTransform(call.operator->());
}
Expr new_expr = f(GetRef<Call>(call_node), message, scale, GetRef<BackwardTransformer>(this));
memo_[call] = new_expr;
return new_expr;
} else {
ICHECK(!message.defined()) << "outstanding scale";
return NormalCallTransform(call_node);
return this->Mutate(expr);
}
}

Expand Down Expand Up @@ -813,6 +829,7 @@ Expr AddSubBackwardTransform(const Call& call, const Message& message, const Exp
if (!message.defined()) {
return transformer->NormalCallTransform(call.operator->());
}

Message lhs_message = transformer->GetMessage(call->args[0]);
Message rhs_message = transformer->GetMessage(call->args[1]);
StructuralEqual equal;
Expand Down Expand Up @@ -959,7 +976,9 @@ Expr Conv2DBackwardTransform(const Call& call, const Message& message, const Exp
} else {
wscale = ReshapeToMatchAxis(scale, weight->type_as<TensorTypeNode>()->shape,
{big_ko_axis, small_ko_axis});
if (!wscale.defined()) return transformer->NormalCallTransform(call.operator->());
if (!wscale.defined()) {
return transformer->NormalCallTransform(call.operator->());
}
}
weight = Multiply(weight, wscale);
return Call(call->op, {data, weight}, call->attrs, call->type_args);
Expand Down

0 comments on commit a2aa033

Please sign in to comment.