[BYOC][TensorRT] Add nn.batch_matmul, nn.layer_norm, erf (apache#8005)
Trevor Morris authored and trevor-m committed Jun 17, 2021
1 parent 06cd7ba commit 4c05219
Showing 4 changed files with 171 additions and 1 deletion.
6 changes: 6 additions & 0 deletions docs/deploy/tensorrt.rst
@@ -181,6 +181,8 @@ Operator support
+------------------------+------------------------------------+
| nn.batch_norm | |
+------------------------+------------------------------------+
| nn.layer_norm | |
+------------------------+------------------------------------+
| nn.softmax | |
+------------------------+------------------------------------+
| nn.conv2d | |
@@ -253,6 +255,8 @@ Operator support
+------------------------+------------------------------------+
| nn.adaptive_avg_pool2d | |
+------------------------+------------------------------------+
| nn.batch_matmul | |
+------------------------+------------------------------------+
| clip | Requires TensorRT 5.1.5 or greater |
+------------------------+------------------------------------+
| nn.leaky_relu | Requires TensorRT 5.1.5 or greater |
@@ -277,6 +281,8 @@ Operator support
+------------------------+------------------------------------+
| nn.conv3d_transpose | Requires TensorRT 6.0.1 or greater |
+------------------------+------------------------------------+
| erf | Requires TensorRT 7.0.0 or greater |
+------------------------+------------------------------------+
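As an illustrative usage sketch (not part of this commit), the snippet below offloads the three newly supported operators through the partition_for_tensorrt workflow described earlier in docs/deploy/tensorrt.rst; the build flow and the (mod, config) return value are assumptions based on that documentation.

    # Hedged sketch: offload nn.batch_matmul, nn.layer_norm, and erf to TensorRT.
    # The partition/build flow follows docs/deploy/tensorrt.rst; treat the exact
    # partition_for_tensorrt return signature as an assumption.
    import tvm
    from tvm import relay
    from tvm.relay.op.contrib.tensorrt import partition_for_tensorrt

    x = relay.var("x", shape=(12, 128, 64), dtype="float32")
    y = relay.var("y", shape=(12, 128, 64), dtype="float32")
    gamma = relay.var("gamma", shape=(128,), dtype="float32")
    beta = relay.var("beta", shape=(128,), dtype="float32")
    # batch_matmul output is (12, 128, 128); layer_norm normalizes its last axis.
    out = relay.erf(relay.nn.layer_norm(relay.nn.batch_matmul(x, y), gamma, beta, axis=2))
    mod = tvm.IRModule.from_expr(relay.Function([x, y, gamma, beta], out))

    mod, config = partition_for_tensorrt(mod)
    with tvm.transform.PassContext(opt_level=3, config={"relay.ext.tensorrt.options": config}):
        lib = relay.build(mod, target="cuda")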


Adding a new operator
29 changes: 29 additions & 0 deletions python/tvm/relay/op/contrib/tensorrt.py
@@ -304,6 +304,7 @@ def _func_wrapper(attrs, args, op_name):
_register_external_op_helper_with_checker("cos", trt_version_annotate_fn((5, 1, 5)))
_register_external_op_helper_with_checker("atan", trt_version_annotate_fn((5, 1, 5)))
_register_external_op_helper_with_checker("ceil", trt_version_annotate_fn((5, 1, 5)))
_register_external_op_helper_with_checker("erf", trt_version_annotate_fn((7, 0, 0)))


@_register_external_dynamic_check_func("add")
@@ -411,6 +412,34 @@ def dense_annotate_fn(expr): # pylint: disable=unused-variable
return True


@_register_external_dynamic_check_func("nn.batch_matmul")
def batch_matmul_annotate_fn(expr):
"""Check if dense is supported by TensorRT."""

if any([x.checked_type.dtype != "float32" for x in expr.args]):
logger.info("Only float32 inputs are supported for TensorRT.")
return False
if get_tensorrt_use_implicit_batch_mode() and len(expr.args[0].checked_type.shape) != len(
expr.args[1].checked_type.shape
):
logger.info("nn.batch_matmul: requires use_implict_batch=False.")
return False
return True


@_register_external_dynamic_check_func("nn.layer_norm")
def layer_norm_annotate_fn(expr):
"""Check if dense is supported by TensorRT."""

if any([x.checked_type.dtype != "float32" for x in expr.args]):
logger.info("Only float32 inputs are supported for TensorRT.")
return False
if get_tensorrt_use_implicit_batch_mode() and int(expr.attrs.axis) == 0:
logger.info("nn.layer_norm: requires use_implict_batch=False.")
return False
return True
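Both checks above reject patterns that only work when implicit batch mode is disabled. A minimal sketch of how a user would opt out of it, assuming the use_implicit_batch flag exposed by partition_for_tensorrt in this module:

    # Hedged sketch: nn.layer_norm over the batch axis (axis=0) is rejected above
    # under implicit batch mode; partitioning with use_implicit_batch=False lets it
    # be annotated for TensorRT. Flag name assumed from partition_for_tensorrt.
    import tvm
    from tvm import relay
    from tvm.relay.op.contrib.tensorrt import partition_for_tensorrt

    x = relay.var("x", shape=(8, 16), dtype="float32")
    gamma = relay.var("gamma", shape=(8,), dtype="float32")
    beta = relay.var("beta", shape=(8,), dtype="float32")
    out = relay.nn.layer_norm(x, gamma, beta, axis=0)  # normalizes the batch axis
    mod = tvm.IRModule.from_expr(relay.Function([x, gamma, beta], out))

    mod, config = partition_for_tensorrt(mod, use_implicit_batch=False)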


@_register_external_dynamic_check_func("nn.bias_add")
def bias_add_annotate_fn(expr): # pylint: disable=unused-variable
"""Check if nn.bias_add is supported by TensorRT."""
93 changes: 93 additions & 0 deletions src/runtime/contrib/tensorrt/tensorrt_ops.cc
@@ -463,6 +463,78 @@ class BatchNormOpConverter : public TensorRTOpConverter {
}
};

class LayerNormOpConverter : public TensorRTOpConverter {
public:
LayerNormOpConverter() : TensorRTOpConverter({kTensor, kWeight, kWeight}) {}

void Convert(TensorRTOpConverterParams* params) const {
auto input = params->inputs.at(0).tensor;
auto gamma_input = params->inputs.at(1).weight;
auto beta_input = params->inputs.at(2).weight;
ICHECK_EQ(gamma_input.count, beta_input.count);

const float epsilon = std::stof(params->node.GetAttr<std::vector<std::string>>("epsilon")[0]);
const bool scale = std::stoi(params->node.GetAttr<std::vector<std::string>>("scale")[0]);
const bool center = std::stoi(params->node.GetAttr<std::vector<std::string>>("center")[0]);
const int input_rank = input->getDimensions().nbDims;
const int original_axis = std::stoi(params->node.GetAttr<std::vector<std::string>>("axis")[0]);
const int axis = ConvertAxis(params, original_axis, input_rank);

std::vector<int> weight_shape(input_rank, 1);
weight_shape[axis] = gamma_input.count;
auto gamma =
params->network->addConstant(VectorToTrtDims(weight_shape), gamma_input)->getOutput(0);
auto beta =
params->network->addConstant(VectorToTrtDims(weight_shape), beta_input)->getOutput(0);

// Compute mean
auto mean_layer = params->network->addReduce(*input, nvinfer1::ReduceOperation::kAVG, 1 << axis,
/*keepdims=*/true);
ICHECK(mean_layer != nullptr);
auto mean = mean_layer->getOutput(0);
// Compute variance
auto diff_layer =
params->network->addElementWise(*input, *mean, nvinfer1::ElementWiseOperation::kSUB);
ICHECK(diff_layer != nullptr);
auto square_layer =
params->network->addElementWise(*diff_layer->getOutput(0), *diff_layer->getOutput(0),
nvinfer1::ElementWiseOperation::kPROD);
ICHECK(square_layer != nullptr);
auto var_layer = params->network->addReduce(
*square_layer->getOutput(0), nvinfer1::ReduceOperation::kAVG, 1 << axis, /*keepdims=*/true);
ICHECK(var_layer != nullptr);
auto var = var_layer->getOutput(0);
// sqrt(var + epsilon)
auto epsilon_tensor = CreateScalar(params, epsilon, var->getDimensions());
auto denom_add_layer = params->network->addElementWise(*var, *epsilon_tensor,
nvinfer1::ElementWiseOperation::kSUM);
ICHECK(denom_add_layer != nullptr);
auto denom_layer =
params->network->addUnary(*denom_add_layer->getOutput(0), nvinfer1::UnaryOperation::kSQRT);
ICHECK(denom_layer != nullptr);
// (input - mean) / sqrt(var + epsilon)
auto output_layer =
params->network->addElementWise(*diff_layer->getOutput(0), *denom_layer->getOutput(0),
nvinfer1::ElementWiseOperation::kDIV);
ICHECK(output_layer != nullptr);
auto output = output_layer->getOutput(0);

if (scale) {
auto scale_layer =
params->network->addElementWise(*output, *gamma, nvinfer1::ElementWiseOperation::kPROD);
ICHECK(scale_layer != nullptr);
output = scale_layer->getOutput(0);
}
if (center) {
auto center_layer =
params->network->addElementWise(*output, *beta, nvinfer1::ElementWiseOperation::kSUM);
ICHECK(center_layer != nullptr);
output = center_layer->getOutput(0);
}
params->outputs.push_back(output);
}
};
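For reference, a NumPy sketch (illustrative only) of the computation the layer sequence above assembles: reduce-mean, squared difference, a second reduce-mean, division by sqrt(var + epsilon), then the optional scale and shift with gamma and beta broadcast along the normalization axis.

    import numpy as np

    def layer_norm_reference(x, gamma, beta, axis, epsilon=1e-5, scale=True, center=True):
        # NumPy mirror of the TensorRT layers built by LayerNormOpConverter.
        mean = x.mean(axis=axis, keepdims=True)              # addReduce(kAVG)
        diff = x - mean                                      # kSUB
        var = (diff * diff).mean(axis=axis, keepdims=True)   # kPROD + addReduce(kAVG)
        out = diff / np.sqrt(var + epsilon)                  # kSUM + kSQRT + kDIV
        shape = [1] * x.ndim
        shape[axis] = x.shape[axis]                          # matches weight_shape above
        if scale:
            out = out * gamma.reshape(shape)                 # kPROD with gamma constant
        if center:
            out = out + beta.reshape(shape)                  # kSUM with beta constant
        return out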

class BatchFlattenOpConverter : public TensorRTOpConverter {
public:
BatchFlattenOpConverter() : TensorRTOpConverter({kTensor}) {}
@@ -686,6 +758,9 @@ class UnaryOpConverter : public TensorRTOpConverter {
{"atan", nvinfer1::UnaryOperation::kATAN},
{"ceil", nvinfer1::UnaryOperation::kCEIL},
{"floor", nvinfer1::UnaryOperation::kFLOOR},
#endif
#if TRT_VERSION_GE(7, 0, 0)
{"erf", nvinfer1::UnaryOperation::kERF},
#endif
};
auto it = op_map.find(params->op_name);
@@ -1094,6 +1169,19 @@ class AdaptivePoolingOpConverter : public TensorRTOpConverter {
}
};

class BatchMatmulOpConverter : public TensorRTOpConverter {
public:
BatchMatmulOpConverter() : TensorRTOpConverter({kTensor, kTensor}) {}

void Convert(TensorRTOpConverterParams* params) const {
nvinfer1::IMatrixMultiplyLayer* matmul_layer = params->network->addMatrixMultiply(
*params->inputs.at(0).tensor, nvinfer1::MatrixOperation::kNONE,
*params->inputs.at(1).tensor, nvinfer1::MatrixOperation::kTRANSPOSE);
ICHECK(matmul_layer != nullptr);
params->outputs.push_back(matmul_layer->getOutput(0));
}
};
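The kTRANSPOSE on the second operand reflects Relay's nn.batch_matmul convention of multiplying the first input by the transpose of the second over the last two dimensions; a small NumPy sketch with illustrative shapes:

    import numpy as np

    # Relay's nn.batch_matmul(A, B) computes A[i] @ B[i].T for each batch i, hence
    # kNONE for the first input and kTRANSPOSE for the second in the converter.
    A = np.random.rand(12, 128, 64).astype("float32")
    B = np.random.rand(12, 128, 64).astype("float32")
    C = np.matmul(A, B.transpose(0, 2, 1))  # shape (12, 128, 128)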

const std::shared_ptr<std::unordered_map<std::string, std::shared_ptr<TensorRTOpConverter>>>
GetOpConverters() {
static auto map =
@@ -1103,6 +1191,7 @@ GetOpConverters() {
map->emplace("sigmoid", std::make_shared<ActivationOpConverter>());
map->emplace("tanh", std::make_shared<ActivationOpConverter>());
map->emplace("nn.batch_norm", std::make_shared<BatchNormOpConverter>());
map->emplace("nn.layer_norm", std::make_shared<LayerNormOpConverter>());
map->emplace("nn.softmax", std::make_shared<SoftmaxOpConverter>());
map->emplace("nn.conv2d", std::make_shared<Conv2DOpConverter>());
map->emplace("nn.dense", std::make_shared<DenseOpConverter>());
@@ -1140,6 +1229,7 @@ GetOpConverters() {
map->emplace("mean", std::make_shared<ReduceOpConverter>());
map->emplace("nn.adaptive_max_pool2d", std::make_shared<AdaptivePoolingOpConverter>());
map->emplace("nn.adaptive_avg_pool2d", std::make_shared<AdaptivePoolingOpConverter>());
map->emplace("nn.batch_matmul", std::make_shared<BatchMatmulOpConverter>());
#if TRT_VERSION_GE(5, 1, 5)
map->emplace("clip", std::make_shared<ActivationOpConverter>());
map->emplace("nn.leaky_relu", std::make_shared<ActivationOpConverter>());
@@ -1156,6 +1246,9 @@ GetOpConverters() {
map->emplace("nn.avg_pool3d", std::make_shared<Pooling3DOpConverter>());
map->emplace("nn.conv3d_transpose", std::make_shared<Conv3DTransposeOpConverter>());
#endif // TRT_VERSION_GE(6, 0, 1)
#if TRT_VERSION_GE(7, 0, 0)
map->emplace("erf", std::make_shared<UnaryOpConverter>());
#endif // TRT_VERSION_GE(7, 0, 0)
return map;
}

44 changes: 43 additions & 1 deletion tests/python/contrib/test_tensorrt.py
@@ -239,7 +239,7 @@ def test_tensorrt_not_compatible():

x = relay.var("x", shape=(xshape), dtype=dtype)
y = relay.add(x, x)
- z = relay.erf(y)
+ z = relay.cast(relay.cast(y, "int32"), "float32")
out = relay.nn.relu(z)
f = relay.Function([x], out)
mod = tvm.IRModule()
@@ -473,6 +473,17 @@ def get_graph(x_shape=(1, 16), k_shape=(32, 16)):
run_and_verify_func(get_graph(k_shape=(1, 16)))


def test_batch_matmul():
def get_graph(x_shape=(12, 128, 64), y_shape=(12, 128, 64)):
x = relay.var("x", shape=(x_shape), dtype="float32")
y = relay.var("y", shape=(y_shape), dtype="float32")
out = relay.nn.batch_matmul(x, y)
f = relay.Function([x, y], out)
return f, {"x": x_shape, "y": y_shape}, []

run_and_verify_func(get_graph())


def test_bias_add():
def get_graph(x_shape=(1, 16), channels=16):
x = relay.var("x", shape=(x_shape), dtype="float32")
@@ -848,6 +859,36 @@ def get_graph(x_shape, param_shape, axis=1, epsilon=1e-5):
run_and_verify_func(get_graph((1, 3, 8), (8,), axis=2))


def test_layer_norm():
def get_graph(x_shape, param_shape, axis=1, epsilon=1e-5):
x = relay.var("x", shape=(x_shape), dtype="float32")
gamma = relay.var("gamma", shape=(param_shape), dtype="float32")
beta = relay.var("beta", shape=(param_shape), dtype="float32")
out = relay.nn.layer_norm(
x,
gamma=gamma,
beta=beta,
axis=axis,
epsilon=epsilon,
center=True,
scale=True,
)
f = relay.Function([x, gamma, beta], out)
return (
f,
{
"x": x_shape,
"beta": param_shape,
"gamma": param_shape,
},
["beta", "gamma"],
)

run_and_verify_func(get_graph((1, 32, 8, 8), (32,)))
run_and_verify_func(get_graph((1, 8, 8, 32), (32,), axis=3, epsilon=1.001e-05))
run_and_verify_func(get_graph((1, 8), (8,), axis=1))


def test_unary():
def get_graph(op, x_shape=(1, 8, 3, 3)):
x = relay.var("x", shape=(x_shape), dtype="float32")
Expand All @@ -869,6 +910,7 @@ def get_graph(op, x_shape=(1, 8, 3, 3)):
relay.atan,
relay.ceil,
relay.floor,
relay.erf,
]:
run_and_verify_func(get_graph(op))

Expand Down
