Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Refactor new model.warmup() method #5810

Merged
merged 2 commits into from
Nov 27, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 1 addition & 2 deletions detect.py
Original file line number Diff line number Diff line change
Expand Up @@ -97,8 +97,7 @@ def run(weights=ROOT / 'yolov5s.pt', # model.pt path(s)
vid_path, vid_writer = [None] * bs, [None] * bs

# Run inference
if pt and device.type != 'cpu':
model(torch.zeros(1, 3, *imgsz).to(device).type_as(next(model.model.parameters()))) # warmup
model.warmup(imgsz=(1, 3, *imgsz), half=half) # warmup
dt, seen = [0.0, 0.0, 0.0], 0
for path, im, im0s, vid_cap, s in dataset:
t1 = time_sync()
Expand Down
7 changes: 7 additions & 0 deletions models/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -421,6 +421,13 @@ def forward(self, im, augment=False, visualize=False, val=False):
y = torch.tensor(y) if isinstance(y, np.ndarray) else y
return (y, []) if val else y

def warmup(self, imgsz=(1, 3, 640, 640), half=False):
# Warmup model by running inference once
if self.pt or self.engine or self.onnx: # warmup types
if isinstance(self.device, torch.device) and self.device.type != 'cpu': # only warmup GPU models
im = torch.zeros(*imgsz).to(self.device).type(torch.half if half else torch.float) # input image
self.forward(im) # warmup


class AutoShape(nn.Module):
# YOLOv5 input-robust model wrapper for passing cv2/np/PIL/torch inputs. Includes preprocessing, inference and NMS
Expand Down
3 changes: 1 addition & 2 deletions val.py
Original file line number Diff line number Diff line change
Expand Up @@ -149,8 +149,7 @@ def run(data,

# Dataloader
if not training:
if pt and device.type != 'cpu':
model(torch.zeros(1, 3, imgsz, imgsz).to(device).type_as(next(model.model.parameters()))) # warmup
model.warmup(imgsz=(1, 3, imgsz, imgsz), half=half) # warmup
pad = 0.0 if task == 'speed' else 0.5
task = task if task in ('train', 'val', 'test') else 'val' # path to train/val/test images
dataloader = create_dataloader(data[task], imgsz, batch_size, stride, single_cls, pad=pad, rect=pt,
Expand Down