Skip to content

Commit

Permalink
Fix issue in refine_bboxes and add doctest (open-mmlab#1962)
Browse files Browse the repository at this point in the history
* Fix issue in refine_bboxes and add doctest

* fix pillow version on travis

* Fixes based on review

* Fix errors in doctest and add comprehensive unit test

* Fix linting error
  • Loading branch information
Erotemic authored and ioir123ju committed Mar 30, 2020
1 parent a404993 commit d85b8a7
Show file tree
Hide file tree
Showing 5 changed files with 279 additions and 4 deletions.
1 change: 1 addition & 0 deletions .travis.yml
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@ before_install:

install:
- pip install Cython torch==1.2
- pip install Pillow==6.2.2
- pip install -r requirements.txt
- pip install -r tests/requirements.txt

Expand Down
65 changes: 65 additions & 0 deletions mmdet/core/bbox/demodata.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,65 @@
import numpy as np
import torch


def ensure_rng(rng=None):
    """
    Coerce the input into a numpy random number generator.

    Simple version of the ``kwarray.ensure_rng``

    Args:
        rng (int | numpy.integer | numpy.random.RandomState | None):
            if None, then defaults to the global rng. If this is an integer
            (python or numpy), it is used as the seed of a new RandomState.
            Any other value is returned unchanged.

    Returns:
        (numpy.random.RandomState) : rng -
            a numpy random number generator

    References:
        https://gitlab.kitware.com/computer-vision/kwarray/blob/master/kwarray/util_random.py#L270

    Example:
        >>> rng = ensure_rng(0)
        >>> isinstance(rng, np.random.RandomState)
        True
    """
    if rng is None:
        # Fall back to numpy's module-level (global) RandomState
        return np.random.mtrand._rand
    if isinstance(rng, (int, np.integer)):
        # numpy integer scalars (e.g. ``np.int64``) are not instances of the
        # builtin ``int`` on Python 3, so they must be checked for explicitly;
        # otherwise a seed would be handed back as if it were a generator.
        return np.random.RandomState(rng)
    # Not a seed: hand the object back to the caller untouched
    return rng


def random_boxes(num=1, scale=1, rng=None):
    """
    Generate random boxes whose corners are correctly ordered.

    Simple version of ``kwimage.Boxes.random``

    Args:
        num (int): number of boxes to generate.
        scale (float): factor every sampled coordinate is multiplied by.
        rng (int | numpy.random.RandomState | None): seed or generator,
            resolved with ``ensure_rng``.

    Returns:
        Tensor: shape (n, 4) in x1, y1, x2, y2 format.

    References:
        https://gitlab.kitware.com/computer-vision/kwimage/blob/master/kwimage/structs/boxes.py#L1390

    Example:
        >>> num = 3
        >>> scale = 512
        >>> rng = 0
        >>> boxes = random_boxes(num, scale, rng)
        >>> print(boxes)
        tensor([[280.9925, 278.9802, 308.6148, 366.1769],
                [216.9113, 330.6978, 224.0446, 456.5878],
                [405.3632, 196.3221, 493.3953, 270.7942]])
    """
    generator = ensure_rng(rng)

    # Four random values per box; columns 0/2 are x candidates and columns
    # 1/3 are y candidates.
    coords = generator.rand(num, 4).astype(np.float32)

    # Sort each candidate pair so the smaller value becomes the top-left
    # corner. All four columns are computed before any are overwritten.
    ordered_columns = (
        np.minimum(coords[:, 0], coords[:, 2]),  # x1
        np.minimum(coords[:, 1], coords[:, 3]),  # y1
        np.maximum(coords[:, 0], coords[:, 2]),  # x2
        np.maximum(coords[:, 1], coords[:, 3]),  # y2
    )

    # Write the scaled, ordered corners back in place
    for col, values in enumerate(ordered_columns):
        coords[:, col] = values * scale

    return torch.from_numpy(coords)
45 changes: 41 additions & 4 deletions mmdet/models/bbox_heads/bbox_head.py
Original file line number Diff line number Diff line change
Expand Up @@ -178,7 +178,8 @@ def refine_bboxes(self, rois, labels, bbox_preds, pos_is_gts, img_metas):
Args:
rois (Tensor): Shape (n*bs, 5), where n is image number per GPU,
and bs is the sampled RoIs per image.
and bs is the sampled RoIs per image. The first column is
the image id and the next 4 columns are x1, y1, x2, y2.
labels (Tensor): Shape (n*bs, ).
bbox_preds (Tensor): Shape (n*bs, 4) or (n*bs, 4*#class).
pos_is_gts (list[Tensor]): Flags indicating if each positive bbox
Expand All @@ -187,13 +188,48 @@ def refine_bboxes(self, rois, labels, bbox_preds, pos_is_gts, img_metas):
Returns:
list[Tensor]: Refined bboxes of each image in a mini-batch.
Example:
>>> # xdoctest: +REQUIRES(module:kwarray)
>>> import kwarray
>>> import numpy as np
>>> from mmdet.core.bbox.demodata import random_boxes
>>> self = BBoxHead(reg_class_agnostic=True)
>>> n_roi = 2
>>> n_img = 4
>>> scale = 512
>>> rng = np.random.RandomState(0)
>>> img_metas = [{'img_shape': (scale, scale)}
... for _ in range(n_img)]
>>> # Create rois in the expected format
>>> roi_boxes = random_boxes(n_roi, scale=scale, rng=rng)
>>> img_ids = torch.randint(0, n_img, (n_roi,))
>>> img_ids = img_ids.float()
>>> rois = torch.cat([img_ids[:, None], roi_boxes], dim=1)
>>> # Create other args
>>> labels = torch.randint(0, 2, (n_roi,)).long()
>>> bbox_preds = random_boxes(n_roi, scale=scale, rng=rng)
>>> # For each image, pretend random positive boxes are gts
>>> is_label_pos = (labels.numpy() > 0).astype(np.int)
>>> lbl_per_img = kwarray.group_items(is_label_pos,
... img_ids.numpy())
>>> pos_per_img = [sum(lbl_per_img.get(gid, []))
... for gid in range(n_img)]
>>> pos_is_gts = [
>>> torch.randint(0, 2, (npos,)).byte().sort(
>>> descending=True)[0]
>>> for npos in pos_per_img
>>> ]
>>> bboxes_list = self.refine_bboxes(rois, labels, bbox_preds,
>>> pos_is_gts, img_metas)
>>> print(bboxes_list)
"""
img_ids = rois[:, 0].long().unique(sorted=True)
assert img_ids.numel() == len(img_metas)
assert img_ids.numel() <= len(img_metas)

bboxes_list = []
for i in range(len(img_metas)):
inds = torch.nonzero(rois[:, 0] == i).squeeze()
inds = torch.nonzero(rois[:, 0] == i).squeeze(dim=1)
num_rois = inds.numel()

bboxes_ = rois[inds, 1:]
Expand All @@ -204,6 +240,7 @@ def refine_bboxes(self, rois, labels, bbox_preds, pos_is_gts, img_metas):

bboxes = self.regress_by_class(bboxes_, label_, bbox_pred_,
img_meta_)

# filter gt bboxes
pos_keep = 1 - pos_is_gts_
keep_inds = pos_is_gts_.new_ones(num_rois)
Expand All @@ -226,7 +263,7 @@ def regress_by_class(self, rois, label, bbox_pred, img_meta):
Returns:
Tensor: Regressed bboxes, the same shape as input rois.
"""
assert rois.size(1) == 4 or rois.size(1) == 5
assert rois.size(1) == 4 or rois.size(1) == 5, repr(rois.shape)

if not self.reg_class_agnostic:
label = label * 4
Expand Down
3 changes: 3 additions & 0 deletions tests/requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -5,3 +5,6 @@ pytest-cov
codecov
xdoctest >= 0.10.0
asynctest

# Note: kwarray is needed for kwarray.group_items; this may be ported to mmcv in the future.
kwarray
169 changes: 169 additions & 0 deletions tests/test_heads.py
Original file line number Diff line number Diff line change
Expand Up @@ -169,3 +169,172 @@ def _dummy_bbox_sampling(proposal_list, gt_bboxes, gt_labels):
bbox_targets, bbox_weights)
assert losses.get('loss_cls', 0) > 0, 'cls-loss should be non-zero'
assert losses.get('loss_bbox', 0) > 0, 'box-loss should be non-zero'


def test_refine_boxes():
    """
    Mirrors the doctest in
    ``mmdet.models.bbox_heads.bbox_head.BBoxHead.refine_bboxes`` but checks
    for multiple values of n_roi / n_img.
    """
    head = BBoxHead(reg_class_agnostic=True)

    # Each entry is (n_roi, n_img, rng seed)
    cases = [
        # Corner case: less rois than images
        (2, 4, 34285940),

        # Corner case: no images
        (0, 0, 52925222),

        # Corner cases: few images / rois
        (1, 1, 1200281),
        (2, 1, 1200282),
        (2, 2, 1200283),
        (1, 2, 1200284),

        # Corner case: no rois few images
        (0, 1, 23955860),
        (0, 2, 25830516),

        # Corner case: no rois many images
        (0, 10, 671346),
        (0, 20, 699807),

        # Corner case: similar num rois and images
        (20, 20, 1200238),
        (10, 20, 1200238),
        (5, 5, 1200238),

        # ----------------------------------
        # Common case: more rois than images
        (100, 1, 337156),
        (150, 2, 275898),
        (500, 5, 4903221),
    ]
    test_settings = [
        {'n_roi': num_rois, 'n_img': num_imgs, 'rng': seed}
        for num_rois, num_imgs, seed in cases
    ]

    for demokw in test_settings:
        try:
            n_roi, n_img, rng = (
                demokw['n_roi'], demokw['n_img'], demokw['rng'])

            print('Test refine_boxes case: {!r}'.format(demokw))
            rois, labels, bbox_preds, pos_is_gts, img_metas = (
                _demodata_refine_boxes(n_roi, n_img, rng=rng))
            bboxes_list = head.refine_bboxes(rois, labels, bbox_preds,
                                             pos_is_gts, img_metas)
            # One output per image, even for images without any rois
            assert len(bboxes_list) == n_img
            # No more boxes may come out than rois went in
            assert sum(map(len, bboxes_list)) <= n_roi
            assert all(b.shape[1] == 4 for b in bboxes_list)
        except Exception:
            # Report which configuration broke before propagating the error
            print('Test failed with demokw={!r}'.format(demokw))
            raise


def _demodata_refine_boxes(n_roi, n_img, rng=0):
    """
    Create random test data for the
    ``mmdet.models.bbox_heads.bbox_head.BBoxHead.refine_bboxes`` method

    Args:
        n_roi (int): total number of rois to generate.
        n_img (int): number of images the rois are spread over.
        rng (int | numpy.random.RandomState | None): seed or generator,
            resolved with ``ensure_rng``.

    Returns:
        tuple: ``(rois, labels, bbox_preds, pos_is_gts, img_metas)`` where
            ``rois`` is the image id column concatenated with the random
            boxes, ``labels`` is a long tensor of 0/1 values with one entry
            per roi, ``bbox_preds`` is a second set of random boxes,
            ``pos_is_gts`` holds one descending-sorted uint8 tensor per image
            with one entry per positive label in that image, and
            ``img_metas`` is a list of ``{'img_shape': ...}`` dicts.
    """
    import numpy as np
    from mmdet.core.bbox.demodata import random_boxes
    from mmdet.core.bbox.demodata import ensure_rng
    try:
        import kwarray
    except ImportError:
        import pytest
        pytest.skip('kwarray is required for this test')
    scale = 512
    rng = ensure_rng(rng)
    img_metas = [{'img_shape': (scale, scale)} for _ in range(n_img)]
    # Create rois in the expected format
    roi_boxes = random_boxes(n_roi, scale=scale, rng=rng)
    if n_img == 0:
        assert n_roi == 0, 'cannot have any rois if there are no images'
        img_ids = torch.empty((0, ), dtype=torch.long)
        roi_boxes = torch.empty((0, 4), dtype=torch.float32)
    else:
        img_ids = rng.randint(0, n_img, (n_roi, ))
        img_ids = torch.from_numpy(img_ids)
    rois = torch.cat([img_ids[:, None].float(), roi_boxes], dim=1)
    # Create other args
    labels = rng.randint(0, 2, (n_roi, ))
    labels = torch.from_numpy(labels).long()
    bbox_preds = random_boxes(n_roi, scale=scale, rng=rng)
    # For each image, pretend random positive boxes are gts.
    # ``np.int`` was only an alias of the builtin ``int`` and has been removed
    # from NumPy, so the builtin is used directly.
    is_label_pos = (labels.numpy() > 0).astype(int)
    lbl_per_img = kwarray.group_items(is_label_pos, img_ids.numpy())
    # Images that received no rois are absent from the grouping -> 0 positives
    pos_per_img = [sum(lbl_per_img.get(gid, [])) for gid in range(n_img)]
    # randomly generate with numpy then sort with torch
    _pos_is_gts = [
        rng.randint(0, 2, (npos, )).astype(np.uint8) for npos in pos_per_img
    ]
    pos_is_gts = [
        torch.from_numpy(p).sort(descending=True)[0] for p in _pos_is_gts
    ]
    return rois, labels, bbox_preds, pos_is_gts, img_metas

0 comments on commit d85b8a7

Please sign in to comment.