Skip to content

Commit

Permalink
Visit shape in Visitor/Mutator (apache#45)
Browse files Browse the repository at this point in the history
  • Loading branch information
YuchenJin committed Nov 17, 2022
1 parent 0fc4547 commit df2c068
Show file tree
Hide file tree
Showing 5 changed files with 154 additions and 68 deletions.
1 change: 1 addition & 0 deletions include/tvm/relax/expr.h
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@ using Expr = RelayExpr;
using ExprNode = RelayExprNode;
using relay::Call;
using relay::CallNode;
using relay::Constant;
using relay::ConstantNode;
using relay::Id;
using relay::If;
Expand Down
2 changes: 1 addition & 1 deletion include/tvm/relax/expr_functor.h
Original file line number Diff line number Diff line change
Expand Up @@ -293,7 +293,7 @@ class ExprMutator : public ExprFunctor<Expr(const Expr&)> {
}

/*!
* \brief Create a new var with specified shape and type if it's original shape or type does not
* \brief Create a new var with specified shape and type if the original var's shape or type does not
* match with the specified ones.
* \param var The var to be updated.
* \param shape The specified shape.
Expand Down
38 changes: 2 additions & 36 deletions src/relax/backend/vm/vm_shape_lower.cc
Original file line number Diff line number Diff line change
Expand Up @@ -32,38 +32,6 @@
namespace tvm {
namespace relax {

/*!
* \brief Visitor to apply a function to every Expr it visits. Also applies the function
* to the shape field of the var definition site if the var's shape is a ShapeExpr.
*/
class ExprApplyVisitWithShape : public ExprVisitor {
public:
explicit ExprApplyVisitWithShape(std::function<void(const Expr&)> f) : f_(f) {}

void VisitVarDef(const Var& var) {
if (var.as<DataflowVarNode>()) {
this->VisitExpr(Downcast<DataflowVar>(var));
} else {
this->VisitExpr(var);
}
if (var->shape_.operator bool() && var->shape_.value().as<ShapeExprNode>()) {
f_(Downcast<ShapeExpr>(var->shape_.value()));
}
}

void VisitExpr(const Expr& e) final {
ExprVisitor::VisitExpr(e);
f_(e);
}

private:
std::function<void(const Expr&)> f_;
};

void PostOrderVisitWithShape(const Expr& e, std::function<void(const Expr&)> fvisit) {
ExprApplyVisitWithShape(fvisit).VisitExpr(e);
}

class VMShapeLowerMutator : public ExprMutator {
public:
static DataType ShapeDType() { return DataType::Int(64); };
Expand Down Expand Up @@ -125,9 +93,7 @@ class VMShapeLowerMutator : public ExprMutator {
builder_->BeginBindingBlock();
builder_->Emit(VarBinding(
shape_heap_, Call(ExternFunc("vm.builtin.alloc_shape_heap"), {ShapeExpr({heap_size_})})));
Array<Var> params;
for (Var param : node->params) {
params.push_back(this->VisitVarDef(param));
if (param->shape_.operator bool() && param->shape_.value().as<ShapeExprNode>()) {
Var shape = builder_->Emit(Call(ExternFunc("vm.builtin.shape_of"), {param}), "sh");
StoreShape(shape, Downcast<ShapeExpr>(param->shape_.value())->values);
Expand All @@ -150,7 +116,7 @@ class VMShapeLowerMutator : public ExprMutator {
blocks.push_back(builder_->EndBlock());
new_body = SeqExpr(blocks, new_body);

return Function(node->name, params, new_body, ret_type);
return Function(node->name, node->params, new_body, ret_type);
}

tir::PrimFunc CalculateShape(ShapeExpr s) {
Expand Down Expand Up @@ -201,7 +167,7 @@ class VMShapeLowerMutator : public ExprMutator {
}
}
};
PostOrderVisitWithShape(expr, func);
PostOrderVisit(expr, func);
return ret;
}

Expand Down
151 changes: 123 additions & 28 deletions src/relax/ir/expr_functor.cc
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,13 @@
namespace tvm {
namespace relax {

void ExprVisitor::VisitExpr_(const ConstantNode* op) { this->VisitSpan(op->span); }
void ExprVisitor::VisitExpr_(const ConstantNode* op) {
this->VisitSpan(op->span);

if (op->shape_) {
this->VisitExpr(Downcast<Expr>(op->shape_.value()));
}
}

void ExprVisitor::VisitExpr_(const GlobalVarNode* op) { this->VisitSpan(op->span); }

Expand All @@ -42,20 +48,20 @@ void ExprVisitor::VisitExpr_(const TupleNode* op) {
for (Expr field : op->fields) {
this->VisitExpr(field);
}

if (op->shape_) {
this->VisitExpr(Downcast<Expr>(op->shape_.value()));
}
}

// Visit the use-site of a defined Var
void ExprVisitor::VisitExpr_(const VarNode* op) {
this->VisitSpan(op->span);
if (op->type_annotation.defined()) {
this->VisitType(op->type_annotation.value());
}
}

// Visit the use-site of a defined DataflowVar
void ExprVisitor::VisitExpr_(const DataflowVarNode* op) {
this->VisitSpan(op->span);
if (op->type_annotation.defined()) {
this->VisitType(op->type_annotation.value());
}
}

void ExprVisitor::VisitExpr_(const FunctionNode* op) {
Expand All @@ -78,6 +84,10 @@ void ExprVisitor::VisitExpr_(const CallNode* op) {
for (Expr arg : op->args) {
this->VisitExpr(arg);
}

if (op->shape_) {
this->VisitExpr(Downcast<Expr>(op->shape_.value()));
}
}

void ExprVisitor::VisitExpr_(const IfNode* op) {
Expand Down Expand Up @@ -142,19 +152,25 @@ void ExprVisitor::VisitVarDef_(const DataflowVarNode* var) {
if (var->type_annotation.defined()) {
this->VisitType(var->type_annotation.value());
}

if (var->shape_) {
this->VisitExpr(Downcast<Expr>(var->shape_.value()));
}
}

void ExprVisitor::VisitVarDef_(const VarNode* var) {
this->VisitSpan(var->span);
if (var->type_annotation.defined()) {
this->VisitType(var->type_annotation.value());
}
}

void ExprVisitor::VisitExpr(const Expr& expr) {
ExprFunctor::VisitExpr(expr);
if (var->shape_) {
this->VisitExpr(Downcast<Expr>(var->shape_.value()));
}
}

void ExprVisitor::VisitExpr(const Expr& expr) { ExprFunctor::VisitExpr(expr); }

void ExprVisitor::VisitBinding(const Binding& binding) {
if (const auto* node = binding.as<VarBindingNode>()) {
VisitBinding_(node);
Expand Down Expand Up @@ -209,23 +225,48 @@ TVM_REGISTER_GLOBAL("relax.analysis.post_order_visit").set_body_typed([](Expr ex
// ==================
// ExprMutator

Expr ExprMutator::VisitExpr_(const ConstantNode* op) { return GetRef<Expr>(op); }
Expr ExprMutator::VisitExpr_(const ConstantNode* op) {
Expr new_shape;
bool unchanged = true;
if (op->shape_) {
new_shape = this->VisitExpr(Downcast<Expr>(op->shape_.value()));
if (!new_shape.same_as(op->shape_)) {
unchanged = false;
}
}

if (unchanged) {
return GetRef<Expr>(op);
} else {
Expr new_constant = Constant(op->data, op->span);
new_constant->shape_ = new_shape;
return new_constant;
}
}

Expr ExprMutator::VisitExpr_(const GlobalVarNode* op) { return GetRef<Expr>(op); }

Expr ExprMutator::VisitExpr_(const TupleNode* op) {
bool unchanged = true;
tvm::Array<Expr> fields;
bool all_fields_unchanged = true;
for (Expr field : op->fields) {
Expr new_field = this->VisitExpr(field);
fields.push_back(new_field);
all_fields_unchanged &= new_field.same_as(field);
unchanged &= new_field.same_as(field);
}

Expr new_shape;
if (op->shape_) {
new_shape = this->VisitExpr(Downcast<Expr>(op->shape_.value()));
unchanged &= new_shape.same_as(op->shape_);
}

if (all_fields_unchanged) {
if (unchanged) {
return GetRef<Expr>(op);
} else {
return Tuple(fields, op->span);
Expr new_tuple = Tuple(fields, op->span);
new_tuple->shape_ = new_shape;
return new_tuple;
}
}

Expand Down Expand Up @@ -288,10 +329,18 @@ Expr ExprMutator::VisitExpr_(const CallNode* call_node) {
unchanged &= new_arg.same_as(arg);
}

Expr new_shape;
if (call_node->shape_) {
new_shape = this->VisitExpr(Downcast<Expr>(call_node->shape_.value()));
unchanged &= new_shape.same_as(call_node->shape_);
}

if (unchanged) {
return GetRef<Expr>(call_node);
} else {
return Call(new_op, call_args, call_node->attrs, ty_args, call_node->span);
Expr new_call = Call(new_op, call_args, call_node->attrs, ty_args, call_node->span);
new_call->shape_ = new_shape;
return new_call;
}
}

Expand Down Expand Up @@ -424,29 +473,75 @@ BindingBlock ExprMutator::VisitBindingBlock_(const DataflowBlockNode* block) {
}

Var ExprMutator::VisitVarDef_(const DataflowVarNode* var) {
bool type_unchanged = true;
Type new_type;
if (var->type_annotation.defined()) {
Type type = this->VisitType(var->type_annotation.value());
if (!var->type_annotation.same_as(type)) {
Var new_var = DataflowVar(var->vid, NullOpt, type, var->span);
new_type = this->VisitType(var->type_annotation.value());
type_unchanged &= new_type.same_as(var->type_annotation);
}

bool shape_unchanged = true;
Expr new_shape;
if (var->shape_) {
new_shape = this->VisitExpr(Downcast<Expr>(var->shape_.value()));
shape_unchanged &= new_shape.same_as(var->shape_);
}

if (type_unchanged && shape_unchanged) {
return GetRef<Var>(var);
} else {
Var new_var;
if (type_unchanged) {
new_var = DataflowVar(var->vid, NullOpt, var->type_annotation, var->span);
} else {
new_var = DataflowVar(var->vid, NullOpt, new_type, var->span);
}

if (shape_unchanged) {
new_var->shape_ = var->shape_;
this->var_remap_[var->vid] = new_var;
return new_var;
} else {
new_var->shape_ = new_shape;
}

this->var_remap_[var->vid] = new_var;
return new_var;
}
return GetRef<Var>(var);
}

Var ExprMutator::VisitVarDef_(const VarNode* var) {
bool type_unchanged = true;
Type new_type;
if (var->type_annotation.defined()) {
Type type = this->VisitType(var->type_annotation.value());
if (!var->type_annotation.same_as(type)) {
Var new_var = Var(var->vid, NullOpt, type, var->span);
new_type = this->VisitType(var->type_annotation.value());
type_unchanged &= new_type.same_as(var->type_annotation);
}

bool shape_unchanged = true;
Expr new_shape;
if (var->shape_) {
new_shape = this->VisitExpr(Downcast<Expr>(var->shape_.value()));
shape_unchanged &= new_shape.same_as(var->shape_);
}

if (type_unchanged && shape_unchanged) {
return GetRef<Var>(var);
} else {
Var new_var;
if (type_unchanged) {
new_var = Var(var->vid, NullOpt, var->type_annotation, var->span);
} else {
new_var = Var(var->vid, NullOpt, new_type, var->span);
}

if (shape_unchanged) {
new_var->shape_ = var->shape_;
this->var_remap_[var->vid] = new_var;
return new_var;
} else {
new_var->shape_ = new_shape;
}

this->var_remap_[var->vid] = new_var;
return new_var;
}
return GetRef<Var>(var);
}

Expr ExprMutator::VisitExpr(const Expr& expr) {
Expand Down
30 changes: 27 additions & 3 deletions tests/python/relax/test_transform.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,9 +53,7 @@ def test_fma_rewrite():
assert structural_equal(gv0.shape, relax.ShapeExpr([m, n]))

# after rewrite
passes = [relax.transform.FMARewrite()]
seq = tvm.transform.Sequential(passes)
new_mod = seq(mod)
new_mod = relax.transform.FMARewrite()(mod)
func = new_mod["main"]
v1 = func.body.blocks[0].bindings[1].var
s1 = func.body.blocks[0].bindings[1].value
Expand All @@ -69,6 +67,31 @@ def test_fma_rewrite():
assert gv0 == v0
assert type(func.body.blocks[0].bindings[1].var) == relax.Var

def test_visit_shape():
@tvm.script.ir_module
class TestVisitShape:
@R.function
def foo(x: Tensor[(m, n), "float32"]):
gv0 = R.add(x, x)
return gv0

mod = TestVisitShape

shape_expr = []
def fvisit(e):
if isinstance(e, relax.ShapeExpr):
nonlocal shape_expr
shape_expr.append(e)

relax.analysis.post_order_visit(mod["foo"], fvisit)

# should have visited ShapeExpr 3 times
# the first time being visited is x.shape
# the last two times are the call node's shape and gv0's shape
assert len(shape_expr) == 3
assert shape_expr[0] == mod["foo"].params[0].shape
assert shape_expr[1] == shape_expr[2]


def test_to_non_dataflow():
@tvm.script.ir_module
Expand Down Expand Up @@ -312,6 +335,7 @@ def foo(x: Tensor[(m, n), "float32"]):

if __name__ == "__main__":
test_fma_rewrite()
test_visit_shape()
test_to_non_dataflow()
test_call_dps_rewrite()
test_vm_memory_lower()
Expand Down

0 comments on commit df2c068

Please sign in to comment.