Skip to content

Commit

Permalink
fix crash
Browse files Browse the repository at this point in the history
  • Loading branch information
TNTwise committed May 21, 2024
1 parent 7522b72 commit 3b36427
Showing 1 changed file with 20 additions and 13 deletions.
33 changes: 20 additions & 13 deletions src/torch/UpscaleImageTensorRT.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,14 +45,8 @@ def __init__(
height (int): The height of the input frame
nt (int): The number of threads to use
"""
original_env = os.environ.copy()
if getattr(sys, 'frozen', False):
cuda_runtime_dir = os.path.join(os.getcwd(), '_internal', 'nvidia', 'cuda_runtime', 'lib')
else:
site_packages = site.getsitepackages()[0]
cuda_runtime_dir = os.path.join(site_packages, 'nvidia', 'cuda_runtime', 'lib')
os.environ['LD_LIBRARY_PATH'] = f'{cuda_runtime_dir}lib:$LD_LIBRARY_PATH'

self.original_env = os.environ.copy()
self.fixTRTBullshit()
from polygraphy.backend.trt import (
TrtRunner,
engine_from_network,
Expand All @@ -71,8 +65,6 @@ def __init__(
self.Profile = Profile
self.EngineFromBytes = EngineFromBytes
self.SaveEngine = SaveEngine
os.environ.clear()
os.environ.update(original_env)

self.modelPath = modelPath
self.upscaleFactor = upscaleFactor
Expand All @@ -88,7 +80,20 @@ def __init__(
if not os.path.exists(self.locationOfOnnxModel):
self.pytorchExportToONNX()
self.handleModel()

self.clearTRTBullshit()
def fixTRTBullshit(self):

if getattr(sys, 'frozen', False):
cuda_runtime_dir = os.path.join(os.getcwd(), '_internal', 'nvidia', 'cuda_runtime', 'lib')
else:
site_packages = site.getsitepackages()[0]
cuda_runtime_dir = os.path.join(site_packages, 'nvidia', 'cuda_runtime', 'lib')
os.environ['LD_LIBRARY_PATH'] = f'{cuda_runtime_dir}:$LD_LIBRARY_PATH'

def clearTRTBullshit(self):
os.environ.clear()
os.environ.update(self.original_env)

def pytorchExportToONNX(self): # Loads model via spandrel, and exports to onnx
model = ModelLoader().load_from_file(self.modelPath)
model = model.model
Expand Down Expand Up @@ -144,6 +149,7 @@ def handleModel(self):
# TO:DO account for FP16/FP32
self.enginePath = f'{self.locationOfOnnxModel.replace(".onnx", "")}{self.width}x{self.height}_scaleFactor={self.upscaleFactor}_half={self.half}_tensorrtVer={self.trt_version}device={self.device_name}_bf16={self.bf16}.engine'
if not os.path.exists(self.enginePath):

toPrint = f"Model engine not found, creating engine for model: {self.locationOfOnnxModel}, this may take a while..."
self.guiLog.emit("Building Engine, this may take a while...")
print((toPrint))
Expand All @@ -161,10 +167,11 @@ def handleModel(self):
config=self.CreateConfig(fp16=self.half, profiles=profiles),
)
self.engine = self.SaveEngine(self.engine, self.enginePath)

with self.TrtRunner(self.engine) as runner:
self.runner = runner


with open(self.enginePath, "rb") as f, trt.Runtime(trt.Logger(trt.Logger.INFO)) as runtime:
self.engine = runtime.deserialize_cuda_engine(f.read())
self.context = self.engine.create_execution_context()
Expand All @@ -189,7 +196,7 @@ def handleModel(self):
tensor_name = self.engine.get_tensor_name(i)
if self.engine.get_tensor_mode(tensor_name) == trt.TensorIOMode.INPUT:
self.context.set_input_shape(tensor_name, self.dummyInput.shape)

@torch.inference_mode()
def UpscaleImage(self, frame: bytearray):
with torch.cuda.stream(self.stream):
Expand Down

0 comments on commit 3b36427

Please sign in to comment.