From 095b63935efdd42effc0298998d67838086b7b26 Mon Sep 17 00:00:00 2001 From: Chun-I Tsai Date: Wed, 26 Jan 2022 22:35:38 +0800 Subject: [PATCH] Revert "[Frontend] Add Span filling for frontends to Relay (#9723)" (#10072) Because of the failure of LSTM conversion from Pytorch --- python/tvm/relay/expr.py | 7 +-- python/tvm/relay/frontend/common.py | 53 ------------------ python/tvm/relay/frontend/pytorch.py | 19 ------- python/tvm/relay/frontend/tensorflow.py | 17 +++++- python/tvm/relay/frontend/tensorflow2.py | 17 +++++- python/tvm/relay/frontend/tflite.py | 16 ++---- src/printer/relay_text_printer.cc | 23 +++----- src/printer/text_printer.h | 2 +- src/relay/ir/expr.cc | 4 +- tests/python/frontend/pytorch/test_forward.py | 47 ---------------- .../frontend/tensorflow/test_forward.py | 54 ------------------- .../tensorflow2/test_sequential_models.py | 24 +-------- tests/python/frontend/tflite/test_forward.py | 54 ------------------- 13 files changed, 48 insertions(+), 289 deletions(-) diff --git a/python/tvm/relay/expr.py b/python/tvm/relay/expr.py index 598354e1b514..811e205fb2b3 100644 --- a/python/tvm/relay/expr.py +++ b/python/tvm/relay/expr.py @@ -316,13 +316,10 @@ class TupleGetItem(ExprWithOp): index: int The index. - - span: Optional[tvm.relay.Span] - Span that points to original source code """ - def __init__(self, tuple_value, index, span=None): - self.__init_handle_by_constructor__(_ffi_api.TupleGetItem, tuple_value, index, span) + def __init__(self, tuple_value, index): + self.__init_handle_by_constructor__(_ffi_api.TupleGetItem, tuple_value, index) @tvm._ffi.register_object("relay.RefCreate") diff --git a/python/tvm/relay/frontend/common.py b/python/tvm/relay/frontend/common.py index f8c12ff334db..eeede181f6f9 100755 --- a/python/tvm/relay/frontend/common.py +++ b/python/tvm/relay/frontend/common.py @@ -25,7 +25,6 @@ from tvm.topi.utils import get_const_tuple from .. import expr as _expr -from ..expr_functor import ExprMutator from .. import function as _function from .. import transform as _transform from .. import op as _op @@ -955,55 +954,3 @@ def try_resolve_var_to_const(x, graph_params): return _op.const(value, dtype) return x - - -def set_span(sym, node_name): - """Set up the span of relay expression(s) while converting OP""" - - class SpanFiller(ExprMutator): - """SpanFiller""" - - def __init__(self, node_name, suffix_str="_PART_"): - ExprMutator.__init__(self) - self.node_name = node_name - self.suffix_str = suffix_str - self.counter = 0 - self.distance_from_leaf = -1 - - def _create_span(self): - if self.distance_from_leaf == 0: - return tvm.relay.Span(tvm.relay.SourceName(self.node_name), 0, 0, 0, 0) - self.distance_from_leaf -= 1 - span_str = "{}{}{}".format(self.node_name, self.suffix_str, str(self.counter)) - self.counter += 1 - return tvm.relay.Span(tvm.relay.SourceName(span_str), 0, 0, 0, 0) - - def visit_call(self, call): - if call.span is None: - self.distance_from_leaf += 1 - new_args = [self.visit(arg) for arg in call.args] - return _expr.Call( - call.op, new_args, call.attrs, call.type_args, self._create_span() - ) - return call - - def visit_tuple(self, tup): - if tup.span is None: - self.distance_from_leaf += 1 - return _expr.Tuple([self.visit(field) for field in tup.fields], self._create_span()) - return tup - - def visit_tuple_getitem(self, op): - if op.span is None: - self.distance_from_leaf += 1 - return _expr.TupleGetItem(self.visit(op.tuple_value), op.index, self._create_span()) - return op - - def fill(self, sym): - if isinstance(sym, _expr.TupleWrapper): - return _expr.TupleWrapper(self.visit(sym.tuple_value), sym.size) - if isinstance(sym, _expr.RelayExpr): - return self.visit(sym) - return sym - - return SpanFiller(node_name).fill(sym) diff --git a/python/tvm/relay/frontend/pytorch.py b/python/tvm/relay/frontend/pytorch.py index b7188370d86e..f7538f0837c6 100644 --- a/python/tvm/relay/frontend/pytorch.py +++ b/python/tvm/relay/frontend/pytorch.py @@ -45,7 +45,6 @@ from .common import infer_value as _infer_value from .common import infer_value_simulated as _infer_value_simulated from .common import lstm_cell, try_infer_value, unbind -from .common import set_span from .pytorch_utils import is_version_greater_than __all__ = ["from_pytorch"] @@ -3276,9 +3275,6 @@ def body(*current_vals): def convert_operators(self, operators, outputs, ret_names): """Convert each Torch IR operators to Relay equivalent""" - # an op node might not belong to any of scope in trace info natively - # use a cunter to prevent from messing up its scope in span - empty_counter = 0 for node_name, op_node in operators: operator = op_node.kind() inputs = _get_op_inputs(op_node, outputs) @@ -3339,9 +3335,6 @@ def _handel_nested_input(inputs): relay_out = relay_op( inputs, _get_input_types(op_node, outputs, default_dtype=self.default_dtype) ) - span_str, empty_counter = self._get_torch_span(op_node, empty_counter) - relay_out = set_span(relay_out, span_str) - self.record_output_type(relay_out) if isinstance(relay_out, tuple): @@ -3355,18 +3348,6 @@ def _handel_nested_input(inputs): return [_wrap_const(outputs[ret_name]) for ret_name in ret_names] - def _get_torch_span(self, node, empty_counter): - # torch span looks like - # %input.5 : Float(...) = aten::relu_(%input.3), scope: __module.relu # ${torch}/nn file - # the scope part might not exist - if node.scopeName(): - scope_name_str = "jit._trace.TopLevelTracedModule: " + node.scopeName() - else: - scope_name_str = "warning: no trace info " + str(empty_counter) - empty_counter += 1 - span_str = "C.graph: {}, {}".format(node.kind(), scope_name_str) - return span_str, empty_counter - def _pytorch_result_type(dtypes, non_tensor_inputs): """This promotes TVM dtypes like PyTorch would""" diff --git a/python/tvm/relay/frontend/tensorflow.py b/python/tvm/relay/frontend/tensorflow.py index c2aa5a165b3c..d35e0e1c203d 100644 --- a/python/tvm/relay/frontend/tensorflow.py +++ b/python/tvm/relay/frontend/tensorflow.py @@ -37,7 +37,6 @@ from .common import infer_type as _infer_type from .common import infer_shape as _infer_shape from .common import infer_value as _infer_value -from .common import set_span from .tensorflow_ops import _convert_map from .tensorflow_ops import _need_prelude_for_shape_inference @@ -1029,10 +1028,24 @@ def _convert_operator( else: raise NotImplementedError("Operator {} not implemented.".format(op_name)) - sym = set_span(sym, node_name) + sym = self._set_span(sym, node_name) return sym + @staticmethod + def _set_span(sym, node_name): + span = tvm.relay.Span(tvm.relay.SourceName(node_name), 0, 0, 0, 0) + if isinstance(sym, _expr.Call) and sym.span is None: + sym = _expr.Call(sym.op, sym.args, sym.attrs, sym.type_args, span) + elif isinstance(sym, _expr.TupleWrapper): + tuple_value = sym.tuple_value + if isinstance(tuple_value, _expr.Call) and tuple_value.span is None: + tuple_value = _expr.Call( + tuple_value.op, tuple_value.args, tuple_value.attrs, tuple_value.type_args, span + ) + sym = _expr.TupleWrapper(tuple_value, sym.size) + return sym + def _licm_construct(self, loop_name, node_name): """Construct a node by considering whether it is loop invariant with the given while loop. If yes, we diff --git a/python/tvm/relay/frontend/tensorflow2.py b/python/tvm/relay/frontend/tensorflow2.py index 2c8b7d4e777b..465f530624b9 100644 --- a/python/tvm/relay/frontend/tensorflow2.py +++ b/python/tvm/relay/frontend/tensorflow2.py @@ -36,7 +36,6 @@ from .. import function as _function from ..loops import while_loop as _while_loop from .common import infer_type as _infer_type -from .common import set_span from .tensorflow_ops import _convert_map as _convert_map_common from .tensorflow_ops import _get_more_static_shape_rank @@ -59,6 +58,22 @@ def _infer_type_with_prelude(val, prelude): return body.checked_type +def set_span(sym, node_name): + """set span of symbol""" + + span = tvm.relay.Span(tvm.relay.SourceName(node_name), 0, 0, 0, 0) + if isinstance(sym, _expr.Call): + sym = _expr.Call(sym.op, sym.args, sym.attrs, sym.type_args, span) + elif isinstance(sym, _expr.TupleWrapper): + tuple_value = sym.tuple_value + if isinstance(tuple_value, _expr.Call): + tuple_value = _expr.Call( + tuple_value.op, tuple_value.args, tuple_value.attrs, tuple_value.type_args, span + ) + sym = _expr.TupleWrapper(tuple_value, sym.size) + return sym + + def is_tensor_list_constuctor(tf_node): """Check whether is tensor list constructor node.""" return tf_node.op == "TensorListReserve" diff --git a/python/tvm/relay/frontend/tflite.py b/python/tvm/relay/frontend/tflite.py index 12296bd50542..b675dd56a7bb 100644 --- a/python/tvm/relay/frontend/tflite.py +++ b/python/tvm/relay/frontend/tflite.py @@ -32,7 +32,6 @@ from .. import qnn as _qnn from .common import ExprTable from .common import infer_shape as _infer_shape -from .common import set_span from .common import to_int_list from .tflite_flexbuffer import FlexBufferDecoder @@ -240,17 +239,12 @@ def convert_op_to_relay(self): if len(output_tensors) == 1: tensor_idx = output_tensors[0].tensor_idx - curr_output = get_tensor_name(self.subgraph, tensor_idx) - ret = set_span(ret, "location: {}, output_name: {}".format(op_idx, curr_output)) - self.exp_tab.set_expr(curr_output, ret) + self.exp_tab.set_expr(get_tensor_name(self.subgraph, tensor_idx), ret) else: - out_names = [] - for output_tensor in output_tensors: - out_names.append(get_tensor_name(self.subgraph, output_tensor.tensor_idx)) - curr_output = ", ".join(out_names) - ret = set_span(ret, "location: {}, output_name: {}".format(op_idx, curr_output)) - for idx, out_name in enumerate(out_names): - self.exp_tab.set_expr(out_name, ret[idx]) + for idx, output_tensor in enumerate(output_tensors): + self.exp_tab.set_expr( + get_tensor_name(self.subgraph, output_tensor.tensor_idx), ret[idx] + ) def get_op_code_str(self, op): """Get TFLite ops string representation""" diff --git a/src/printer/relay_text_printer.cc b/src/printer/relay_text_printer.cc index 7654ef17b753..fdc6c37e527a 100644 --- a/src/printer/relay_text_printer.cc +++ b/src/printer/relay_text_printer.cc @@ -389,21 +389,12 @@ Doc RelayTextPrinter::VisitExpr_(const TupleNode* op) { if (op->fields.size() == 1) { doc << ","; } - doc << ")"; - if (op->span.defined()) { - doc << " /* " << PrintSpan(op->span) << " */"; - } - return doc; + return doc << ")"; } Doc RelayTextPrinter::VisitExpr_(const TupleGetItemNode* op) { Doc doc; - doc << Print(op->tuple) << "." << op->index; - - if (op->span.defined()) { - doc << " /* " << PrintSpan(op->span) << " */"; - } - return doc; + return doc << Print(op->tuple) << "." << op->index; } Doc RelayTextPrinter::VisitExpr_(const IfNode* op) { @@ -977,13 +968,11 @@ Doc RelayTextPrinter::PrintMapAsAttributeValue(const Map& return doc; } -Doc RelayTextPrinter::PrintSpan(const Span& span, bool include_spans) { +Doc RelayTextPrinter::PrintSpan(const Span& span) { Doc doc; - if (include_spans) { - const auto* span_node = span.as(); - ICHECK(span_node); - doc << span_node->source_name->name; - } + const auto* span_node = span.as(); + ICHECK(span_node); + doc << span_node->source_name->name; return doc; } diff --git a/src/printer/text_printer.h b/src/printer/text_printer.h index ca46700d9cf5..a4d0ff30fa62 100644 --- a/src/printer/text_printer.h +++ b/src/printer/text_printer.h @@ -113,7 +113,7 @@ class RelayTextPrinter : public ExprFunctor, */ Doc PrintMapAsAttributeValue(const Map& map); - Doc PrintSpan(const Span& span, bool include_spans = true); + Doc PrintSpan(const Span& span); Doc Print(const ObjectRef& node, bool meta = false, bool try_inline = false); diff --git a/src/relay/ir/expr.cc b/src/relay/ir/expr.cc index 64d921efe6a6..73ae3faf7078 100644 --- a/src/relay/ir/expr.cc +++ b/src/relay/ir/expr.cc @@ -375,8 +375,8 @@ TupleGetItem WithFields(TupleGetItem tuple_get_item, Optional opt_tuple, TVM_REGISTER_NODE_TYPE(TupleGetItemNode); -TVM_REGISTER_GLOBAL("relay.ir.TupleGetItem").set_body_typed([](Expr tuple, int index, Span span) { - return TupleGetItem(tuple, index, span); +TVM_REGISTER_GLOBAL("relay.ir.TupleGetItem").set_body_typed([](Expr tuple, int index) { + return TupleGetItem(tuple, index); }); TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) diff --git a/tests/python/frontend/pytorch/test_forward.py b/tests/python/frontend/pytorch/test_forward.py index 2c07094c1e9f..3fbef494f16d 100644 --- a/tests/python/frontend/pytorch/test_forward.py +++ b/tests/python/frontend/pytorch/test_forward.py @@ -247,53 +247,6 @@ def visit(op): torch.cuda.empty_cache() -def verify_span(model_name, input_data=[], custom_convert_map={}): - if isinstance(model_name, str): - baseline_model, baseline_input = load_model(model_name) - elif isinstance(input_data, list): - baseline_model = model_name - baseline_input = input_data - elif isinstance(input_data, torch.Tensor) or len(input_data.shape) == 0: - baseline_model = model_name - baseline_input = [input_data] - else: - assert False, "Unexpected input format" - - trace = torch.jit.trace(baseline_model, [input.clone() for input in baseline_input]) - if isinstance(baseline_model, torch.nn.Module): - trace = trace.float().eval() - - if torch.cuda.is_available(): - trace = trace.cuda() - else: - trace = trace.cpu() - - input_names = ["input{}".format(idx) for idx, inp in enumerate(baseline_input)] - input_shapes = list(zip(input_names, [inp.shape for inp in baseline_input])) - mod, params = relay.frontend.from_pytorch(trace, input_shapes, custom_convert_map) - - # collect fail cases for the convenience of further improvement - fail_cases = [] - mod_main_start = False - for line in str(mod.__str__).split("\n"): - if "@main" in line: - mod_main_start = True - continue - - if mod_main_start == True: - if "}" == line: - break - elif not ("/*" in line and "*/" in line): - fail_cases.append(line) - - print(fail_cases) - assert len(fail_cases) == 0 - - -def test_span(): - verify_span("resnet18") - - # Single operator tests @tvm.testing.uses_gpu def test_forward_pixel_shuffle(): diff --git a/tests/python/frontend/tensorflow/test_forward.py b/tests/python/frontend/tensorflow/test_forward.py index c76803b8fb3c..a5a67e149986 100644 --- a/tests/python/frontend/tensorflow/test_forward.py +++ b/tests/python/frontend/tensorflow/test_forward.py @@ -298,60 +298,6 @@ def is_gpu_available(): return False -def verify_span(mod): - # collect fail cases for the convenience of further improvement - fail_cases = [] - mod_main_start = False - for line in str(mod.__str__).split("\n"): - if "@main" in line: - mod_main_start = True - continue - - if mod_main_start == True: - if "}" == line: - break - elif not ("/*" in line and "*/" in line): - fail_cases.append(line) - - print(fail_cases) - assert len(fail_cases) == 0 - - -def simple_model(): - input_node = tf.placeholder(shape=[None, None, 3, 1], dtype=np.float32, name="input") - - shape = tf.shape(input_node) - stack = tf.stack([shape[0], 3, 3], axis=0) - output_node = tf.reshape(input_node, stack, name="output") - return output_node - - -####################################################################### -# Span fill up -# ------- -def test_span_complement_simple_model(): - with tf.Graph().as_default() as graph: - model_graph = simple_model() - graph_def = graph.as_graph_def() - - graph_def = tf_testing.ProcessGraphDefParam(graph_def) - - mod, params = relay.frontend.from_tensorflow(graph_def, shape={"input:0", (1, 3, 3, 1)}) - verify_span(mod) - - -def test_span_complement_big_model(): - with tf.Graph().as_default() as graph: - graph_def = tf_testing.get_workload("ResnetV2/resnet-20180601_resnet_v2_imagenet-shapes.pb") - # Call the utility to import the graph definition into default graph. - graph_def = tf_testing.ProcessGraphDefParam(graph_def) - - mod, params = relay.frontend.from_tensorflow( - graph_def, shape={"input_tensor:0", (128, 224, 224, 3)} - ) - verify_span(mod) - - ####################################################################### # Pooling # ------- diff --git a/tests/python/frontend/tensorflow2/test_sequential_models.py b/tests/python/frontend/tensorflow2/test_sequential_models.py index b76b4a714938..1b5a6342f07d 100644 --- a/tests/python/frontend/tensorflow2/test_sequential_models.py +++ b/tests/python/frontend/tensorflow2/test_sequential_models.py @@ -26,25 +26,6 @@ from common import compare_tf_tvm from common import run_tf_code -from tvm.relay.frontend.tensorflow2 import from_tensorflow - - -def verify_span(mod): - fail_cases = [] - mod_main_start = False - for line in str(mod.__str__).split("\n"): - if "@main" in line: - mod_main_start = True - continue - - if mod_main_start == True: - if "}" == line: - break - elif not ("/*" in line and "*/" in line): - fail_cases.append(line) - - print(fail_cases) - assert len(fail_cases) == 0 def run_sequential_model(model_fn, input_shape): @@ -67,10 +48,7 @@ def model_graph(model, input_shape): gdef = f.graph.as_graph_def(add_shapes=True) return gdef, _input, _output - gdef, _input, _output = model_graph(model_fn, input_shape) - mod, _ = from_tensorflow(gdef) - compare_tf_tvm(gdef, _input, _output, runtime="vm") - verify_span(mod) + compare_tf_tvm(*model_graph(model_fn, input_shape), runtime="vm") def test_dense_model(): diff --git a/tests/python/frontend/tflite/test_forward.py b/tests/python/frontend/tflite/test_forward.py index 77acce459fc9..60af94b53a51 100644 --- a/tests/python/frontend/tflite/test_forward.py +++ b/tests/python/frontend/tflite/test_forward.py @@ -259,59 +259,6 @@ def run_tflite_graph(tflite_model_buf, input_data): return tflite_output -def run_span_verification( - tflite_model_buf, - input_data, - input_node, - num_output=1, - target="llvm", - out_names=None, - mode="graph_executor", -): - """Generic function to compile on relay and execute on tvm""" - # TFLite.Model.Model has changed to TFLite.Model from 1.14 to 2.1 - try: - import tflite.Model - - tflite_model = tflite.Model.Model.GetRootAsModel(tflite_model_buf, 0) - except AttributeError: - import tflite - - tflite_model = tflite.Model.GetRootAsModel(tflite_model_buf, 0) - except ImportError: - raise ImportError("The tflite package must be installed") - - input_data = convert_to_list(input_data) - input_node = convert_to_list(input_node) - - shape_dict = {} - dtype_dict = {} - for i, e in enumerate(input_node): - shape_dict[e] = input_data[i].shape - dtype_dict[e] = input_data[i].dtype.name - - mod, _ = relay.frontend.from_tflite(tflite_model, shape_dict=shape_dict, dtype_dict=dtype_dict) - verify_span(mod) - - -def verify_span(mod): - fail_cases = [] - mod_main_start = False - for line in str(mod.__str__).split("\n"): - if "@main" in line: - mod_main_start = True - continue - - if mod_main_start == True: - if "}" == line: - break - elif not ("/*" in line and "*/" in line): - fail_cases.append(line) - - print(fail_cases) - assert len(fail_cases) == 0 - - def compare_tflite_with_tvm( in_data, in_name, @@ -4620,7 +4567,6 @@ def test_forward_tflite2_qnn_resnet50(): tflite_output = run_tflite_graph(tflite_model_buf, data) tflite_predictions = np.squeeze(tflite_output) tflite_sorted_labels = tflite_predictions.argsort()[-3:][::-1] - run_span_verification(tflite_model_buf, np.array(data), "input_1") tvm_output = run_tvm_graph(tflite_model_buf, np.array(data), "input_1") tvm_predictions = np.squeeze(tvm_output) tvm_sorted_labels = tvm_predictions.argsort()[-3:][::-1]