diff --git a/src/relay/transforms/fold_scale_axis.cc b/src/relay/transforms/fold_scale_axis.cc index a93532895b5a..7056dfe79fee 100644 --- a/src/relay/transforms/fold_scale_axis.cc +++ b/src/relay/transforms/fold_scale_axis.cc @@ -202,7 +202,7 @@ using FForwardRewrite = TypedPackedFunc Prepare(const Expr& body) { this->Update(body, NullValue()); @@ -585,15 +585,22 @@ RELAY_REGISTER_OP("nn.conv2d") Expr ForwardFoldScaleAxis(const Expr& data) { auto message = ForwardPrep().Prepare(data); - auto fcontext = [&](const Call& call) -> ObjectRef { - auto it = message.find(call.get()); - if (it != message.end()) { - return it->second; - } else { - return ObjectRef(nullptr); + for (const auto& m : message) { + if (m.second.defined()) { + // run optimization + auto fcontext = [&](const Call& call) -> ObjectRef { + auto it = message.find(call.get()); + if (it != message.end()) { + return it->second; + } else { + return ObjectRef(nullptr); + } + }; + return ForwardRewrite(data, "FScaleAxisForwardRewrite", fcontext); } - }; - return ForwardRewrite(data, "FScaleAxisForwardRewrite", fcontext); + } + // no messages - no optimization + return data; } //---------------------------------------- @@ -618,7 +625,7 @@ using FBackwardTransform = // Generic Visitors for FScaleAxisBackward //---------------------------------------------- -class BackwardPrep : private ExprVisitor { +class BackwardPrep : private MixedModeVisitor { public: // The message on each node. std::unordered_map Prepare(const Expr& body) { @@ -643,6 +650,14 @@ class BackwardPrep : private ExprVisitor { // We only allow propagation of scale backward // if the expression is only referred by a single parent. 
if (rit->second != 1) return; + Array in_messages = GetInMessages(call); + Message out_message = f(GetRef(call), in_messages); + if (out_message.defined()) { + message_[call] = out_message; + } + } + + Array GetInMessages(const CallNode* call) { Array in_messages; for (Expr arg : call->args) { auto it = message_.find(arg.get()); @@ -652,52 +667,34 @@ class BackwardPrep : private ExprVisitor { in_messages.push_back(NullValue()); } } - Message out_message = f(GetRef(call), in_messages); - if (out_message.defined()) { - message_[call] = out_message; - } + return in_messages; } }; -class BackwardTransformerNode : public Object, private ExprMutator { +/* + * Hybrid approach is used: the transformation + * itself is recursive but the traversal is non-recursive + */ +class BackwardTransformerNode : public Object, private MixedModeMutator { public: + using MixedModeMutator::Mutate; // Run forward transform. Expr Fold(Expr expr) { message_ = BackwardPrep().Prepare(expr); - return this->Mutate(expr); - } - /*! - * \brief Transform the expr to consider the scaling. - * - * \param expr The input expression. - * \param axes The axes to scale. - * \param scale The scale applied to the axes. - * \return The result of transformation. - */ - Expr Transform(const Expr& expr, Message message, Expr scale) { - // NOTE: the result of Transform is memoized. - if (const CallNode* call_node = expr.as()) { - return Transform(call_node, message, scale); - } else { - ICHECK(!message.defined()) << "outstanding scale"; - return ExprMutator::VisitExpr(expr); + for (const auto& m : message_) { + if (m.second.defined()) { + // run optimization + return this->Mutate(expr); + } } + // no messages - no optimization + return expr; } + /*! - * \brief Normal way of mutating call node. - * \param call_node The call node to be mutated. - * \return the result of the call Mutation. + * \brief Transform the expr to consider the scaling. 
*/ - Expr NormalCallTransform(const CallNode* call_node) { - const Call call = GetRef(call_node); - const auto it = memo_.find(call); - if (it != memo_.end()) { - return it->second; - } - Expr new_expr = ExprMutator::VisitExpr_(call_node); - memo_[call] = new_expr; - return new_expr; - } + Expr Transform(const Expr& expr, Message message, Expr scale); /*! * \brief Get the message propogated to the expr. * \param expr The expresison. @@ -719,11 +716,12 @@ class BackwardTransformerNode : public Object, private ExprMutator { // Valid axes on each node. std::unordered_map message_; // Override mutation of call. - Expr VisitExpr_(const CallNode* call_node) final { - return Transform(call_node, NullValue(), NullValue()); + Expr Rewrite_(const CallNode* call_node, const Expr& post) final { + return Transform(GetRef(call_node), NullValue(), NullValue()); } - // Transform of CallNode. - Expr Transform(const CallNode* call_node, Message message, Expr scale); + + public: + Expr NormalCallTransform(const CallNode* call_node) { return ExprMutator::VisitExpr_(call_node); } }; class BackwardTransformer : public ObjectRef { @@ -736,21 +734,39 @@ class BackwardTransformer : public ObjectRef { using ContainerType = BackwardTransformerNode; }; -Expr BackwardTransformerNode::Transform(const CallNode* call_node, Message message, Expr scale) { - static const auto& ftransform = Op::GetAttrMap("FScaleAxisBackwardTransform"); - auto f = ftransform.get(call_node->op, nullptr); - if (f != nullptr) { +/*! + * \brief Transform the expr to consider the scaling. + * + * \param expr The input expression. + * \param message The axes to scale. + * \param scale The scale applied to the axes. + * \return The result of transformation. 
+ */ +Expr BackwardTransformerNode::Transform(const Expr& expr, Message message, Expr scale) { + if (const CallNode* call_node = expr.as()) { + static const auto& ftransform = + Op::GetAttrMap("FScaleAxisBackwardTransform"); + auto f = ftransform.get(call_node->op, nullptr); const Call call = GetRef(call_node); - const auto it = memo_.find(call); - if (it != memo_.end()) { - return it->second; + // ignore if there is a message + if (!message.defined()) { + const auto it = memo_.find(call); + if (it != memo_.end()) { + return it->second; + } + } + Expr new_expr = NullValue(); + if (f != nullptr) { + new_expr = f(call, message, scale, GetRef(this)); + } else { + ICHECK(!message.defined()) << "outstanding scale"; + new_expr = NormalCallTransform(call.operator->()); } - Expr new_expr = f(GetRef(call_node), message, scale, GetRef(this)); memo_[call] = new_expr; return new_expr; } else { ICHECK(!message.defined()) << "outstanding scale"; - return NormalCallTransform(call_node); + return this->Mutate(expr); } } @@ -813,6 +829,7 @@ Expr AddSubBackwardTransform(const Call& call, const Message& message, const Exp if (!message.defined()) { return transformer->NormalCallTransform(call.operator->()); } + Message lhs_message = transformer->GetMessage(call->args[0]); Message rhs_message = transformer->GetMessage(call->args[1]); StructuralEqual equal; @@ -959,7 +976,9 @@ Expr Conv2DBackwardTransform(const Call& call, const Message& message, const Exp } else { wscale = ReshapeToMatchAxis(scale, weight->type_as()->shape, {big_ko_axis, small_ko_axis}); - if (!wscale.defined()) return transformer->NormalCallTransform(call.operator->()); + if (!wscale.defined()) { + return transformer->NormalCallTransform(call.operator->()); + } } weight = Multiply(weight, wscale); return Call(call->op, {data, weight}, call->attrs, call->type_args);