Skip to content

Commit

Permalink
Merge pull request #231 from dwdcth/patch-1
Browse files Browse the repository at this point in the history
fix  single label error
  • Loading branch information
CaitinZhao committed Dec 14, 2023
2 parents 0983315 + 780ee23 commit a287863
Showing 1 changed file with 6 additions and 4 deletions.
10 changes: 6 additions & 4 deletions mindyolo/utils/metrics.py
Original file line number Diff line number Diff line change
Expand Up @@ -85,9 +85,11 @@ def non_max_suppression(
x = np.concatenate((box[i], x[i, j + 5, None], j[:, None].astype(np.float32)), 1) if nm == 0 else \
np.concatenate((box[i], x[i, j + 5, None], j[:, None].astype(np.float32), x[i, -nm:]), 1)
else: # best class only
conf, j = x[:, 5:5+nc].max(1, keepdim=True)
x = np.concatenate((box, conf, j.float()), 1)[conf.view(-1) > conf_thres] if nm == 0 else \
np.concatenate((box, conf, j.float(), x[:, -nm:]), 1)[conf.view(-1) > conf_thres]
conf = x[:, 5:5+nc].max(1, keepdims=True) # get maximum conf
j = np.argmax(x[:, 5:5+nc], axis=1,keepdims=True) # get maximum index
x = np.concatenate((box, conf, j.astype(np.float32)), 1)[conf.flatten() > conf_thres] if nm == 0 else \
np.concatenate((box, conf, j.astype(np.float32), x[:, -nm:]), 1)[conf.flatten() > conf_thres]


# Filter by class
if classes is not None:
Expand Down Expand Up @@ -350,4 +352,4 @@ def process_mask(protos, masks_in, bboxes, shape, upsample=False):
def sigmoid(x):
return 1 / (1 + np.exp(-x))

#----------------------------------------------------------
#----------------------------------------------------------

0 comments on commit a287863

Please sign in to comment.