Skip to content

Commit

Permalink
[Fix] open-mmlab#1169 score_per_joint
Browse files Browse the repository at this point in the history
  • Loading branch information
piercus committed Feb 5, 2022
1 parent dca589a commit ce3a0e9
Show file tree
Hide file tree
Showing 4 changed files with 60 additions and 6 deletions.
7 changes: 6 additions & 1 deletion mmpose/apis/inference.py
Original file line number Diff line number Diff line change
Expand Up @@ -525,6 +525,7 @@ def inference_bottom_up_pose_model(model,

cfg = model.cfg
device = next(model.parameters()).device
score_per_joint = cfg.model.test_cfg.get('score_per_joint', False)

# build the data pipeline
channel_order = cfg.test_pipeline[0].get('channel_order', 'rgb')
Expand Down Expand Up @@ -576,7 +577,11 @@ def inference_bottom_up_pose_model(model,
})

# pose nms
keep = oks_nms(pose_results, pose_nms_thr, sigmas)
keep = oks_nms(
pose_results,
pose_nms_thr,
sigmas,
score_per_joint=score_per_joint)
pose_results = [pose_results[_keep] for _keep in keep]

return pose_results, returned_outputs
Expand Down
6 changes: 5 additions & 1 deletion mmpose/core/post_processing/group.py
Original file line number Diff line number Diff line change
Expand Up @@ -150,6 +150,7 @@ def __init__(self, cfg):
self.pool = torch.nn.MaxPool2d(cfg['nms_kernel'], 1,
cfg['nms_padding'])
self.use_udp = cfg.get('use_udp', False)
self.score_per_joint = cfg.get('score_per_joint', False)

def nms(self, heatmaps):
"""Non-Maximum Suppression for heatmaps.
Expand Down Expand Up @@ -388,7 +389,10 @@ def parse(self, heatmaps, tags, adjust=True, refine=True):
else:
results = self.adjust(results, heatmaps)

scores = [i[:, 2].mean() for i in results[0]]
if self.score_per_joint:
scores = [i[:, 2] for i in results[0]]
else:
scores = [i[:, 2].mean() for i in results[0]]

if refine:
results = results[0]
Expand Down
21 changes: 17 additions & 4 deletions mmpose/core/post_processing/nms.py
Original file line number Diff line number Diff line change
Expand Up @@ -86,7 +86,7 @@ def oks_iou(g, d, a_g, a_d, sigmas=None, vis_thr=None):
return ious


def oks_nms(kpts_db, thr, sigmas=None, vis_thr=None):
def oks_nms(kpts_db, thr, sigmas=None, vis_thr=None, score_per_joint=False):
"""OKS NMS implementations.
Args:
Expand All @@ -101,7 +101,11 @@ def oks_nms(kpts_db, thr, sigmas=None, vis_thr=None):
if len(kpts_db) == 0:
return []

scores = np.array([k['score'] for k in kpts_db])
if score_per_joint:
scores = np.array([k['score'].mean() for k in kpts_db])
else:
scores = np.array([k['score'] for k in kpts_db])

kpts = np.array([k['keypoints'].flatten() for k in kpts_db])
areas = np.array([k['area'] for k in kpts_db])

Expand Down Expand Up @@ -147,7 +151,12 @@ def _rescore(overlap, scores, thr, type='gaussian'):
return scores


def soft_oks_nms(kpts_db, thr, max_dets=20, sigmas=None, vis_thr=None):
def soft_oks_nms(kpts_db,
thr,
max_dets=20,
sigmas=None,
vis_thr=None,
score_per_joint=False):
"""Soft OKS NMS implementations.
Args:
Expand All @@ -162,7 +171,11 @@ def soft_oks_nms(kpts_db, thr, max_dets=20, sigmas=None, vis_thr=None):
if len(kpts_db) == 0:
return []

scores = np.array([k['score'] for k in kpts_db])
if score_per_joint:
scores = np.array([k['score'].mean() for k in kpts_db])
else:
scores = np.array([k['score'] for k in kpts_db])

kpts = np.array([k['keypoints'].flatten() for k in kpts_db])
areas = np.array([k['area'] for k in kpts_db])

Expand Down
32 changes: 32 additions & 0 deletions tests/test_post_processing/test_group.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,9 +32,41 @@ def test_group():
fake_tag[0, 8, 6, 6] = 0.9
grouped, scores = parser.parse(fake_heatmap, fake_tag, True, True)
assert grouped[0][0, 0, 0] == 10.25
assert abs(scores[0] - 0.2) < 0.001
cfg['tag_per_joint'] = False
parser = HeatmapParser(cfg)
grouped, scores = parser.parse(fake_heatmap, fake_tag, False, False)
assert grouped[0][0, 0, 0] == 10.
grouped, scores = parser.parse(fake_heatmap, fake_tag, False, True)
assert grouped[0][0, 0, 0] == 10.


def test_group_score_per_joint():
cfg = {}
cfg['num_joints'] = 17
cfg['detection_threshold'] = 0.1
cfg['tag_threshold'] = 1
cfg['use_detection_val'] = True
cfg['ignore_too_much'] = False
cfg['nms_kernel'] = 5
cfg['nms_padding'] = 2
cfg['tag_per_joint'] = True
cfg['max_num_people'] = 1
cfg['score_per_joint'] = True
parser = HeatmapParser(cfg)
fake_heatmap = torch.zeros(1, 1, 5, 5)
fake_heatmap[0, 0, 3, 3] = 1
fake_heatmap[0, 0, 3, 2] = 0.8
assert parser.nms(fake_heatmap)[0, 0, 3, 2] == 0
fake_heatmap = torch.zeros(1, 17, 32, 32)
fake_tag = torch.zeros(1, 17, 32, 32, 1)
fake_heatmap[0, 0, 10, 10] = 0.8
fake_heatmap[0, 1, 12, 12] = 0.9
fake_heatmap[0, 4, 8, 8] = 0.8
fake_heatmap[0, 8, 6, 6] = 0.9
fake_tag[0, 0, 10, 10] = 0.8
fake_tag[0, 1, 12, 12] = 0.9
fake_tag[0, 4, 8, 8] = 0.8
fake_tag[0, 8, 6, 6] = 0.9
grouped, scores = parser.parse(fake_heatmap, fake_tag, True, True)
assert len(scores[0]) == 17

0 comments on commit ce3a0e9

Please sign in to comment.