diff --git a/src/torch/UpscaleImageTensorRT.py b/src/torch/UpscaleImageTensorRT.py index ce6296a4..2ca5c7a9 100644 --- a/src/torch/UpscaleImageTensorRT.py +++ b/src/torch/UpscaleImageTensorRT.py @@ -45,14 +45,8 @@ def __init__( height (int): The height of the input frame nt (int): The number of threads to use """ - original_env = os.environ.copy() - if getattr(sys, 'frozen', False): - cuda_runtime_dir = os.path.join(os.getcwd(), '_internal', 'nvidia', 'cuda_runtime', 'lib') - else: - site_packages = site.getsitepackages()[0] - cuda_runtime_dir = os.path.join(site_packages, 'nvidia', 'cuda_runtime', 'lib') - os.environ['LD_LIBRARY_PATH'] = f'{cuda_runtime_dir}lib:$LD_LIBRARY_PATH' - + self.original_env = os.environ.copy() + self.fixTRTBullshit() from polygraphy.backend.trt import ( TrtRunner, engine_from_network, @@ -71,8 +65,6 @@ def __init__( self.Profile = Profile self.EngineFromBytes = EngineFromBytes self.SaveEngine = SaveEngine - os.environ.clear() - os.environ.update(original_env) self.modelPath = modelPath self.upscaleFactor = upscaleFactor @@ -88,7 +80,20 @@ def __init__( if not os.path.exists(self.locationOfOnnxModel): self.pytorchExportToONNX() self.handleModel() - + self.clearTRTBullshit() + def fixTRTBullshit(self): + + if getattr(sys, 'frozen', False): + cuda_runtime_dir = os.path.join(os.getcwd(), '_internal', 'nvidia', 'cuda_runtime', 'lib') + else: + site_packages = site.getsitepackages()[0] + cuda_runtime_dir = os.path.join(site_packages, 'nvidia', 'cuda_runtime', 'lib') + os.environ['LD_LIBRARY_PATH'] = f'{cuda_runtime_dir}:$LD_LIBRARY_PATH' + + def clearTRTBullshit(self): + os.environ.clear() + os.environ.update(self.original_env) + def pytorchExportToONNX(self): # Loads model via spandrel, and exports to onnx model = ModelLoader().load_from_file(self.modelPath) model = model.model @@ -144,6 +149,7 @@ def handleModel(self): # TO:DO account for FP16/FP32 self.enginePath = f'{self.locationOfOnnxModel.replace(".onnx", "")}{self.width}x{self.height}_scaleFactor={self.upscaleFactor}_half={self.half}_tensorrtVer={self.trt_version}device={self.device_name}_bf16={self.bf16}.engine' if not os.path.exists(self.enginePath): + toPrint = f"Model engine not found, creating engine for model: {self.locationOfOnnxModel}, this may take a while..." self.guiLog.emit("Building Engine, this may take a while...") print((toPrint)) @@ -161,10 +167,11 @@ def handleModel(self): config=self.CreateConfig(fp16=self.half, profiles=profiles), ) self.engine = self.SaveEngine(self.engine, self.enginePath) + with self.TrtRunner(self.engine) as runner: self.runner = runner - + with open(self.enginePath, "rb") as f, trt.Runtime(trt.Logger(trt.Logger.INFO)) as runtime: self.engine = runtime.deserialize_cuda_engine(f.read()) self.context = self.engine.create_execution_context() @@ -189,7 +196,7 @@ def handleModel(self): tensor_name = self.engine.get_tensor_name(i) if self.engine.get_tensor_mode(tensor_name) == trt.TensorIOMode.INPUT: self.context.set_input_shape(tensor_name, self.dummyInput.shape) - + @torch.inference_mode() def UpscaleImage(self, frame: bytearray): with torch.cuda.stream(self.stream):