Commit f142df6

Merge 4b7b094 into 0e09f00

RangiLyu committed Jun 22, 2021
2 parents 0e09f00 + 4b7b094; commit f142df6

Showing 12 changed files with 428 additions and 115 deletions.
10 changes: 7 additions & 3 deletions configs/_base_/models/ssd300.py
@@ -5,14 +5,18 @@
     pretrained='open-mmlab://vgg16_caffe',
     backbone=dict(
         type='SSDVGG',
-        input_size=input_size,
         depth=16,
         with_last_pool=False,
         ceil_mode=True,
         out_indices=(3, 4),
-        out_feature_indices=(22, 34),
+        out_feature_indices=(22, 34)),
+    neck=dict(
+        type='SSDNeck',
+        in_channels=(512, 1024),
+        out_channels=(512, 1024, 512, 256, 256, 256),
+        level_strides=(2, 2, 1, 1),
+        level_paddings=(1, 1, 0, 0),
         l2_norm_scale=20),
-    neck=None,
     bbox_head=dict(
         type='SSDHead',
         in_channels=(512, 1024, 512, 256, 256, 256),
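The practical effect of the change above is that the extra SSD feature levels and the L2 normalization move out of SSDVGG into a dedicated SSDNeck. Below is a minimal sketch of building the two refactored pieces from these config fragments; it assumes MMDetection v2.14.0+ and uses the standard builder helpers from mmdet.models, and the shapes noted in the comments are the usual SSD300 values rather than anything stated in this diff.

```python
import torch
from mmdet.models import build_backbone, build_neck

# Config fragments mirroring the refactored ssd300.py shown above.
backbone_cfg = dict(
    type='SSDVGG',
    depth=16,
    with_last_pool=False,
    ceil_mode=True,
    out_indices=(3, 4),
    out_feature_indices=(22, 34))
neck_cfg = dict(
    type='SSDNeck',
    in_channels=(512, 1024),
    out_channels=(512, 1024, 512, 256, 256, 256),
    level_strides=(2, 2, 1, 1),
    level_paddings=(1, 1, 0, 0),
    l2_norm_scale=20)

backbone = build_backbone(backbone_cfg)
neck = build_neck(neck_cfg)

x = torch.randn(1, 3, 300, 300)
feats = neck(backbone(x))  # six feature maps (38, 19, 10, 5, 3, 1) for the SSD300 head
print([f.shape for f in feats])
```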
1 change: 0 additions & 1 deletion configs/pascal_voc/ssd512_voc0712.py
@@ -1,7 +1,6 @@
 _base_ = 'ssd300_voc0712.py'
 input_size = 512
 model = dict(
-    backbone=dict(input_size=input_size),
     bbox_head=dict(
         in_channels=(512, 1024, 512, 256, 256, 256, 256),
         anchor_generator=dict(
18 changes: 16 additions & 2 deletions configs/ssd/README.md
@@ -17,5 +17,19 @@

 | Backbone | Size | Style | Lr schd | Mem (GB) | Inf time (fps) | box AP | Config | Download |
 | :------: | :---: | :---: | :-----: | :------: | :------------: | :----: | :------: | :--------: |
-| VGG16 | 300 | caffe | 120e | 10.2 | 43.7 | 25.6 | [config](https://github.com/open-mmlab/mmdetection/tree/master/configs/ssd/ssd300_coco.py) | [model](https://download.openmmlab.com/mmdetection/v2.0/ssd/ssd300_coco/ssd300_coco_20200307-a92d2092.pth) | [log](https://download.openmmlab.com/mmdetection/v2.0/ssd/ssd300_coco/ssd300_coco_20200307_174216.log.json) |
-| VGG16 | 512 | caffe | 120e | 9.3 | 30.7 | 29.4 | [config](https://github.com/open-mmlab/mmdetection/tree/master/configs/ssd/ssd512_coco.py) | [model](https://download.openmmlab.com/mmdetection/v2.0/ssd/ssd512_coco/ssd512_coco_20200308-038c5591.pth) | [log](https://download.openmmlab.com/mmdetection/v2.0/ssd/ssd512_coco/ssd512_coco_20200308_134447.log.json) |
+| VGG16 | 300 | caffe | 120e | 9.9 | 43.7 | 25.9 | [config](https://github.com/open-mmlab/mmdetection/tree/master/configs/ssd/ssd300_coco.py) | [model](https://download.openmmlab.com/mmdetection/v2.0/ssd/ssd300_coco/ssd300_coco_20210604_193052-b61137df.pth) | [log](https://download.openmmlab.com/mmdetection/v2.0/ssd/ssd300_coco/ssd300_coco_20210604_193052.log.json) |
+| VGG16 | 512 | caffe | 120e | 19.4 | 30.7 | 29.8 | [config](https://github.com/open-mmlab/mmdetection/tree/master/configs/ssd/ssd512_coco.py) | [model](https://download.openmmlab.com/mmdetection/v2.0/ssd/ssd512_coco/ssd512_coco_20210604_111835-d3eba047.pth) | [log](https://download.openmmlab.com/mmdetection/v2.0/ssd/ssd512_coco/ssd512_coco_20210604_111835.log.json) |
+
+## Notice
+
+In v2.14.0, [PR5291](https://github.com/open-mmlab/mmdetection/pull/5291) refactored the SSD neck and head for more
+flexible usage. If users want to use SSD checkpoints trained with older versions, we provide the script
+`tools/model_converters/upgrade_ssd_version.py` to convert the model weights (a short loading sketch follows the argument list below).
+
+```bash
+python tools/model_converters/upgrade_ssd_version.py ${OLD_MODEL_PATH} ${NEW_MODEL_PATH}
+```
+
+- OLD_MODEL_PATH: the path to load the old version SSD model.
+- NEW_MODEL_PATH: the path to save the converted model weights.
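Once a checkpoint has been converted, it loads like any other MMDetection model. The snippet below is a minimal sketch, assuming MMDetection v2.14.0+ is installed and a converted file is available locally; the paths are placeholders rather than files shipped with this commit.

```python
from mmdet.apis import inference_detector, init_detector

# Placeholder paths: the refactored SSD300 config plus a checkpoint produced by
# tools/model_converters/upgrade_ssd_version.py.
config_file = 'configs/ssd/ssd300_coco.py'
checkpoint_file = 'ssd300_converted.pth'

# Build the refactored SSDVGG + SSDNeck + SSDHead model and load the converted weights.
model = init_detector(config_file, checkpoint_file, device='cpu')

# A quick inference call confirms the converted state dict matches the new layout.
result = inference_detector(model, 'demo/demo.jpg')
```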
12 changes: 6 additions & 6 deletions configs/ssd/metafile.yml
@@ -16,24 +16,24 @@ Models:
     In Collection: SSD
     Config: configs/ssd/ssd300_coco.py
     Metadata:
-      Training Memory (GB): 10.2
+      Training Memory (GB): 9.9
       inference time (s/im): 0.02288
     Results:
       - Task: Object Detection
         Dataset: COCO
         Metrics:
-          box AP: 25.6
-    Weights: https://download.openmmlab.com/mmdetection/v2.0/ssd/ssd300_coco/ssd300_coco_20200307-a92d2092.pth
+          box AP: 25.9
+    Weights: https://download.openmmlab.com/mmdetection/v2.0/ssd/ssd300_coco/ssd300_coco_20210604_193052-b61137df.pth

   - Name: ssd512_coco
     In Collection: SSD
     Config: configs/ssd/ssd512_coco.py
     Metadata:
-      Training Memory (GB): 9.3
+      Training Memory (GB): 19.4
       inference time (s/im): 0.03257
     Results:
       - Task: Object Detection
         Dataset: COCO
         Metrics:
-          box AP: 29.4
-    Weights: https://download.openmmlab.com/mmdetection/v2.0/ssd/ssd512_coco/ssd512_coco_20200308-038c5591.pth
+          box AP: 29.8
+    Weights: https://download.openmmlab.com/mmdetection/v2.0/ssd/ssd512_coco/ssd512_coco_20210604_111835-d3eba047.pth
6 changes: 5 additions & 1 deletion configs/ssd/ssd512_coco.py
@@ -1,7 +1,11 @@
 _base_ = 'ssd300_coco.py'
 input_size = 512
 model = dict(
-    backbone=dict(input_size=input_size),
+    neck=dict(
+        out_channels=(512, 1024, 512, 256, 256, 256, 256),
+        level_strides=(2, 2, 2, 2, 1),
+        level_paddings=(1, 1, 1, 1, 1),
+        last_kernel_size=4),
     bbox_head=dict(
         in_channels=(512, 1024, 512, 256, 256, 256, 256),
         anchor_generator=dict(
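For reference, the extra levels configured above reproduce the usual SSD512 feature-map pyramid. The sketch below walks through the output-size arithmetic, assuming each extra level of SSDNeck downsamples with a 3x3 convolution using the listed stride and padding, and that the final level uses the 4x4 kernel set by `last_kernel_size`; it illustrates the arithmetic only, not the neck's actual code.

```python
# Spatial size after a convolution: floor((size + 2*padding - kernel) / stride) + 1.
def conv_out(size, kernel, stride, padding):
    return (size + 2 * padding - kernel) // stride + 1


# For a 512x512 input, SSDVGG yields conv4_3 at 64x64 and conv7 at 32x32.
sizes = [64, 32]
size = 32

level_strides = (2, 2, 2, 2, 1)
level_paddings = (1, 1, 1, 1, 1)
kernels = (3, 3, 3, 3, 4)  # the final level uses last_kernel_size=4

for k, s, p in zip(kernels, level_strides, level_paddings):
    size = conv_out(size, k, s, p)
    sizes.append(size)

print(sizes)  # [64, 32, 16, 8, 4, 2, 1] -> seven levels, matching the 7-tuple in_channels above
```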
13 changes: 13 additions & 0 deletions docs/compatibility.md
@@ -1,5 +1,18 @@
 # Compatibility of MMDetection 2.x

+## MMDetection 2.14.0
+
+### SSD compatibility
+
+In v2.14.0, to make SSD more flexible to use, [PR5291](https://github.com/open-mmlab/mmdetection/pull/5291) refactored its backbone, neck and head. Users can run the script `tools/model_converters/upgrade_ssd_version.py` to convert their models.
+
+```bash
+python tools/model_converters/upgrade_ssd_version.py ${OLD_MODEL_PATH} ${NEW_MODEL_PATH}
+```
+
+- OLD_MODEL_PATH: the path to load the old version SSD model.
+- NEW_MODEL_PATH: the path to save the converted model weights.
+
 ## MMDetection 2.12.0

 MMDetection is undergoing a big refactoring for more general and convenient usage during the releases from v2.12.0 to v2.15.0 (and possibly longer).
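The converter script ships with MMDetection itself; conceptually, upgrading an old SSD checkpoint amounts to renaming state-dict keys and saving the result. The following is a heavily hedged sketch of that idea, not the actual tool: the old prefixes follow from the `extra` and `l2_norm` modules that SSDVGG used to own (see the removed lines in ssd_vgg.py below), while the new prefixes are assumptions made purely for illustration.

```python
from collections import OrderedDict

import torch

# Old prefixes come from modules the pre-2.14 SSDVGG owned; the new prefixes are
# hypothetical stand-ins for wherever SSDNeck registers them.
PREFIX_MAP = {
    'backbone.extra.': 'neck.extra_layers.',   # assumed new name
    'backbone.l2_norm.': 'neck.l2_norm.',      # assumed new name
}


def upgrade_ssd_state_dict(old_path, new_path):
    checkpoint = torch.load(old_path, map_location='cpu')
    state_dict = checkpoint['state_dict']  # MMDetection checkpoints keep weights here

    new_state_dict = OrderedDict()
    for key, value in state_dict.items():
        for old_prefix, new_prefix in PREFIX_MAP.items():
            if key.startswith(old_prefix):
                key = new_prefix + key[len(old_prefix):]
                break
        new_state_dict[key] = value

    checkpoint['state_dict'] = new_state_dict
    torch.save(checkpoint, new_path)
```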
107 changes: 26 additions & 81 deletions mmdet/models/backbones/ssd_vgg.py
@@ -1,25 +1,32 @@
+import warnings
+
-import torch
 import torch.nn as nn
-import torch.nn.functional as F
 from mmcv.cnn import VGG
-from mmcv.runner import BaseModule, Sequential
+from mmcv.runner import BaseModule

 from ..builder import BACKBONES
+from ..necks import ssd_neck


 @BACKBONES.register_module()
 class SSDVGG(VGG, BaseModule):
     """VGG Backbone network for single-shot-detection.
     Args:
-        input_size (int): width and height of input, from {300, 512}.
         depth (int): Depth of vgg, from {11, 13, 16, 19}.
         with_last_pool (bool): Whether to add a pooling layer at the last
             of the model
         ceil_mode (bool): When True, will use `ceil` instead of `floor`
             to compute the output shape.
         out_indices (Sequence[int]): Output from which stages.
         out_feature_indices (Sequence[int]): Output from which feature map.
         pretrained (str, optional): model pretrained path. Default: None
         init_cfg (dict or list[dict], optional): Initialization config dict.
             Default: None
+        input_size (int, optional): Deprecated argument.
+            Width and height of input, from {300, 512}.
+        l2_norm_scale (float, optional): Deprecated argument.
+            L2 normalization layer init scale.
     Example:
         >>> self = SSDVGG(input_size=300, depth=11)
@@ -40,23 +47,21 @@ class SSDVGG(VGG, BaseModule):
     }

     def __init__(self,
-                 input_size,
                  depth,
                  with_last_pool=False,
                  ceil_mode=True,
                  out_indices=(3, 4),
                  out_feature_indices=(22, 34),
-                 l2_norm_scale=20.,
                  pretrained=None,
-                 init_cfg=None):
+                 init_cfg=None,
+                 input_size=None,
+                 l2_norm_scale=None):
+        # TODO: in_channels for mmcv.VGG
         super(SSDVGG, self).__init__(
             depth,
             with_last_pool=with_last_pool,
             ceil_mode=ceil_mode,
             out_indices=out_indices)
-        assert input_size in (300, 512)
-        self.input_size = input_size

         self.features.add_module(
             str(len(self.features)),
@@ -72,18 +77,17 @@ def __init__(self,
             str(len(self.features)), nn.ReLU(inplace=True))
         self.out_feature_indices = out_feature_indices

-        self.inplanes = 1024
-        self.extra = self._make_extra_layers(self.extra_setting[input_size])
-        self.l2_norm = L2Norm(
-            self.features[out_feature_indices[0] - 1].out_channels,
-            l2_norm_scale)

         assert not (init_cfg and pretrained), \
             'init_cfg and pretrained cannot be setting at the same time'
         if isinstance(pretrained, str):
             warnings.warn('DeprecationWarning: pretrained is deprecated, '
                           'please use "init_cfg" instead')
             self.init_cfg = [dict(type='Pretrained', checkpoint=pretrained)]
+        if input_size is not None:
+            warnings.warn('DeprecationWarning: input_size is deprecated')
+        if l2_norm_scale is not None:
+            warnings.warn('DeprecationWarning: l2_norm_scale in VGG is '
+                          'deprecated, it has been moved to SSDNeck.')
         elif pretrained is None:
             if init_cfg is None:
                 self.init_cfg = [
@@ -94,18 +98,6 @@ def __init__(self,
         else:
             raise TypeError('pretrained must be a str or None')

-        if init_cfg is None:
-            self.init_cfg += [
-                dict(
-                    type='Xavier',
-                    distribution='uniform',
-                    override=dict(name='extra')),
-                dict(
-                    type='Constant',
-                    val=self.l2_norm.scale,
-                    override=dict(name='l2_norm'))
-            ]

     def init_weights(self, pretrained=None):
         super(VGG, self).init_weights()
@@ -116,64 +108,17 @@ def forward(self, x):
             x = layer(x)
             if i in self.out_feature_indices:
                 outs.append(x)
-        for i, layer in enumerate(self.extra):
-            x = F.relu(layer(x), inplace=True)
-            if i % 2 == 1:
-                outs.append(x)
-        outs[0] = self.l2_norm(outs[0])

         if len(outs) == 1:
             return outs[0]
         else:
             return tuple(outs)

-    def _make_extra_layers(self, outplanes):
-        layers = []
-        kernel_sizes = (1, 3)
-        num_layers = 0
-        outplane = None
-        for i in range(len(outplanes)):
-            if self.inplanes == 'S':
-                self.inplanes = outplane
-                continue
-            k = kernel_sizes[num_layers % 2]
-            if outplanes[i] == 'S':
-                outplane = outplanes[i + 1]
-                conv = nn.Conv2d(
-                    self.inplanes, outplane, k, stride=2, padding=1)
-            else:
-                outplane = outplanes[i]
-                conv = nn.Conv2d(
-                    self.inplanes, outplane, k, stride=1, padding=0)
-            layers.append(conv)
-            self.inplanes = outplanes[i]
-            num_layers += 1
-        if self.input_size == 512:
-            layers.append(nn.Conv2d(self.inplanes, 256, 4, padding=1))

-        return Sequential(*layers)

+class L2Norm(ssd_neck.L2Norm):

-class L2Norm(nn.Module):

-    def __init__(self, n_dims, scale=20., eps=1e-10):
-        """L2 normalization layer.
-        Args:
-            n_dims (int): Number of dimensions to be normalized
-            scale (float, optional): Defaults to 20..
-            eps (float, optional): Used to avoid division by zero.
-                Defaults to 1e-10.
-        """
-        super(L2Norm, self).__init__()
-        self.n_dims = n_dims
-        self.weight = nn.Parameter(torch.Tensor(self.n_dims))
-        self.eps = eps
-        self.scale = scale
-
-    def forward(self, x):
-        """Forward function."""
-        # normalization layer convert to FP32 in FP16 training
-        x_float = x.float()
-        norm = x_float.pow(2).sum(1, keepdim=True).sqrt() + self.eps
-        return (self.weight[None, :, None, None].float().expand_as(x_float) *
-                x_float / norm).type_as(x)
+    def __init__(self, **kwargs):
+        super(L2Norm, self).__init__(**kwargs)
+        warnings.warn('DeprecationWarning: L2Norm in ssd_vgg.py '
+                      'is deprecated, please use L2Norm in '
+                      'mmdet/models/necks/ssd_neck.py instead')
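Because the L2Norm left in ssd_vgg.py is now only a deprecation shim over the neck's implementation, it is worth spelling out what the removed code computed: channel-wise L2 normalization with a learnable per-channel scale, carried out in FP32 for FP16-training safety. Below is a self-contained sketch mirroring the removed forward pass above; the scale is initialized directly to 20 instead of via an init_cfg, so treat it as an illustration rather than the mmdet module itself.

```python
import torch
import torch.nn as nn


class L2Norm(nn.Module):
    """Channel-wise L2 normalization with a learnable per-channel scale."""

    def __init__(self, n_dims, scale=20., eps=1e-10):
        super().__init__()
        self.eps = eps
        # Initialized directly to `scale`, matching l2_norm_scale=20 in the configs.
        self.weight = nn.Parameter(torch.full((n_dims,), float(scale)))

    def forward(self, x):
        # Normalize over the channel dimension in FP32, rescale, then cast back.
        x_float = x.float()
        norm = x_float.pow(2).sum(1, keepdim=True).sqrt() + self.eps
        return (self.weight[None, :, None, None].float() * x_float / norm).type_as(x)


# The conv4_3 feature of SSDVGG (512 channels, 38x38 for a 300x300 input) is what gets normalized.
feat = torch.randn(2, 512, 38, 38)
print(L2Norm(512)(feat).shape)  # torch.Size([2, 512, 38, 38])
```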