Skip to content

Commit

Permalink
[TensorRT] Add transpose_a/b for TensorRT batch_matmul (#8607)
Browse files Browse the repository at this point in the history
* Add transpose support for tensorrt batch_matmul

* Address PR comment

* Refactor to add ONNX_DEFAULT_CONFIGS
  • Loading branch information
ymwangg committed Aug 5, 2021
1 parent a495f95 commit 26c2a9a
Show file tree
Hide file tree
Showing 5 changed files with 97 additions and 18 deletions.
35 changes: 30 additions & 5 deletions python/tvm/relay/frontend/onnx.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,15 @@

__all__ = ["from_onnx"]

# The default configurations of Relay ONNX frontend.
ONNX_DEFAULT_CONFIGS = {
# By default, TVM converts qualified onnx `matmul` to `transpose(weight) + nn.batch_matmul_NT`.
# Change this flag to False to directly convert to `nn.batch_matmul`.
# Note that `nn.batch_matmul` with format other than NT is in experimental, it may have some
# performance issues.
"use_nt_batch_matmul": True,
}


class onnx_input:
"""Dual purpose list or dictionary access object."""
Expand Down Expand Up @@ -770,10 +779,14 @@ def flatten_to_nd(x, x_shape, nd=3):
# Convert a and b into 3 dimensional tensors.
a = flatten_to_nd(inputs[0], a_shape, 3)
b = flatten_to_nd(inputs[1], b_shape, 3)
# Transpose matrix dimensions of b.
b = _op.transpose(b, [0, 2, 1])
# Perform a batch matmul.
output = _op.nn.batch_matmul(a, b)
if ONNX_DEFAULT_CONFIGS["use_nt_batch_matmul"]:
# Transpose matrix dimensions of b.
b = _op.transpose(b, [0, 2, 1])
# Perform a NT batch matmul.
output = _op.nn.batch_matmul(a, b)
else:
# Perform a NN batch matmul.
output = _op.nn.batch_matmul(a, b, transpose_b=False)
# Determine the output batch dimension.
if a_rank > b_rank:
out_batch = _op.strided_slice(a_shape, [0], [a_rank - 2])
Expand Down Expand Up @@ -3916,7 +3929,9 @@ def _fix_outputs(self, op_name, outputs):
return outputs


def from_onnx(model, shape=None, dtype="float32", opset=None, freeze_params=False):
def from_onnx(
model, shape=None, dtype="float32", opset=None, freeze_params=False, convert_config=None
):
"""Convert a ONNX model into an equivalent Relay Function.
ONNX graphs are represented as Python Protobuf objects.
Expand Down Expand Up @@ -3955,6 +3970,12 @@ def from_onnx(model, shape=None, dtype="float32", opset=None, freeze_params=Fals
at compile time and helps in making models static if certain inputs represent
attributes relay would traditionally consider compile-time constants.
convert_config : Optional[Dict[str, Any]]
Default config:
use_nt_batch_matmul : bool = True
True to convert qualified onnx `matmul` to `nn.batch_matmul` strict to NT format
(transpose_a=False, transpose_b=True).
Returns
-------
mod : tvm.IRModule
Expand All @@ -3963,6 +3984,10 @@ def from_onnx(model, shape=None, dtype="float32", opset=None, freeze_params=Fals
params : dict of str to tvm.nd.NDArray
The parameter dict to be used by relay
"""
global ONNX_DEFAULT_CONFIGS
if convert_config is not None:
ONNX_DEFAULT_CONFIGS.update(convert_config)

try:
import onnx

Expand Down
2 changes: 1 addition & 1 deletion src/runtime/contrib/tensorrt/tensorrt_logger.h
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@ class TensorRTLogger : public nvinfer1::ILogger {
public:
TensorRTLogger() : TensorRTLogger(Severity::kWARNING) {}
explicit TensorRTLogger(Severity severity) : reportable_severity(severity) {}
void log(Severity severity, const char* msg) override {
void log(Severity severity, const char* msg) noexcept override {
// suppress messages with severity enum value greater than the reportable
if (severity > reportable_severity) return;

Expand Down
11 changes: 8 additions & 3 deletions src/runtime/contrib/tensorrt/tensorrt_ops.cc
Original file line number Diff line number Diff line change
Expand Up @@ -835,7 +835,7 @@ class SplitOpConverter : public TensorRTOpConverter {
std::vector<int> start(input_dims.size(), 0);
std::vector<int> size(input_dims.begin(), input_dims.end());
std::vector<int> strides(input_dims.size(), 1);
for (int i = 0; i < split_sizes.size(); ++i) {
for (size_t i = 0; i < split_sizes.size(); ++i) {
start[axis] = split_starts[i];
size[axis] = split_sizes[i];
auto slice_layer = params->network->addSlice(*input, VectorToTrtDims(start),
Expand Down Expand Up @@ -1174,9 +1174,14 @@ class BatchMatmulOpConverter : public TensorRTOpConverter {
BatchMatmulOpConverter() : TensorRTOpConverter({kTensor, kTensor}) {}

void Convert(TensorRTOpConverterParams* params) const {
auto transa = std::stoi(params->node.GetAttr<std::vector<std::string>>("transpose_a")[0]);
auto transb = std::stoi(params->node.GetAttr<std::vector<std::string>>("transpose_b")[0]);
nvinfer1::MatrixOperation trt_transa =
transa ? nvinfer1::MatrixOperation::kTRANSPOSE : nvinfer1::MatrixOperation::kNONE;
nvinfer1::MatrixOperation trt_transb =
transb ? nvinfer1::MatrixOperation::kTRANSPOSE : nvinfer1::MatrixOperation::kNONE;
nvinfer1::IMatrixMultiplyLayer* matmul_layer = params->network->addMatrixMultiply(
*params->inputs.at(0).tensor, nvinfer1::MatrixOperation::kNONE,
*params->inputs.at(1).tensor, nvinfer1::MatrixOperation::kTRANSPOSE);
*params->inputs.at(0).tensor, trt_transa, *params->inputs.at(1).tensor, trt_transb);
ICHECK(matmul_layer != nullptr);
params->outputs.push_back(matmul_layer->getOutput(0));
}
Expand Down
17 changes: 14 additions & 3 deletions tests/python/contrib/test_tensorrt.py
Original file line number Diff line number Diff line change
Expand Up @@ -474,14 +474,25 @@ def get_graph(x_shape=(1, 16), k_shape=(32, 16)):


def test_batch_matmul():
def get_graph(x_shape=(12, 128, 64), y_shape=(12, 128, 64)):
def get_graph(x_shape=(12, 128, 64), y_shape=(12, 128, 64), transa=False, transb=True):
x = relay.var("x", shape=(x_shape), dtype="float32")
y = relay.var("y", shape=(y_shape), dtype="float32")
out = relay.nn.batch_matmul(x, y)
out = relay.nn.batch_matmul(x, y, transpose_a=transa, transpose_b=transb)
f = relay.Function([x, y], out)
return f, {"x": x_shape, "y": y_shape}, []

run_and_verify_func(get_graph())
run_and_verify_func(
get_graph(x_shape=(12, 64, 128), y_shape=(12, 128, 64), transa=True, transb=True)
)
run_and_verify_func(
get_graph(x_shape=(12, 64, 128), y_shape=(12, 64, 128), transa=True, transb=False)
)
run_and_verify_func(
get_graph(x_shape=(12, 128, 64), y_shape=(12, 128, 64), transa=False, transb=True)
)
run_and_verify_func(
get_graph(x_shape=(12, 128, 64), y_shape=(12, 64, 128), transa=False, transb=False)
)


def test_bias_add():
Expand Down
50 changes: 44 additions & 6 deletions tests/python/frontend/onnx/test_forward.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,15 +48,26 @@ def get_input_data_shape_dict(graph_def, input_data):


def get_tvm_output_with_vm(
graph_def, input_data, target, dev, opset=None, freeze_params=False, convert_to_static=False
graph_def,
input_data,
target,
dev,
opset=None,
freeze_params=False,
convert_to_static=False,
convert_config=None,
):
"""Generic function to execute and get tvm output with vm executor"""
if not isinstance(input_data, list):
input_data = [input_data]
_, shape_dict = get_input_data_shape_dict(graph_def, input_data)

mod, params = relay.frontend.from_onnx(
graph_def, shape_dict, opset=opset, freeze_params=freeze_params
graph_def,
shape_dict,
opset=opset,
freeze_params=freeze_params,
convert_config=convert_config,
)

if convert_to_static:
Expand All @@ -78,12 +89,15 @@ def get_tvm_output(
output_dtype="float32",
opset=None,
opt_level=1,
convert_config=None,
):
"""Generic function to execute and get tvm output"""
# TODO: Resolve the issues and remove the following lines
input_names, shape_dict = get_input_data_shape_dict(graph_def, input_data)

mod, params = relay.frontend.from_onnx(graph_def, shape_dict, opset=opset)
mod, params = relay.frontend.from_onnx(
graph_def, shape_dict, opset=opset, convert_config=convert_config
)

with tvm.transform.PassContext(opt_level=opt_level):
graph, lib, params = relay.build(mod, target, params=params)
Expand Down Expand Up @@ -146,6 +160,7 @@ def verify_with_ort_with_inputs(
atol=1e-5,
apply_softmax=False,
opt_level=1,
convert_config=None,
):
if opset is not None:
model.opset_import[0].version = opset
Expand All @@ -161,10 +176,19 @@ def verify_with_ort_with_inputs(
opset=opset,
freeze_params=freeze_params,
convert_to_static=convert_to_static,
convert_config=convert_config,
)
else:
tvm_out = get_tvm_output(
model, inputs, target, dev, out_shape, dtype, opset=opset, opt_level=opt_level
model,
inputs,
target,
dev,
out_shape,
dtype,
opset=opset,
opt_level=opt_level,
convert_config=convert_config,
)
if not isinstance(tvm_out, list):
tvm_out = [tvm_out]
Expand Down Expand Up @@ -1179,7 +1203,7 @@ def test_matmul(target, dev):

@tvm.testing.parametrize_targets
def test_batch_matmul(target, dev):
def verify_batch_matmul(a_shape, b_shape, out_shape):
def verify_batch_matmul(a_shape, b_shape, out_shape, convert_config=None):
a_array = np.random.uniform(size=a_shape).astype("float32")
b_array = np.random.uniform(size=b_shape).astype("float32")

Expand All @@ -1196,7 +1220,14 @@ def verify_batch_matmul(a_shape, b_shape, out_shape):
)

model = helper.make_model(graph, producer_name="matmul_test")
verify_with_ort_with_inputs(model, [a_array, b_array], use_vm=True, target=target, dev=dev)
verify_with_ort_with_inputs(
model,
[a_array, b_array],
use_vm=True,
target=target,
dev=dev,
convert_config=convert_config,
)

verify_batch_matmul((2, 3, 4, 3), (2, 3, 3, 4), (2, 3, 4, 4))
verify_batch_matmul((2, 4, 3), (3, 4), (2, 4, 4))
Expand All @@ -1207,6 +1238,13 @@ def verify_batch_matmul(a_shape, b_shape, out_shape):
verify_batch_matmul((1, 4, 3), (2, 3, 4), (2, 4, 4))
verify_batch_matmul((4, 32, 16), (16, 32), (4, 32, 32))
verify_batch_matmul((4, 32, 16, 32), (32, 16), (4, 32, 16, 16))
# Test transb=False
verify_batch_matmul(
(2, 3, 4, 3),
(2, 3, 3, 4),
(2, 3, 4, 4),
convert_config={"use_nt_batch_matmul": False},
)


def verify_simple_dynamic_model(a_shape, b_shape, target, dev):
Expand Down

0 comments on commit 26c2a9a

Please sign in to comment.