Skip to content

Commit

Permalink
[tensorrt] Fix multi gpu issue (#1249)
Browse files Browse the repository at this point in the history
Change-Id: Ib9b219b773a6fa52cac6dd6d59a2541530c91a43
  • Loading branch information
frankfliu committed Sep 24, 2021
1 parent 0c311d0 commit 5907d50
Show file tree
Hide file tree
Showing 2 changed files with 8 additions and 5 deletions.
8 changes: 5 additions & 3 deletions engines/tensorrt/src/main/native/ai_djl_tensorrt_jni_model.cc
Original file line number Diff line number Diff line change
Expand Up @@ -151,15 +151,15 @@ void TrtModel::buildModel() {
}

TrtSession *TrtModel::createSession() {
auto *session = new TrtSession(mEngine, mParams.maxBatchSize);

CHECK(cudaSetDevice(mParams.device));
auto *session = new TrtSession(mEngine, mParams.device, mParams.maxBatchSize);

session->init();
return session;
}

void TrtSession::init() {
CHECK(cudaSetDevice(mDeviceId));

mContext = mEngine->createExecutionContext();
int bindings = mEngine->getNbBindings();
mBufferSizes.reserve(bindings);
Expand Down Expand Up @@ -221,6 +221,8 @@ void TrtSession::copyOutputs() {
}

void TrtSession::predict() {
CHECK(cudaSetDevice(mDeviceId));

copyInputs();

bool status;
Expand Down
5 changes: 3 additions & 2 deletions engines/tensorrt/src/main/native/ai_djl_tensorrt_jni_model.h
Original file line number Diff line number Diff line change
Expand Up @@ -42,8 +42,8 @@ struct ModelParams {

class TrtSession {
public:
explicit TrtSession(std::shared_ptr<nvinfer1::ICudaEngine> engine, int batchSize)
: mEngine(std::move(engine)), mBatchSize(batchSize), mContext(nullptr) {}
explicit TrtSession(std::shared_ptr<nvinfer1::ICudaEngine> engine, int deviceId, int batchSize)
: mEngine(std::move(engine)), mDeviceId(deviceId), mBatchSize(batchSize), mContext(nullptr) {}

void init();
nvinfer1::Dims getShape(const char* name);
Expand All @@ -58,6 +58,7 @@ class TrtSession {

private:
std::shared_ptr<nvinfer1::ICudaEngine> mEngine;
int mDeviceId;
nvinfer1::IExecutionContext* mContext;
int32_t mBatchSize;
std::vector<size_t> mBufferSizes;
Expand Down

0 comments on commit 5907d50

Please sign in to comment.