diff --git a/engines/tensorrt/src/main/native/ai_djl_tensorrt_jni_model.cc b/engines/tensorrt/src/main/native/ai_djl_tensorrt_jni_model.cc index 1bda9b57148..764467a7526 100644 --- a/engines/tensorrt/src/main/native/ai_djl_tensorrt_jni_model.cc +++ b/engines/tensorrt/src/main/native/ai_djl_tensorrt_jni_model.cc @@ -151,15 +151,15 @@ void TrtModel::buildModel() { } TrtSession *TrtModel::createSession() { - auto *session = new TrtSession(mEngine, mParams.maxBatchSize); - - CHECK(cudaSetDevice(mParams.device)); + auto *session = new TrtSession(mEngine, mParams.device, mParams.maxBatchSize); session->init(); return session; } void TrtSession::init() { + CHECK(cudaSetDevice(mDeviceId)); + mContext = mEngine->createExecutionContext(); int bindings = mEngine->getNbBindings(); mBufferSizes.reserve(bindings); @@ -221,6 +221,8 @@ void TrtSession::copyOutputs() { } void TrtSession::predict() { + CHECK(cudaSetDevice(mDeviceId)); + copyInputs(); bool status; diff --git a/engines/tensorrt/src/main/native/ai_djl_tensorrt_jni_model.h b/engines/tensorrt/src/main/native/ai_djl_tensorrt_jni_model.h index 94110e3ee59..896c9d5daeb 100644 --- a/engines/tensorrt/src/main/native/ai_djl_tensorrt_jni_model.h +++ b/engines/tensorrt/src/main/native/ai_djl_tensorrt_jni_model.h @@ -42,8 +42,8 @@ struct ModelParams { class TrtSession { public: - explicit TrtSession(std::shared_ptr engine, int batchSize) - : mEngine(std::move(engine)), mBatchSize(batchSize), mContext(nullptr) {} + explicit TrtSession(std::shared_ptr engine, int deviceId, int batchSize) + : mEngine(std::move(engine)), mDeviceId(deviceId), mBatchSize(batchSize), mContext(nullptr) {} void init(); nvinfer1::Dims getShape(const char* name); @@ -58,6 +58,7 @@ class TrtSession { private: std::shared_ptr mEngine; + int mDeviceId; nvinfer1::IExecutionContext* mContext; int32_t mBatchSize; std::vector mBufferSizes;