diff --git a/python/tvm/relay/frontend/onnx.py b/python/tvm/relay/frontend/onnx.py index 1f5ac445f7ea8..ff03282756049 100644 --- a/python/tvm/relay/frontend/onnx.py +++ b/python/tvm/relay/frontend/onnx.py @@ -43,7 +43,7 @@ class onnx_input: - """ Dual purpose list or dictionary access object.""" + """Dual purpose list or dictionary access object.""" def __init__(self): self.input_keys = [] @@ -126,7 +126,10 @@ def get_info(info_proto): shape.append(value) name = info_proto.name - dtype = get_type(info_proto.type.tensor_type.elem_type) + if info_proto.type.tensor_type.elem_type: + dtype = get_type(info_proto.type.tensor_type.elem_type) + else: + dtype = None return name, shape, dtype, shape_name @@ -2495,6 +2498,8 @@ def get_var(name, val, scan=False): scan_output_init = [] for i in range(num_scan_outputs): name, shape, dtype, _ = get_info(body.output[i + 1 + num_deps]) + if dtype is None: + dtype = infer_type(loop_deps[i]).checked_type.dtype if dtype == "float": dtype = "float32" scan_output_vars.append( diff --git a/tests/python/frontend/onnx/test_forward.py b/tests/python/frontend/onnx/test_forward.py index c25dc01651624..f81606c6ae50e 100644 --- a/tests/python/frontend/onnx/test_forward.py +++ b/tests/python/frontend/onnx/test_forward.py @@ -45,7 +45,7 @@ def get_input_data_shape_dict(graph_def, input_data): def get_tvm_output_with_vm( graph_def, input_data, target, device, opset=None, freeze_params=False, convert_to_static=False ): - """ Generic function to execute and get tvm output with vm executor""" + """Generic function to execute and get tvm output with vm executor""" if not isinstance(input_data, list): input_data = [input_data] _, shape_dict = get_input_data_shape_dict(graph_def, input_data) @@ -67,7 +67,7 @@ def get_tvm_output_with_vm( def get_tvm_output( graph_def, input_data, target, device, output_shape=None, output_dtype="float32", opset=None ): - """ Generic function to execute and get tvm output""" + """Generic function to execute and get tvm output""" # TODO: Resolve the issues and remove the following lines target = "llvm" device = tvm.cpu(0) @@ -4218,8 +4218,6 @@ def verify_cumsum(indata, axis, exclusive=0, reverse=0, type="float32"): "test_qlinearconv/", "test_qlinearmatmul_2D/", "test_qlinearmatmul_3D/", - "test_range_float_type_positive_delta_expanded/", - "test_range_int32_type_negative_delta_expanded/", "test_resize_tf_crop_and_resize/", ## For these three tests, ONNX 1.6.0 has incorrect graphs, they pass with ONNX 1.7.0 "test_resize_upsample_sizes_nearest_ceil_half_pixel/",