# Source code for mmagic.models.archs.vgg
# Copyright (c) OpenMMLab. All rights reserved.
from typing import Dict, List, Optional
import torch.nn as nn
from mmengine.model import BaseModule
from mmengine.model.weight_init import constant_init, xavier_init
from torch import Tensor
from mmagic.registry import MODELS
from ..archs.aspp import ASPP
@MODELS.register_module()
class VGG16(BaseModule):
    """Customized VGG16 Encoder.

    A 1x1 conv is added after the original VGG16 conv layers. The indices of
    max pooling layers are returned for unpooling layers in decoders.

    Args:
        in_channels (int): Number of input channels.
        batch_norm (bool, optional): Whether use ``nn.BatchNorm2d``.
            Default to False.
        aspp (bool, optional): Whether use ASPP module after the last conv
            layer. Default to False.
        dilations (list[int], optional): Atrous rates of ASPP module.
            Default to None.
        init_cfg (dict, optional): Initialization config dict.
    """

    def __init__(self,
                 in_channels: int,
                 batch_norm: Optional[bool] = False,
                 aspp: Optional[bool] = False,
                 dilations: Optional[List[int]] = None,
                 init_cfg: Optional[dict] = None):
        super().__init__(init_cfg=init_cfg)
        self.batch_norm = batch_norm
        self.aspp = aspp
        self.dilations = dilations

        # Five standard VGG16 conv stages; each stage ends with a max-pool
        # that records indices for later unpooling in a decoder.
        self.layer1 = self._make_layer(in_channels, 64, 2)
        self.layer2 = self._make_layer(64, 128, 2)
        self.layer3 = self._make_layer(128, 256, 3)
        self.layer4 = self._make_layer(256, 512, 3)
        self.layer5 = self._make_layer(512, 512, 3)

        # Extra 1x1 conv appended after the original VGG16 layers.
        self.conv6 = nn.Conv2d(512, 512, kernel_size=1)
        if self.batch_norm:
            self.bn = nn.BatchNorm2d(512)
        self.relu = nn.ReLU(inplace=True)
        if self.aspp:
            # NOTE: intentionally overwrites the bool flag with the module so
            # that ``forward`` can both test and call ``self.aspp``; keeps the
            # state-dict key ``aspp`` used by existing checkpoints.
            self.aspp = ASPP(512, dilations=self.dilations)
            self.out_channels = 256
        else:
            self.out_channels = 512

    def _make_layer(self, inplanes: int, planes: int,
                    convs_layers: int) -> nn.Module:
        """Build one VGG stage of ``convs_layers`` 3x3 convs + max-pool.

        Args:
            inplanes (int): Number of input channels of the first conv.
            planes (int): Number of output channels of every conv.
            convs_layers (int): Number of conv-(bn)-relu groups in the stage.

        Returns:
            nn.Module: Sequential stage whose final max-pool returns indices.
        """
        layers = []
        for _ in range(convs_layers):
            conv2d = nn.Conv2d(inplanes, planes, kernel_size=3, padding=1)
            if self.batch_norm:
                bn = nn.BatchNorm2d(planes)
                layers += [conv2d, bn, nn.ReLU(inplace=True)]
            else:
                layers += [conv2d, nn.ReLU(inplace=True)]
            inplanes = planes
        # ``return_indices=True`` makes the stage output a (tensor, indices)
        # tuple consumed by the unpooling layers in the decoder.
        layers += [nn.MaxPool2d(kernel_size=2, stride=2, return_indices=True)]
        return nn.Sequential(*layers)

    def init_weights(self) -> None:
        """Init weights for the model.

        Falls back to Xavier init for convs and constant init for batch-norm
        when no ``init_cfg`` is provided.
        """
        if self.init_cfg is not None:
            super().init_weights()
        else:
            # Default initialization
            for m in self.modules():
                if isinstance(m, nn.Conv2d):
                    xavier_init(m)
                elif isinstance(m, nn.BatchNorm2d):
                    constant_init(m, 1)

    def forward(self, x: Tensor) -> Dict[str, Tensor]:
        """Forward function of the VGG16 encoder.

        Args:
            x (Tensor): Input tensor with shape (N, C, H, W).

        Returns:
            dict: Dict containing output tensor and max-pooling indices of
                each stage (for unpooling in the decoder).
        """
        out, max_idx_1 = self.layer1(x)
        out, max_idx_2 = self.layer2(out)
        out, max_idx_3 = self.layer3(out)
        out, max_idx_4 = self.layer4(out)
        out, max_idx_5 = self.layer5(out)
        out = self.conv6(out)
        if self.batch_norm:
            out = self.bn(out)
        out = self.relu(out)
        if self.aspp:
            out = self.aspp(out)
        return {
            'out': out,
            'max_idx_1': max_idx_1,
            'max_idx_2': max_idx_2,
            'max_idx_3': max_idx_3,
            'max_idx_4': max_idx_4,
            'max_idx_5': max_idx_5
        }