diff --git a/src/relay/analysis/util.cc b/src/relay/analysis/util.cc index 59ce01ce4227a..80925117b6cf0 100644 --- a/src/relay/analysis/util.cc +++ b/src/relay/analysis/util.cc @@ -71,7 +71,7 @@ class TypeVarTVisitor : public TypeVisitor { InsertionSet* bound_type_vars_; }; -class TypeVarEVisitor : private ExprVisitor { +class TypeVarEVisitor : private MixedModeVisitor { public: explicit TypeVarEVisitor(const IRModule& mod) : mod_(mod) {} @@ -159,7 +159,7 @@ class TypeVarEVisitor : private ExprVisitor { const IRModule& mod_; }; -class VarVisitor : protected ExprVisitor, protected PatternVisitor { +class VarVisitor : protected MixedModeVisitor, protected PatternVisitor { public: Array Free(const Expr& expr) { this->VisitExpr(expr); diff --git a/src/relay/analysis/well_formed.cc b/src/relay/analysis/well_formed.cc index 3e409d10b8855..be1ebc661ceca 100644 --- a/src/relay/analysis/well_formed.cc +++ b/src/relay/analysis/well_formed.cc @@ -32,7 +32,7 @@ namespace tvm { namespace relay { //! brief make sure each Var is bound at most once in a scope. -class WellFormedChecker : private ExprVisitor, PatternVisitor { +class WellFormedChecker : private MixedModeVisitor, PatternVisitor { public: Optional diag_ctx; Span occurs_in; @@ -126,7 +126,7 @@ class WellFormedChecker : private ExprVisitor, PatternVisitor { // CHECK(call->attrs.defined()); CHECK(call->type_args.defined()); - ExprVisitor::VisitExpr_(call); + MixedModeVisitor::VisitExpr_(call); } void VisitClause(const Clause& c) final { @@ -139,18 +139,14 @@ class WellFormedChecker : private ExprVisitor, PatternVisitor { void VisitVar(const Var& v) final { Bound(v); } - void VisitExpr(const Expr& e) final { + public: + bool CheckWellFormed(const Expr& e) { if (auto v = e.as()) { VisitExpr_(v); } else { // this->occurs_in = e->span; - ExprVisitor::VisitExpr(e); + VisitExpr(e); } - } - - public: - bool CheckWellFormed(const Expr& e) { - this->VisitExpr(e); return well_formed; } }; diff --git a/src/relay/ir/expr_functor.cc b/src/relay/ir/expr_functor.cc index cbc41d225d4b5..9125e07d98ab2 100644 --- a/src/relay/ir/expr_functor.cc +++ b/src/relay/ir/expr_functor.cc @@ -517,7 +517,7 @@ TVM_REGISTER_GLOBAL("relay.analysis.post_order_visit").set_body_typed([](Expr ex }); // Implement bind. -class ExprBinder : public ExprMutator, PatternMutator { +class ExprBinder : public MixedModeMutator, PatternMutator { public: explicit ExprBinder(const tvm::Map& args_map) : args_map_(args_map) {} diff --git a/src/relay/transforms/de_duplicate.cc b/src/relay/transforms/de_duplicate.cc index d90e5c584df36..0cb27840acebe 100644 --- a/src/relay/transforms/de_duplicate.cc +++ b/src/relay/transforms/de_duplicate.cc @@ -31,7 +31,7 @@ namespace tvm { namespace relay { Expr DeDup(const Expr& e) { - class DeDupMutator : public TypeMutator, public ExprMutator, public PatternMutator { + class DeDupMutator : public TypeMutator, public MixedModeMutator, public PatternMutator { public: TypeVar Fresh(const TypeVar& tv) { TypeVar ret = TypeVar(tv->name_hint, tv->kind); @@ -47,7 +47,7 @@ Expr DeDup(const Expr& e) { return ret; } - Expr VisitExpr(const Expr& e) final { + Expr DispatchVisitExpr(const Expr& e) final { auto ret = ExprMutator::VisitExpr(e); ret->checked_type_ = e->checked_type_; return ret; diff --git a/src/relay/transforms/fold_constant.cc b/src/relay/transforms/fold_constant.cc index 660aff2eed9a6..61c3a41f5cdfd 100644 --- a/src/relay/transforms/fold_constant.cc +++ b/src/relay/transforms/fold_constant.cc @@ -75,7 +75,7 @@ TVM_REGISTER_GLOBAL("relay.analysis.check_constant").set_body_typed(ConstantChec // TODO(tvm-team) consider combine dead-code with constant folder. // or make a more powerful partial evaluator. -class ConstantFolder : public ExprMutator { +class ConstantFolder : public MixedModeMutator { public: explicit ConstantFolder(IRModule module) : module_(module), @@ -118,7 +118,7 @@ class ConstantFolder : public ExprMutator { } } - Expr VisitExpr_(const CallNode* call) final { + Expr Rewrite_(const CallNode* call, const Expr& post) final { if (inside_primitive) { return GetRef(call); } @@ -127,26 +127,25 @@ class ConstantFolder : public ExprMutator { std::unordered_set skip_list{"zeros_like", "ones_like", "full_like", "full"}; auto origin_args = call->args; - Expr res = ExprMutator::VisitExpr_(call); - call = res.as(); + call = post.as(); // We don't constant fold function with zero arguments. // This is a heuristic that is useful. // For example it is harmful to fold ones(shape=(4, 5)). - if (call->args.size() == 0) return res; + if (call->args.size() == 0) return post; const OpNode* op = call->op.as(); - if (op == nullptr) return res; + if (op == nullptr) return post; if (skip_list.count(op->name)) { - return res; + return post; } // skip stateful ops. - if (op_stateful.get(GetRef(op), false)) return res; + if (op_stateful.get(GetRef(op), false)) return post; // Try to evaluate shape_of op if (call->op == shape_of_op_ || call->op == vm_shape_of_op_) { - return EvaluateShapeOf(res, origin_args, call->attrs); + return EvaluateShapeOf(post, origin_args, call->attrs); } if (call->op == ndarray_size_op_) { - return EvaluateNdarraySize(res, origin_args, call->attrs); + return EvaluateNdarraySize(post, origin_args, call->attrs); } // We should think about potentially constant evaluation over these ops too. @@ -162,19 +161,18 @@ class ConstantFolder : public ExprMutator { } } if (all_const_args) { - return ConstEvaluate(res); + return ConstEvaluate(post); } else { - return res; + return post; } } - Expr VisitExpr_(const TupleGetItemNode* op) final { - Expr res = ExprMutator::VisitExpr_(op); - op = res.as(); + Expr Rewrite_(const TupleGetItemNode* op, const Expr& post) final { + op = post.as(); if (const auto* tuple = op->tuple.as()) { return tuple->fields[op->index]; } else { - return res; + return post; } }