From 413dbef234e1aa949501ee7d832e40d4ffa031f1 Mon Sep 17 00:00:00 2001
From: Glenn Jocher
Date: Sun, 15 May 2022 19:44:16 +0200
Subject: [PATCH] Refactor modules (#7823)

---
 models/experimental.py | 24 ++++++++++--------------
 models/tf.py           | 14 ++++++--------
 models/yolo.py         |  2 +-
 3 files changed, 17 insertions(+), 23 deletions(-)

diff --git a/models/experimental.py b/models/experimental.py
index b8d4d70d26e8..1ffe0fcb5971 100644
--- a/models/experimental.py
+++ b/models/experimental.py
@@ -78,9 +78,7 @@ def __init__(self):
         super().__init__()
 
     def forward(self, x, augment=False, profile=False, visualize=False):
-        y = []
-        for module in self:
-            y.append(module(x, augment, profile, visualize)[0])
+        y = [module(x, augment, profile, visualize)[0] for module in self]
         # y = torch.stack(y).max(0)[0]  # max ensemble
         # y = torch.stack(y).mean(0)  # mean ensemble
         y = torch.cat(y, 1)  # nms ensemble
@@ -102,10 +100,9 @@ def attempt_load(weights, map_location=None, inplace=True, fuse=True):
         t = type(m)
         if t in (nn.Hardswish, nn.LeakyReLU, nn.ReLU, nn.ReLU6, nn.SiLU, Detect, Model):
             m.inplace = inplace  # torch 1.7.0 compatibility
-            if t is Detect:
-                if not isinstance(m.anchor_grid, list):  # new Detect Layer compatibility
-                    delattr(m, 'anchor_grid')
-                    setattr(m, 'anchor_grid', [torch.zeros(1)] * m.nl)
+            if t is Detect and not isinstance(m.anchor_grid, list):
+                delattr(m, 'anchor_grid')
+                setattr(m, 'anchor_grid', [torch.zeros(1)] * m.nl)
         elif t is Conv:
             m._non_persistent_buffers_set = set()  # torch 1.6.0 compatibility
         elif t is nn.Upsample and not hasattr(m, 'recompute_scale_factor'):
@@ -113,10 +110,9 @@ def attempt_load(weights, map_location=None, inplace=True, fuse=True):
 
     if len(model) == 1:
         return model[-1]  # return model
-    else:
-        print(f'Ensemble created with {weights}\n')
-        for k in 'names', 'nc', 'yaml':
-            setattr(model, k, getattr(model[0], k))
-        model.stride = model[torch.argmax(torch.tensor([m.stride.max() for m in model])).int()].stride  # max stride
-        assert all(model[0].nc == m.nc for m in model), f'Models have different class counts: {[m.nc for m in model]}'
-        return model  # return ensemble
+    print(f'Ensemble created with {weights}\n')
+    for k in 'names', 'nc', 'yaml':
+        setattr(model, k, getattr(model[0], k))
+    model.stride = model[torch.argmax(torch.tensor([m.stride.max() for m in model])).int()].stride  # max stride
+    assert all(model[0].nc == m.nc for m in model), f'Models have different class counts: {[m.nc for m in model]}'
+    return model  # return ensemble
diff --git a/models/tf.py b/models/tf.py
index 04b1cd378f18..7e0d61729e36 100644
--- a/models/tf.py
+++ b/models/tf.py
@@ -362,7 +362,7 @@ def predict(self,
                 conf_thres=0.25):
         y = []  # outputs
         x = inputs
-        for i, m in enumerate(self.model.layers):
+        for m in self.model.layers:
             if m.f != -1:  # if not from previous layer
                 x = y[m.f] if isinstance(m.f, int) else [x if j == -1 else y[j] for j in m.f]  # from earlier layers
 
@@ -377,7 +377,6 @@ def predict(self,
             scores = probs * classes
             if agnostic_nms:
                 nms = AgnosticNMS()((boxes, classes, scores), topk_all, iou_thres, conf_thres)
-                return nms, x[1]
             else:
                 boxes = tf.expand_dims(boxes, 2)
                 nms = tf.image.combined_non_max_suppression(boxes,
@@ -387,8 +386,7 @@ def predict(self,
                                                             iou_thres,
                                                             conf_thres,
                                                             clip_boxes=False)
-                return nms, x[1]
-
+            return nms, x[1]
         return x[0]  # output only first tensor [1,6300,85] = [xywh, conf, class0, class1, ...]
         # x = x[0][0]  # [x(1,6300,85), ...] to x(6300,85)
         # xywh = x[..., :4]  # x(6300,4) boxes
@@ -444,10 +442,10 @@ def _nms(x, topk_all=100, iou_thres=0.45, conf_thres=0.25):  # agnostic NMS
 def representative_dataset_gen(dataset, ncalib=100):
     # Representative dataset generator for use with converter.representative_dataset, returns a generator of np arrays
     for n, (path, img, im0s, vid_cap, string) in enumerate(dataset):
-        input = np.transpose(img, [1, 2, 0])
-        input = np.expand_dims(input, axis=0).astype(np.float32)
-        input /= 255
-        yield [input]
+        im = np.transpose(img, [1, 2, 0])
+        im = np.expand_dims(im, axis=0).astype(np.float32)
+        im /= 255
+        yield [im]
         if n >= ncalib:
             break
 
diff --git a/models/yolo.py b/models/yolo.py
index b17a59c376f6..55356a6a9b44 100644
--- a/models/yolo.py
+++ b/models/yolo.py
@@ -197,7 +197,7 @@ def _profile_one_layer(self, m, x, dt):
         m(x.copy() if c else x)
         dt.append((time_sync() - t) * 100)
         if m == self.model[0]:
-            LOGGER.info(f"{'time (ms)':>10s} {'GFLOPs':>10s} {'params':>10s}  {'module'}")
+            LOGGER.info(f"{'time (ms)':>10s} {'GFLOPs':>10s} {'params':>10s}  module")
         LOGGER.info(f'{dt[-1]:10.2f} {o:10.2f} {m.np:10.0f}  {m.type}')
         if c:
             LOGGER.info(f"{sum(dt):10.2f} {'-':>10s} {'-':>10s}  Total")
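
Note (not part of the patch): a minimal sketch of how the refactored attempt_load path is exercised, assuming the stock YOLOv5 repo layout and hypothetical local checkpoint files yolov5s.pt / yolov5m.pt. A single path returns the bare model (the len(model) == 1 branch above); a list builds the Ensemble, whose forward() now concatenates each model's detections along dim 1 ("nms ensemble") and which copies names/nc/yaml from model[0] and uses the maximum stride.

    import torch
    from models.experimental import attempt_load

    # Single checkpoint -> plain model
    model = attempt_load('yolov5s.pt', map_location='cpu')

    # Multiple checkpoints -> Ensemble; names/nc/yaml come from model[0],
    # stride is the max stride across members, and class counts must match
    ensemble = attempt_load(['yolov5s.pt', 'yolov5m.pt'], map_location='cpu')

    im = torch.zeros(1, 3, 640, 640)  # dummy BCHW input
    pred = ensemble(im)[0]  # detections from all members concatenated along dim 1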