From 9ea4fa24506de8ae1f2a8ad82edcb25320ba5b41 Mon Sep 17 00:00:00 2001
From: Jiawei Liu
Date: Fri, 29 Apr 2022 18:17:40 -0500
Subject: [PATCH] [fix] vec * mat in matmul in onnx converter (#11174)

* fix: vec * mat in matmul in onnx converter

* fix: pylint

* fix: vec-mat matmul

* fix test

* fix test
---
 python/tvm/contrib/nvcc.py                 | 14 ++++++--
 python/tvm/relay/frontend/onnx.py          |  4 +++
 tests/python/frontend/onnx/test_forward.py | 39 ++++++++++++----------
 3 files changed, 37 insertions(+), 20 deletions(-)

diff --git a/python/tvm/contrib/nvcc.py b/python/tvm/contrib/nvcc.py
index 11ac6169192f..5a104be9966d 100644
--- a/python/tvm/contrib/nvcc.py
+++ b/python/tvm/contrib/nvcc.py
@@ -229,10 +229,18 @@ def find_libdevice_path(arch):
     for fn in os.listdir(lib_path):
         if not fn.startswith("libdevice"):
             continue
-        ver = int(fn.split(".")[-3].split("_")[-1])
-        if selected_ver < ver <= arch:
-            selected_ver = ver
+
+        try:
+            # expected pattern: libdevice.${ARCH}.10.bc
+            # e.g., libdevice.compute_20.10.bc
+            ver = int(fn.split(".")[-3].split("_")[-1])
+            if selected_ver < ver <= arch:
+                selected_ver = ver
+                selected_path = fn
+        except ValueError:
+            # it can just be `libdevice.10.bc` in CUDA 10
             selected_path = fn
+
     if selected_path is None:
         raise RuntimeError("Cannot find libdevice for arch {}".format(arch))
     path = os.path.join(lib_path, selected_path)
diff --git a/python/tvm/relay/frontend/onnx.py b/python/tvm/relay/frontend/onnx.py
index 7a2379693842..0fc6e9e7b2b2 100644
--- a/python/tvm/relay/frontend/onnx.py
+++ b/python/tvm/relay/frontend/onnx.py
@@ -324,6 +324,10 @@ def flatten_to_nd(x, x_shape, nd=3):
             0,
         )
         return _op.reshape(output, fold_constant(final_shape))
+
+    if a_rank == 1:
+        return _op.squeeze(_op.nn.matmul(_op.expand_dims(inputs[0], axis=0), inputs[1]), axis=[0])
+
     # Otherwise a simple dense op will get the job done.
     input_1_t = _op.transpose(inputs[1], axes=(1, 0))
     return _op.nn.dense(inputs[0], input_1_t, out_dtype=out_dtype)
diff --git a/tests/python/frontend/onnx/test_forward.py b/tests/python/frontend/onnx/test_forward.py
index 581075403c43..23f594a69ccb 100644
--- a/tests/python/frontend/onnx/test_forward.py
+++ b/tests/python/frontend/onnx/test_forward.py
@@ -1213,27 +1213,32 @@ def verify_gemm(a_shape, b_shape, c_shape=None, freeze_params=False, dtype="floa

 @tvm.testing.parametrize_targets
 def test_matmul(target, dev):
-    a_shape = (4, 3)
-    b_shape = (3, 4)
-    out_shape = [a_shape[0], b_shape[1]]
+    def test_one_matmul(a_shape, b_shape):
+        if len(a_shape) == 1:
+            out_shape = [b_shape[1]]
+        else:
+            out_shape = [a_shape[0], b_shape[1]]

-    a_array = np.random.uniform(size=a_shape).astype("float32")
-    b_array = np.random.uniform(size=b_shape).astype("float32")
+        a_array = np.random.uniform(size=a_shape).astype("float32")
+        b_array = np.random.uniform(size=b_shape).astype("float32")

-    mul_node = helper.make_node("MatMul", ["a", "b"], ["out"])
+        mul_node = helper.make_node("MatMul", ["a", "b"], ["out"])

-    graph = helper.make_graph(
-        [mul_node],
-        "matmul_test",
-        inputs=[
-            helper.make_tensor_value_info("a", TensorProto.FLOAT, list(a_shape)),
-            helper.make_tensor_value_info("b", TensorProto.FLOAT, list(b_shape)),
-        ],
-        outputs=[helper.make_tensor_value_info("out", TensorProto.FLOAT, list(out_shape))],
-    )
+        graph = helper.make_graph(
+            [mul_node],
+            "matmul_test",
+            inputs=[
+                helper.make_tensor_value_info("a", TensorProto.FLOAT, list(a_shape)),
+                helper.make_tensor_value_info("b", TensorProto.FLOAT, list(b_shape)),
+            ],
+            outputs=[helper.make_tensor_value_info("out", TensorProto.FLOAT, list(out_shape))],
+        )

-    model = helper.make_model(graph, producer_name="matmul_test")
-    verify_with_ort_with_inputs(model, [a_array, b_array], target=target, dev=dev)
+        model = helper.make_model(graph, producer_name="matmul_test")
+        verify_with_ort_with_inputs(model, [a_array, b_array], target=target, dev=dev)
+
+    test_one_matmul((4, 3), (3, 4))
+    test_one_matmul((3,), (3, 1))


 @tvm.testing.parametrize_targets