From 780ee23e41ab5e1b23ba9dad829273e6a1fc7d0f Mon Sep 17 00:00:00 2001 From: banshan Date: Wed, 1 Nov 2023 10:35:30 +0800 Subject: [PATCH] fix single label error single label will get error: got an unexpected keyword argument 'keepdim' --- mindyolo/utils/metrics.py | 10 ++++++---- 1 file changed, 6 insertions(+), 4 deletions(-) diff --git a/mindyolo/utils/metrics.py b/mindyolo/utils/metrics.py index 8a98bbd3..0f6d1733 100644 --- a/mindyolo/utils/metrics.py +++ b/mindyolo/utils/metrics.py @@ -85,9 +85,11 @@ def non_max_suppression( x = np.concatenate((box[i], x[i, j + 5, None], j[:, None].astype(np.float32)), 1) if nm == 0 else \ np.concatenate((box[i], x[i, j + 5, None], j[:, None].astype(np.float32), x[i, -nm:]), 1) else: # best class only - conf, j = x[:, 5:5+nc].max(1, keepdim=True) - x = np.concatenate((box, conf, j.float()), 1)[conf.view(-1) > conf_thres] if nm == 0 else \ - np.concatenate((box, conf, j.float(), x[:, -nm:]), 1)[conf.view(-1) > conf_thres] + conf = x[:, 5:5+nc].max(1, keepdims=True) # get maximum conf + j = np.argmax(x[:, 5:5+nc], axis=1,keepdims=True) # get maximum index + x = np.concatenate((box, conf, j.astype(np.float32)), 1)[conf.flatten() > conf_thres] if nm == 0 else \ + np.concatenate((box, conf, j.astype(np.float32), x[:, -nm:]), 1)[conf.flatten() > conf_thres] + # Filter by class if classes is not None: @@ -350,4 +352,4 @@ def process_mask(protos, masks_in, bboxes, shape, upsample=False): def sigmoid(x): return 1 / (1 + np.exp(-x)) -#---------------------------------------------------------- \ No newline at end of file +#----------------------------------------------------------