Skip to content

Commit

Permalink
Revert "[Frontend] Add Span filling for frontends to Relay (#9723)" (#…
Browse files Browse the repository at this point in the history
…10072)

Because of the failure of LSTM conversion from Pytorch
  • Loading branch information
chunit-quic committed Jan 26, 2022
1 parent ffbe491 commit 095b639
Show file tree
Hide file tree
Showing 13 changed files with 48 additions and 289 deletions.
7 changes: 2 additions & 5 deletions python/tvm/relay/expr.py
Original file line number Diff line number Diff line change
Expand Up @@ -316,13 +316,10 @@ class TupleGetItem(ExprWithOp):
index: int
The index.
span: Optional[tvm.relay.Span]
Span that points to original source code
"""

def __init__(self, tuple_value, index, span=None):
self.__init_handle_by_constructor__(_ffi_api.TupleGetItem, tuple_value, index, span)
def __init__(self, tuple_value, index):
self.__init_handle_by_constructor__(_ffi_api.TupleGetItem, tuple_value, index)


@tvm._ffi.register_object("relay.RefCreate")
Expand Down
53 changes: 0 additions & 53 deletions python/tvm/relay/frontend/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,6 @@
from tvm.topi.utils import get_const_tuple

from .. import expr as _expr
from ..expr_functor import ExprMutator
from .. import function as _function
from .. import transform as _transform
from .. import op as _op
Expand Down Expand Up @@ -955,55 +954,3 @@ def try_resolve_var_to_const(x, graph_params):
return _op.const(value, dtype)

return x


def set_span(sym, node_name):
"""Set up the span of relay expression(s) while converting OP"""

class SpanFiller(ExprMutator):
"""SpanFiller"""

def __init__(self, node_name, suffix_str="_PART_"):
ExprMutator.__init__(self)
self.node_name = node_name
self.suffix_str = suffix_str
self.counter = 0
self.distance_from_leaf = -1

def _create_span(self):
if self.distance_from_leaf == 0:
return tvm.relay.Span(tvm.relay.SourceName(self.node_name), 0, 0, 0, 0)
self.distance_from_leaf -= 1
span_str = "{}{}{}".format(self.node_name, self.suffix_str, str(self.counter))
self.counter += 1
return tvm.relay.Span(tvm.relay.SourceName(span_str), 0, 0, 0, 0)

def visit_call(self, call):
if call.span is None:
self.distance_from_leaf += 1
new_args = [self.visit(arg) for arg in call.args]
return _expr.Call(
call.op, new_args, call.attrs, call.type_args, self._create_span()
)
return call

def visit_tuple(self, tup):
if tup.span is None:
self.distance_from_leaf += 1
return _expr.Tuple([self.visit(field) for field in tup.fields], self._create_span())
return tup

def visit_tuple_getitem(self, op):
if op.span is None:
self.distance_from_leaf += 1
return _expr.TupleGetItem(self.visit(op.tuple_value), op.index, self._create_span())
return op

def fill(self, sym):
if isinstance(sym, _expr.TupleWrapper):
return _expr.TupleWrapper(self.visit(sym.tuple_value), sym.size)
if isinstance(sym, _expr.RelayExpr):
return self.visit(sym)
return sym

return SpanFiller(node_name).fill(sym)
19 changes: 0 additions & 19 deletions python/tvm/relay/frontend/pytorch.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,6 @@
from .common import infer_value as _infer_value
from .common import infer_value_simulated as _infer_value_simulated
from .common import lstm_cell, try_infer_value, unbind
from .common import set_span
from .pytorch_utils import is_version_greater_than

__all__ = ["from_pytorch"]
Expand Down Expand Up @@ -3276,9 +3275,6 @@ def body(*current_vals):

def convert_operators(self, operators, outputs, ret_names):
"""Convert each Torch IR operators to Relay equivalent"""
# an op node might not belong to any of scope in trace info natively
# use a cunter to prevent from messing up its scope in span
empty_counter = 0
for node_name, op_node in operators:
operator = op_node.kind()
inputs = _get_op_inputs(op_node, outputs)
Expand Down Expand Up @@ -3339,9 +3335,6 @@ def _handel_nested_input(inputs):
relay_out = relay_op(
inputs, _get_input_types(op_node, outputs, default_dtype=self.default_dtype)
)
span_str, empty_counter = self._get_torch_span(op_node, empty_counter)
relay_out = set_span(relay_out, span_str)

self.record_output_type(relay_out)

if isinstance(relay_out, tuple):
Expand All @@ -3355,18 +3348,6 @@ def _handel_nested_input(inputs):

return [_wrap_const(outputs[ret_name]) for ret_name in ret_names]

def _get_torch_span(self, node, empty_counter):
# torch span looks like
# %input.5 : Float(...) = aten::relu_(%input.3), scope: __module.relu # ${torch}/nn file
# the scope part might not exist
if node.scopeName():
scope_name_str = "jit._trace.TopLevelTracedModule: " + node.scopeName()
else:
scope_name_str = "warning: no trace info " + str(empty_counter)
empty_counter += 1
span_str = "C.graph: {}, {}".format(node.kind(), scope_name_str)
return span_str, empty_counter


def _pytorch_result_type(dtypes, non_tensor_inputs):
"""This promotes TVM dtypes like PyTorch would"""
Expand Down
17 changes: 15 additions & 2 deletions python/tvm/relay/frontend/tensorflow.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,6 @@
from .common import infer_type as _infer_type
from .common import infer_shape as _infer_shape
from .common import infer_value as _infer_value
from .common import set_span

from .tensorflow_ops import _convert_map
from .tensorflow_ops import _need_prelude_for_shape_inference
Expand Down Expand Up @@ -1029,10 +1028,24 @@ def _convert_operator(
else:
raise NotImplementedError("Operator {} not implemented.".format(op_name))

sym = set_span(sym, node_name)
sym = self._set_span(sym, node_name)

return sym

@staticmethod
def _set_span(sym, node_name):
span = tvm.relay.Span(tvm.relay.SourceName(node_name), 0, 0, 0, 0)
if isinstance(sym, _expr.Call) and sym.span is None:
sym = _expr.Call(sym.op, sym.args, sym.attrs, sym.type_args, span)
elif isinstance(sym, _expr.TupleWrapper):
tuple_value = sym.tuple_value
if isinstance(tuple_value, _expr.Call) and tuple_value.span is None:
tuple_value = _expr.Call(
tuple_value.op, tuple_value.args, tuple_value.attrs, tuple_value.type_args, span
)
sym = _expr.TupleWrapper(tuple_value, sym.size)
return sym

def _licm_construct(self, loop_name, node_name):
"""Construct a node by considering whether it is
loop invariant with the given while loop. If yes, we
Expand Down
17 changes: 16 additions & 1 deletion python/tvm/relay/frontend/tensorflow2.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,6 @@
from .. import function as _function
from ..loops import while_loop as _while_loop
from .common import infer_type as _infer_type
from .common import set_span

from .tensorflow_ops import _convert_map as _convert_map_common
from .tensorflow_ops import _get_more_static_shape_rank
Expand All @@ -59,6 +58,22 @@ def _infer_type_with_prelude(val, prelude):
return body.checked_type


def set_span(sym, node_name):
"""set span of symbol"""

span = tvm.relay.Span(tvm.relay.SourceName(node_name), 0, 0, 0, 0)
if isinstance(sym, _expr.Call):
sym = _expr.Call(sym.op, sym.args, sym.attrs, sym.type_args, span)
elif isinstance(sym, _expr.TupleWrapper):
tuple_value = sym.tuple_value
if isinstance(tuple_value, _expr.Call):
tuple_value = _expr.Call(
tuple_value.op, tuple_value.args, tuple_value.attrs, tuple_value.type_args, span
)
sym = _expr.TupleWrapper(tuple_value, sym.size)
return sym


def is_tensor_list_constuctor(tf_node):
"""Check whether is tensor list constructor node."""
return tf_node.op == "TensorListReserve"
Expand Down
16 changes: 5 additions & 11 deletions python/tvm/relay/frontend/tflite.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,6 @@
from .. import qnn as _qnn
from .common import ExprTable
from .common import infer_shape as _infer_shape
from .common import set_span
from .common import to_int_list
from .tflite_flexbuffer import FlexBufferDecoder

Expand Down Expand Up @@ -240,17 +239,12 @@ def convert_op_to_relay(self):

if len(output_tensors) == 1:
tensor_idx = output_tensors[0].tensor_idx
curr_output = get_tensor_name(self.subgraph, tensor_idx)
ret = set_span(ret, "location: {}, output_name: {}".format(op_idx, curr_output))
self.exp_tab.set_expr(curr_output, ret)
self.exp_tab.set_expr(get_tensor_name(self.subgraph, tensor_idx), ret)
else:
out_names = []
for output_tensor in output_tensors:
out_names.append(get_tensor_name(self.subgraph, output_tensor.tensor_idx))
curr_output = ", ".join(out_names)
ret = set_span(ret, "location: {}, output_name: {}".format(op_idx, curr_output))
for idx, out_name in enumerate(out_names):
self.exp_tab.set_expr(out_name, ret[idx])
for idx, output_tensor in enumerate(output_tensors):
self.exp_tab.set_expr(
get_tensor_name(self.subgraph, output_tensor.tensor_idx), ret[idx]
)

def get_op_code_str(self, op):
"""Get TFLite ops string representation"""
Expand Down
23 changes: 6 additions & 17 deletions src/printer/relay_text_printer.cc
Original file line number Diff line number Diff line change
Expand Up @@ -389,21 +389,12 @@ Doc RelayTextPrinter::VisitExpr_(const TupleNode* op) {
if (op->fields.size() == 1) {
doc << ",";
}
doc << ")";
if (op->span.defined()) {
doc << " /* " << PrintSpan(op->span) << " */";
}
return doc;
return doc << ")";
}

Doc RelayTextPrinter::VisitExpr_(const TupleGetItemNode* op) {
Doc doc;
doc << Print(op->tuple) << "." << op->index;

if (op->span.defined()) {
doc << " /* " << PrintSpan(op->span) << " */";
}
return doc;
return doc << Print(op->tuple) << "." << op->index;
}

Doc RelayTextPrinter::VisitExpr_(const IfNode* op) {
Expand Down Expand Up @@ -977,13 +968,11 @@ Doc RelayTextPrinter::PrintMapAsAttributeValue(const Map<ObjectRef, ObjectRef>&
return doc;
}

Doc RelayTextPrinter::PrintSpan(const Span& span, bool include_spans) {
Doc RelayTextPrinter::PrintSpan(const Span& span) {
Doc doc;
if (include_spans) {
const auto* span_node = span.as<SpanNode>();
ICHECK(span_node);
doc << span_node->source_name->name;
}
const auto* span_node = span.as<SpanNode>();
ICHECK(span_node);
doc << span_node->source_name->name;
return doc;
}

Expand Down
2 changes: 1 addition & 1 deletion src/printer/text_printer.h
Original file line number Diff line number Diff line change
Expand Up @@ -113,7 +113,7 @@ class RelayTextPrinter : public ExprFunctor<Doc(const Expr&)>,
*/
Doc PrintMapAsAttributeValue(const Map<ObjectRef, ObjectRef>& map);

Doc PrintSpan(const Span& span, bool include_spans = true);
Doc PrintSpan(const Span& span);

Doc Print(const ObjectRef& node, bool meta = false, bool try_inline = false);

Expand Down
4 changes: 2 additions & 2 deletions src/relay/ir/expr.cc
Original file line number Diff line number Diff line change
Expand Up @@ -375,8 +375,8 @@ TupleGetItem WithFields(TupleGetItem tuple_get_item, Optional<Expr> opt_tuple,

TVM_REGISTER_NODE_TYPE(TupleGetItemNode);

TVM_REGISTER_GLOBAL("relay.ir.TupleGetItem").set_body_typed([](Expr tuple, int index, Span span) {
return TupleGetItem(tuple, index, span);
TVM_REGISTER_GLOBAL("relay.ir.TupleGetItem").set_body_typed([](Expr tuple, int index) {
return TupleGetItem(tuple, index);
});

TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
Expand Down
47 changes: 0 additions & 47 deletions tests/python/frontend/pytorch/test_forward.py
Original file line number Diff line number Diff line change
Expand Up @@ -247,53 +247,6 @@ def visit(op):
torch.cuda.empty_cache()


def verify_span(model_name, input_data=[], custom_convert_map={}):
if isinstance(model_name, str):
baseline_model, baseline_input = load_model(model_name)
elif isinstance(input_data, list):
baseline_model = model_name
baseline_input = input_data
elif isinstance(input_data, torch.Tensor) or len(input_data.shape) == 0:
baseline_model = model_name
baseline_input = [input_data]
else:
assert False, "Unexpected input format"

trace = torch.jit.trace(baseline_model, [input.clone() for input in baseline_input])
if isinstance(baseline_model, torch.nn.Module):
trace = trace.float().eval()

if torch.cuda.is_available():
trace = trace.cuda()
else:
trace = trace.cpu()

input_names = ["input{}".format(idx) for idx, inp in enumerate(baseline_input)]
input_shapes = list(zip(input_names, [inp.shape for inp in baseline_input]))
mod, params = relay.frontend.from_pytorch(trace, input_shapes, custom_convert_map)

# collect fail cases for the convenience of further improvement
fail_cases = []
mod_main_start = False
for line in str(mod.__str__).split("\n"):
if "@main" in line:
mod_main_start = True
continue

if mod_main_start == True:
if "}" == line:
break
elif not ("/*" in line and "*/" in line):
fail_cases.append(line)

print(fail_cases)
assert len(fail_cases) == 0


def test_span():
verify_span("resnet18")


# Single operator tests
@tvm.testing.uses_gpu
def test_forward_pixel_shuffle():
Expand Down
54 changes: 0 additions & 54 deletions tests/python/frontend/tensorflow/test_forward.py
Original file line number Diff line number Diff line change
Expand Up @@ -298,60 +298,6 @@ def is_gpu_available():
return False


def verify_span(mod):
# collect fail cases for the convenience of further improvement
fail_cases = []
mod_main_start = False
for line in str(mod.__str__).split("\n"):
if "@main" in line:
mod_main_start = True
continue

if mod_main_start == True:
if "}" == line:
break
elif not ("/*" in line and "*/" in line):
fail_cases.append(line)

print(fail_cases)
assert len(fail_cases) == 0


def simple_model():
input_node = tf.placeholder(shape=[None, None, 3, 1], dtype=np.float32, name="input")

shape = tf.shape(input_node)
stack = tf.stack([shape[0], 3, 3], axis=0)
output_node = tf.reshape(input_node, stack, name="output")
return output_node


#######################################################################
# Span fill up
# -------
def test_span_complement_simple_model():
with tf.Graph().as_default() as graph:
model_graph = simple_model()
graph_def = graph.as_graph_def()

graph_def = tf_testing.ProcessGraphDefParam(graph_def)

mod, params = relay.frontend.from_tensorflow(graph_def, shape={"input:0", (1, 3, 3, 1)})
verify_span(mod)


def test_span_complement_big_model():
with tf.Graph().as_default() as graph:
graph_def = tf_testing.get_workload("ResnetV2/resnet-20180601_resnet_v2_imagenet-shapes.pb")
# Call the utility to import the graph definition into default graph.
graph_def = tf_testing.ProcessGraphDefParam(graph_def)

mod, params = relay.frontend.from_tensorflow(
graph_def, shape={"input_tensor:0", (128, 224, 224, 3)}
)
verify_span(mod)


#######################################################################
# Pooling
# -------
Expand Down
Loading

0 comments on commit 095b639

Please sign in to comment.