Skip to content

Commit

Permalink
[fix] vec * mat in matmul in onnx converter (apache#11174)
Browse files Browse the repository at this point in the history
* fix: vec * mat in matmul in onnx converter

* fix: pylint

* fix: vec-mat matmul

* fix test

* fix test
  • Loading branch information
ganler authored and juda committed Jun 21, 2022
1 parent da91e9e commit d79896f
Show file tree
Hide file tree
Showing 3 changed files with 37 additions and 20 deletions.
14 changes: 11 additions & 3 deletions python/tvm/contrib/nvcc.py
Original file line number Diff line number Diff line change
Expand Up @@ -229,10 +229,18 @@ def find_libdevice_path(arch):
for fn in os.listdir(lib_path):
if not fn.startswith("libdevice"):
continue
ver = int(fn.split(".")[-3].split("_")[-1])
if selected_ver < ver <= arch:
selected_ver = ver

try:
# expected pattern: libdevice.${ARCH}.10.bc
# e.g., libdevice.compute_20.10.bc
ver = int(fn.split(".")[-3].split("_")[-1])
if selected_ver < ver <= arch:
selected_ver = ver
selected_path = fn
except ValueError:
# it can just be `libdevice.10.bc` in CUDA 10
selected_path = fn

if selected_path is None:
raise RuntimeError("Cannot find libdevice for arch {}".format(arch))
path = os.path.join(lib_path, selected_path)
Expand Down
4 changes: 4 additions & 0 deletions python/tvm/relay/frontend/onnx.py
Original file line number Diff line number Diff line change
Expand Up @@ -324,6 +324,10 @@ def flatten_to_nd(x, x_shape, nd=3):
0,
)
return _op.reshape(output, fold_constant(final_shape))

if a_rank == 1:
return _op.squeeze(_op.nn.matmul(_op.expand_dims(inputs[0], axis=0), inputs[1]), axis=[0])

# Otherwise a simple dense op will get the job done.
input_1_t = _op.transpose(inputs[1], axes=(1, 0))
return _op.nn.dense(inputs[0], input_1_t, out_dtype=out_dtype)
Expand Down
39 changes: 22 additions & 17 deletions tests/python/frontend/onnx/test_forward.py
Original file line number Diff line number Diff line change
Expand Up @@ -1213,27 +1213,32 @@ def verify_gemm(a_shape, b_shape, c_shape=None, freeze_params=False, dtype="floa

@tvm.testing.parametrize_targets
def test_matmul(target, dev):
    """Check ONNX MatMul conversion, including a 1-D (vector) left operand.

    Regression coverage for vec * mat matmul (apache#11174): per ONNX/NumPy
    semantics, a 1-D left operand is promoted to a matrix by prepending a 1
    to its shape, and that prepended dimension is removed from the result.
    """

    def test_one_matmul(a_shape, b_shape):
        # Expected output shape follows ONNX MatMul semantics: a 1-D left
        # operand contributes no output dimension.
        if len(a_shape) == 1:
            out_shape = [b_shape[1]]
        else:
            out_shape = [a_shape[0], b_shape[1]]

        a_array = np.random.uniform(size=a_shape).astype("float32")
        b_array = np.random.uniform(size=b_shape).astype("float32")

        mul_node = helper.make_node("MatMul", ["a", "b"], ["out"])

        graph = helper.make_graph(
            [mul_node],
            "matmul_test",
            inputs=[
                helper.make_tensor_value_info("a", TensorProto.FLOAT, list(a_shape)),
                helper.make_tensor_value_info("b", TensorProto.FLOAT, list(b_shape)),
            ],
            outputs=[helper.make_tensor_value_info("out", TensorProto.FLOAT, list(out_shape))],
        )

        model = helper.make_model(graph, producer_name="matmul_test")
        # Compare the TVM-compiled result against onnxruntime's reference.
        verify_with_ort_with_inputs(model, [a_array, b_array], target=target, dev=dev)

    # Plain mat * mat case.
    test_one_matmul((4, 3), (3, 4))
    # vec * mat case — the regression being fixed by this commit.
    test_one_matmul((3,), (3, 1))


@tvm.testing.parametrize_targets
Expand Down

0 comments on commit d79896f

Please sign in to comment.