From 183f5180631aa07cdad721101036a06d5a20ee19 Mon Sep 17 00:00:00 2001
From: Mannat Singh
Date: Fri, 7 Feb 2020 13:40:37 -0800
Subject: [PATCH] Add support to attach heads to DenseNets (#383)

Summary:
Pull Request resolved: https://github.com/facebookresearch/ClassyVision/pull/383

Added the following attachable blocks to the `DenseNet` implementation:
- `block{block_idx}-{idx}`
- `transition-{idx}`
- `trunk_output`

The `trunk_output` block is the final output of the `DenseNet`. This is where a `fully_connected` head will normally be attached.

Also added an additional arg (`zero_init_bias`) to `FullyConnectedHead`, since `DenseNet`s initialize the fc bias to 0.
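As an illustration, the sketch below builds the model from the updated densenet121 config in this diff. It assumes the standard `classy_vision.models.build_model` helper; the input shape and assertion are illustrative only.

```python
# Sketch: build a DenseNet-121 trunk with a fully_connected head forked
# from the new trunk_output block (config keys mirror the diff below).
import torch

from classy_vision.models import build_model

config = {
    "name": "densenet",
    "num_blocks": [6, 12, 24, 16],
    "small_input": False,
    "heads": [
        {
            "name": "fully_connected",
            "unique_id": "default_head",
            "num_classes": 1000,
            "fork_block": "trunk_output",  # the new final attachable block
            "in_plane": 1024,  # channel width at the DenseNet-121 trunk output
            "zero_init_bias": True,  # the new FullyConnectedHead arg
        }
    ],
}

model = build_model(config)
output = model(torch.randn(2, 3, 224, 224))
assert output.shape == (2, 1000)  # a single head returns a single tensor
```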
- """ - - def __init__(self, num_layers, in_planes, growth_rate=32, expansion=4): - - # assertions: - assert is_pos_int(in_planes) - assert is_pos_int(growth_rate) - assert is_pos_int(expansion) - - # create block of dense layers at same resolution: - super(_DenseBlock, self).__init__() - for idx in range(num_layers): - layer = _DenseLayer( - in_planes + idx * growth_rate, - growth_rate=growth_rate, - expansion=expansion, - ) - self.add_module("denselayer-%d" % (idx + 1), layer) - - class _Transition(nn.Sequential): """ Transition layer to reduce spatial resolution. @@ -130,6 +107,13 @@ def __init__( Set `final_bn_relu` to `False` to exclude the final batchnorm and ReLU layers. These settings are useful when training Siamese networks. + + Contains the following attachable blocks: + block{block_idx}-{idx}: This is the output of each dense block, + indexed by the block index and the index of the dense layer + transition-{idx}: This is the output of the transition layers + trunk_output: The final output of the `DenseNet`. This is + where a `fully_connected` head is normally attached. """ super().__init__() @@ -165,31 +149,28 @@ def __init__( ) # loop over spatial resolutions: num_planes = init_planes - self.features = nn.Sequential() + blocks = [] for idx, num_layers in enumerate(num_blocks): - - # add dense block: - block = _DenseBlock( - num_layers, num_planes, growth_rate=growth_rate, expansion=expansion + # add dense block + block = self._make_dense_block( + num_layers, + num_planes, + idx, + growth_rate=growth_rate, + expansion=expansion, ) - self.features.add_module("denseblock-%d" % (idx + 1), block) + blocks.append(block) num_planes = num_planes + num_layers * growth_rate # add transition layer: if idx != len(num_blocks) - 1: trans = _Transition(num_planes, num_planes // 2) - self.features.add_module("transition-%d" % (idx + 1), trans) + blocks.append(self.build_attachable_block(f"transition-{idx}", trans)) num_planes = num_planes // 2 - # final batch normalization: - if final_bn_relu: - self.features.add_module("norm-final", nn.BatchNorm2d(num_planes)) - self.features.add_module("relu-final", nn.ReLU(inplace=INPLACE)) + blocks.append(self._make_trunk_output_block(num_planes, final_bn_relu)) - # final classifier: - self.avgpool = nn.AdaptiveAvgPool2d((1, 1)) - self.fc = None if num_classes is None else nn.Linear(num_planes, num_classes) - self.num_planes = num_planes + self.features = nn.Sequential(*blocks) # initialize weights of convolutional and batchnorm layers: for m in self.modules(): @@ -202,6 +183,36 @@ def __init__( elif isinstance(m, nn.Linear): m.bias.data.zero_() + def _make_trunk_output_block(self, num_planes, final_bn_relu): + layers = nn.Sequential() + if final_bn_relu: + # final batch normalization: + layers.add_module("norm-final", nn.BatchNorm2d(num_planes)) + layers.add_module("relu-final", nn.ReLU(inplace=INPLACE)) + return self.build_attachable_block("trunk_output", layers) + + def _make_dense_block( + self, num_layers, in_planes, block_idx, growth_rate=32, expansion=4 + ): + assert is_pos_int(in_planes) + assert is_pos_int(growth_rate) + assert is_pos_int(expansion) + + # create a block of dense layers at same resolution: + layers = [] + for idx in range(num_layers): + layers.append( + self.build_attachable_block( + f"block{block_idx}-{idx}", + _DenseLayer( + in_planes + idx * growth_rate, + growth_rate=growth_rate, + expansion=expansion, + ), + ) + ) + return nn.Sequential(*layers) + @classmethod def from_config(cls, config: Dict[str, Any]) -> "DenseNet": 
"""Instantiates a DenseNet from a configuration. @@ -234,14 +245,16 @@ def forward(self, x): # evaluate all dense blocks: out = self.features(out) - # perform average pooling: - out = self.avgpool(out) - - # final classifier: - out = out.view(out.size(0), -1) - if self.fc is not None: - out = self.fc(out) - return out + # By default the classification layer is implemented as one head on top + # of the last block. The head is automatically computed right after the + # last block. + head_outputs = self.execute_heads() + if len(head_outputs) == 0: + raise Exception("Expecting at least one head that generates output") + elif len(head_outputs) == 1: + return list(head_outputs.values())[0] + else: + return head_outputs def get_optimizer_params(self): # use weight decay on BatchNorm for DenseNets diff --git a/test/models_densenet_test.py b/test/models_densenet_test.py index eba18a1bed..dde30b7e52 100644 --- a/test/models_densenet_test.py +++ b/test/models_densenet_test.py @@ -15,12 +15,21 @@ "small_densenet": { "name": "densenet", "num_blocks": [1, 1, 1, 1], - "num_classes": 1000, "init_planes": 4, "growth_rate": 32, "expansion": 4, "final_bn_relu": True, "small_input": True, + "heads": [ + { + "name": "fully_connected", + "unique_id": "default_head", + "num_classes": 1000, + "fork_block": "trunk_output", + "in_plane": 60, + "zero_init_bias": True, + } + ], } } @@ -49,5 +58,5 @@ def _test_model(self, model_config): compare_model_state(self, state, new_state, check_heads=True) - def test_small_resnet(self): + def test_small_densenet(self): self._test_model(MODELS["small_densenet"])