Dynamic Batch Support for TRT (apache#6955)
* add_annotate_fn

* Reshape_ann_fn

* Prune Subgraph

* Dynamic Shape

* Make PT Mask RCNN Work

* Cleanup

* Remove comments

* Remove comments

* GetBatchSizeFix

* Fix RemoveDropout

* Fix RemoveDropout

* TRT Runtime

* Add MaskRCNN R50

* New Testing code

* Fix black

* Test MaskRCNN R50 done

* Test MR50

* Space typo

* Change LOG to DLOG

* Move test to tensorrt.py

* Remove imports

* Remove function

* Add it to trt

* import error

* Imports

* Add torch to CI

* trt_test

* Check test

* Revert Pytorch install

* Fix

* test dynamic batch

* TRT

* Resolve PR comments

* Zero batch size add

Co-authored-by: Ubuntu <ubuntu@ip-172-31-27-149.us-east-2.compute.internal>
2 people authored and Trevor Morris committed Dec 4, 2020
1 parent c9f303a commit e6b06d5
Showing 4 changed files with 318 additions and 27 deletions.
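A minimal end-to-end sketch of the workflow this commit targets (illustrative only: it relies on the existing partition_for_tensorrt helper in python/tvm/relay/op/contrib/tensorrt.py, a TensorRT-enabled build, and relay.Any() for the symbolic batch dimension):

import tvm
from tvm import relay
from tvm.relay.op.contrib.tensorrt import partition_for_tensorrt

# A toy network whose batch dimension is left symbolic.
data = relay.var("data", shape=(relay.Any(), 3, 224, 224), dtype="float32")
weight = relay.var("weight", shape=(16, 3, 3, 3), dtype="float32")
out = relay.nn.conv2d(data, weight, channels=16, kernel_size=(3, 3), padding=(1, 1))
mod = tvm.IRModule.from_expr(relay.Function([data, weight], out))

# Offload supported ops to TensorRT, then compile with the Relay VM, which handles
# dynamic shapes.
mod, config = partition_for_tensorrt(mod)
with tvm.transform.PassContext(opt_level=3, config={"relay.ext.tensorrt.options": config}):
    vm_exec = relay.vm.compile(mod, target="cuda")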
117 changes: 104 additions & 13 deletions python/tvm/relay/op/contrib/tensorrt.py
@@ -23,7 +23,7 @@
from tvm.relay import transform
from tvm.relay.build_module import bind_params_by_name
from tvm.relay.expr import Call, Constant, Tuple, GlobalVar, Var, TupleGetItem
from tvm.relay.expr_functor import ExprMutator
from tvm.relay.expr_functor import ExprMutator, ExprVisitor

logger = logging.getLogger("TensorRT")

@@ -173,7 +173,7 @@ def check_dynamism(args, op_name):
"""
for arg in args:
if isinstance(arg, (Call, Var, Constant, TupleGetItem)):
for dim_shape in arg.checked_type.shape:
for dim_shape in arg.checked_type.shape[1:]:
if isinstance(dim_shape, tvm.tir.expr.Any):
return True
elif isinstance(arg, Tuple):
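A small illustration of the relaxed check above (hypothetical tensor types, standard relay/tir APIs only): a type whose only dynamic dimension is the batch axis is no longer flagged as dynamic, while dynamism in any other axis still is.

import tvm
from tvm import relay

dyn_batch = relay.TensorType((tvm.tir.Any(), 3, 224, 224), "float32")    # dynamic batch only
dyn_spatial = relay.TensorType((1, tvm.tir.Any(), 224, 224), "float32")  # dynamic non-batch dim

# check_dynamism now inspects shape[1:], so only dyn_spatial is treated as dynamic.
assert not any(isinstance(d, tvm.tir.Any) for d in dyn_batch.shape[1:])
assert any(isinstance(d, tvm.tir.Any) for d in dyn_spatial.shape[1:])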
@@ -198,6 +198,21 @@ def _func_wrapper(expr):
if any([x.checked_type.dtype != "float32" for x in args]):
logger.info("Only float32 inputs are supported for TensorRT.")
return False
if op_name == "multiply":
shapes = [
[
int(x) if not isinstance(x, tvm.tir.expr.Any) else -1
for x in arg.checked_type.shape
]
for arg in args
]
# Batched multiply operations don't work in implicit batch mode. The following shapes
# have been excluded because they occur in the PT MaskRCNN model. The long-term solution
# is to switch to explicit batch mode after the performance regressions are solved.
if all(
[list(map(int, shape)) in [[300, 64, 7, 7], [300, 1, 1, 1]] for shape in shapes]
):
return False
return checker(attrs, args, op_name)

return _func_wrapper
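The list comprehension above normalizes a checked type's shape so that a dynamic dimension becomes -1; the same idiom appears again in add_annotate_fn below. As a standalone helper (hypothetical name), it would look roughly like:

import tvm

def normalized_shape(ttype):
    """Return the shape as ints, with dynamic (Any) dims encoded as -1."""
    return [int(x) if not isinstance(x, tvm.tir.expr.Any) else -1 for x in ttype.shape]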
@@ -292,19 +307,26 @@ def add_annotate_fn(expr): # pylint: disable=unused-variable
"""Check if add is supported by TensorRT."""

args = expr.args

shapes = [
[int(x) if not isinstance(x, tvm.tir.expr.Any) else -1 for x in arg.checked_type.shape]
for arg in args
]

# RelayVM + TRT doesn't support scalar addition yet.
for arg in args:
if not arg.checked_type.shape:
for shape in shapes:
if len(shape) < 1:
return False

if any([x.checked_type.dtype != "float32" for x in args]):
logger.info("Only float32 inputs are supported for TensorRT.")
return False
if (
not get_tensorrt_use_implicit_batch_mode()
and (isinstance(args[0], Constant) or isinstance(args[1], Constant))
and args[0].checked_type.shape[0] == args[1].checked_type.shape[0]
and args[0].checked_type.shape[0] != 1
and (len(args[0].checked_type.shape) > 3 or len(args[1].checked_type.shape) > 3)
and shapes[0][0] == shapes[1][0]
and shapes[0][0] != 1
and (len(shapes[0]) > 3 or len(shapes[1]) > 3)
):
logger.info("add: bug in TRT with adding batched constants.")
return False
@@ -592,11 +614,35 @@ def reshape_annotate_fn(expr): # pylint: disable=unused-variable
logger.info("reshape: new shape dims must be explicit.")
return False
if get_tensorrt_use_implicit_batch_mode():
shape = list(map(int, args[0].checked_type.shape))
new_shape = list(map(int, attrs.newshape))
shape = args[0].checked_type.shape
new_shape = attrs.newshape
if len(new_shape) == 0 or len(shape) == 0:
logger.info("reshape: Can't reshape to or from scalar.")
return False

dynamic_reshape = any([isinstance(x, tvm.tir.expr.Any) for x in shape])

if dynamic_reshape:
# Make sure that the batch dim is unmodified.
if int(new_shape[0]) < 0:
for shape_val, new_shape_val in zip(shape[1:], new_shape[1:]):
if not (
isinstance(shape_val, (int, tvm.tir.expr.IntImm))
and isinstance(new_shape_val, (int, tvm.tir.expr.IntImm))
and int(shape_val) == int(new_shape_val)
):
return False
elif int(new_shape[0]) > 0:
if not (
isinstance(shape[0], (int, tvm.tir.expr.IntImm))
and isinstance(new_shape[0], (int, tvm.tir.expr.IntImm))
and int(shape[0]) == int(new_shape[0])
):
return False
return True
shape = list(map(int, shape))
new_shape = list(map(int, new_shape))

# TRT cannot modify batch dimension.
original_volume = np.prod(shape)
# First, resolve 0.
@@ -607,6 +653,7 @@ def reshape_annotate_fn(expr): # pylint: disable=unused-variable
for i, value in enumerate(new_shape):
if value == -1:
new_shape[i] = original_volume // np.prod([x for x in new_shape if x != -1])
# Remove batch dimension and see if volumes match
if shape[0] != new_shape[0]:
logger.info("reshape: can't modify batch dimension.")
return False
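A few worked cases of the rule above (illustrative shapes only; "?" marks a dynamic dimension):

# (input shape, newshape, offloaded to TRT?)
reshape_cases = [
    (("?", 100, 4), (-1, 100, 4), True),    # dynamic batch kept, remaining dims match
    (("?", 100, 4), (-1, 400), False),      # would fold non-batch dims together
    (("?", 100, 4), (300, 100, 4), False),  # cannot pin a dynamic batch to a fixed size
    ((2, 3, 4), (2, 12), True),             # static: batch dimension unchanged
    ((2, 3, 4), (6, 2), False),             # static: batch dimension modified
]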
@@ -795,31 +842,73 @@ def conv3d_transpose_annotate_fn(expr): # pylint: disable=unused-variable
return True


class IsComputeIntensiveGraph(ExprVisitor):
"""
Visits the graph recursively and checks if it contains compute-heavy ops such as convolutions,
their transposes, dense, and batch matmul.
"""

def __init__(self):
ExprVisitor.__init__(self)
self.is_compute_intensive = False

def visit_call(self, call):
compute_intensive_ops = set(
[
"nn.conv2d",
"nn.conv2d_transpose",
"nn.conv3d",
"nn.conv3d_transpose",
"nn.dense",
"nn.batch_matmul",
]
)
if isinstance(call.op, tvm.tir.op.Op):
if str(call.op) in compute_intensive_ops:
self.is_compute_intensive = True

return super().visit_call(call)

def is_graph_compute_intensive(self, subgraph) -> bool:
"""
Recursively visits the graph and checks whether it is compute-intensive.
"""
self.visit(subgraph)
return self.is_compute_intensive
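# Illustrative usage sketch, assuming `relay` as imported in this module: the visitor
# marks a graph as compute-intensive as soon as any op from the set above appears.
x = relay.var("x", shape=(1, 8), dtype="float32")
w = relay.var("w", shape=(4, 8), dtype="float32")
assert IsComputeIntensiveGraph().is_graph_compute_intensive(relay.nn.dense(x, w))
assert not IsComputeIntensiveGraph().is_graph_compute_intensive(relay.add(x, x))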


def is_valid_subgraph(params, body):
"""Final check on whether the subgraph is valid and should be offloaded to TensorRT."""
# Remove invalid subgraphs for implicit batch mode.
if get_tensorrt_use_implicit_batch_mode():
input_batch_sizes = []
for var in params:
# In implicit batch mode, all inputs must have the same batch size.
# TODO(codeislife99): Fix inputs with different dynamic batch sizes.

if isinstance(var.checked_type, relay.TupleType):
for tuple_type in var.checked_type.fields:
# Scalar inputs not allowed
if len(tupe_type.shape) == 0:
logger.info("tensorrt: scalar inputs not supported")
return False
input_batch_sizes.append(int(tuple_type.shape[0]))

if not isinstance(tuple_type.shape[0], tvm.tir.expr.Any):
input_batch_sizes.append(int(tuple_type.shape[0]))
else:
# Scalar inputs not allowed
if len(var.checked_type.shape) == 0:
logger.info("tensorrt: scalar inputs not supported")
return False
input_batch_sizes.append(int(var.checked_type.shape[0]))
if not isinstance(var.checked_type.shape[0], tvm.tir.expr.Any):
input_batch_sizes.append(int(var.checked_type.shape[0]))
if len(input_batch_sizes) > 1 and len(set(input_batch_sizes)) != 1:
logger.info("tensorrt: inputs have different batch sizes")
return False
# Remove subgraphs with no multiply-accumulates
if get_tensorrt_remove_no_mac_subgraphs() and relay.analysis.get_total_mac_number(body) == 0:
if (
get_tensorrt_remove_no_mac_subgraphs()
and not IsComputeIntensiveGraph().is_graph_compute_intensive(body)
):
return False
return True

@@ -880,6 +969,8 @@ class RemoveDropout(ExprMutator):

def visit_tuple_getitem(self, op):
visit = super().visit_tuple_getitem(op)
if visit.index != 0:
return visit
if (
isinstance(visit.tuple_value, Call)
and visit.tuple_value.op.name == "nn.dropout"
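The new index check in RemoveDropout.visit_tuple_getitem matters because nn.dropout lowers to a call returning an (output, mask) tuple, and only element 0 may be short-circuited to the dropout input. A minimal sketch, assuming relay.nn.dropout_raw (which exposes the raw tuple):

from tvm import relay

x = relay.var("x", shape=(4, 16), dtype="float32")
drop = relay.nn.dropout_raw(x, rate=0.5)  # TupleWrapper over (output, mask)
out = drop[0]   # index 0: RemoveDropout rewrites this to `x`
mask = drop[1]  # index != 0: now returned unchanged instead of being rewritten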
3 changes: 1 addition & 2 deletions src/relay/backend/utils.h
@@ -160,8 +160,7 @@ inline std::vector<int64_t> GetIntShape(const Array<IndexExpr>& shape) {
std::vector<int64_t> ret;
for (const auto& dim : shape) {
const int64_t* pval = tir::as_const_int(dim);
ICHECK(pval) << "Expect integer, but received: " << dim->GetTypeKey();
ret.push_back(*pval);
ret.push_back(pval ? *pval : -1);
}
return ret;
}
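// With this change, a dynamic (tir::Any) dimension is encoded as -1 instead of failing
// the ICHECK, matching the -1 convention the TensorRT annotators in tensorrt.py use for
// dynamic dimensions.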
40 changes: 28 additions & 12 deletions src/runtime/contrib/tensorrt/tensorrt_runtime.cc
@@ -41,6 +41,13 @@ namespace tvm {
namespace runtime {
namespace contrib {

struct PairHash {
template <class T1, class T2>
std::size_t operator()(const std::pair<T1, T2>& pair) const {
return std::hash<T1>()(pair.first) ^ std::hash<T2>()(pair.second);
}
};
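// PairHash provides the hash for the (symbol_name, batch_size) pairs that key
// trt_engine_cache_ below, so each subgraph can cache a separately built TensorRT
// engine per batch size observed at run time.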

using namespace tvm::runtime::json;

class TensorRTRuntime : public JSONRuntimeBase {
@@ -105,12 +112,13 @@ class TensorRTRuntime : public JSONRuntimeBase {
/*! \brief Run inference using built engine. */
void Run() override {
BuildEngine();
auto& engine_and_context = trt_engine_cache_.at(symbol_name_);
batch_size_ = data_entry_[input_var_eid_[0]]->shape[0];
if (batch_size_ == 0) return;
auto& engine_and_context = trt_engine_cache_.at(std::make_pair(symbol_name_, batch_size_));
auto engine = engine_and_context.engine;
auto context = engine_and_context.context;
auto& device_buffers = engine_and_context.device_buffers;
std::vector<void*> bindings(engine->getNbBindings(), nullptr);

for (size_t i = 0; i < input_nodes_.size(); ++i) {
auto nid = input_nodes_[i];
if (nodes_[nid].GetOpType() == "input") {
@@ -169,10 +177,11 @@
* do nothing.
*/
void BuildEngine() {
if (trt_engine_cache_.count(symbol_name_)) return;
DLOG(INFO) << "Building new TensorRT engine for subgraph " << symbol_name_;
batch_size_ = data_entry_[input_var_eid_[0]]->shape[0];
if (trt_engine_cache_.count(std::make_pair(symbol_name_, batch_size_))) return;
DLOG(INFO) << "Building new TensorRT engine for subgraph " << symbol_name_
<< " with batch size " << batch_size_;
const bool use_fp16 = dmlc::GetEnv("TVM_TENSORRT_USE_FP16", false);
batch_size_ = GetBatchSize();
TensorRTBuilder builder(&logger_, data_entry_, max_workspace_size_, use_implicit_batch_,
use_fp16, batch_size_);

@@ -203,8 +212,9 @@
}

// Build engine.
trt_engine_cache_[symbol_name_] = builder.BuildEngine();
DLOG(INFO) << "Finished building TensorRT engine for subgraph " << symbol_name_;
trt_engine_cache_[std::make_pair(symbol_name_, batch_size_)] = builder.BuildEngine();
DLOG(INFO) << "Finished building TensorRT engine for subgraph " << symbol_name_
<< " with batch size " << batch_size_;
CacheEngineToDisk();
}

@@ -240,30 +250,35 @@
helper.DeclareField("inputs", &engine_and_context.inputs);
helper.DeclareField("outputs", &engine_and_context.outputs);
helper.ReadAllFields(&reader);
trt_engine_cache_[symbol_name_] = engine_and_context;
const int batch_size = 1;
trt_engine_cache_[std::make_pair(symbol_name_, batch_size)] = engine_and_context;
return true;
}

/*! \brief If TVM_TENSORRT_CACHE_DIR is set, will save the engine to that
* directory so it can be loaded later.
*/
void CacheEngineToDisk() {
batch_size_ = data_entry_[input_var_eid_[0]]->shape[0];
std::string cache_dir = dmlc::GetEnv("TVM_TENSORRT_CACHE_DIR", std::string(""));
if (cache_dir.empty()) return;
std::string key = GetSubgraphKey();
std::string path = cache_dir + "/" + key + ".plan";
DLOG(INFO) << "Caching TensorRT engine to " << path;
// Serialize engine to disk
nvinfer1::IHostMemory* serialized_engine = trt_engine_cache_[symbol_name_].engine->serialize();
nvinfer1::IHostMemory* serialized_engine =
trt_engine_cache_[std::make_pair(symbol_name_, batch_size_)].engine->serialize();
SaveBinaryToFile(path, std::string(static_cast<const char*>(serialized_engine->data()),
serialized_engine->size()));
serialized_engine->destroy();
// Serialize metadata
std::ostringstream os;
dmlc::JSONWriter writer(&os);
writer.BeginObject();
writer.WriteObjectKeyValue("inputs", trt_engine_cache_[symbol_name_].inputs);
writer.WriteObjectKeyValue("outputs", trt_engine_cache_[symbol_name_].outputs);
writer.WriteObjectKeyValue("inputs",
trt_engine_cache_[std::make_pair(symbol_name_, batch_size_)].inputs);
writer.WriteObjectKeyValue(
"outputs", trt_engine_cache_[std::make_pair(symbol_name_, batch_size_)].outputs);
writer.EndObject();
std::string meta_path = cache_dir + "/" + key + ".meta";
SaveBinaryToFile(meta_path, os.str());
@@ -290,7 +305,8 @@ class TensorRTRuntime : public JSONRuntimeBase {
}

/*! \brief Map of function name to TRT engine if built already. */
std::unordered_map<std::string, TensorRTEngineAndContext> trt_engine_cache_;
std::unordered_map<std::pair<std::string, int>, TensorRTEngineAndContext, PairHash>
trt_engine_cache_;

/*! \brief TensorRT logger. */
TensorRTLogger logger_;
(Diff for the fourth changed file not shown.)
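A hedged sketch of the resulting runtime behaviour (continuing the compile sketch near the top; it assumes a TensorRT-enabled build and a CUDA GPU): every new batch size seen at run time builds and caches its own engine, and a zero-sized batch returns without running TensorRT.

import numpy as np
import tvm

dev = tvm.gpu(0)
vm = tvm.runtime.vm.VirtualMachine(vm_exec, dev)  # vm_exec from the earlier sketch
weight = np.random.uniform(size=(16, 3, 3, 3)).astype("float32")
for batch in (1, 4, 8):
    data = np.random.uniform(size=(batch, 3, 224, 224)).astype("float32")
    # The first call at each batch size triggers a new engine build; later calls reuse it.
    out = vm.invoke("main", data, weight)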
