diff --git a/python/tvm/relay/frontend/onnx.py b/python/tvm/relay/frontend/onnx.py index 9405bc532702..7a3b168fc8fd 100644 --- a/python/tvm/relay/frontend/onnx.py +++ b/python/tvm/relay/frontend/onnx.py @@ -17,6 +17,7 @@ # pylint: disable=invalid-name, import-self, len-as-condition, unused-argument, too-many-lines # pylint: disable=import-outside-toplevel """ONNX: Open Neural Network Exchange frontend for Relay.""" +import copy import warnings import numpy as np import tvm @@ -1028,10 +1029,6 @@ def _impl_v9(cls, inputs, attr, params): 'Value {} in attribute "mode" of operator Upsample is not valid.'.format(mode) ) - if method == "nearest_neighbor": - align_corners = False - else: - align_corners = True # in 3d case, we use the purely static op if dims == 5: if isinstance(scales, _expr.Call): @@ -1065,7 +1062,7 @@ def _impl_v9(cls, inputs, attr, params): scale_w, layout=layout, method=method, - align_corners=align_corners, + align_corners=False, ) return out @@ -1111,17 +1108,22 @@ class Split(OnnxOpConverter): @classmethod def _impl_v1(cls, inputs, attr, params): - splits = attr.get("split", False) - if splits: + splits = attr.get("split", None) + if splits is not None: + indices = [] attr["indices_or_sections"] = [] index = 0 for i in splits[:-1]: index += i - attr["indices_or_sections"].append(index) + indices.append(index) # When splits isnt specified divide evenly over axis. 
else: - attr["indices_or_sections"] = attr["tvm_custom"]["num_outputs"] - return AttrCvt("split", ignores=["split"])(inputs, attr, params) + indices = attr["tvm_custom"]["num_outputs"] + output = _op.split(inputs[0], indices, attr.get("axis", 0)) + # If the output of split is a single value, unpack it from the TupleWrapper + if len(output) == 1: + output = output[0] + return output class Slice(OnnxOpConverter): @@ -1227,7 +1229,9 @@ class GatherND(OnnxOpConverter): @classmethod def _impl_v1(cls, inputs, attr, params): - return _op.gather_nd(inputs[0], inputs[1]) + indices_dims = len(infer_shape(inputs[1])) + indices = _op.transpose(inputs[1], axes=[-1] + list(range(indices_dims - 1))) + return _op.gather_nd(inputs[0], indices) class Scatter(OnnxOpConverter): @@ -1538,15 +1542,6 @@ def _impl_v1(cls, inputs, attr, params): class Tile(Elemwise): """Operator converter for Tile""" - @classmethod - def _impl_v1(cls, inputs, attr, params): - if "repeats" not in attr: - raise tvm.error.OpAttributeInvalid( - 'Attribute "repeats" should be set ' "for operator Tile." - ) - reps = attr.pop("repeats") # The number of times repeating the tensor data. - return _op.tile(inputs[0], reps) - @classmethod def _impl_v6(cls, inputs, attr, params): return _op.tile(inputs[0], inputs[1]) @@ -2113,7 +2108,9 @@ def _impl_v11(cls, inputs, attr, params): cond = inputs[1] loop_deps = inputs[2:] num_deps = len(loop_deps) - body = attr["body"] + # Create a copy of the body function to prevent the original + # from being modified. + body = copy.copy(attr["body"]) iter_dtype = infer_type(max_loop_count).checked_type.dtype # Determine what condition mode we're in. 
@@ -2150,6 +2147,8 @@ def get_var(name, val, scan=False): checked_type = infer_type(val) if hasattr(checked_type, "type_annotation"): checked_type = checked_type.type_annotation + if hasattr(checked_type, "checked_type"): + checked_type = checked_type.checked_type shape = get_const_tuple(checked_type.shape) actual_shape = [] for dim in shape: @@ -2185,8 +2184,14 @@ def get_var(name, val, scan=False): scan_output_init = [] for i in range(num_scan_outputs): name, shape, dtype, _ = get_info(body.output[i + 1 + num_deps]) - scan_output_vars.append(_expr.var(name, shape=([_ty.Any()] + shape), dtype=dtype)) - scan_output_init.append(_op.reshape(_expr.const([]), [0] + shape)) + if dtype == "float": + dtype = "float32" + scan_output_vars.append( + _expr.var(name, shape=([_ty.Any()] * (len(shape) + 1)), dtype=dtype) + ) + scan_output_init.append( + _op.reshape(_expr.const(np.array([]).astype(dtype)), [0] + [1] * len(shape)) + ) # Now we can remove loop iter variables from our inner loop's inputs. # This is kind of a hack since we have graph inputs that we don't @@ -2219,11 +2224,6 @@ def body_fn(*loop_inputs): new_loop_vars = [loop_outputs[i] for i in range(1, 1 + num_deps)] new_scan_outputs = [loop_outputs[i] for i in range(1 + num_deps, len(loop_outputs))] - # Increment counter. - if max_loop_count is not None: - incr = _expr.const(1, dtype=iter_dtype) - loop_count = loop_count + incr - # Add new scan outputs to tracking combined_scan_outputs = [] for i, scan in enumerate(scan_outputs): @@ -2231,6 +2231,11 @@ def body_fn(*loop_inputs): combined_scan = _op.concatenate([scan, new_scan], axis=0) combined_scan_outputs.append(combined_scan) + # Increment counter. 
+ if max_loop_count is not None: + incr = _expr.const(1, dtype=iter_dtype) + loop_count = loop_count + incr + # Pack loop outputs for next iteration # [iter_count, cond, loop_deps, loop_scans] return [loop_count, max_count, new_cond] + new_loop_vars + combined_scan_outputs @@ -2630,12 +2635,12 @@ def _get_convert_map(opset): "Greater": Greater.get_converter(opset), "Less": Less.get_converter(opset), "Log": Renamer("log"), - "ACos": Renamer("acos"), - "ACosh": Renamer("acosh"), - "ASin": Renamer("asin"), - "ASinh": Renamer("asinh"), - "ATan": Renamer("atan"), - "ATanh": Renamer("atanh"), + "Acos": Renamer("acos"), + "Acosh": Renamer("acosh"), + "Asin": Renamer("asin"), + "Asinh": Renamer("asinh"), + "Atan": Renamer("atan"), + "Atanh": Renamer("atanh"), "Cos": Renamer("cos"), "Cosh": Renamer("cosh"), "Sin": Renamer("sin"), diff --git a/tests/python/frontend/onnx/test_forward.py b/tests/python/frontend/onnx/test_forward.py index 96be6fba113a..20937d2060c5 100644 --- a/tests/python/frontend/onnx/test_forward.py +++ b/tests/python/frontend/onnx/test_forward.py @@ -15,7 +15,6 @@ # specific language governing permissions and limitations # under the License. 
import numpy as np -import math import onnx from onnx import helper, TensorProto, mapping, numpy_helper import torch @@ -94,7 +93,7 @@ def get_tvm_output( # execute m.run() # get outputs - if isinstance(output_shape, list) and isinstance(output_dtype, list): + if isinstance(output_shape, list): tvm_output_list = [] for i, _ in enumerate(output_shape): tvm_output = m.get_output(i) @@ -105,17 +104,19 @@ def get_tvm_output( return tvm_output.asnumpy() -def get_onnxruntime_output(model, inputs, dtype="float32"): +def get_onnxruntime_output(model, inputs): import onnxruntime.backend rep = onnxruntime.backend.prepare(model, "CPU") - if isinstance(inputs, list) and len(inputs) > 1: - return rep.run(inputs) - elif isinstance(inputs, list) and len(inputs) == 1: + if isinstance(inputs, list) and len(inputs) == 1: inp = inputs[0] else: inp = inputs - return rep.run(inp.astype(dtype))[0] + output = rep.run(inp) + # Unpack output if there's only a single value. + if len(output) == 1: + output = output[0] + return output def verify_with_ort_with_inputs( @@ -130,15 +131,11 @@ def verify_with_ort_with_inputs( dtype="float32", rtol=1e-5, atol=1e-5, + apply_softmax=False, ): - def flatten(out): - if isinstance(out, list) and len(out) == 1: - out = out[0] - if isinstance(out, np.ndarray): - return out.flatten() - return out - - ort_out = get_onnxruntime_output(model, inputs, dtype) + if opset is not None: + model.opset_import[0].version = opset + ort_out = get_onnxruntime_output(model, inputs) if targets is None: targets = [tgt for (tgt, _) in tvm.testing.enabled_targets()] @@ -157,8 +154,15 @@ def flatten(out): ) else: tvm_out = get_tvm_output(model, inputs, target, ctx, out_shape, dtype, opset=opset) - - tvm.testing.assert_allclose(flatten(ort_out), flatten(tvm_out), rtol=rtol, atol=atol) + if not isinstance(tvm_out, list): + tvm_out = [tvm_out] + if not isinstance(ort_out, list): + ort_out = [ort_out] + for tvm_val, ort_val in zip(tvm_out, ort_out): + if apply_softmax: + ort_val = 
scipy.special.softmax(ort_val) + tvm_val = scipy.special.softmax(tvm_val) + tvm.testing.assert_allclose(ort_val, tvm_val, rtol=rtol, atol=atol) def verify_with_ort( @@ -342,7 +346,7 @@ def verify_depth_to_space(inshape, outshape, mode, blockSize): model = helper.make_model(graph, producer_name="depth_to_space_test") - verify_with_ort(model, [inshape], outshape) + verify_with_ort(model, [inshape], [outshape]) @tvm.testing.uses_gpu @@ -365,7 +369,7 @@ def verify_space_to_depth(inshape, outshape, blockSize): model = helper.make_model(graph, producer_name="space_to_depth_test") - verify_with_ort(model, [inshape], outshape) + verify_with_ort(model, [inshape], [outshape]) @tvm.testing.uses_gpu @@ -494,11 +498,8 @@ def test_squeeze(): ) model = helper.make_model(graph, producer_name="squeeze_test") - - for target, ctx in tvm.testing.enabled_targets(): - x = np.random.uniform(size=in_shape).astype("float32") - tvm_out = get_tvm_output(model, x, target, ctx, out_shape, "float32") - tvm.testing.assert_allclose(out_shape, tvm_out.shape) + x = np.random.uniform(size=in_shape).astype("float32") + verify_with_ort_with_inputs(model, [x], [out_shape]) @tvm.testing.uses_gpu @@ -518,11 +519,7 @@ def test_flatten(): ) model = helper.make_model(graph, producer_name="flatten_test") - - for target, ctx in tvm.testing.enabled_targets(): - x = np.random.uniform(size=in_shape).astype("int32") - tvm_out = get_tvm_output(model, x, target, ctx, ref_shape, "float32") - tvm.testing.assert_allclose(ref_shape, tvm_out.shape) + verify_with_ort(model, [in_shape]) @tvm.testing.uses_gpu @@ -540,16 +537,12 @@ def test_unsqueeze(): ) model = helper.make_model(graph, producer_name="squeeze_test") - - for target, ctx in tvm.testing.enabled_targets(): - x = np.random.uniform(size=in_shape).astype("float32") - tvm_out = get_tvm_output(model, x, target, ctx, out_shape, "float32") - tvm.testing.assert_allclose(out_shape, tvm_out.shape) + verify_with_ort(model, [in_shape]) def verify_gather(in_shape, indices, 
axis, dtype): x = np.random.uniform(size=in_shape).astype(dtype) - indices = np.array(indices, dtype="int32") + indices = np.array(indices, dtype="int64") out_np = np.take(x, indices, axis=axis) y = helper.make_node("Gather", ["in", "indices"], ["out"], axis=axis) @@ -558,16 +551,19 @@ def verify_gather(in_shape, indices, axis, dtype): [y], "gather_test", inputs=[ - helper.make_tensor_value_info("in", TensorProto.FLOAT, list(in_shape)), - helper.make_tensor_value_info("indices", TensorProto.INT32, list(indices.shape)), + helper.make_tensor_value_info( + "in", mapping.NP_TYPE_TO_TENSOR_TYPE[np.dtype(dtype)], list(in_shape) + ), + helper.make_tensor_value_info("indices", TensorProto.INT64, list(indices.shape)), + ], + outputs=[ + helper.make_tensor_value_info( + "out", mapping.NP_TYPE_TO_TENSOR_TYPE[np.dtype(dtype)], list(out_np.shape) + ) ], - outputs=[helper.make_tensor_value_info("out", TensorProto.FLOAT, list(out_np.shape))], ) model = helper.make_model(graph, producer_name="gather_test") - - for target, ctx in tvm.testing.enabled_targets(): - tvm_out = get_tvm_output(model, [x, indices], target, ctx, out_np.shape) - tvm.testing.assert_allclose(out_np, tvm_out) + verify_with_ort_with_inputs(model, [x, indices], dtype=dtype) @tvm.testing.uses_gpu @@ -660,10 +656,7 @@ def _test_slice_iteration_v1(indata, outdata, starts, ends, axes=None): ) model = helper.make_model(graph, producer_name="slice_test") - - for target, ctx in tvm.testing.enabled_targets(): - tvm_out = get_tvm_output(model, indata, target, ctx, outdata.shape, "float32", opset=1) - tvm.testing.assert_allclose(outdata, tvm_out) + verify_with_ort_with_inputs(model, [indata], [outdata.shape], opset=1) def _test_slice_iteration_v10(indata, outdata, **attrs): @@ -738,14 +731,14 @@ def add_noop_to_input_attr(attr_name, attr): if axes: axes = np.asarray(axes) - inputs.append(helper.make_tensor_value_info("axes", TensorProto.INT32, list(axes.shape))) - initializer.append(helper.make_tensor("axes", 
TensorProto.INT32, list(axes.shape), axes)) + inputs.append(helper.make_tensor_value_info("axes", TensorProto.INT64, list(axes.shape))) + initializer.append(helper.make_tensor("axes", TensorProto.INT64, list(axes.shape), axes)) if steps: assert axes is not None and len(axes) == len(steps) steps = np.asarray(steps) - inputs.append(helper.make_tensor_value_info("steps", TensorProto.INT32, list(axes.shape))) - initializer.append(helper.make_tensor("steps", TensorProto.INT32, list(steps.shape), steps)) + inputs.append(helper.make_tensor_value_info("steps", TensorProto.INT64, list(axes.shape))) + initializer.append(helper.make_tensor("steps", TensorProto.INT64, list(steps.shape), steps)) y = helper.make_node("Slice", ["data", *slice_inputs], ["out"]) @@ -758,10 +751,7 @@ def add_noop_to_input_attr(attr_name, attr): initializer=initializer, ) model = helper.make_model(graph, producer_name="slice_test") - - for target, ctx in tvm.testing.enabled_targets(): - tvm_out = get_tvm_output_with_vm(model, indata, target, ctx, opset=10, freeze_params=True) - tvm.testing.assert_allclose(outdata, tvm_out) + verify_with_ort_with_inputs(model, [indata], opset=10, freeze_params=True, use_vm=True) # TODO(mbrookhart): enable once VM supports heterogenous execution @@ -854,10 +844,7 @@ def _test_onnx_op_elementwise(inshape, outfunc, npargs, dtype, opname, kwargs, o ) model = helper.make_model(graph, producer_name=opname + "_test") - - for target, ctx in tvm.testing.enabled_targets(): - tvm_out = get_tvm_output(model, indata, target, ctx, outdata.shape, dtype, opset=opset) - tvm.testing.assert_allclose(outdata, tvm_out) + verify_with_ort_with_inputs(model, [indata], [outdata.shape], opset=opset, dtype=dtype) @tvm.testing.uses_gpu @@ -879,6 +866,7 @@ def test_clip(): "float32", "Clip", {"min": -1.0, "max": 1.0}, + opset=6, ) _test_onnx_op_elementwise( @@ -888,7 +876,7 @@ def test_clip(): "float32", "Clip", {"max": 1.0}, - opset=1, + opset=6, ) _test_onnx_op_elementwise( @@ -898,7 +886,7 @@ 
def test_clip(): "float32", "Clip", {"min": -1.0}, - opset=1, + opset=6, ) @@ -919,7 +907,7 @@ def test_clip_min_max_as_inputs(): ) model = helper.make_model(graph, producer_name="clip_test") - verify_with_ort(model, [input_shape], input_shape) + verify_with_ort(model, [input_shape], out_shape=[input_shape]) @tvm.testing.uses_gpu @@ -941,10 +929,7 @@ def _test_finite_ops(inshape, outfunc, npargs, dtype, opname, kwargs): ) model = helper.make_model(graph, producer_name=opname + "_test") - - for target, ctx in tvm.testing.enabled_targets(): - tvm_out = get_tvm_output(model, indata, target, ctx, outdata.shape, dtype) - tvm.testing.assert_allclose(outdata, tvm_out) + verify_with_ort_with_inputs(model, [indata], [outdata.shape], dtype=dtype) @tvm.testing.uses_gpu @@ -957,10 +942,9 @@ def test_isnan(): _test_finite_ops((2, 4, 5, 6), np.isnan, {}, "float32", "IsNaN", {}) -def verify_gather_nd(in_shape, indices, dtype): +def verify_gather_nd(in_shape, indices, out_shape, dtype="float32"): x = np.random.uniform(size=in_shape).astype(dtype) - indices = np.array(indices, dtype="int32") - out_np = tvm.topi.testing.gather_nd_python(x, indices) + indices = np.array(indices, dtype="int64") y = helper.make_node("GatherND", ["in", "indices"], ["out"]) @@ -968,23 +952,27 @@ def verify_gather_nd(in_shape, indices, dtype): [y], "gather_test", inputs=[ - helper.make_tensor_value_info("in", TensorProto.FLOAT, list(in_shape)), - helper.make_tensor_value_info("indices", TensorProto.INT32, list(indices.shape)), + helper.make_tensor_value_info( + "in", mapping.NP_TYPE_TO_TENSOR_TYPE[np.dtype(dtype)], list(in_shape) + ), + helper.make_tensor_value_info("indices", TensorProto.INT64, list(indices.shape)), + ], + outputs=[ + helper.make_tensor_value_info( + "out", mapping.NP_TYPE_TO_TENSOR_TYPE[np.dtype(dtype)], list(out_shape) + ) ], - outputs=[helper.make_tensor_value_info("out", TensorProto.FLOAT, list(out_np.shape))], ) model = helper.make_model(graph, producer_name="gather_test") - - for 
target, ctx in tvm.testing.enabled_targets(): - tvm_out = get_tvm_output(model, [x, indices], target, ctx, out_np.shape) - tvm.testing.assert_allclose(out_np, tvm_out) + verify_with_ort_with_inputs(model, [x, indices], [out_shape]) @tvm.testing.uses_gpu def test_gather_nd(): - verify_gather_nd((2, 2), [[0, 0], [1, 1]], "int32") - verify_gather_nd((3, 3, 3), [[0, 1], [1, 0]], "float32") - verify_gather_nd((4, 3, 5, 6), [[2, 1, 0, 0]], "float32") + verify_gather_nd([2, 2], [[0, 0], [1, 1]], [2], "int32") + verify_gather_nd([2, 2], [[1], [0]], [2, 2]) + verify_gather_nd([2, 2, 2], [[0, 1], [1, 0]], [2, 2]) + verify_gather_nd([2, 2, 2], [[[0, 1]], [[1, 0]]], [2, 1, 2]) # TODO(mbrookhart): enable once VM supports heterogenous execution @@ -1011,6 +999,7 @@ def test_onehot(): model = helper.make_model(graph, producer_name="onehot_test") + # TODO(jwfromm): Replace test against np with test against onnxrt once we update versions. for target, ctx in tvm.testing.enabled_targets(): tvm_out = get_tvm_output_with_vm( model, [indices_array, np.array([depth]).astype("int32"), values], target, ctx @@ -1022,10 +1011,10 @@ def test_onehot(): def test_matmul(): a_shape = (4, 3) b_shape = (3, 4) + out_shape = [a_shape[0], b_shape[1]] a_array = np.random.uniform(size=a_shape).astype("float32") b_array = np.random.uniform(size=b_shape).astype("float32") - out_np = np.matmul(a_array, b_array) mul_node = helper.make_node("MatMul", ["a", "b"], ["out"]) @@ -1036,14 +1025,11 @@ def test_matmul(): helper.make_tensor_value_info("a", TensorProto.FLOAT, list(a_shape)), helper.make_tensor_value_info("b", TensorProto.FLOAT, list(b_shape)), ], - outputs=[helper.make_tensor_value_info("out", TensorProto.FLOAT, list(out_np.shape))], + outputs=[helper.make_tensor_value_info("out", TensorProto.FLOAT, list(out_shape))], ) model = helper.make_model(graph, producer_name="matmul_test") - - for target, ctx in tvm.testing.enabled_targets(): - tvm_out = get_tvm_output(model, [a_array, b_array], target, ctx, 
out_np.shape) - tvm.testing.assert_allclose(out_np, tvm_out, rtol=1e-5, atol=1e-5) + verify_with_ort_with_inputs(model, [a_array, b_array]) def verify_batch_matmul(a_shape, b_shape, out_shape, target, ctx): @@ -1063,10 +1049,7 @@ def verify_batch_matmul(a_shape, b_shape, out_shape, target, ctx): ) model = helper.make_model(graph, producer_name="matmul_test") - onnx_out = get_onnxruntime_output(model, [a_array, b_array], "float32")[0] - - tvm_out = get_tvm_output_with_vm(model, [a_array, b_array], target, ctx) - tvm.testing.assert_allclose(onnx_out, tvm_out, rtol=1e-5, atol=1e-5) + verify_with_ort_with_inputs(model, [a_array, b_array], use_vm=True, targets=[target]) # TODO(mbrookhart): enable cuda once VM supports heterogenous execution @@ -1152,29 +1135,7 @@ def verify_lrn(shape, nsize, dtype, alpha=None, beta=None, bias=None): outputs=[helper.make_tensor_value_info("out", TensorProto.FLOAT, list(shape))], ) model = helper.make_model(graph, producer_name="lrn_test") - - def _get_python_lrn(): - square_sum = np.zeros(shape).astype(dtype) - for n, c, h, w in np.ndindex(in_array.shape): - square_sum[n, c, h, w] = sum( - in_array[ - n, - max(0, c - int(math.floor((nsize - 1) / 2))) : min( - 5, c + int(math.ceil((nsize - 1) / 2)) + 1 - ), - h, - w, - ] - ** 2 - ) - py_out = in_array / ((bias + (alpha / nsize) * square_sum) ** beta) - return py_out - - for target, ctx in tvm.testing.enabled_targets(): - input_name = model.graph.input[0].name - py_out = _get_python_lrn() - tvm_out = get_tvm_output(model, in_array, target, ctx, py_out.shape, "float32") - tvm.testing.assert_allclose(py_out, tvm_out, rtol=1e-5, atol=1e-5) + verify_with_ort_with_inputs(model, [in_array]) @tvm.testing.uses_gpu @@ -1184,21 +1145,10 @@ def test_lrn(): def verify_instance_norm(shape, axis=1): - def _get_python_instance_norm(x, gamma, beta, epsilon=1e-5): - dims_x = len(x.shape) - axis = tuple(range(2, dims_x)) - mean = np.mean(x, axis=axis, keepdims=True) - var = np.var(x, axis=axis, 
keepdims=True) - dim_ones = (1,) * (dims_x - 2) - gamma = gamma.reshape(-1, *dim_ones) - beta = beta.reshape(-1, *dim_ones) - return gamma * (x - mean) / np.sqrt(var + epsilon) + beta - x = np.random.randn(*shape).astype(np.float32) gamma = np.random.randn(shape[1]).astype(np.float32) beta = np.random.randn(shape[1]).astype(np.float32) epsilon = 1e-5 - y = _get_python_instance_norm(x, gamma, beta, epsilon).astype(np.float32) node = onnx.helper.make_node( "InstanceNormalization", @@ -1217,9 +1167,7 @@ def _get_python_instance_norm(x, gamma, beta, epsilon=1e-5): outputs=[helper.make_tensor_value_info("y", TensorProto.FLOAT, list(shape))], ) model = helper.make_model(graph, producer_name="instance_norm_test") - for target, ctx in tvm.testing.enabled_targets(): - tvm_out = get_tvm_output(model, [x, gamma, beta], target, ctx, shape, "float32") - tvm.testing.assert_allclose(y, tvm_out, rtol=1e-5, atol=1e-5) + verify_with_ort_with_inputs(model, [x, gamma, beta], out_shape=[shape]) @tvm.testing.uses_gpu @@ -1230,14 +1178,13 @@ def test_instance_norm(): verify_instance_norm((8, 7, 6, 5, 4)) -def _test_upsample_nearest(): +def verify_upsample_nearest(): scale = 2 in_shape = (1, 1, 3, 3) out_shape = (1, 1, 3 * scale, 3 * scale) y = helper.make_node("Upsample", ["in"], ["out"], mode="nearest", scales=[1.0, 1.0, 2.0, 2.0]) in_array = np.random.uniform(size=in_shape).astype(np.float32) - out_array = tvm.topi.testing.upsampling_python(in_array, (scale, scale), "NCHW") graph = helper.make_graph( [y], @@ -1247,13 +1194,10 @@ def _test_upsample_nearest(): ) model = helper.make_model(graph, producer_name="upsample_nearest_test") + verify_with_ort_with_inputs(model, [in_array], [out_shape], opset=7) - for target, ctx in tvm.testing.enabled_targets(): - tvm_out = get_tvm_output(model, in_array, target, ctx, out_shape, "float32") - tvm.testing.assert_allclose(out_array, tvm_out) - -def _test_upsample3d_nearest(): +def verify_upsample3d_nearest(): scale = 2 in_shape = (1, 1, 3, 3, 3) 
out_shape = (1, 1, 3 * scale, 3 * scale, 3 * scale) @@ -1262,7 +1206,6 @@ def _test_upsample3d_nearest(): ) in_array = np.random.uniform(size=in_shape).astype(np.float32) - out_array = tvm.topi.testing.upsampling3d_python(in_array, (scale, scale, scale), "NCDHW") graph = helper.make_graph( [y], @@ -1272,20 +1215,17 @@ def _test_upsample3d_nearest(): ) model = helper.make_model(graph, producer_name="upsample_nearest_test") - - for target, ctx in tvm.testing.enabled_targets(): - tvm_out = get_tvm_output(model, in_array, target, ctx, out_shape, "float32") - tvm.testing.assert_allclose(out_array, tvm_out) + # Upsample is deprecated after opset 9 + verify_with_ort_with_inputs(model, [in_array], [out_shape], opset=7) -def _test_upsample_bilinear(): +def verify_upsample_bilinear(): scale = 2 in_shape = (1, 1, 3, 3) out_shape = (1, 1, 3 * scale, 3 * scale) y = helper.make_node("Upsample", ["in"], ["out"], mode="linear", scales=[1.0, 1.0, 2.0, 2.0]) in_array = np.random.uniform(size=in_shape).astype(np.float32) - out_array = tvm.topi.testing.bilinear_resize_python(in_array, (3 * scale, 3 * scale), "NCHW") graph = helper.make_graph( [y], @@ -1295,51 +1235,10 @@ def _test_upsample_bilinear(): ) model = helper.make_model(graph, producer_name="upsample_bilinear_test") - - for target, ctx in tvm.testing.enabled_targets(): - tvm_out = get_tvm_output(model, in_array, target, ctx, out_shape, "float32") - tvm.testing.assert_allclose(out_array, tvm_out, rtol=1e-5, atol=1e-5) + verify_with_ort_with_inputs(model, [in_array], [out_shape], opset=7) -def _test_upsample_bilinear_opset9(): - scale = 2 - in_shape = (1, 1, 3, 3) - out_shape = (1, 1, 3 * scale, 3 * scale) - y = helper.make_node("Upsample", ["in", "scales"], ["out"], mode="linear") - scales = [1, 1, 2, 2] - in_array = np.random.uniform(size=in_shape).astype(np.float32) - out_array = tvm.topi.testing.bilinear_resize_python(in_array, (3 * scale, 3 * scale), "NCHW") - - ref_node = helper.make_node( - "Constant", - inputs=[], - 
outputs=["const"], - value=onnx.helper.make_tensor( - name="const_tensor", - data_type=TensorProto.FLOAT, - dims=scales, - vals=np.random.random(scales).flatten().astype(float), - ), - ) - - shape_node = helper.make_node("Shape", ["const"], ["scales"]) - - graph = helper.make_graph( - [ref_node, shape_node, y], - "upsample_bilinear_opset9_test", - inputs=[helper.make_tensor_value_info("in", TensorProto.FLOAT, list(in_shape))], - outputs=[helper.make_tensor_value_info("out", TensorProto.FLOAT, list(out_shape))], - ) - - model = helper.make_model(graph, producer_name="upsample_bilinear_opset9_test") - - for target, ctx in tvm.testing.enabled_targets(): - tvm_out = get_tvm_output_with_vm( - model, [in_array], target, ctx, opset=9, freeze_params=True - ) - - -def _test_upsample3d_trilinear(): +def verify_upsample3d_trilinear(): scale = 2 in_shape = (1, 1, 3, 3, 3) out_shape = (1, 1, 3 * scale, 3 * scale, 3 * scale) @@ -1374,7 +1273,8 @@ def _test_upsample3d_trilinear(): ) model = helper.make_model(graph, producer_name="upsample_trilinear_test") - + # TODO(jwfromm): Trilinear upsampling not supported in 1.0.0 onnxruntime. + # Replace topi comparison with verify_with_ort once we update. 
for target, ctx in tvm.testing.enabled_targets(): tvm_out = get_tvm_output(model, in_array, target, ctx, out_shape, "float32") tvm.testing.assert_allclose(out_array, tvm_out, rtol=1e-5, atol=1e-5) @@ -1383,41 +1283,36 @@ def _test_upsample3d_trilinear(): # TODO(mbrookhart): enable once VM supports heterogenous execution # @tvm.testing.uses_gpu def test_upsample(): - _test_upsample_nearest() - _test_upsample_bilinear() - _test_upsample_bilinear_opset9() - _test_upsample3d_nearest() - _test_upsample3d_trilinear() + verify_upsample_nearest() + verify_upsample_bilinear() + verify_upsample3d_nearest() + verify_upsample3d_trilinear() -def _test_softmax(inshape, axis): +def verify_softmax(inshape, axis): opname = "Softmax" indata = np.random.uniform(size=inshape).astype(np.float32) outshape = inshape - outdata = tvm.topi.testing.softmax_python(indata) - if isinstance(axis, int): - y = helper.make_node(opname, ["in"], ["out"], axis=axis) - elif axis is None: - y = helper.make_node(opname, ["in"], ["out"]) + y = helper.make_node(opname, ["in"], ["out"]) + if axis is not None: + axis_attr = helper.make_attribute("axis", axis) + y.attribute.append(axis_attr) graph = helper.make_graph( [y], opname + "_test", inputs=[helper.make_tensor_value_info("in", TensorProto.FLOAT, list(indata.shape))], - outputs=[helper.make_tensor_value_info("out", TensorProto.FLOAT, list(outdata.shape))], + outputs=[helper.make_tensor_value_info("out", TensorProto.FLOAT, list(outshape))], ) model = helper.make_model(graph, producer_name=opname + "_test") - - for target, ctx in tvm.testing.enabled_targets(): - tvm_out = get_tvm_output(model, indata, target, ctx, outshape, "float32") - tvm.testing.assert_allclose(outdata, tvm_out, rtol=1e-5, atol=1e-5) + verify_with_ort_with_inputs(model, [indata]) @tvm.testing.uses_gpu def test_softmax(): - _test_softmax((1, 10), None) - _test_softmax((1, 10), 1) + verify_softmax((1, 10), None) + verify_softmax((1, 10), 1) def verify_min(input_dim): @@ -1427,8 +1322,6 
@@ def verify_min(input_dim): a_np2 = np.random.uniform(size=input_dim).astype(dtype) a_np3 = np.random.uniform(size=input_dim).astype(dtype) - b_np = np.min((a_np1, a_np2, a_np3), axis=0) - min_node = helper.make_node("Min", ["a_np1", "a_np2", "a_np3"], ["out"]) graph = helper.make_graph( @@ -1439,14 +1332,11 @@ def verify_min(input_dim): helper.make_tensor_value_info("a_np2", TensorProto.FLOAT, list(input_dim)), helper.make_tensor_value_info("a_np3", TensorProto.FLOAT, list(input_dim)), ], - outputs=[helper.make_tensor_value_info("out", TensorProto.FLOAT, list(b_np.shape))], + outputs=[helper.make_tensor_value_info("out", TensorProto.FLOAT, list(input_dim))], ) model = helper.make_model(graph, producer_name="Min_test") - - for target, ctx in tvm.testing.enabled_targets(): - tvm_out = get_tvm_output(model, [a_np1, a_np2, a_np3], target, ctx, b_np.shape) - tvm.testing.assert_allclose(b_np, tvm_out, rtol=1e-5, atol=1e-5) + verify_with_ort_with_inputs(model, [a_np1, a_np2, a_np3]) @tvm.testing.uses_gpu @@ -1462,8 +1352,6 @@ def verify_max(input_dim): a_np2 = np.random.uniform(size=input_dim).astype(dtype) a_np3 = np.random.uniform(size=input_dim).astype(dtype) - b_np = np.max((a_np1, a_np2, a_np3), axis=0) - max_node = helper.make_node("Max", ["a_np1", "a_np2", "a_np3"], ["out"]) graph = helper.make_graph( @@ -1474,14 +1362,11 @@ def verify_max(input_dim): helper.make_tensor_value_info("a_np2", TensorProto.FLOAT, list(input_dim)), helper.make_tensor_value_info("a_np3", TensorProto.FLOAT, list(input_dim)), ], - outputs=[helper.make_tensor_value_info("out", TensorProto.FLOAT, list(b_np.shape))], + outputs=[helper.make_tensor_value_info("out", TensorProto.FLOAT, list(input_dim))], ) model = helper.make_model(graph, producer_name="Max_test") - - for target, ctx in tvm.testing.enabled_targets(): - tvm_out = get_tvm_output(model, [a_np1, a_np2, a_np3], target, ctx, b_np.shape) - tvm.testing.assert_allclose(b_np, tvm_out, rtol=1e-5, atol=1e-5) + 
verify_with_ort_with_inputs(model, [a_np1, a_np2, a_np3]) @tvm.testing.uses_gpu @@ -1497,8 +1382,6 @@ def verify_mean(input_dim): a_np2 = np.random.uniform(size=input_dim).astype(dtype) a_np3 = np.random.uniform(size=input_dim).astype(dtype) - b_np = np.mean((a_np1, a_np2, a_np3), axis=0) - mean_node = helper.make_node("Mean", ["a_np1", "a_np2", "a_np3"], ["out"]) graph = helper.make_graph( @@ -1509,14 +1392,11 @@ def verify_mean(input_dim): helper.make_tensor_value_info("a_np2", TensorProto.FLOAT, list(input_dim)), helper.make_tensor_value_info("a_np3", TensorProto.FLOAT, list(input_dim)), ], - outputs=[helper.make_tensor_value_info("out", TensorProto.FLOAT, list(b_np.shape))], + outputs=[helper.make_tensor_value_info("out", TensorProto.FLOAT, list(input_dim))], ) model = helper.make_model(graph, producer_name="Mean_test") - - for target, ctx in tvm.testing.enabled_targets(): - tvm_out = get_tvm_output(model, [a_np1, a_np2, a_np3], target, ctx, b_np.shape) - tvm.testing.assert_allclose(b_np, tvm_out, rtol=1e-5, atol=1e-5) + verify_with_ort_with_inputs(model, [a_np1, a_np2, a_np3]) @tvm.testing.uses_gpu @@ -1530,22 +1410,17 @@ def verify_hardsigmoid(input_dim, alpha, beta): a_np1 = np.random.uniform(size=input_dim).astype(dtype) - b_np = np.clip(a_np1 * alpha + beta, 0, 1) - hardsigmoid_node = helper.make_node("HardSigmoid", ["a_np1"], ["out"], alpha=alpha, beta=beta) graph = helper.make_graph( [hardsigmoid_node], "HardSigmoid_test", inputs=[helper.make_tensor_value_info("a_np1", TensorProto.FLOAT, list(input_dim))], - outputs=[helper.make_tensor_value_info("out", TensorProto.FLOAT, list(b_np.shape))], + outputs=[helper.make_tensor_value_info("out", TensorProto.FLOAT, list(input_dim))], ) model = helper.make_model(graph, producer_name="HardSigmoid_test") - - for target, ctx in tvm.testing.enabled_targets(): - tvm_out = get_tvm_output(model, [a_np1], target, ctx, b_np.shape) - tvm.testing.assert_allclose(b_np, tvm_out, rtol=1e-5, atol=1e-5) + 
verify_with_ort_with_inputs(model, [a_np1]) @tvm.testing.uses_gpu @@ -1554,98 +1429,51 @@ def test_forward_hardsigmoid(): verify_hardsigmoid((20, 20), 0.3, 0.4) -def verify_argmin(input_dim, axis=None, keepdims=None): - def _argmin_numpy(data, axis=0, keepdims=True): - result = np.argmin(data, axis=axis) - if keepdims == 1: - result = np.expand_dims(result, axis) - return result.astype(data.dtype) - +def verify_argreduce(input_dim, op_name, axis=None, keepdims=None): a_np1 = np.random.uniform(-10, 10, input_dim).astype(np.int32) - if keepdims is None and axis is None: - b_np = _argmin_numpy(a_np1) - node = onnx.helper.make_node("ArgMin", inputs=["a_np1"], outputs=["out"]) - elif axis is None: - b_np = _argmin_numpy(a_np1, keepdims=keepdims) - node = onnx.helper.make_node("ArgMin", inputs=["a_np1"], outputs=["out"], keepdims=keepdims) - elif keepdims is None: - b_np = _argmin_numpy(a_np1, axis=axis) - node = onnx.helper.make_node("ArgMin", inputs=["a_np1"], outputs=["out"], axis=axis) + out_shape = list(a_np1.shape) + def_axis = axis if axis is not None else 0 + if keepdims == 1 or keepdims == None: + out_shape[def_axis] = 1 else: - b_np = _argmin_numpy(a_np1, axis=axis, keepdims=keepdims) - node = onnx.helper.make_node( - "ArgMin", inputs=["a_np1"], outputs=["out"], axis=axis, keepdims=keepdims - ) - graph = helper.make_graph( - [node], - "argmin_test", - inputs=[helper.make_tensor_value_info("a_np1", TensorProto.INT32, list(a_np1.shape))], - outputs=[helper.make_tensor_value_info("out", TensorProto.INT32, list(b_np.shape))], - ) - - model = helper.make_model(graph, producer_name="argmin_test") - - for target, ctx in tvm.testing.enabled_targets(): - tvm_out = get_tvm_output(model, [a_np1], target, ctx, b_np.shape, b_np.dtype) - tvm.testing.assert_allclose(b_np, tvm_out, rtol=1e-5, atol=1e-5) - + out_shape.pop(def_axis) -def verify_argmax(input_dim, axis=None, keepdims=None): - def _argmax_numpy(data, axis=0, keepdims=True): - result = np.argmax(data, axis=axis) - 
if keepdims == 1: - result = np.expand_dims(result, axis) - return result.astype(data.dtype) + node = onnx.helper.make_node(op_name, inputs=["a_np1"], outputs=["out"]) - a_np1 = np.random.uniform(-10, 10, input_dim).astype(np.int32) - if keepdims is None and axis is None: - b_np = _argmax_numpy(a_np1) - node = onnx.helper.make_node("ArgMax", inputs=["a_np1"], outputs=["out"]) - elif axis is None: - b_np = _argmax_numpy(a_np1, keepdims=keepdims) - node = onnx.helper.make_node("ArgMax", inputs=["a_np1"], outputs=["out"], keepdims=keepdims) - elif keepdims is None: - b_np = _argmax_numpy(a_np1, axis=axis) - node = onnx.helper.make_node("ArgMax", inputs=["a_np1"], outputs=["out"], axis=axis) - else: - b_np = _argmax_numpy(a_np1, axis=axis, keepdims=keepdims) - node = onnx.helper.make_node( - "ArgMax", inputs=["a_np1"], outputs=["out"], axis=axis, keepdims=keepdims - ) + if keepdims is not None: + keepdims_attr = helper.make_attribute("keepdims", keepdims) + node.attribute.append(keepdims_attr) + if axis is not None: + axis_attr = helper.make_attribute("axis", axis) + node.attribute.append(axis_attr) graph = helper.make_graph( [node], - "argmax_test", + "argreduce_test", inputs=[helper.make_tensor_value_info("a_np1", TensorProto.INT32, list(a_np1.shape))], - outputs=[helper.make_tensor_value_info("out", TensorProto.INT32, list(b_np.shape))], + outputs=[helper.make_tensor_value_info("out", TensorProto.INT64, list(out_shape))], ) - model = helper.make_model(graph, producer_name="argmax_test") - - for target, ctx in tvm.testing.enabled_targets(): - tvm_out = get_tvm_output(model, [a_np1], target, ctx, b_np.shape, b_np.dtype) - tvm.testing.assert_allclose(b_np, tvm_out, rtol=1e-5, atol=1e-5) + model = helper.make_model(graph, producer_name="argreduce_test") + verify_with_ort_with_inputs(model, [a_np1]) @tvm.testing.uses_gpu def test_forward_arg_min_max(): """Verify argmin and argmax""" - verify_argmin([3, 4, 4]) - verify_argmax([3, 4, 4]) - verify_argmin([3, 4, 4], axis=1) 
- verify_argmax([3, 4, 4], axis=0) - verify_argmin([3, 4, 4], keepdims=0) - verify_argmax([3, 4, 4], keepdims=1) + verify_argreduce([3, 4, 4], "ArgMin") + verify_argreduce([3, 4, 4], "ArgMax") + verify_argreduce([3, 4, 4], "ArgMin", axis=1) + verify_argreduce([3, 4, 4], "ArgMax", axis=0) + verify_argreduce([3, 4, 4], "ArgMin", keepdims=0) + verify_argreduce([3, 4, 4], "ArgMax", keepdims=1) for axis in [None, 0, 1, 2]: for keepdims in [None, True, False]: - verify_argmin([3, 4, 4], axis, keepdims) - verify_argmax([3, 4, 4], axis, keepdims) + verify_argreduce([3, 4, 4], "ArgMin", axis, keepdims) + verify_argreduce([3, 4, 4], "ArgMax", axis, keepdims) def verify_constantofshape(input_dim, value, dtype): - out = np.empty(shape=input_dim, dtype=dtype) - out.fill(value) - fill_node = helper.make_node( "ConstantOfShape", ["input"], @@ -1655,22 +1483,22 @@ def verify_constantofshape(input_dim, value, dtype): ), ) - inputs = [helper.make_tensor_value_info("input", TensorProto.FLOAT, input_dim)] + inputs = [helper.make_tensor_value_info("input", TensorProto.INT64, [len(input_dim)])] graph = helper.make_graph( [fill_node], "fill_test", inputs, - outputs=[helper.make_tensor_value_info("output", TensorProto.FLOAT, list(out.shape))], + outputs=[ + helper.make_tensor_value_info( + "output", mapping.NP_TYPE_TO_TENSOR_TYPE[np.dtype(dtype)], input_dim + ) + ], ) model = helper.make_model(graph, producer_name="fill_test") - - for target, ctx in tvm.testing.enabled_targets(): - input_np = np.array(input_dim).astype("float32") - tvm_out = get_tvm_output_with_vm(model, [input_np], target, ctx) - - tvm.testing.assert_allclose(out, tvm_out, rtol=1e-5, atol=1e-5) + input_np = np.array(input_dim).astype("int64") + verify_with_ort_with_inputs(model, [input_np], use_vm=True) # TODO(mbrookhart): enable once VM supports heterogenous execution @@ -1708,10 +1536,7 @@ def verify_pad(indata, pads, mode="constant", value=0.0): outputs=[helper.make_tensor_value_info("output", TensorProto.FLOAT, 
list(outdata.shape))], ) model = helper.make_model(graph, producer_name="pad_test") - # tvm result - for target, ctx in tvm.testing.enabled_targets(): - tvm_out = get_tvm_output(model, indata, target, ctx, outdata.shape, "float32", opset=2) - tvm.testing.assert_allclose(outdata, tvm_out, rtol=1e-5, atol=1e-5) + verify_with_ort_with_inputs(model, [indata], [outdata.shape], dtype="float32", opset=2) def verify_pad_v11(indata, pads, mode="constant", value=0.0): @@ -1760,10 +1585,7 @@ def verify_pad_v11(indata, pads, mode="constant", value=0.0): ], ) model = helper.make_model(graph, producer_name="pad_test") - # tvm result - for target, ctx in tvm.testing.enabled_targets(): - tvm_out = get_tvm_output_with_vm(model, inputs, target, ctx, opset=11, freeze_params=False) - tvm.testing.assert_allclose(outdata, tvm_out, rtol=1e-5, atol=1e-5) + verify_with_ort_with_inputs(model, inputs, opset=11, use_vm=True) # TODO(mbrookhart): enable once VM supports heterogenous execution @@ -1804,7 +1626,7 @@ def verify_reduce_func(func, data, axis, keepdims): model = helper.make_model(graph, producer_name="reduce_test") - verify_with_ort_with_inputs(model, [data], outshape) + verify_with_ort_with_inputs(model, [data], [outshape]) @tvm.testing.uses_gpu @@ -1849,32 +1671,45 @@ def test_all_reduce_funcs(): ) -def verify_split(indata, outdatas, split, axis=0, pass_split=True): +def verify_split(indata, outdatas, split, axis=0, pass_split=True, opset=11): indata = np.array(indata).astype(np.float32) outdatas = [np.array(o).astype(np.float32) for o in outdatas] + inputs = [helper.make_tensor_value_info("input", TensorProto.FLOAT, list(indata.shape))] + input_names = ["input"] + initializer = [] + if split: split_index = range(len(split)) else: split_index = range(len(outdatas)) + if pass_split: - node = helper.make_node( - "Split", - inputs=["input"], - outputs=["output_{}".format(i) for i in range(len(split_index))], - axis=axis, - split=split, - ) - else: - node = helper.make_node( - "Split", 
- inputs=["input"], - outputs=["output_{}".format(i) for i in range(len(split_index))], - axis=axis, - ) + if opset >= 13: + input_names.append("split") + np_split = np.array(split).astype(np.int64) + inputs.append( + helper.make_tensor_value_info("split", TensorProto.INT64, list(np_split.shape)) + ) + indata = [indata, np_split] + initializer.append( + helper.make_tensor("split", TensorProto.INT64, list(np_split.shape), np_split) + ) + node = helper.make_node( + "Split", + inputs=input_names, + outputs=["output_{}".format(i) for i in range(len(split_index))], + axis=axis, + ) + + if pass_split and opset < 13: + split_attr = helper.make_attribute("split", split) + node.attribute.append(split_attr) + graph = helper.make_graph( [node], "split_test", - inputs=[helper.make_tensor_value_info("input", TensorProto.FLOAT, list(indata.shape))], + inputs=inputs, + initializer=initializer, outputs=[ helper.make_tensor_value_info( "output_{}".format(i), TensorProto.FLOAT, list(outdatas[i].shape) @@ -1883,18 +1718,7 @@ def verify_split(indata, outdatas, split, axis=0, pass_split=True): ], ) model = helper.make_model(graph, producer_name="split_test") - - import onnxruntime.backend - - rep = onnxruntime.backend.prepare(model, "CPU") - onnx_out = rep.run(indata) - - for target, ctx in tvm.testing.enabled_targets(): - output_shape = [o.shape for o in outdatas] - output_type = ["float32", "float32", "float32"] - tvm_out = get_tvm_output(model, indata, target, ctx, output_shape, output_type) - for o, t in zip(onnx_out, tvm_out): - tvm.testing.assert_allclose(o, t) + verify_with_ort_with_inputs(model, indata, out_shape=list(range(len(split_index))), opset=opset) @tvm.testing.uses_gpu @@ -1914,6 +1738,8 @@ def test_split(): ) # Split evenly (unstack) verify_split([1, 2, 3], [[1], [2], [3]], False, 0, False) + # Split a single value to a single value + verify_split([1], [[1]], [1], pass_split=True) @tvm.testing.uses_gpu @@ -1922,50 +1748,52 @@ def test_binary_ops(): dtype = "float32" 
out_shape = in_shape - def verify_binary_ops(op, x, y, out_np, x_name="in1", y_name="in2", broadcast=None): - if broadcast is None: - z = helper.make_node(op, [x_name, y_name], ["out"]) - else: - z = helper.make_node(op, [x_name, y_name], ["out"], broadcast=1) + def verify_binary_ops(op, x, y, out_type="float32"): + z = helper.make_node(op, ["in1", "in2"], ["out"]) graph = helper.make_graph( [z], "_test", inputs=[ - helper.make_tensor_value_info(x_name, TensorProto.FLOAT, list(in_shape)), - helper.make_tensor_value_info(y_name, TensorProto.FLOAT, list(in_shape)), + helper.make_tensor_value_info("in1", TensorProto.FLOAT, x.shape), + helper.make_tensor_value_info("in2", TensorProto.FLOAT, y.shape), + ], + outputs=[ + helper.make_tensor_value_info( + "out", mapping.NP_TYPE_TO_TENSOR_TYPE[np.dtype(out_type)], list(out_shape) + ) ], - outputs=[helper.make_tensor_value_info("out", TensorProto.FLOAT, list(out_shape))], ) model = helper.make_model(graph, producer_name="_test") - for target, ctx in tvm.testing.enabled_targets(): - tvm_out = get_tvm_output(model, [x, y], target, ctx) - tvm.testing.assert_allclose(out_np, tvm_out, rtol=1e-5, atol=1e-5) + verify_with_ort_with_inputs(model, [x, y]) x = np.random.uniform(size=in_shape).astype(dtype) y = np.random.uniform(size=in_shape).astype(dtype) z = np.random.uniform(size=(3,)).astype(dtype) - verify_binary_ops("Add", x, y, x + y, broadcast=None) - verify_binary_ops("Add", x, z, x + z, broadcast=True) - verify_binary_ops("Sub", x, y, x - y, broadcast=None) - verify_binary_ops("Sub", x, z, x - z, broadcast=True) - verify_binary_ops("Mul", x, y, x * y, broadcast=None) - verify_binary_ops("Mul", x, z, x * z, broadcast=True) - verify_binary_ops("Mul", x, x, x * x, x_name="in1", y_name="in1", broadcast=None) - verify_binary_ops("Div", x, y, x / y, broadcast=None) - verify_binary_ops("Div", x, z, x / z, broadcast=True) - verify_binary_ops("Sum", x, y, x + y, broadcast=None) - verify_binary_ops("Greater", x, y, x > y, 
broadcast=True) - verify_binary_ops("Less", x, y, x < y, broadcast=True) - verify_binary_ops("Equal", x, y, x == y, broadcast=True) - - -@tvm.testing.uses_gpu -def test_single_ops(): + verify_binary_ops("Add", x, y) + verify_binary_ops("Add", x, z) + verify_binary_ops("Sub", x, y) + verify_binary_ops("Sub", x, z) + verify_binary_ops("Mul", x, y) + verify_binary_ops("Mul", x, z) + verify_binary_ops("Div", x, y) + verify_binary_ops("Div", x, z) + verify_binary_ops("Sum", x, y) + verify_binary_ops("Sum", x, z) + verify_binary_ops("Greater", x, y, "bool") + verify_binary_ops("Greater", x, z, "bool") + verify_binary_ops("Less", x, y, "bool") + verify_binary_ops("Less", x, z, "bool") + verify_binary_ops("Equal", x, y, "bool") + verify_binary_ops("Equal", x, z, "bool") + + +@tvm.testing.uses_gpu +def test_unary_ops(): in_shape = (1, 2, 3, 3) dtype = "float32" out_shape = in_shape - def verify_single_ops(op, x, out_np, rtol=1e-5, atol=1e-5): + def verify_unary_ops(op, x, rtol=1e-5, atol=1e-5): z = helper.make_node(op, ["in1"], ["out"]) graph = helper.make_graph( [z], @@ -1976,33 +1804,31 @@ def verify_single_ops(op, x, out_np, rtol=1e-5, atol=1e-5): outputs=[helper.make_tensor_value_info("out", TensorProto.FLOAT, list(out_shape))], ) model = helper.make_model(graph, producer_name="_test") - for target, ctx in tvm.testing.enabled_targets(): - tvm_out = get_tvm_output(model, [x], target, ctx) - tvm.testing.assert_allclose(out_np, tvm_out, rtol=rtol, atol=atol) + verify_with_ort_with_inputs(model, [x], rtol=rtol, atol=atol) x = np.random.uniform(size=in_shape).astype(dtype) - verify_single_ops("Neg", x, -x) - verify_single_ops("Abs", x, np.abs(x)) - verify_single_ops("Reciprocal", x, 1 / x) - verify_single_ops("Sqrt", x, np.sqrt(x)) - verify_single_ops("Relu", x, np.maximum(x, 0)) - verify_single_ops("Exp", x, np.exp(x)) - verify_single_ops("Log", x, np.log(x)) - verify_single_ops("Log", x, np.log(x)) - verify_single_ops("ACos", x, np.arccos(x)) - verify_single_ops("ACosh", 
x, np.arccosh(x)) - verify_single_ops("ASin", x, np.arcsin(x)) - verify_single_ops("ASinh", x, np.arcsinh(x)) - verify_single_ops("ATan", x, np.arctan(x)) - verify_single_ops("ATanh", x, np.arctanh(x)) - verify_single_ops("Cos", x, np.cos(x)) - verify_single_ops("Cosh", x, np.cosh(x)) - verify_single_ops("Sin", x, np.sin(x)) - verify_single_ops("Sinh", x, np.sinh(x)) - verify_single_ops("Tan", x, np.tan(x)) - verify_single_ops("Tanh", x, np.tanh(x)) - verify_single_ops("Sigmoid", x, 1 / (1 + np.exp(-x))) - verify_single_ops("Softsign", x, x / (1 + np.abs(x))) + verify_unary_ops("Neg", x) + verify_unary_ops("Abs", x) + verify_unary_ops("Reciprocal", x) + verify_unary_ops("Sqrt", x) + verify_unary_ops("Relu", x) + verify_unary_ops("Exp", x) + verify_unary_ops("Log", x) + verify_unary_ops("Log", x) + verify_unary_ops("Acos", x) + verify_unary_ops("Acosh", x) + verify_unary_ops("Asin", x) + verify_unary_ops("Asinh", x) + verify_unary_ops("Atan", x) + verify_unary_ops("Atanh", x) + verify_unary_ops("Cos", x) + verify_unary_ops("Cosh", x) + verify_unary_ops("Sin", x) + verify_unary_ops("Sinh", x) + verify_unary_ops("Tan", x) + verify_unary_ops("Tanh", x) + verify_unary_ops("Sigmoid", x) + verify_unary_ops("Softsign", x) @tvm.testing.uses_gpu @@ -2058,7 +1884,11 @@ def verify_prelu(x_shape, a_shape): model = helper.make_model(graph, producer_name="prelu_test") verify_with_ort( - model, [x_shape, a_shape], list(x_shape), use_vm=True, convert_to_static=True + model, + [x_shape, a_shape], + out_shape=[list(x_shape)], + use_vm=True, + convert_to_static=True, ) verify_prelu([3, 4, 5, 6], [1, 4, 1, 1]) @@ -2085,46 +1915,6 @@ def ThresholdedRelu_x(x, alpha): ) -@tvm.testing.uses_gpu -def test_ScaledTanh(): - def ScaledTanh_x(x, alpha, beta): - return alpha * np.tanh(beta * x) - - _test_onnx_op_elementwise( - (2, 4, 5, 6), - ScaledTanh_x, - {"alpha": 0.25, "beta": 0.3}, - "float32", - "ScaledTanh", - {"alpha": 0.25, "beta": 0.3}, - ) - - -@tvm.testing.uses_gpu -def 
test_ParametricSoftplus(): - def ParametricSoftplus_x(x, alpha, beta): - return alpha * np.log(np.exp(beta * x) + 1) - - _test_onnx_op_elementwise( - (2, 4, 5, 6), - ParametricSoftplus_x, - {"alpha": 0.25, "beta": 0.3}, - "float32", - "ParametricSoftplus", - {"alpha": 0.25, "beta": 0.3}, - ) - - -@tvm.testing.uses_gpu -def test_Scale(): - def Scale_x(x, scale): - return scale * x - - _test_onnx_op_elementwise( - (2, 4, 5, 6), Scale_x, {"scale": 0.25}, "float32", "Scale", {"scale": 0.25} - ) - - @tvm.testing.uses_gpu def test_LogSoftmax(): _test_onnx_op_elementwise( @@ -2138,8 +1928,8 @@ def check_torch_conversion(model, input_size): # Set verbose=True for more output torch.onnx.export(model(), dummy_input, file_name, export_params=True, verbose=False) onnx_model = onnx.load(file_name) - input_data = np.random.uniform(size=input_size).astype("int32") - verify_with_ort_with_inputs(onnx_model, [input_data]) + input_data = np.random.uniform(size=input_size).astype("float32") + verify_with_ort_with_inputs(onnx_model, [input_data], apply_softmax=True) @tvm.testing.uses_gpu @@ -2191,7 +1981,6 @@ def Sign_x(x): def verify_not(indata, dtype): x = indata.astype(dtype) - outdata = np.logical_not(x) node = helper.make_node( "Not", @@ -2203,14 +1992,11 @@ def verify_not(indata, dtype): [node], "not_test", inputs=[helper.make_tensor_value_info("in", TensorProto.BOOL, list(x.shape))], - outputs=[helper.make_tensor_value_info("out", TensorProto.BOOL, list(outdata.shape))], + outputs=[helper.make_tensor_value_info("out", TensorProto.BOOL, list(x.shape))], ) model = helper.make_model(graph, producer_name="not_test") - - for target, ctx in tvm.testing.enabled_targets(): - tvm_out = get_tvm_output(model, [x], target, ctx, outdata.shape) - tvm.testing.assert_allclose(outdata, tvm_out) + verify_with_ort_with_inputs(model, [x]) @tvm.testing.uses_gpu @@ -2245,10 +2031,7 @@ def verify_and(indata, dtype): ) model = helper.make_model(graph, producer_name="and_test") - - for target, ctx in 
tvm.testing.enabled_targets(): - tvm_out = get_tvm_output(model, [x, y], target, ctx, outdata.shape) - tvm.testing.assert_allclose(outdata, tvm_out) + verify_with_ort_with_inputs(model, [x, y], [outdata.shape]) @tvm.testing.uses_gpu @@ -2279,22 +2062,6 @@ def test_and(): verify_and(indata=[x, y], dtype=bool) -def verify_tile_v1(indata, outdata, **kwargs): - node = helper.make_node("Tile", inputs=["in"], outputs=["out"], **kwargs) - graph = helper.make_graph( - [node], - "tile_test", - inputs=[helper.make_tensor_value_info("in", TensorProto.FLOAT, list(indata.shape))], - outputs=[helper.make_tensor_value_info("out", TensorProto.FLOAT, list(outdata.shape))], - ) - - model = helper.make_model(graph, producer_name="tile_test") - - for target, ctx in tvm.testing.enabled_targets(): - tvm_out = get_tvm_output(model, [indata], target, ctx, outdata.shape, opset=1) - tvm.testing.assert_allclose(outdata, tvm_out) - - def verify_tile_v6(indata, repeats, outdata): node = helper.make_node("Tile", inputs=["input", "repeats"], outputs=["out"]) graph = helper.make_graph( @@ -2308,10 +2075,7 @@ def verify_tile_v6(indata, repeats, outdata): ) model = helper.make_model(graph, producer_name="tile_test") - - for target, ctx in tvm.testing.enabled_targets(): - tvm_out = get_tvm_output_with_vm(model, [indata, repeats], target, ctx, opset=6) - tvm.testing.assert_allclose(outdata, tvm_out) + verify_with_ort_with_inputs(model, [indata, repeats], use_vm=True, opset=6) # TODO(mbrookhart): enable once VM supports heterogenous execution @@ -2320,7 +2084,6 @@ def test_tile(): x = np.random.rand(2, 3, 4, 5).astype(np.float32) repeats = np.random.randint(low=1, high=10, size=(np.ndim(x),)).astype(np.int64) z = np.tile(x, repeats) - verify_tile_v1(x, z, repeats=repeats) verify_tile_v6(x, repeats, z) @@ -2333,10 +2096,7 @@ def verify_erf(indata, outdata): outputs=[helper.make_tensor_value_info("out", TensorProto.FLOAT, list(outdata.shape))], ) model = helper.make_model(graph, 
producer_name="erf_test") - - for target, ctx in tvm.testing.enabled_targets(): - tvm_out = get_tvm_output(model, [indata], target, ctx, outdata.shape) - tvm.testing.assert_allclose(outdata, tvm_out) + verify_with_ort_with_inputs(model, [indata], [outdata.shape]) @tvm.testing.uses_gpu @@ -2359,10 +2119,7 @@ def verify_where(condition, x, y, dtype, outdata): outputs=[helper.make_tensor_value_info("out", dtype, list(outdata.shape))], ) model = helper.make_model(graph, producer_name="where_test") - - for target, ctx in tvm.testing.enabled_targets(): - tvm_out = get_tvm_output(model, [condition, x, y], target, ctx, outdata.shape) - tvm.testing.assert_allclose(outdata, tvm_out) + verify_with_ort_with_inputs(model, [condition, x, y], [outdata.shape]) @tvm.testing.uses_gpu @@ -2422,10 +2179,7 @@ def verify_or(indata, dtype): ) model = helper.make_model(graph, producer_name="or_test") - - for target, ctx in tvm.testing.enabled_targets(): - tvm_out = get_tvm_output(model, [x, y], target, ctx, outdata.shape) - tvm.testing.assert_allclose(outdata, tvm_out) + verify_with_ort_with_inputs(model, [x, y], [outdata.shape]) @tvm.testing.uses_gpu @@ -2479,7 +2233,7 @@ def verify_batch_norm(in_shape): model = helper.make_model(graph, producer_name="batchnorm_test") # X, scale, b, mean, var inshapes = [in_shape, in_shape[1], in_shape[1], in_shape[1], in_shape[1]] - verify_with_ort(model, inshapes, in_shape) + verify_with_ort(model, inshapes, out_shape=[in_shape]) verify_batch_norm([1, 3, 224, 224]) verify_batch_norm([1, 3, 24, 24]) @@ -2517,7 +2271,7 @@ def verify_batch_norm_dynamic_subgraph(in_shape, o_shape): # X, inp, scale, b, mean, var inshapes = [in_shape, o_shape, in_shape[1], in_shape[1], in_shape[1], in_shape[1]] - verify_with_ort(model, inshapes, in_shape, use_vm=True) + verify_with_ort(model, inshapes, out_shape=[in_shape], use_vm=True) verify_batch_norm_dynamic_subgraph([16, 16, 10, 10], [160, 160]) @@ -2581,7 +2335,7 @@ def verify_conv( model = helper.make_model(graph, 
producer_name="conv_test") - verify_with_ort(model, [x_shape, w_shape], y_shape, use_vm=True, convert_to_static=True) + verify_with_ort(model, [x_shape, w_shape], [y_shape], use_vm=True, convert_to_static=True) @tvm.testing.uses_gpu @@ -2735,7 +2489,7 @@ def verify_convtranspose_with_padding( model = helper.make_model(graph, producer_name="conv_test") - verify_with_ort(model, [x_shape, w_shape], y_shape, use_vm=True, convert_to_static=True) + verify_with_ort(model, [x_shape, w_shape], [y_shape], use_vm=True, convert_to_static=True) def verify_convtranspose(x_shape, w_shape, y_shape, p): @@ -2908,7 +2662,7 @@ def verify_pooling(x_shape, kernel_shape, strides, pads, out_shape, mode, auto_p ) model = helper.make_model(graph, producer_name="pooling_test") - verify_with_ort(model, [x_shape], out_shape, use_vm=True, convert_to_static=True) + verify_with_ort(model, [x_shape], [out_shape], use_vm=True, convert_to_static=True) @tvm.testing.uses_gpu @@ -3013,7 +2767,7 @@ def verify_mod(x_shape, y_shape, fmod, out_shape, dtype="float32"): outputs=[helper.make_tensor_value_info("z", onnx_dtype, list(out_shape))], ) model = helper.make_model(graph, producer_name="mod_test") - verify_with_ort_with_inputs(model, [x_np, y_np], out_shape) + verify_with_ort_with_inputs(model, [x_np, y_np], [out_shape]) @tvm.testing.uses_gpu @@ -3066,10 +2820,7 @@ def verify_xor(x_shape, y_shape): outputs=[helper.make_tensor_value_info("z", onnx_dtype, list(out_shape))], ) model = helper.make_model(graph, producer_name="xor_test") - - for target, ctx in tvm.testing.enabled_targets(): - tvm_out = get_tvm_output(model, [x_np, y_np], target, ctx, out_shape) - tvm.testing.assert_allclose(np_out, tvm_out, rtol=1e-5, atol=1e-5) + verify_with_ort_with_inputs(model, [x_np, y_np], [out_shape]) @tvm.testing.uses_gpu @@ -3106,7 +2857,7 @@ def verify_max_roi_pool(x_shape, rois_shape, pooled_shape, spatial_scale, out_sh ) model = helper.make_model(graph, producer_name="pool_test") - verify_with_ort(model, 
[x_shape, rois_shape], out_shape) + verify_with_ort(model, [x_shape, rois_shape], [out_shape]) @tvm.testing.uses_gpu @@ -3158,7 +2909,7 @@ def verify_lppool(x_shape, kernel_shape, p, strides, pads, out_shape, auto_pad=" ) model = helper.make_model(graph, producer_name="lppool_test") - verify_with_ort(model, [x_shape], out_shape, use_vm=True, convert_to_static=True) + verify_with_ort(model, [x_shape], [out_shape], use_vm=True, convert_to_static=True) @tvm.testing.uses_gpu @@ -3350,18 +3101,7 @@ def verify_rnn( model = helper.make_model(graph, producer_name="rnn_test") - for target, ctx in tvm.testing.enabled_targets(): - onnx_out = get_onnxruntime_output(model, input_values, "float32") - tvm_out = get_tvm_output( - model, - input_values, - target, - ctx, - output_shapes, - output_dtype=["float32"] * len(output_shapes), - ) - for o_out, t_out in zip(onnx_out, tvm_out): - tvm.testing.assert_allclose(o_out, t_out, rtol=5e-3, atol=5e-3) + verify_with_ort_with_inputs(model, input_values, output_shapes, atol=1e-2, rtol=1e-2) @tvm.testing.uses_gpu @@ -3566,7 +3306,7 @@ def verify(ishape, oshape, scales, mode, coord_trans): model = helper.make_model(graph, producer_name="resize_test") - verify_with_ort(model, [ishape], oshape, use_vm=True, opset=11, freeze_params=True) + verify_with_ort(model, [ishape], [oshape], use_vm=True, opset=11, freeze_params=True) # upsampling verify([1, 16, 32, 32], [1, 16, 64, 64], [], "nearest", "asymmetric") @@ -3603,9 +3343,7 @@ def verify_opset_10(ishape, scales, mode): ) model = helper.make_model(graph, producer_name="resize_test") - model.opset_import[0].version = 10 - - verify_with_ort(model, [ishape], oshape, use_vm=True, freeze_params=True) + verify_with_ort(model, [ishape], [oshape], use_vm=True, freeze_params=True, opset=10) verify_opset_10([1, 16, 32, 32], [1, 1, 2, 2], "nearest") verify_opset_10([1, 16, 32, 32], [1, 1, 0.5, 0.5], "linear") @@ -3674,11 +3412,7 @@ def verify_topk(input_dims, K, axis=-1): model = helper.make_model(graph, 
producer_name="topk_test") indata = np.random.uniform(-10, 10, input_dims).astype(np.float32) - onnx_out = get_onnxruntime_output(model, [indata, np.array([K])]) - - for target, ctx in [("llvm", tvm.cpu())]: - tvm_out = get_tvm_output_with_vm(model, [indata, np.array(K)], target, ctx) - tvm.testing.assert_allclose(onnx_out, tvm_out, rtol=1e-05, atol=1e-05) + verify_with_ort_with_inputs(model, [indata, np.array([K])], use_vm=True) for n in [12, 32]: for shape in [[n], [n, n], [n, n, n]]: @@ -3731,7 +3465,9 @@ def verify_roi_align( np_rois = np.random.uniform(size=[num_roi, 4]).astype("float32") * input_dims[2] np_batch_indicies = np.random.randint(low=0, high=input_dims[0], size=num_roi) - verify_with_ort_with_inputs(model, [np_data, np_rois, np_batch_indicies], output_dims) + verify_with_ort_with_inputs( + model, [np_data, np_rois, np_batch_indicies], out_shape=[output_dims] + ) verify_roi_align((1, 4, 16, 16), 32, 7, 7, sampling_ratio=0, spatial_scale=1.0) verify_roi_align((4, 4, 16, 32), 32, 7, 7, sampling_ratio=0, spatial_scale=1.0) @@ -3914,12 +3650,7 @@ def verify_cond_loop(): trip_count = np.array(40).astype(np.int64) cond = np.array(1).astype(np.bool) input_vals = [trip_count, cond, y] - onnx_out = get_onnxruntime_output(loop_model, input_vals) - - for target, ctx in [("llvm", tvm.cpu())]: - tvm_out = get_tvm_output_with_vm(loop_model, input_vals, target, ctx, freeze_params=True) - for i in range(len(tvm_out)): - tvm.testing.assert_allclose(onnx_out[i], tvm_out[i], rtol=1e-05, atol=1e-05) + verify_with_ort_with_inputs(loop_model, input_vals, use_vm=True, freeze_params=True) def verify_count_loop(): @@ -3974,12 +3705,7 @@ def verify_count_loop(): trip_count = np.array(5).astype(np.int64) cond = np.array(1).astype(np.bool) input_vals = [trip_count, cond, y] - onnx_out = get_onnxruntime_output(loop_model, input_vals) - - for target, ctx in [("llvm", tvm.cpu())]: - tvm_out = get_tvm_output_with_vm(loop_model, input_vals, target, ctx, freeze_params=True) - for i 
in range(len(tvm_out)): - tvm.testing.assert_allclose(onnx_out[i], tvm_out[i], rtol=1e-05, atol=1e-05) + verify_with_ort_with_inputs(loop_model, input_vals, use_vm=True, freeze_params=True) def test_loop(): @@ -3999,11 +3725,11 @@ def verify_if(cond_array): y = np.array([5, 4, 3, 2, 1]).astype(np.float32) then_const_node = onnx.helper.make_node( - "Constant", inputs=[], outputs=["then_out"], value=onnx.numpy_helper.from_array(x) + "Constant", inputs=[], outputs=["then_out"], value=numpy_helper.from_array(x) ) else_const_node = onnx.helper.make_node( - "Constant", inputs=[], outputs=["else_out"], value=onnx.numpy_helper.from_array(y) + "Constant", inputs=[], outputs=["else_out"], value=numpy_helper.from_array(y) ) then_body = onnx.helper.make_graph([then_const_node], "then_body", [], [then_out]) @@ -4032,6 +3758,8 @@ def verify_if(cond_array): cond = np.array(1).astype("bool") correct_out = x if cond else y + # TODO(jwfromm): Onnxruntime 1.0.0 is buggy with If statements. Replace this with + # verify_with_ort once we update versions. for target, ctx in tvm.testing.enabled_targets(): tvm_out = get_tvm_output_with_vm(if_model, [cond], target, ctx, freeze_params=True) for i in range(len(tvm_out)): @@ -4204,15 +3932,12 @@ def verify_softplus(indata): test_pad() test_split() test_binary_ops() - test_single_ops() + test_unary_ops() test_leaky_relu() test_elu() test_selu() test_prelu() test_ThresholdedRelu() - test_ScaledTanh() - test_ParametricSoftplus() - test_Scale() test_LogSoftmax() test_resnet() test_inception()