From e90ab87fe64d804658124b04f2f962169dee648f Mon Sep 17 00:00:00 2001
From: Thomas Viehmann
Date: Wed, 9 Dec 2020 16:37:11 +0100
Subject: [PATCH] Incremental type inference (#6900)

---
 python/tvm/relay/frontend/pytorch.py | 345 +++++++++++++++++----------
 1 file changed, 214 insertions(+), 131 deletions(-)

diff --git a/python/tvm/relay/frontend/pytorch.py b/python/tvm/relay/frontend/pytorch.py
index d2c52fbc262a..ca188c78322a 100644
--- a/python/tvm/relay/frontend/pytorch.py
+++ b/python/tvm/relay/frontend/pytorch.py
@@ -28,63 +28,70 @@
 import tvm
 from tvm.topi.utils import get_const_tuple
+from tvm.ir import IRModule
 
 from .. import analysis as _analysis
 from .. import expr as _expr
+from .. import function as _function
 from .. import op as _op
 from ..ty import TupleType, TensorType, Any
 from ..loops import while_loop
 from .. import transform
 from .common import AttrCvt, get_relay_op
-from .common import infer_shape as _infer_shape
 from .common import infer_value as _infer_value
 from .common import try_infer_value
 from .common import infer_value_simulated as _infer_value_simulated
-from .common import infer_type as _infer_type
 from ..prelude import Prelude, StaticTensorArrayOps
+from ..expr_functor import ExprMutator
 
 from . import qnn_torch
 from .pytorch_utils import is_version_greater_than
 
 __all__ = ["from_pytorch"]
 
-
-# List ADT utilities
-def _infer_type_with_prelude(val, prelude):
-    body = _infer_type(val, prelude.mod)
-    return body.checked_type
-
-
-def _convert_to_list_adt(py_lst, prelude):
-    elem_tys = [_infer_type_with_prelude(elem, prelude) for elem in py_lst]
-    msg = "List elements should have identical types"
-    assert all(map(lambda ty: ty == elem_tys[0], elem_tys)), msg
-
-    # get_type returns type_name, ctor1, ..., ctorN
-    # 1 is nil
-    _, cons, nil = prelude.mod.get_type("List")
-    adt_lst = nil()
-    for elem in reversed(py_lst):
-        adt_lst = cons(elem, adt_lst)
-    return adt_lst
-
-
-def _map_tensor_array_constructor(adt_lst, prelude, shape):
-    static_tensor_array_ops = StaticTensorArrayOps(prelude, "float32", shape)
-    static_tensor_array_ops.register()
-    tensor_create = prelude.get_tensor_ctor_static("tensor_constructor", "float32", shape)
-    return prelude.map(tensor_create, adt_lst)
-
-
-def _convert_to_tensor_array(adt_lst, prelude):
-    _, cons, nil = prelude.mod.get_type("List")
-    if prelude.length(adt_lst) == 0:
-        return nil()
-
-    checked_type = _infer_type_with_prelude(prelude.hd(adt_lst), prelude)
-    shape = checked_type.shape
-    tensor_array = _map_tensor_array_constructor(adt_lst, prelude, shape)
-    return tensor_array, tuple(shape)
+# This returns a "subgraph" which puts variables whenever
+# the type is known. It also records a mapping from the input
+# nodes to the extracted graph's nodes.
+# As Python objects are not round-trippable through C++, and
+# our type annotations only live in Python, we need to map
+# the nodes we get in visiting to the nodes
+# we used to construct the graph (they are the same in C++,
+# match each other in dictionary lookups, but are not the same
+# in Python) by using the hint dictionary filled as
+# {node: node for node in nodes} to get the type annotations.
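+# In other words: a node handed back later by the FFI hashes and compares
+# equal to the node we created, so it still works as a dictionary key, but it
+# is not the same Python object, and the hint dictionary is what lets us get
+# back the original wrapper that the type annotation was recorded for.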
+# https://discuss.tvm.apache.org/t/round-tripping-objects-through-the-ffi/8440
+class _TypeFinder(ExprMutator):
+    def __init__(self, types):
+        super().__init__()
+        self.counter = 0
+        self.vars = {}
+        self.types = types
+        self.leave = set()  # some variables are not inputs
+
+    def visit_let(self, let):
+        self.leave.add(let.var)
+        return super().visit_let(let)
+
+    def visit_function(self, fn):
+        self.leave.update(fn.params)
+        return super().visit_function(fn)
+
+    def visit(self, expr):
+        if expr in self.leave:
+            return super().visit(expr)
+        if expr in self.vars:
+            return self.vars[expr]
+        if isinstance(expr, tvm.relay.Var):
+            self.vars[expr] = expr
+            return expr
+        if expr in self.types:
+            ty = self.types[expr]
+            v = tvm.relay.var(f"_{self.counter}", type_annotation=ty)
+            self.counter += 1
+            self.vars[expr] = v
+            return v
+        v = super().visit(expr)
+        return v
 
 
 def _should_construct_dynamic_list(list_construct_node):
@@ -125,17 +132,7 @@ def _is_int_seq(seq):
     return len(seq) > 0 and all([isinstance(i, int) for i in seq])
 
 
-def _is_quantized_tensor(data, prelude):
-    # If a quantized Torch module is saved and loaded back, dtype will be dropped
-    # Since dtypes from Torch tensors are not reliable in such cases, we use
-    # Relay's type inference result to decide if an input tensor is quantized
-    ty = _infer_type_with_prelude(data, prelude)
-    return ty.dtype == "uint8"
-
-
 # operator implementation
-
-
 class PyTorchOpConverter:
     """A helper class for holding PyTorch op converters."""
 
@@ -143,26 +140,137 @@ def __init__(self, prelude, default_dtype):
         self.prelude = prelude
         self.default_dtype = default_dtype
         self.create_convert_map()
+        self.types = {}  # map from nodes to (Relay) type annotations
+
+    # this incrementally infers the type, see the comments on the type visitor
+    # above.
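+    # self.types acts as a cache: a type query only runs Relay's InferType pass
+    # on the small function that _TypeFinder cuts out of the graph (cut off at
+    # nodes whose type has already been recorded), instead of re-inferring the
+    # whole module on every query.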
+    def infer_type(self, node, mod=None):
+        """An incremental method to infer the type of a node in the relay graph."""
+
+        if node in self.types:
+            return self.types[node]
+        if isinstance(node, tvm.relay.Var):
+            return node.type_annotation
+
+        tf = _TypeFinder(types=self.types)
+        new_node = tf.visit(node)
+        fn = _function.Function(list(tf.vars.values()), new_node)
+        new_mod = IRModule({"main": fn})
+        if mod is not None:
+            new_mod.update(mod)
+        new_mod = transform.RemoveUnusedFunctions()(new_mod)
+        new_mod = transform.InferType()(new_mod)
+        entry = new_mod["main"]
+        ty = entry.body.checked_type
+        self.types[node] = ty
+        return self.types[node]
+
+    def infer_type_with_prelude(self, val):
+        body = self.infer_type(val, self.prelude.mod)
+        return body
+
+    # list ADT utilities
+    def convert_to_list_adt(self, py_lst):
+        elem_tys = [self.infer_type_with_prelude(elem) for elem in py_lst]
+        msg = "List elements should have identical types"
+        assert all(map(lambda ty: ty == elem_tys[0], elem_tys)), msg
+
+        # get_type returns type_name, ctor1, ..., ctorN
+        # 1 is nil
+        _, cons, nil = self.prelude.mod.get_type("List")
+        adt_lst = nil()
+        for elem in reversed(py_lst):
+            adt_lst = cons(elem, adt_lst)
+        return adt_lst
+
+    def map_tensor_array_constructor(self, adt_lst, shape):
+        static_tensor_array_ops = StaticTensorArrayOps(self.prelude, "float32", shape)
+        static_tensor_array_ops.register()
+        tensor_create = self.prelude.get_tensor_ctor_static("tensor_constructor", "float32", shape)
+        return self.prelude.map(tensor_create, adt_lst)
+
+    def convert_to_tensor_array(self, adt_lst):
+        _, cons, nil = self.prelude.mod.get_type("List")
+        if self.prelude.length(adt_lst) == 0:
+            return nil()
+
+        checked_type = self.infer_type_with_prelude(self.prelude.hd(adt_lst))
+        shape = checked_type.shape
+        tensor_array = self.map_tensor_array_constructor(adt_lst, shape)
+        return tensor_array, tuple(shape)
+
+    def infer_shape(self, inputs, mod=None):
+        """A method to get the output type of an intermediate node in the graph."""
+        typ = self.infer_type(inputs, mod=mod)
+        if hasattr(typ, "shape"):
+            # Regular operator that outputs tensors
+            return get_const_tuple(typ.shape)
+        # The return type is not a tensor, for example List
+        return typ
+
+    def infer_shape_with_prelude(self, inputs):
+        return self.infer_shape(inputs, mod=self.prelude.mod)
+
+    def record_output_type(self, output):
+        if isinstance(output, tuple):
+            cleaned_output = [o for o in output if o is not None]
+            types = self.infer_type_with_prelude(_expr.Tuple(cleaned_output))
+            for o, t in zip(cleaned_output, types.fields):
+                self.types[o] = t
+        elif isinstance(output, _expr.Expr):
+            self.infer_type_with_prelude(output)
+        # it can also happen that the type is int or so
+
+    def pytorch_promote_types(self, inputs, dtypes):
+        """This promotes TVM inputs with TVM dtypes passed like PyTorch would"""
+        actual_dtypes = []
+        for i, inp in enumerate(inputs):
+            if isinstance(inp, _expr.Expr):
+                idt = self.infer_type(inp).dtype
+                actual_dtypes.append(idt)
+            else:
+                actual_dtypes.append(dtypes[i])
+        dtypes = actual_dtypes
+        tensor_dtypes = [dt for inp, dt in zip(inputs, dtypes) if not np.isscalar(inp)]
+        non_tensor_inputs = [inp for inp in inputs if np.isscalar(inp)]
+        result_type = _pytorch_result_type(tensor_dtypes, non_tensor_inputs)
+        results = []
+        for inp, dt in zip(inputs, dtypes):
+            if np.isscalar(inp):
+                results.append(_expr.const(inp, dtype=result_type))
+            elif dt == result_type:
+                results.append(inp)
+            else:
+                results.append(_op.cast(inp, result_type))
+        return results
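A rough usage sketch of the incremental inference (illustrative only, not part of the patch; it assumes a Prelude instance such as the one from_pytorch builds):

    converter = PyTorchOpConverter(prelude, "float32")
    x = tvm.relay.var("x", shape=(2, 3), dtype="float32")
    y = _op.add(x, _expr.const(1.0))
    converter.infer_type(y)   # runs InferType on a small extracted function
    z = _op.multiply(y, _expr.const(2.0))
    converter.infer_type(z)   # the cached type of `y` turns it into a typed var,
                              # so only the new multiply node is inferred

Each call stores its result in converter.types; record_output_type fills the same cache as operators are converted.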
+ + def is_quantized_tensor(self, data): + # If a quantized Torch module is saved and loaded back, dtype will be dropped + # Since dtypes from Torch tensors are not reliable in such cases, we use + # Relay's type inference result to decide if an input tensor is quantized + ty = self.infer_type_with_prelude(data) + return ty.dtype == "uint8" + # Operator implementations def make_elemwise(self, name): def elemwise(inputs, input_types): - data0, data1 = _pytorch_promote_types(inputs[:2], input_types[:2]) + data0, data1 = self.pytorch_promote_types(inputs[:2], input_types[:2]) return get_relay_op(name)(data0, data1) return elemwise def min_max_common(self, name_elemwise, name_reduce, inputs, input_types): if len(inputs) == 1: - data = _pytorch_promote_types(inputs[:1], input_types[:1]) + data = self.pytorch_promote_types(inputs[:1], input_types[:1]) return get_relay_op(name_reduce)(data[0]) elif len(inputs) >= 2 and isinstance(inputs[1], int): - data = _pytorch_promote_types(inputs[:1], input_types[:1]) + data = self.pytorch_promote_types(inputs[:1], input_types[:1]) dim = inputs[1] keepdims = inputs[2] if len(inputs) > 2 else False # also return dummy indices return get_relay_op(name_reduce)(data[0], axis=dim, keepdims=keepdims), None else: - data0, data1 = _pytorch_promote_types(inputs[:2], input_types[:2]) + data0, data1 = self.pytorch_promote_types(inputs[:2], input_types[:2]) return get_relay_op(name_elemwise)(data0, data1) def max(self, inputs, input_types): @@ -174,7 +282,7 @@ def min(self, inputs, input_types): def make_unary(self, name): def unary(inputs, input_types): # this is just to ensure tensor input - (data,) = _pytorch_promote_types(inputs[:1], input_types[:1]) + (data,) = self.pytorch_promote_types(inputs[:1], input_types[:1]) return get_relay_op(name)(data) return unary @@ -197,7 +305,7 @@ def _get_value(val, dtype): def _get_type(val, inp_type): if isinstance(val, _expr.Expr): - dtype = str(_infer_type(val).checked_type) + dtype = str(self.infer_type(val)) return dtype return inp_type @@ -252,7 +360,7 @@ def unsqueeze(self, inputs, input_types): def concatenate(self, inputs, input_types): def tensor_array_concat(lst, axis): assert axis == 0, "Tensor array concat supported only for axis 0" - tensor_array, shape = _convert_to_tensor_array(lst, self.prelude) + tensor_array, shape = self.convert_to_tensor_array(lst) concat_shape = (Any(),) + shape[1:] concat = self.prelude.get_global_var_static("tensor_array_concat", "float32", shape) concatenated = concat(tensor_array) @@ -279,7 +387,7 @@ def slice(self, inputs, input_types): axis_dtype = "int64" index_size_limit = 2 ** 63 - 1 data = inputs[0] - dshape = _infer_shape(data) + dshape = self.infer_shape(data) ndim = len(dshape) end = [] for dim in dshape: @@ -305,7 +413,7 @@ def slice(self, inputs, input_types): else: tmp.append(_op.cast(_op.expand_dims(b, axis=0), axis_dtype)) begin = _op.concatenate(tmp, axis=0) - btype = _infer_type(begin).checked_type.dtype + btype = self.infer_type(begin).dtype if str(btype) != axis_dtype: begin = _op.cast(begin, axis_dtype) @@ -353,7 +461,7 @@ def slice(self, inputs, input_types): else: end = _op.cast(_op.shape_of(data), axis_dtype) if not isinstance(target_end, tvm.tir.Any): - ttype = _infer_type(target_end).checked_type.dtype + ttype = self.infer_type(target_end).dtype if str(ttype) != axis_dtype: target_end = _op.cast(target_end, axis_dtype) end = _op.scatter( @@ -364,7 +472,7 @@ def slice(self, inputs, input_types): ) if not isinstance(end, list): - etype = 
_infer_type(end).checked_type.dtype + etype = self.infer_type(end).dtype if str(etype) != axis_dtype: end = _op.cast(end, axis_dtype) @@ -382,7 +490,7 @@ def split(self, inputs, input_types): split_index = split_size indices = [] - while split_index < _infer_shape(data)[dim]: + while split_index < self.infer_shape(data)[dim]: indices.append(split_index) split_index += split_size @@ -460,11 +568,11 @@ def repeat_interleave(self, inputs, input_types): return _op.transform.repeat(data, repeats=repeats, axis=axis) def addcdiv(self, inputs, input_types): - data, t1, t2, c = _pytorch_promote_types(inputs[:4], input_types[:4]) + data, t1, t2, c = self.pytorch_promote_types(inputs[:4], input_types[:4]) return data + (c * (t1 / t2)) def addcmul(self, inputs, input_types): - data, t1, t2, c = _pytorch_promote_types(inputs[:4], input_types[:4]) + data, t1, t2, c = self.pytorch_promote_types(inputs[:4], input_types[:4]) return data + (c * (t1 * t2)) def where(self, inputs, input_types): @@ -472,7 +580,7 @@ def where(self, inputs, input_types): return self.nonzero([inputs[0], True], input_types) cond = inputs[0] - x, y = _pytorch_promote_types(inputs[1:3], input_types[1:3]) + x, y = self.pytorch_promote_types(inputs[1:3], input_types[1:3]) return _op.where(cond, x, y) def full_impl(self, data, fill_value, dtype): @@ -626,7 +734,7 @@ def linspace(self, inputs, input_types): def relu(self, inputs, input_types): data = inputs[0] - if _is_quantized_tensor(data, self.prelude): + if self.is_quantized_tensor(data): assert len(inputs) == 3, "Input quant param not found in op inputs" input_zero_point = _expr.const(inputs[2], dtype="int32") return qnn_torch.quantized_relu(data, input_zero_point) @@ -689,7 +797,7 @@ def adaptive_avg_pool_2d(self, inputs, input_types): def func(x): return _op.nn.adaptive_avg_pool2d(x, output_size=output_size) - if _is_quantized_tensor(data, self.prelude): + if self.is_quantized_tensor(data): return qnn_torch.apply_with_upcast(data, func) return func(data) @@ -780,7 +888,7 @@ def convolution(self, inputs, input_types): dilation = tuple(inputs[5]) if isinstance(weight, _expr.Expr): - inferred_shape = _infer_shape(weight) + inferred_shape = self.infer_shape(weight) weight_shape = [] for infer in inferred_shape: weight_shape.append(infer) @@ -877,7 +985,7 @@ def batch_norm(self, inputs, input_types): data = inputs[0] data_type = input_types[0] - channels = _infer_shape(data) + channels = self.infer_shape(data) if isinstance(inputs[1], _expr.Expr) and isinstance(inputs[2], _expr.Expr): scale = center = True @@ -912,7 +1020,7 @@ def batch_norm(self, inputs, input_types): def instance_norm(self, inputs, input_types): data = inputs[0] data_type = input_types[0] - channels = _infer_shape(data) + channels = self.infer_shape(data) if isinstance(inputs[1], _expr.Expr) and isinstance(inputs[2], _expr.Expr): scale = center = True @@ -938,7 +1046,7 @@ def get_dims(data): import torch if isinstance(data, _expr.Expr): - dims = _infer_shape(data) + dims = self.infer_shape(data) elif isinstance(data, list): dims = data elif isinstance(data, (torch.Tensor, np.ndarray)): @@ -987,7 +1095,7 @@ def transpose(self, inputs, input_types): import torch if isinstance(data, _expr.Expr): - ndims = len(_infer_shape(data, self.prelude.mod)) + ndims = len(self.infer_shape_with_prelude(data)) elif isinstance(data, list): ndims = data elif isinstance(data, (torch.Tensor, np.ndarray)): @@ -1022,7 +1130,7 @@ def flatten(self, inputs, input_types): data = inputs[0] start = int(inputs[1]) end = int(inputs[2]) - dshape = 
get_const_tuple(_infer_shape(data)) + dshape = get_const_tuple(self.infer_shape_with_prelude(data)) ndim = len(dshape) if end < 0: end += ndim @@ -1059,13 +1167,13 @@ def addmm(self, inputs, input_types): transposed_mat2 = _op.transform.transpose(mat2, axes=[1, 0]) - units = _infer_shape(transposed_mat2)[0] + units = self.infer_shape(transposed_mat2)[0] dense_out = _op.nn.dense(mat1, transposed_mat2, units=units) return dense_out + input_mat def size(self, inputs, input_types): - shape = _infer_shape(inputs[0], self.prelude.mod) + shape = self.infer_shape_with_prelude(inputs[0]) axis = None if len(inputs) > 1: axis = int(inputs[1]) @@ -1102,12 +1210,12 @@ def view(self, inputs, input_types): data = inputs[0] if len(inputs) == 3: - shape_inp = [inputs[1], _infer_shape(inputs[2])[0]] + shape_inp = [inputs[1], self.infer_shape(inputs[2])[0]] else: if isinstance(inputs[1], list): shape_inp = inputs[1] else: - shape_inp = _infer_shape(inputs[1]) + shape_inp = self.infer_shape(inputs[1]) new_shape = shape_inp for i, shape in enumerate(shape_inp): if isinstance(shape, _expr.Expr): @@ -1151,12 +1259,12 @@ def pixel_shuffle(self, inputs, input_types): data = inputs[0] upscale_factor = inputs[1] upscale_squared = upscale_factor * upscale_factor - b, c, h, w = _infer_shape(data) + b, c, h, w = self.infer_shape(data) assert ( c % upscale_squared == 0 ), "input channel should be divisible by square of upscale_factor" - ndims = len(_infer_shape(data, self.prelude.mod)) + ndims = len(self.infer_shape_with_prelude(data)) axes = list(range(ndims)) num_inputs = len(inputs) oc = c // upscale_squared @@ -1212,7 +1320,7 @@ def func(x): count_include_pad=count_include_pad, ) - if _is_quantized_tensor(data, self.prelude): + if self.is_quantized_tensor(data): return qnn_torch.apply_with_upcast(data, func) return func(data) @@ -1253,7 +1361,7 @@ def reduce(inputs, input_types): elif _is_int_seq(inputs[1]): axis = inputs[1] else: - axis = list(_infer_shape(inputs[1])) + axis = list(self.infer_shape(inputs[1])) keepdims = bool(inputs[2]) return get_relay_op(name)(data, axis=axis, keepdims=keepdims) @@ -1338,7 +1446,7 @@ def mean(self, inputs, input_types): def func(x): return _op.mean(x, axis, keepdims, exclude) - if _is_quantized_tensor(data, self.prelude): + if self.is_quantized_tensor(data): assert len(inputs) == 6, "Input quant param not found in op inputs" input_scale = _expr.const(inputs[4]) input_zero_point = _expr.const(inputs[5]) @@ -1353,7 +1461,7 @@ def chunk(self, inputs, input_types): axis = int(inputs[2]) if isinstance(data, _expr.Expr): - inferred_shape = _infer_shape(data, self.prelude.mod) + inferred_shape = self.infer_shape_with_prelude(data) shape = [] for infer in inferred_shape: @@ -1395,8 +1503,8 @@ def matmul(self, inputs, input_types): inputs_1 = inputs[1] # Need to check input shape as batch matmul must be supported. - a_shape = _infer_shape(inputs_0, self.prelude.mod) - b_shape = _infer_shape(inputs_1, self.prelude.mod) + a_shape = self.infer_shape_with_prelude(inputs_0) + b_shape = self.infer_shape_with_prelude(inputs_1) # When performing a batch matmul, we need to properly handle N-dim shapes. 
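# For concreteness (illustrative shapes, not taken from the original code): with
# a of shape (2, 3, 4, 5) and b of shape (2, 3, 5, 6), both are collapsed below
# to rank 3, giving (6, 4, 5) and (6, 5, 6); if the collapsed batch dimension of
# b is smaller than that of a, b is additionally broadcast to match.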
if len(a_shape) > 2 or len(b_shape) > 2: @@ -1404,8 +1512,8 @@ def matmul(self, inputs, input_types): a = _op.reshape(inputs_0, [-1, a_shape[-2], a_shape[-1]]) b = _op.reshape(inputs_1, [-1, b_shape[-2], b_shape[-1]]) # Broadcast b to match batch size of a - new_b_shape = list(_infer_shape(b, self.prelude.mod)) - new_a_shape = _infer_shape(a, self.prelude.mod) + new_b_shape = list(self.infer_shape_with_prelude(b)) + new_a_shape = self.infer_shape_with_prelude(a) if new_a_shape[0] > new_b_shape[0]: new_b_shape[0] = new_a_shape[0] b = _op.broadcast_to(b, new_b_shape) @@ -1431,7 +1539,7 @@ def matmul(self, inputs, input_types): def expand(self, inputs, input_types): data_in = inputs[0] - shape = list(_infer_shape(data_in)) + shape = list(self.infer_shape(data_in)) ndims = len(shape) sizes = inputs[1] @@ -1468,10 +1576,10 @@ def pad(inputs, input_types): if isinstance(inputs[1], list): pad_list = inputs[1] else: - pad_list = list(_infer_shape(inputs[1])) + pad_list = list(self.infer_shape(inputs[1])) # initialize paddings based on input len - pad_len = len(_infer_shape(data)) * 2 + pad_len = len(self.infer_shape(data)) * 2 paddings = [0] * pad_len if len(pad_list) >= 2: @@ -1526,7 +1634,7 @@ def to(self, inputs, input_types): ret = data if isinstance(data, _expr.Expr): - actual_dtype = str(_infer_type(data).checked_type.dtype) + actual_dtype = str(self.infer_type(data).dtype) if dtype in cast_map and cast_map[dtype] != actual_dtype: ret = _op.cast(data, cast_map[dtype]) elif dtype in cast_map: @@ -1534,8 +1642,7 @@ def to(self, inputs, input_types): return ret - @staticmethod - def get_upsample_out_size(inputs, method): + def get_upsample_out_size(self, inputs, method): # This assumes a static shape out_size = [] if inputs[1] is not None: @@ -1549,7 +1656,7 @@ def get_upsample_out_size(inputs, method): scales = inputs[scale_index] assert scales is not None, "neither out size nor scale provided" assert isinstance(scales, list) - ishape = _infer_shape(inputs[0]) + ishape = self.infer_shape(inputs[0]) for i, scale in enumerate(scales): out_size.append(int(math.floor(float(ishape[2 + i]) * scale))) @@ -1575,7 +1682,7 @@ def upsample(inputs, input_types): def func(x): return _op.image.resize(x, out_size, "NCHW", method, coord_trans) - if _is_quantized_tensor(data, self.prelude): + if self.is_quantized_tensor(data): # input qparams are manually appended by us assert isinstance(inputs[-2], float) assert isinstance(inputs[-1], int) @@ -1610,8 +1717,8 @@ def upsample3d(inputs, input_types): def expand_as(self, inputs, input_types): target = inputs[1] - t0 = _infer_type(inputs[0]).checked_type.dtype - t1 = _infer_type(inputs[1]).checked_type.dtype + t0 = self.infer_type(inputs[0]).dtype + t1 = self.infer_type(inputs[1]).dtype if str(t0) != str(t1): target = _op.cast(target, t0) return _op.broadcast_to_like(inputs[0], target) @@ -1683,7 +1790,7 @@ def add(self, inputs, input_types): def tensor_array_stack(self, inputs, input_types): dim = inputs[1] assert dim == 0, "stacking on a dynamic tensor list only supported on a first axis" - tensor_array, shape = _convert_to_tensor_array(inputs[0], self.prelude) + tensor_array, shape = self.convert_to_tensor_array(inputs[0]) stacked_shape = (Any(),) + shape stack = self.prelude.get_global_var_static("tensor_array_stack", "float32", shape) @@ -1702,14 +1809,14 @@ def stack(self, inputs, input_types): else: # List ADT case assert isinstance(inputs[0], _expr.Expr) - ty = _infer_type_with_prelude(inputs[0], self.prelude) + ty = self.infer_type_with_prelude(inputs[0]) 
list_ty = self.prelude.mod.get_global_type_var("List") msg = "The input list is expected to be List ADT" assert isinstance(ty, tvm.ir.TypeCall) and ty.func == list_ty, msg return self.tensor_array_stack(inputs, input_types) def rsub(self, inputs, input_types): - data0, data1 = _pytorch_promote_types(inputs[:2], input_types[:2]) + data0, data1 = self.pytorch_promote_types(inputs[:2], input_types[:2]) # TODO (t-vi): should this also be part of the type promotion? alpha = _expr.const(float(inputs[2])) @@ -1792,7 +1899,7 @@ def nms(self, inputs, input_types): return _op.cast(ret, "int64") def logsumexp(self, inputs, input_types): - data = _pytorch_promote_types(inputs[:1], input_types[:1]) + data = self.pytorch_promote_types(inputs[:1], input_types[:1]) dim_list = inputs[1] keepdim = inputs[2] if len(inputs) > 2 else False # dim is output of prim::ListConstruct, even if it is int in python code @@ -1816,7 +1923,7 @@ def roi_align(self, inputs, input_types): def unbind(self, inputs, input_types): data = inputs[0] dim = int(inputs[1]) - ishapes = _infer_shape(data) + ishapes = self.infer_shape(data) if dim >= len(ishapes): msg = "Please check input dim, it shouldn't be greater than or equal to rank." raise AttributeError(msg) @@ -1833,7 +1940,7 @@ def unbind(self, inputs, input_types): def shape_as_tensor(self, inputs, input_types): is_symbolic_shape = False - input_shape = _infer_shape(inputs[0], self.prelude.mod) + input_shape = self.infer_shape(inputs[0], self.prelude.mod) for axis in input_shape: if not isinstance(axis, (int, tvm.tir.IntImm)): is_symbolic_shape = True @@ -1921,7 +2028,7 @@ def empty(self, inputs, input_types): def bincount(self, inputs, input_types): data = inputs[0] weights = inputs[1] - input_type = _infer_type(data).checked_type.dtype + input_type = self.infer_type(data).dtype if input_type == "int64": logging.warning( "Casting an int64 input to int32, since we do not have int64 atomic add" @@ -1931,7 +2038,7 @@ def bincount(self, inputs, input_types): maximum = _op.max(data) dim = maximum + _expr.const(1, dtype="int32") if weights: - weight_type = _infer_type(weights).checked_type + weight_type = self.infer_type(weights) out_dtype = weight_type.dtype updates = weights else: @@ -2235,7 +2342,7 @@ def get_input(index): def get_var(name, val): if val: - checked_type = _infer_type_with_prelude(val, self.prelude) + checked_type = self.infer_type_with_prelude(val) if hasattr(checked_type, "shape"): shape = get_const_tuple(checked_type.shape) actual_shape = [] @@ -2325,7 +2432,7 @@ def convert_operators(self, operators, outputs, ret_names): if operator == "prim::Constant": outputs[node_name] = _get_constant(op_node) elif operator == "prim::ListConstruct" and _should_construct_dynamic_list(op_node): - outputs[node_name] = _convert_to_list_adt(inputs, self.prelude) + outputs[node_name] = self.convert_to_list_adt(inputs) elif operator == "prim::ListConstruct": # This assumes that no more elements will be appended to this list # In this case, we keep the Python list @@ -2355,6 +2462,7 @@ def convert_operators(self, operators, outputs, ret_names): relay_out = relay_op( inputs, _get_input_types(op_node, outputs, default_dtype=self.default_dtype) ) + self.record_output_type(relay_out) if isinstance(relay_out, tuple): # This is for torch operators that return multiple outputs @@ -2406,30 +2514,6 @@ def _pytorch_result_type(dtypes, non_tensor_inputs): return result_type -def _pytorch_promote_types(inputs, dtypes): - """This promotes TVM inputs with TVM dtypes passed like PyTorch 
would""" - actual_dtypes = [] - for i, inp in enumerate(inputs): - if isinstance(inp, _expr.Expr): - idt = _infer_type(inp).checked_type.dtype - actual_dtypes.append(idt) - else: - actual_dtypes.append(dtypes[i]) - dtypes = actual_dtypes - tensor_dtypes = [dt for inp, dt in zip(inputs, dtypes) if not np.isscalar(inp)] - non_tensor_inputs = [inp for inp in inputs if np.isscalar(inp)] - result_type = _pytorch_result_type(tensor_dtypes, non_tensor_inputs) - results = [] - for inp, dt in zip(inputs, dtypes): - if np.isscalar(inp): - results.append(_expr.const(inp, dtype=result_type)) - elif dt == result_type: - results.append(inp) - else: - results.append(_op.cast(inp, result_type)) - return results - - # Helper functions for operator implementation def _convert_dtype_value(val): """converts a PyTorch the PyTorch numeric type id to a torch scalar type.""" @@ -2957,5 +3041,4 @@ def from_pytorch(script_module, input_infos, custom_convert_map=None, default_dt ret = converter.convert_operators(_get_operator_nodes(graph.nodes()), outputs, ret_name) mod["main"] = tvm.relay.Function(_analysis.free_vars(ret[0]), ret[0]) - return transform.RemoveUnusedFunctions()(mod), tvm_params
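For context, a minimal end-to-end use of the converter after this change (shapes and model are illustrative; the from_pytorch interface itself is unchanged by the patch):

    import torch
    import tvm
    from tvm import relay

    model = torch.nn.Sequential(torch.nn.Conv2d(3, 8, 3), torch.nn.ReLU()).eval()
    inp = torch.randn(1, 3, 32, 32)
    scripted = torch.jit.trace(model, inp)

    mod, params = relay.frontend.from_pytorch(scripted, [("input0", (1, 3, 32, 32))])
    print(mod["main"])

The conversion path is the same as before; the difference is that convert_operators now calls record_output_type on every converted node, so later infer_shape/infer_type queries hit the per-node cache instead of running type inference over the whole module each time.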