From 3edf15c36fdac2b501073b704235b309f37ab507 Mon Sep 17 00:00:00 2001
From: Altan Haan
Date: Mon, 27 Sep 2021 16:45:51 -0700
Subject: [PATCH] [Parser][Printer] update parser and printer for match_shape (#13)

---
 python/tvm/relax/parser.py                | 101 +++++++++++++++-------
 src/relay/printer/relax_script_printer.cc |  51 ++++++-----
 tests/python/relax/test_parser.py         |  28 ++++--
 tests/python/relax/test_printer.py        |  13 +--
 4 files changed, 129 insertions(+), 64 deletions(-)

diff --git a/python/tvm/relax/parser.py b/python/tvm/relax/parser.py
index cf4c5c261ad5..d39ae8ba6357 100644
--- a/python/tvm/relax/parser.py
+++ b/python/tvm/relax/parser.py
@@ -479,6 +479,26 @@ def transform_function(self, func: ast.Function, is_global: bool = False) -> rx.
             params, new_body, ret_type, name=func_name, span=self.to_tvm_span(func.span)
         )
 
+    def is_match_shape(self, stmt: ast.Stmt) -> bool:
+        """Returns whether or not the given statement is a MatchShape binding.
+
+        Parameters
+        ----------
+        stmt : ast.Stmt
+            The statement to be parsed.
+
+        Returns
+        -------
+        bool
+            Whether or not the statement is a MatchShape binding.
+        """
+        call_op = None
+        if isinstance(stmt, ast.UnassignedCall):
+            call_op = self.transform_expr(stmt.call.func_name)
+        elif isinstance(stmt, ast.Assign) and isinstance(stmt.rhs, ast.Call):
+            call_op = self.transform_expr(stmt.rhs.func_name)
+        return call_op == SpecialOp.MATCH_SHAPE
+
     def parse_binding(self, stmt: ast.Stmt, is_dataflow: bool = False) -> rx.Binding:
         """Parses the input synr statement to the corresponding Relax binding.
 
@@ -495,42 +515,62 @@ def parse_binding(self, stmt: ast.Stmt, is_dataflow: bool = False) -> rx.Binding
             The parsed Relax binding
         """
         assert isinstance(stmt, (ast.Assign, ast.UnassignedCall))
-        if isinstance(stmt, ast.Assign):
-            return self.parse_var_binding(stmt, is_dataflow=is_dataflow)
+        if self.is_match_shape(stmt):
+            return self.parse_shape_binding(stmt, is_dataflow=is_dataflow)
         else:
-            return self.parse_shape_binding(stmt)
+            assert isinstance(stmt, ast.Assign)
+            return self.parse_var_binding(stmt, is_dataflow=is_dataflow)
 
-    def parse_shape_binding(self, stmt: ast.UnassignedCall) -> rx.MatchShape:
+    def parse_shape_binding(self, stmt: ast.Stmt, is_dataflow: bool = False) -> rx.MatchShape:
         """Parses the input synr statement to a Relax shape binding.
 
         Parameters
         ----------
-        stmt : ast.UnassignedCall
+        stmt : ast.Stmt
             The input synr statement
+        is_dataflow : bool, optional
+            Whether or not the bound variable (if any) is a dataflow variable, by default False
 
         Returns
         -------
         rx.MatchShape
             The parsed Relax shape binding
         """
-        call: synr.ast.Call = stmt.call
+        var: ast.Var = None
+        call: ast.Call = None
+
+        if isinstance(stmt, ast.UnassignedCall):
+            # case where only dimension variables are bound, e.g. `match_shape(x.shape, (n, m))`
+            call = stmt.call
+        else:
+            # case where the statement also binds a Relax variable to the value being matched
+            assert isinstance(stmt, ast.Assign)
+            if not isinstance(stmt.lhs, ast.Var):
+                self.report_error(
+                    "the left hand side of a binding must be a variable", stmt.lhs.span
+                )
+            var = stmt.lhs
+            call = stmt.rhs
+
         op = self.transform_expr(call.func_name)
-        if op != SpecialOp.MATCH_SHAPE:
-            self.report_error("the results of calls must be bound or used", stmt.span)
-        if len(stmt.call.params) != 2:
-            self.report_error(op.value + " takes exactly two arguments", stmt.span)
-        lhs = stmt.call.params[0]
-        rhs = stmt.call.params[1]
+        assert op == SpecialOp.MATCH_SHAPE
+        if len(call.params) != 2:
+            self.report_error(op.value + " takes exactly two arguments", call.span)
 
-        rhs_expr = self.transform_expr(rhs)
-        if not isinstance(lhs, ast.Tuple):
-            self.report_error(
-                "the pattern (lhs) of " + op.value + " must be a tuple",
-                lhs.span,
-            )
-        lhs_expr = self.parse_shape(lhs, bind_free_vars=True)
-        return rx.MatchShape(lhs_expr, rhs_expr, self.to_tvm_span(stmt.span))
+        value, pattern = call.params
+
+        value = self.transform_expr(value)
+        if not isinstance(pattern, ast.Tuple):
+            self.report_error(f"the pattern of a {op.value} call must be a tuple", pattern.span)
+        pattern = self.parse_shape(pattern, bind_free_vars=True)
+
+        if var is not None:
+            # TODO(@altanh): keep or discard annotation?
+            ty, shape = self.transform_type(stmt.ty, bind_free_vars=False)
+            var = self.decl_var(var.id.name, ty, shape, var.span, is_dataflow=is_dataflow)
+
+        return rx.MatchShape(value, pattern, var, self.to_tvm_span(stmt.span))
 
     def parse_var_binding(self, stmt: ast.Assign, is_dataflow=False) -> rx.VarBinding:
         """Parses the input synr assignment to a Relax variable binding.
@@ -540,12 +580,12 @@ def parse_var_binding(self, stmt: ast.Assign, is_dataflow=False) -> rx.VarBindin
         stmt : ast.Assign
             The input synr assignment
         is_dataflow : bool, optional
-            Whether or not the binding is in a dataflow block, by default False
+            Whether or not the bound variable is a dataflow variable, by default False
 
         Returns
         -------
         rx.VarBinding
-            The prased Relax variable binding
+            The parsed Relax variable binding
         """
         if not isinstance(stmt.lhs, ast.Var):
             self.report_error(
@@ -644,8 +684,10 @@ def transform_stmt(self, stmt: ast.Stmt) -> Union[rx.Expr, rx.Binding, rx.Datafl
             return self.transform_expr(stmt.value)
 
         elif isinstance(stmt, ast.UnassignedCall):
-            # FIXME: when we add ref support, ref_write can be unassigned
-            return self.parse_shape_binding(stmt)
+            if self.transform_expr(stmt.call.func_name) == SpecialOp.MATCH_SHAPE:
+                return self.parse_shape_binding(stmt)
+            else:
+                self.report_error("the results of normal function calls must be bound", stmt.span)
 
         elif isinstance(stmt, ast.With):
             if not isinstance(stmt.rhs, ast.Call):
@@ -727,9 +769,10 @@ def parse_dataflow(self, block: ast.Block) -> rx.DataflowBlock:
                     "only bindings are supported in dataflow blocks",
                     binding_stmt.span,
                 )
-            is_match_shape = isinstance(binding_stmt, ast.UnassignedCall)
-            is_dataflow = not is_match_shape and (
-                binding_stmt.lhs.id.name not in output_var_names
+            is_match_shape = self.is_match_shape(binding_stmt)
+            is_dataflow = (
+                isinstance(binding_stmt, ast.Assign)
+                and binding_stmt.lhs.id.name not in output_var_names
             )
             binding = self.parse_binding(binding_stmt, is_dataflow=is_dataflow)
             bindings.append(binding)
@@ -737,9 +780,9 @@
             if is_match_shape:
                 for var in binding.pattern:
                     output_vars.append(var)
-            else:
+            if binding.var is not None:
                 output_vars.append(binding.var)
-                unbound_output_vars.pop(binding_stmt.lhs.id.name)
+                unbound_output_vars.pop(binding.var.name_hint)
 
         # check that the output variables are all bound locally
         for unbound_var in unbound_output_vars.values():
diff --git a/src/relay/printer/relax_script_printer.cc b/src/relay/printer/relax_script_printer.cc
index e50588ad1f18..9a7b4a126b2e 100644
--- a/src/relay/printer/relax_script_printer.cc
+++ b/src/relay/printer/relax_script_printer.cc
@@ -79,6 +79,7 @@ class RelaxScriptPrinter : public relax::IRFunctor<Doc(const ObjectRef&)>,
 
   Doc PrintIfStmt(const relax::Var& var, const relay::If& ite);
   Doc PrintFunctionDef(const Doc& name, const relax::Function& func);
+  Doc PrintVarAnnotation(const relax::Var& var);
   Doc PrintTensorAnnotation(const relax::DynTensorType& ty, const Optional<ObjectRef>& shape);
 
   Doc VisitType_(const relax::ShapeTypeNode* node) override;
@@ -238,9 +239,12 @@ Doc RelaxScriptPrinter::VisitNode_(const relax::ShapeExprNode* op) {
 
 Doc RelaxScriptPrinter::VisitNode_(const relax::MatchShapeNode* op) {
   Doc doc;
+  if (op->var.defined()) {
+    doc << Print(op->var) << PrintVarAnnotation(op->var) << " = ";
+  }
   doc << "relax.match_shape(";
   // TODO(@altanh): maybe op->pattern should just be a ShapeExpr?
-  doc << Print(relax::ShapeExpr(op->pattern)) << ", " << Print(op->value);
+  doc << Print(op->value) << ", " << Print(relax::ShapeExpr(op->pattern));
   doc << ")";
   return doc;
 }
@@ -260,16 +264,7 @@ Doc RelaxScriptPrinter::VisitNode_(const relax::VarBindingNode* op) {
     return tir::AsTVMScriptDoc(mod, false, prim_func_ref);
   } else {
     Doc doc;
-    doc << Print(op->var);
-    if (op->var->type_annotation.defined()) {
-      doc << ": ";
-      if (const relax::DynTensorTypeNode* tty =
-              op->var->type_annotation.as<relax::DynTensorTypeNode>()) {
-        doc << PrintTensorAnnotation(GetRef<relax::DynTensorType>(tty), op->var->shape_);
-      } else {
-        doc << Print(op->var->type_annotation);
-      }
-    }
+    doc << Print(op->var) << PrintVarAnnotation(op->var);
     doc << " = " << Print(op->value);
     return doc;
   }
@@ -289,10 +284,14 @@ Doc RelaxScriptPrinter::VisitNode_(const relax::DataflowBlockNode* op) {
   std::vector<Doc> return_vars;
   for (const relax::Binding& binding : op->bindings) {
     body << Print(binding) << Doc::NewLine();
+    relax::Var var;
     if (const relax::VarBindingNode* var_binding = binding.as<relax::VarBindingNode>()) {
-      if (!var_binding->var.as<relax::DataflowVarNode>()) {
-        return_vars.push_back(Print(var_binding->var));
-      }
+      var = var_binding->var;
+    } else if (const relax::MatchShapeNode* shape_binding = binding.as<relax::MatchShapeNode>()) {
+      var = shape_binding->var;
+    }
+    if (var.defined() && !var.as<relax::DataflowVarNode>()) {
+      return_vars.push_back(Print(var));
     }
   }
   ICHECK(!return_vars.empty()) << "dataflow blocks should have at least one output variable";
@@ -444,16 +443,7 @@ Doc RelaxScriptPrinter::PrintFunctionDef(const Doc& name, const relax::Function&
   for (size_t i = 0; i < func->params.size(); ++i) {
     relax::Var var = func->params[i];
     Doc param;
-    param << Print(var);
-    if (var->type_annotation.defined()) {
-      param << ": ";
-      if (const relax::DynTensorTypeNode* tty =
-              var->type_annotation.as<relax::DynTensorTypeNode>()) {
-        param << PrintTensorAnnotation(GetRef<relax::DynTensorType>(tty), var->shape_);
-      } else {
-        param << Print(var->type_annotation);
-      }
-    }
+    param << Print(var) << PrintVarAnnotation(var);
     params.push_back(param);
   }
 
@@ -471,6 +461,19 @@ Doc RelaxScriptPrinter::PrintFunctionDef(const Doc& name, const relax::Function&
   return doc;
 }
 
+Doc RelaxScriptPrinter::PrintVarAnnotation(const relax::Var& var) {
+  Doc doc;
+  if (var->type_annotation.defined()) {
+    doc << ": ";
+    if (const relax::DynTensorTypeNode* tty = var->type_annotation.as<relax::DynTensorTypeNode>()) {
+      doc << PrintTensorAnnotation(GetRef<relax::DynTensorType>(tty), var->shape_);
+    } else {
+      doc << Print(var->type_annotation);
+    }
+  }
+  return doc;
+}
+
 Doc RelaxScriptPrinter::PrintTensorAnnotation(const relax::DynTensorType& ty,
                                               const Optional<ObjectRef>& shape) {
   Doc doc;
diff --git a/tests/python/relax/test_parser.py b/tests/python/relax/test_parser.py
index 0575129ea12a..8e58c84ca613 100644
--- a/tests/python/relax/test_parser.py
+++ b/tests/python/relax/test_parser.py
@@ -90,7 +90,7 @@ def foo(x: Tensor[(32, m), "float32"], y: Tensor[(m, k), "float32"]) -> Tensor:
 def test_match_shape():
     @rx.script
     def foo(x: Tensor[_, "float32"]):
-        relax.match_shape((n, m), x.shape)
+        relax.match_shape(x.shape, (n, m))
         y: Tensor[(n, m), "float32"] = add(x, x)
         return x
 
@@ -289,13 +289,31 @@ def test_dataflow_match_shape():
     @rx.script
     def foo(x: Tensor[_, _]):
         with relax.dataflow():
-            y = add(x, x)
+            x2: Tensor[(n, m), _] = relax.match_shape(x, (n, m))
+            y = add(x2, x2)
             z = multiply(y, x)
-            relax.match_shape((n, m), z.shape)
+            relax.match_shape(z.shape, (n, m))
             w: Tensor[(n, m), _] = subtract(z, x)
-            relax.output(y, w)
+            relax.output(y, w, x2)
         t: Tensor[(n, m), _] = divide(y, w)
-        return t
+        q: Tensor[(n, m), _] = add(t, x2)
+        return q
+
+    f = rx_func(foo)
+    x = f.params[0]
+    df_block = f.body.blocks[0]
+    x2_bind = df_block.bindings[0]
+    z_shape_bind = df_block.bindings[3]
+    q_bind = f.body.blocks[1].bindings[1]
+
+    assert x2_bind.var.name_hint == "x2"
+    check_tensor_var(x2_bind.var, ("n", "m"), "")
+    check_shape(x2_bind.pattern, ("n", "m"))
+    assert x2_bind.value == x
+
+    check_shape(z_shape_bind.pattern, ("n", "m"))
+
+    assert q_bind.value.args[1] == x2_bind.var
 
 
 @pytest.mark.xfail
diff --git a/tests/python/relax/test_printer.py b/tests/python/relax/test_printer.py
index 4623ad01b52e..bb4363cd20fc 100644
--- a/tests/python/relax/test_printer.py
+++ b/tests/python/relax/test_printer.py
@@ -31,14 +31,13 @@ def foo(x: Tensor[(32, m), "float32"], y: Tensor[(m, k), "float32"]) -> Tensor:
 def test_match_shape():
     @rx.script
     def foo(x: Tensor[_, "float32"]):
-        relax.match_shape((n, m), x.shape)
+        relax.match_shape(x.shape, (n, m))
         y: Tensor[(n, m), "float32"] = add(x, x)
         return x
 
     check_roundtrip(foo)
 
-
 def test_if():
     @rx.script
     def foo(cond: Tensor[(), "bool"], x: Tensor[(1,), "float32"]):
@@ -94,13 +93,15 @@ def test_dataflow_match_shape():
     @rx.script
     def foo(x: Tensor[_, _]):
         with relax.dataflow():
-            y = add(x, x)
+            x2: Tensor[(n, m), _] = relax.match_shape(x, (n, m))
+            y = add(x2, x2)
             z = multiply(y, x)
-            relax.match_shape((n, m), z.shape)
+            relax.match_shape(z.shape, (n, m))
             w: Tensor[(n, m), _] = subtract(z, x)
-            relax.output(y, w)
+            relax.output(y, w, x2)
         t: Tensor[(n, m), _] = divide(y, w)
-        return t
+        q: Tensor[(n, m), _] = add(t, x2)
+        return q
 
     check_roundtrip(foo)
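
For reference, a minimal sketch of the two `match_shape` binding forms this patch accepts, distilled from the test cases above. The `rx.script` decorator, the `add` intrinsic, and the import alias are assumptions following the conventions of tests/python/relax/test_parser.py, not part of the patch itself:

    from tvm import relax as rx  # assumed import path for the rx.script decorator

    @rx.script
    def foo(x: Tensor[_, "float32"]):
        # bare form: binds only the dimension variables n and m
        relax.match_shape(x.shape, (n, m))
        # assignment form: additionally binds x2 to the matched value
        x2: Tensor[(n, m), "float32"] = relax.match_shape(x, (n, m))
        y: Tensor[(n, m), "float32"] = add(x2, x2)
        return y

In the bare form the parsed MatchShape carries no bound variable (`var` is None on the parser side, an undefined Var in C++), which is why the printer only emits the `... = ` prefix when `op->var.defined()` holds.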