Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Enhancement]: Refactor SSD #5291

Merged
merged 17 commits into from
Jun 22, 2021
10 changes: 7 additions & 3 deletions configs/_base_/models/ssd300.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,14 +5,18 @@
pretrained='open-mmlab://vgg16_caffe',
backbone=dict(
type='SSDVGG',
input_size=input_size,
depth=16,
with_last_pool=False,
ceil_mode=True,
out_indices=(3, 4),
out_feature_indices=(22, 34),
out_feature_indices=(22, 34)),
neck=dict(
type='SSDNeck',
in_channels=(512, 1024),
out_channels=(512, 1024, 512, 256, 256, 256),
level_strides=(2, 2, 1, 1),
level_paddings=(1, 1, 0, 0),
l2_norm_scale=20),
neck=None,
bbox_head=dict(
type='SSDHead',
in_channels=(512, 1024, 512, 256, 256, 256),
Expand Down
1 change: 0 additions & 1 deletion configs/pascal_voc/ssd512_voc0712.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
_base_ = 'ssd300_voc0712.py'
input_size = 512
model = dict(
backbone=dict(input_size=input_size),
bbox_head=dict(
in_channels=(512, 1024, 512, 256, 256, 256, 256),
anchor_generator=dict(
Expand Down
6 changes: 5 additions & 1 deletion configs/ssd/ssd512_coco.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,11 @@
_base_ = 'ssd300_coco.py'
input_size = 512
model = dict(
backbone=dict(input_size=input_size),
neck=dict(
out_channels=(512, 1024, 512, 256, 256, 256, 256),
level_strides=(2, 2, 2, 2, 1),
level_paddings=(1, 1, 1, 1, 1),
last_kernel_size=4),
bbox_head=dict(
in_channels=(512, 1024, 512, 256, 256, 256, 256),
anchor_generator=dict(
Expand Down
88 changes: 9 additions & 79 deletions mmdet/models/backbones/ssd_vgg.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,11 @@
import warnings

import torch
import torch.nn as nn
import torch.nn.functional as F
from mmcv.cnn import VGG
from mmcv.runner import BaseModule, Sequential
from mmcv.runner import BaseModule

from ..builder import BACKBONES
from ..necks import ssd_neck


@BACKBONES.register_module()
Expand Down Expand Up @@ -40,13 +39,11 @@ class SSDVGG(VGG, BaseModule):
}

def __init__(self,
input_size,
RangiLyu marked this conversation as resolved.
Show resolved Hide resolved
depth,
with_last_pool=False,
ceil_mode=True,
out_indices=(3, 4),
out_feature_indices=(22, 34),
l2_norm_scale=20.,
pretrained=None,
init_cfg=None):
# TODO: in_channels for mmcv.VGG
Expand All @@ -55,8 +52,6 @@ def __init__(self,
with_last_pool=with_last_pool,
ceil_mode=ceil_mode,
out_indices=out_indices)
assert input_size in (300, 512)
self.input_size = input_size

self.features.add_module(
str(len(self.features)),
Expand All @@ -72,12 +67,6 @@ def __init__(self,
str(len(self.features)), nn.ReLU(inplace=True))
self.out_feature_indices = out_feature_indices

self.inplanes = 1024
self.extra = self._make_extra_layers(self.extra_setting[input_size])
self.l2_norm = L2Norm(
self.features[out_feature_indices[0] - 1].out_channels,
l2_norm_scale)

assert not (init_cfg and pretrained), \
'init_cfg and pretrained cannot be setting at the same time'
if isinstance(pretrained, str):
Expand All @@ -94,18 +83,6 @@ def __init__(self,
else:
raise TypeError('pretrained must be a str or None')

if init_cfg is None:
self.init_cfg += [
dict(
type='Xavier',
distribution='uniform',
override=dict(name='extra')),
dict(
type='Constant',
val=self.l2_norm.scale,
override=dict(name='l2_norm'))
]

def init_weights(self, pretrained=None):
super(VGG, self).init_weights()

Expand All @@ -116,64 +93,17 @@ def forward(self, x):
x = layer(x)
if i in self.out_feature_indices:
outs.append(x)
for i, layer in enumerate(self.extra):
x = F.relu(layer(x), inplace=True)
if i % 2 == 1:
outs.append(x)
outs[0] = self.l2_norm(outs[0])

if len(outs) == 1:
return outs[0]
else:
return tuple(outs)

def _make_extra_layers(self, outplanes):
layers = []
kernel_sizes = (1, 3)
num_layers = 0
outplane = None
for i in range(len(outplanes)):
if self.inplanes == 'S':
self.inplanes = outplane
continue
k = kernel_sizes[num_layers % 2]
if outplanes[i] == 'S':
outplane = outplanes[i + 1]
conv = nn.Conv2d(
self.inplanes, outplane, k, stride=2, padding=1)
else:
outplane = outplanes[i]
conv = nn.Conv2d(
self.inplanes, outplane, k, stride=1, padding=0)
layers.append(conv)
self.inplanes = outplanes[i]
num_layers += 1
if self.input_size == 512:
layers.append(nn.Conv2d(self.inplanes, 256, 4, padding=1))

return Sequential(*layers)

class L2Norm(ssd_neck.L2Norm):

class L2Norm(nn.Module):
RangiLyu marked this conversation as resolved.
Show resolved Hide resolved

def __init__(self, n_dims, scale=20., eps=1e-10):
"""L2 normalization layer.

Args:
n_dims (int): Number of dimensions to be normalized
scale (float, optional): Defaults to 20..
eps (float, optional): Used to avoid division by zero.
Defaults to 1e-10.
"""
super(L2Norm, self).__init__()
self.n_dims = n_dims
self.weight = nn.Parameter(torch.Tensor(self.n_dims))
self.eps = eps
self.scale = scale

def forward(self, x):
"""Forward function."""
# normalization layer convert to FP32 in FP16 training
x_float = x.float()
norm = x_float.pow(2).sum(1, keepdim=True).sqrt() + self.eps
return (self.weight[None, :, None, None].float().expand_as(x_float) *
x_float / norm).type_as(x)
def __init__(self, **kwargs):
super(L2Norm, self).__init__(**kwargs)
warnings.warn('DeprecationWarning: L2Norm in ssd_vgg.py '
'is deprecated, please use L2Norm in '
'mmdet/models/necks/ssd_neck.py instead')
118 changes: 99 additions & 19 deletions mmdet/models/dense_heads/ssd_head.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,8 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
from mmcv.runner import ModuleList, force_fp32
from mmcv.cnn import ConvModule, DepthwiseSeparableConvModule
from mmcv.runner import force_fp32

from mmdet.core import (build_anchor_generator, build_assigner,
build_bbox_coder, build_sampler, multi_apply)
Expand All @@ -19,6 +20,18 @@ class SSDHead(AnchorHead):
num_classes (int): Number of categories excluding the background
category.
in_channels (int): Number of channels in the input feature map.
stacked_convs (int): Number of conv layers in cls and reg tower.
Default: 0.
feat_channels (int): Number of hidden channels when stacked_convs
> 0. Default: 256.
use_depthwise (bool): Whether to use DepthwiseSeparableConv.
Default: False.
conv_cfg (dict): Dictionary to construct and config conv layer.
Default: None.
norm_cfg (dict): Dictionary to construct and config norm layer.
Default: None.
act_cfg (dict): Dictionary to construct and config activation layer.
Default: None.
anchor_generator (dict): Config dict for anchor generator
bbox_coder (dict): Config of bounding box coder.
reg_decoded_bbox (bool): If true, the regression loss would be
Expand All @@ -34,6 +47,12 @@ class SSDHead(AnchorHead):
def __init__(self,
num_classes=80,
in_channels=(512, 1024, 512, 256, 256, 256),
stacked_convs=0,
feat_channels=256,
use_depthwise=False,
conv_cfg=None,
norm_cfg=None,
act_cfg=None,
anchor_generator=dict(
type='SSDAnchorGenerator',
scale_major=False,
Expand All @@ -58,27 +77,18 @@ def __init__(self,
super(AnchorHead, self).__init__(init_cfg)
self.num_classes = num_classes
self.in_channels = in_channels
self.stacked_convs = stacked_convs
self.feat_channels = feat_channels
self.use_depthwise = use_depthwise
self.conv_cfg = conv_cfg
self.norm_cfg = norm_cfg
self.act_cfg = act_cfg

self.cls_out_channels = num_classes + 1 # add background class
self.anchor_generator = build_anchor_generator(anchor_generator)
num_anchors = self.anchor_generator.num_base_anchors
self.num_anchors = self.anchor_generator.num_base_anchors

reg_convs = []
cls_convs = []
for i in range(len(in_channels)):
reg_convs.append(
nn.Conv2d(
in_channels[i],
num_anchors[i] * 4,
kernel_size=3,
padding=1))
cls_convs.append(
nn.Conv2d(
in_channels[i],
num_anchors[i] * (num_classes + 1),
kernel_size=3,
padding=1))
self.reg_convs = ModuleList(reg_convs)
self.cls_convs = ModuleList(cls_convs)
self._init_layers()

self.bbox_coder = build_bbox_coder(bbox_coder)
self.reg_decoded_bbox = reg_decoded_bbox
Expand All @@ -95,6 +105,76 @@ def __init__(self,
self.sampler = build_sampler(sampler_cfg, context=self)
self.fp16_enabled = False

def _init_layers(self):
"""Initialize layers of the head."""
self.cls_convs = nn.ModuleList()
self.reg_convs = nn.ModuleList()
# TODO: Use registry to choose ConvModule type
conv = DepthwiseSeparableConvModule \
RangiLyu marked this conversation as resolved.
Show resolved Hide resolved
if self.use_depthwise else ConvModule

for channel, num_anchors in zip(self.in_channels, self.num_anchors):
cls_layers = []
reg_layers = []
in_channel = channel
# build stacked conv tower, not used in default ssd
for i in range(self.stacked_convs):
cls_layers.append(
conv(
in_channel,
self.feat_channels,
3,
padding=1,
conv_cfg=self.conv_cfg,
norm_cfg=self.norm_cfg,
act_cfg=self.act_cfg))
reg_layers.append(
conv(
in_channel,
self.feat_channels,
3,
padding=1,
conv_cfg=self.conv_cfg,
norm_cfg=self.norm_cfg,
act_cfg=self.act_cfg))
in_channel = self.feat_channels
# SSD-Lite head
if self.use_depthwise:
cls_layers.append(
ConvModule(
in_channel,
in_channel,
3,
padding=1,
groups=in_channel,
conv_cfg=self.conv_cfg,
norm_cfg=self.norm_cfg,
act_cfg=self.act_cfg))
reg_layers.append(
ConvModule(
in_channel,
in_channel,
3,
padding=1,
groups=in_channel,
conv_cfg=self.conv_cfg,
norm_cfg=self.norm_cfg,
act_cfg=self.act_cfg))
cls_layers.append(
nn.Conv2d(
in_channel,
num_anchors * self.cls_out_channels,
kernel_size=1 if self.use_depthwise else 3,
padding=0 if self.use_depthwise else 1))
reg_layers.append(
nn.Conv2d(
in_channel,
num_anchors * 4,
kernel_size=1 if self.use_depthwise else 3,
padding=0 if self.use_depthwise else 1))
self.cls_convs.append(nn.Sequential(*cls_layers))
self.reg_convs.append(nn.Sequential(*reg_layers))

def forward(self, feats):
"""Forward features from the upstream network.

Expand Down
4 changes: 3 additions & 1 deletion mmdet/models/necks/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,9 +10,11 @@
from .nasfcos_fpn import NASFCOS_FPN
from .pafpn import PAFPN
from .rfp import RFP
from .ssd_neck import SSDNeck
from .yolo_neck import YOLOV3Neck

__all__ = [
'FPN', 'BFP', 'ChannelMapper', 'HRFPN', 'NASFPN', 'FPN_CARAFE', 'PAFPN',
'NASFCOS_FPN', 'RFP', 'YOLOV3Neck', 'FPG', 'DilatedEncoder', 'CTResNetNeck'
'NASFCOS_FPN', 'RFP', 'YOLOV3Neck', 'FPG', 'DilatedEncoder',
'CTResNetNeck', 'SSDNeck'
]
Loading