| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542 |
- """
- SEResNet implementation from Cadene's pretrained models
- https://github.com/Cadene/pretrained-models.pytorch/blob/master/pretrainedmodels/models/senet.py
- Additional credit to https://github.com/creafz
- Original model: https://github.com/hujie-frank/SENet
- ResNet code gently borrowed from
- https://github.com/pytorch/vision/blob/master/torchvision/models/resnet.py
- FIXME I'm deprecating these models and moving them to ResNet, as I don't want to maintain duplicate
- support for extras like dilation, switchable BN/activations, feature extraction, etc. that don't exist here.
- """
- import math
- from collections import OrderedDict
- from typing import Type, Optional, Tuple
- import torch
- import torch.nn as nn
- import torch.nn.functional as F
- from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
- from timm.layers import create_classifier
- from ._builder import build_model_with_cfg
- from ._registry import register_model, generate_default_cfgs
- __all__ = ['SENet']
- def _weight_init(m):
- if isinstance(m, nn.Conv2d):
- nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
- elif isinstance(m, nn.BatchNorm2d):
- nn.init.constant_(m.weight, 1.)
- nn.init.constant_(m.bias, 0.)
class SEModule(nn.Module):
    """Squeeze-and-Excitation gate.

    The input is averaged over dims 2 and 3, passed through a 1x1 conv that reduces
    ``channels`` to ``channels // reduction``, a ReLU, a 1x1 conv back to ``channels``
    and a sigmoid. The result is multiplied onto the original input.
    """

    def __init__(self, channels: int, reduction: int, device=None, dtype=None):
        super().__init__()
        factory_kwargs = {'device': device, 'dtype': dtype}
        squeezed_chs = channels // reduction
        self.fc1 = nn.Conv2d(channels, squeezed_chs, kernel_size=1, **factory_kwargs)
        self.relu = nn.ReLU(inplace=True)
        self.fc2 = nn.Conv2d(squeezed_chs, channels, kernel_size=1, **factory_kwargs)
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        # Squeeze: one value per channel, kept as an (.., 1, 1) map so the 1x1 convs apply.
        gate = x.mean((2, 3), keepdim=True)
        # Excite: reduce -> ReLU -> expand -> sigmoid.
        gate = self.fc2(self.relu(self.fc1(gate)))
        return x * self.sigmoid(gate)
class Bottleneck(nn.Module):
    """Shared ``forward()`` for the three-conv SE bottleneck variants.

    Subclasses are expected to define ``conv1``/``bn1``, ``conv2``/``bn2``,
    ``conv3``/``bn3``, ``relu``, ``se_module`` and ``downsample`` (which may be ``None``).
    """

    def forward(self, x):
        out = self.relu(self.bn1(self.conv1(x)))
        out = self.relu(self.bn2(self.conv2(out)))
        # No activation after the third conv: it is applied after the residual sum.
        out = self.bn3(self.conv3(out))

        # Identity shortcut unless a projection was supplied.
        shortcut = x if self.downsample is None else self.downsample(x)

        out = self.se_module(out) + shortcut
        return self.relu(out)
class SEBottleneck(Bottleneck):
    """Bottleneck for SENet154.

    Layout: 1x1 conv (``inplanes`` -> ``planes * 2``), grouped 3x3 conv carrying the
    stride (``planes * 2`` -> ``planes * 4``), 1x1 conv (``planes * 4`` -> ``planes * 4``),
    each followed by batch norm, then an SE module on the ``planes * 4`` output.
    """
    expansion = 4

    def __init__(
            self,
            inplanes: int,
            planes: int,
            groups: int,
            reduction: int,
            stride: int = 1,
            downsample: Optional[nn.Module] = None,
            device=None,
            dtype=None,
    ):
        super().__init__()
        factory_kwargs = {'device': device, 'dtype': dtype}
        mid_chs = planes * 2
        out_chs = planes * 4

        self.conv1 = nn.Conv2d(inplanes, mid_chs, kernel_size=1, bias=False, **factory_kwargs)
        self.bn1 = nn.BatchNorm2d(mid_chs, **factory_kwargs)
        self.conv2 = nn.Conv2d(
            mid_chs, out_chs, kernel_size=3, stride=stride, padding=1,
            groups=groups, bias=False, **factory_kwargs)
        self.bn2 = nn.BatchNorm2d(out_chs, **factory_kwargs)
        self.conv3 = nn.Conv2d(out_chs, out_chs, kernel_size=1, bias=False, **factory_kwargs)
        self.bn3 = nn.BatchNorm2d(out_chs, **factory_kwargs)
        self.relu = nn.ReLU(inplace=True)
        self.se_module = SEModule(out_chs, reduction=reduction, **factory_kwargs)
        self.downsample = downsample
        self.stride = stride
class SEResNetBottleneck(Bottleneck):
    """ResNet bottleneck with a Squeeze-and-Excitation module.

    Follows the Caffe implementation: the stride is applied in ``conv1`` (the 1x1 conv)
    rather than in ``conv2`` (the 3x3 conv), which is where the torchvision ResNet
    implementation puts it.
    """
    expansion = 4

    def __init__(
            self,
            inplanes: int,
            planes: int,
            groups: int,
            reduction: int,
            stride: int = 1,
            downsample: Optional[nn.Module] = None,
            device=None,
            dtype=None,
    ):
        super().__init__()
        factory_kwargs = {'device': device, 'dtype': dtype}
        out_chs = planes * 4

        # Strided 1x1 reduction (Caffe-style placement of the stride).
        self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False, stride=stride, **factory_kwargs)
        self.bn1 = nn.BatchNorm2d(planes, **factory_kwargs)
        self.conv2 = nn.Conv2d(
            planes, planes, kernel_size=3, padding=1, groups=groups, bias=False, **factory_kwargs)
        self.bn2 = nn.BatchNorm2d(planes, **factory_kwargs)
        self.conv3 = nn.Conv2d(planes, out_chs, kernel_size=1, bias=False, **factory_kwargs)
        self.bn3 = nn.BatchNorm2d(out_chs, **factory_kwargs)
        self.relu = nn.ReLU(inplace=True)
        self.se_module = SEModule(out_chs, reduction=reduction, **factory_kwargs)
        self.downsample = downsample
        self.stride = stride
class SEResNeXtBottleneck(Bottleneck):
    """ResNeXt bottleneck type C with a Squeeze-and-Excitation module.

    The inner width is ``floor(planes * (base_width / 64)) * groups``; the grouped 3x3
    conv carries the stride and the block outputs ``planes * 4`` channels.
    """
    expansion = 4

    def __init__(
            self,
            inplanes: int,
            planes: int,
            groups: int,
            reduction: int,
            stride: int = 1,
            downsample: Optional[nn.Module] = None,
            base_width: int = 4,
            device=None,
            dtype=None,
    ):
        super().__init__()
        factory_kwargs = {'device': device, 'dtype': dtype}
        width = math.floor(planes * (base_width / 64)) * groups
        out_chs = planes * 4

        self.conv1 = nn.Conv2d(inplanes, width, kernel_size=1, bias=False, stride=1, **factory_kwargs)
        self.bn1 = nn.BatchNorm2d(width, **factory_kwargs)
        self.conv2 = nn.Conv2d(
            width, width, kernel_size=3, stride=stride, padding=1,
            groups=groups, bias=False, **factory_kwargs)
        self.bn2 = nn.BatchNorm2d(width, **factory_kwargs)
        self.conv3 = nn.Conv2d(width, out_chs, kernel_size=1, bias=False, **factory_kwargs)
        self.bn3 = nn.BatchNorm2d(out_chs, **factory_kwargs)
        self.relu = nn.ReLU(inplace=True)
        self.se_module = SEModule(out_chs, reduction=reduction, **factory_kwargs)
        self.downsample = downsample
        self.stride = stride
class SEResNetBlock(nn.Module):
    """Two-conv (3x3 + 3x3) residual block with a Squeeze-and-Excitation module.

    The first conv carries the stride, the second conv is grouped, and the block
    keeps the channel count at ``planes`` (``expansion = 1``).
    """
    expansion = 1

    def __init__(
            self,
            inplanes: int,
            planes: int,
            groups: int,
            reduction: int,
            stride: int = 1,
            downsample: Optional[nn.Module] = None,
            device=None,
            dtype=None,
    ):
        super().__init__()
        factory_kwargs = {'device': device, 'dtype': dtype}

        self.conv1 = nn.Conv2d(
            inplanes, planes, kernel_size=3, padding=1, stride=stride, bias=False, **factory_kwargs)
        self.bn1 = nn.BatchNorm2d(planes, **factory_kwargs)
        self.conv2 = nn.Conv2d(
            planes, planes, kernel_size=3, padding=1, groups=groups, bias=False, **factory_kwargs)
        self.bn2 = nn.BatchNorm2d(planes, **factory_kwargs)
        self.relu = nn.ReLU(inplace=True)
        self.se_module = SEModule(planes, reduction=reduction, **factory_kwargs)
        self.downsample = downsample
        self.stride = stride

    def forward(self, x):
        out = self.relu(self.bn1(self.conv1(x)))
        # Unlike the bottleneck variants, ReLU is applied before the SE module here.
        out = self.relu(self.bn2(self.conv2(out)))

        # Identity shortcut unless a projection was supplied.
        shortcut = x if self.downsample is None else self.downsample(x)

        out = self.se_module(out) + shortcut
        return self.relu(out)
class SENet(nn.Module):
    """Squeeze-and-Excitation network.

    Structure: a convolutional stem (``layer0``) followed by a ceil-mode max-pool
    (``pool0``), four residual stages (``layer1`` .. ``layer4``), then a pooling layer
    and classifier (``global_pool`` / ``last_linear``) with optional dropout in between.
    """

    def __init__(
            self,
            block: Type[nn.Module],
            layers: Tuple[int, ...],
            groups: int,
            reduction: int,
            drop_rate: Optional[float] = 0.2,
            in_chans: int = 3,
            inplanes: int = 64,
            input_3x3: bool = False,
            downsample_kernel_size: int = 1,
            downsample_padding: int = 0,
            num_classes: int = 1000,
            global_pool: str = 'avg',
            device=None,
            dtype=None,
    ):
        """
        Args:
            block: Residual block class. It must expose an ``expansion`` class attribute
                and accept ``(inplanes, planes, groups, reduction, stride, downsample,
                device=..., dtype=...)``.
                - For SENet154: SEBottleneck
                - For SE-ResNet models: SEResNetBottleneck (or SEResNetBlock for 18/34)
                - For SE-ResNeXt models: SEResNeXtBottleneck
            layers: Number of residual blocks in each of the 4 stages
                (layer1...layer4). Only indices 0..3 are read.
            groups: Number of groups for the 3x3 convolution in each block.
                - For SENet154: 64
                - For SE-ResNet models: 1
                - For SE-ResNeXt models: 32
            reduction: Reduction ratio for Squeeze-and-Excitation modules.
                - For all models: 16
            drop_rate: Drop probability applied between global pooling and the
                classifier. ``None`` or a value ``<= 0`` disables dropout.
                Defaults to 0.2.
            in_chans: Number of channels of the input image.
            inplanes: Number of input channels for layer1 (= stem output channels).
                - For SENet154: 128
                - For SE-ResNet models: 64
                - For SE-ResNeXt models: 64
            input_3x3: If `True`, use three 3x3 convolutions instead of
                a single 7x7 convolution in layer0.
                - For SENet154: True
                - For SE-ResNet models: False
                - For SE-ResNeXt models: False
            downsample_kernel_size: Kernel size for downsampling convolutions
                in layer2, layer3 and layer4 (layer1 always uses 1).
                - For SENet154: 3
                - For SE-ResNet models: 1
                - For SE-ResNeXt models: 1
            downsample_padding: Padding for downsampling convolutions in
                layer2, layer3 and layer4 (layer1 always uses 0).
                - For SENet154: 1
                - For SE-ResNet models: 0
                - For SE-ResNeXt models: 0
            num_classes: Number of outputs in `last_linear` layer.
                - For all models: 1000
            global_pool: Pooling type forwarded to ``create_classifier`` as ``pool_type``.
            device: Forwarded to every layer constructor.
            dtype: Forwarded to every layer constructor.
        """
        super().__init__()
        dd = {'device': device, 'dtype': dtype}
        self.inplanes = inplanes
        self.num_classes = num_classes
        self.in_chans = in_chans
        self.drop_rate = drop_rate

        if input_3x3:
            layer0_modules = [
                ('conv1', nn.Conv2d(in_chans, 64, 3, stride=2, padding=1, bias=False, **dd)),
                ('bn1', nn.BatchNorm2d(64, **dd)),
                ('relu1', nn.ReLU(inplace=True)),
                ('conv2', nn.Conv2d(64, 64, 3, stride=1, padding=1, bias=False, **dd)),
                ('bn2', nn.BatchNorm2d(64, **dd)),
                ('relu2', nn.ReLU(inplace=True)),
                ('conv3', nn.Conv2d(64, inplanes, 3, stride=1, padding=1, bias=False, **dd)),
                ('bn3', nn.BatchNorm2d(inplanes, **dd)),
                ('relu3', nn.ReLU(inplace=True)),
            ]
        else:
            layer0_modules = [
                ('conv1', nn.Conv2d(in_chans, inplanes, kernel_size=7, stride=2, padding=3, bias=False, **dd)),
                ('bn1', nn.BatchNorm2d(inplanes, **dd)),
                ('relu1', nn.ReLU(inplace=True)),
            ]
        self.layer0 = nn.Sequential(OrderedDict(layer0_modules))
        # To preserve compatibility with Caffe weights `ceil_mode=True` is used instead of `padding=1`.
        self.pool0 = nn.MaxPool2d(3, stride=2, ceil_mode=True)
        self.feature_info = [dict(num_chs=inplanes, reduction=2, module='layer0')]

        # layer1 keeps the resolution and always uses a plain 1x1 projection.
        self.layer1 = self._make_layer(
            block,
            planes=64,
            blocks=layers[0],
            groups=groups,
            reduction=reduction,
            downsample_kernel_size=1,
            downsample_padding=0,
            **dd,
        )
        self.feature_info += [dict(num_chs=64 * block.expansion, reduction=4, module='layer1')]
        self.layer2 = self._make_layer(
            block,
            planes=128,
            blocks=layers[1],
            stride=2,
            groups=groups,
            reduction=reduction,
            downsample_kernel_size=downsample_kernel_size,
            downsample_padding=downsample_padding,
            **dd,
        )
        self.feature_info += [dict(num_chs=128 * block.expansion, reduction=8, module='layer2')]
        self.layer3 = self._make_layer(
            block,
            planes=256,
            blocks=layers[2],
            stride=2,
            groups=groups,
            reduction=reduction,
            downsample_kernel_size=downsample_kernel_size,
            downsample_padding=downsample_padding,
            **dd,
        )
        self.feature_info += [dict(num_chs=256 * block.expansion, reduction=16, module='layer3')]
        self.layer4 = self._make_layer(
            block,
            planes=512,
            blocks=layers[3],
            stride=2,
            groups=groups,
            reduction=reduction,
            downsample_kernel_size=downsample_kernel_size,
            downsample_padding=downsample_padding,
            **dd,
        )
        self.feature_info += [dict(num_chs=512 * block.expansion, reduction=32, module='layer4')]

        self.num_features = self.head_hidden_size = 512 * block.expansion
        self.global_pool, self.last_linear = create_classifier(
            self.num_features,
            self.num_classes,
            pool_type=global_pool,
            **dd,
        )

        for m in self.modules():
            _weight_init(m)

    def _make_layer(self, block, planes, blocks, groups, reduction, stride=1,
                    downsample_kernel_size=1, downsample_padding=0, device=None, dtype=None):
        """Build one residual stage of ``blocks`` blocks and advance ``self.inplanes``.

        Only the first block receives ``stride`` and the shortcut projection. The
        projection (conv + batch norm) is created when the stage is strided or when
        the channel count changes. Afterwards ``self.inplanes`` is set to
        ``planes * block.expansion``.
        """
        dd = {'device': device, 'dtype': dtype}
        downsample = None
        if stride != 1 or self.inplanes != planes * block.expansion:
            downsample = nn.Sequential(
                nn.Conv2d(
                    self.inplanes, planes * block.expansion, kernel_size=downsample_kernel_size,
                    stride=stride, padding=downsample_padding, bias=False, **dd),
                nn.BatchNorm2d(planes * block.expansion, **dd),
            )

        layers = [block(self.inplanes, planes, groups, reduction, stride, downsample, **dd)]
        self.inplanes = planes * block.expansion
        for _ in range(1, blocks):
            layers.append(block(self.inplanes, planes, groups, reduction, **dd))

        return nn.Sequential(*layers)

    @torch.jit.ignore
    def group_matcher(self, coarse=False):
        """Return regex patterns grouping parameters into the stem and per-stage (or per-block) groups."""
        matcher = dict(stem=r'^layer0', blocks=r'^layer(\d+)' if coarse else r'^layer(\d+)\.(\d+)')
        return matcher

    @torch.jit.ignore
    def set_grad_checkpointing(self, enable=True):
        """Gradient checkpointing is not implemented; only ``enable=False`` is accepted."""
        assert not enable, 'gradient checkpointing not supported'

    @torch.jit.ignore
    def get_classifier(self) -> nn.Module:
        """Return the classifier layer (``last_linear``)."""
        return self.last_linear

    def reset_classifier(self, num_classes: int, global_pool: str = 'avg'):
        """Replace the pooling layer and classifier for a new ``num_classes`` / pooling type."""
        self.num_classes = num_classes
        self.global_pool, self.last_linear = create_classifier(
            self.num_features, self.num_classes, pool_type=global_pool)

    def forward_features(self, x):
        """Run the stem, max-pool and the four residual stages; returns the layer4 output."""
        x = self.layer0(x)
        x = self.pool0(x)
        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)
        x = self.layer4(x)
        return x

    def forward_head(self, x, pre_logits: bool = False):
        """Pool, optionally apply dropout, and classify.

        With ``pre_logits=True`` the pooled (and dropped-out) features are returned
        without passing through ``last_linear``.
        """
        x = self.global_pool(x)
        # `drop_rate=None` means "no dropout"; guard explicitly because `None > 0.` raises TypeError.
        if self.drop_rate is not None and self.drop_rate > 0.:
            x = F.dropout(x, p=self.drop_rate, training=self.training)
        return x if pre_logits else self.last_linear(x)

    def forward(self, x):
        x = self.forward_features(x)
        x = self.forward_head(x)
        return x
def _create_senet(variant, pretrained=False, **kwargs):
    """Build the ``SENet`` named ``variant`` via ``build_model_with_cfg``, forwarding ``kwargs`` unchanged."""
    model = build_model_with_cfg(SENet, variant, pretrained, **kwargs)
    return model
def _cfg(url='', **kwargs):
    """Return the default pretrained-config dict for this file's models.

    ``kwargs`` override or extend the defaults (e.g. ``interpolation='bicubic'``).
    """
    cfg = dict(
        url=url,
        num_classes=1000,
        input_size=(3, 224, 224),
        pool_size=(7, 7),
        crop_pct=0.875,
        interpolation='bilinear',
        mean=IMAGENET_DEFAULT_MEAN,
        std=IMAGENET_DEFAULT_STD,
        first_conv='layer0.conv1',
        classifier='last_linear',
        license='apache-2.0',
    )
    cfg.update(kwargs)
    return cfg
# Pretrained configs keyed by '<model name>.<tag>'. Entries use the `_cfg` defaults
# unless a keyword (here only `interpolation`) overrides them.
default_cfgs = generate_default_cfgs({
    # SENet-154
    'legacy_senet154.in1k': _cfg(
        'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/legacy_senet154-e9eb9fe6.pth',
    ),

    # SE-ResNet
    'legacy_seresnet18.in1k': _cfg(
        'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/seresnet18-4bb0ce65.pth',
        interpolation='bicubic',
    ),
    'legacy_seresnet34.in1k': _cfg(
        'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/seresnet34-a4004e63.pth',
    ),
    'legacy_seresnet50.in1k': _cfg(
        'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-cadene/se_resnet50-ce0d4300.pth',
    ),
    'legacy_seresnet101.in1k': _cfg(
        'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-cadene/se_resnet101-7e38fcc6.pth',
    ),
    'legacy_seresnet152.in1k': _cfg(
        'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-cadene/se_resnet152-d17c99b7.pth',
    ),

    # SE-ResNeXt
    'legacy_seresnext26_32x4d.in1k': _cfg(
        'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/seresnext26_32x4d-65ebdb501.pth',
        interpolation='bicubic',
    ),
    'legacy_seresnext50_32x4d.in1k': _cfg(
        'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/legacy_se_resnext50_32x4d-f3651bad.pth',
    ),
    'legacy_seresnext101_32x4d.in1k': _cfg(
        'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/legacy_se_resnext101_32x4d-37725eac.pth',
    ),
})
@register_model
def legacy_seresnet18(pretrained=False, **kwargs) -> SENet:
    """Create 'legacy_seresnet18': SEResNetBlock, stage depths 2-2-2-2, groups 1, SE reduction 16.

    Caller-supplied ``kwargs`` take precedence over these defaults.
    """
    defaults = {
        'block': SEResNetBlock,
        'layers': [2, 2, 2, 2],
        'groups': 1,
        'reduction': 16,
    }
    return _create_senet('legacy_seresnet18', pretrained, **{**defaults, **kwargs})
@register_model
def legacy_seresnet34(pretrained=False, **kwargs) -> SENet:
    """Create 'legacy_seresnet34': SEResNetBlock, stage depths 3-4-6-3, groups 1, SE reduction 16.

    Caller-supplied ``kwargs`` take precedence over these defaults.
    """
    defaults = {
        'block': SEResNetBlock,
        'layers': [3, 4, 6, 3],
        'groups': 1,
        'reduction': 16,
    }
    return _create_senet('legacy_seresnet34', pretrained, **{**defaults, **kwargs})
@register_model
def legacy_seresnet50(pretrained=False, **kwargs) -> SENet:
    """Create 'legacy_seresnet50': SEResNetBottleneck, stage depths 3-4-6-3, groups 1, SE reduction 16.

    Caller-supplied ``kwargs`` take precedence over these defaults.
    """
    defaults = {
        'block': SEResNetBottleneck,
        'layers': [3, 4, 6, 3],
        'groups': 1,
        'reduction': 16,
    }
    return _create_senet('legacy_seresnet50', pretrained, **{**defaults, **kwargs})
@register_model
def legacy_seresnet101(pretrained=False, **kwargs) -> SENet:
    """Create 'legacy_seresnet101': SEResNetBottleneck, stage depths 3-4-23-3, groups 1, SE reduction 16.

    Caller-supplied ``kwargs`` take precedence over these defaults.
    """
    defaults = {
        'block': SEResNetBottleneck,
        'layers': [3, 4, 23, 3],
        'groups': 1,
        'reduction': 16,
    }
    return _create_senet('legacy_seresnet101', pretrained, **{**defaults, **kwargs})
@register_model
def legacy_seresnet152(pretrained=False, **kwargs) -> SENet:
    """Create 'legacy_seresnet152': SEResNetBottleneck, stage depths 3-8-36-3, groups 1, SE reduction 16.

    Caller-supplied ``kwargs`` take precedence over these defaults.
    """
    defaults = {
        'block': SEResNetBottleneck,
        'layers': [3, 8, 36, 3],
        'groups': 1,
        'reduction': 16,
    }
    return _create_senet('legacy_seresnet152', pretrained, **{**defaults, **kwargs})
@register_model
def legacy_senet154(pretrained=False, **kwargs) -> SENet:
    """Create 'legacy_senet154': SEBottleneck, stage depths 3-8-36-3, groups 64, SE reduction 16.

    Also uses the three-conv 3x3 stem with 128 output channels and 3x3 (padding 1)
    shortcut projections. Caller-supplied ``kwargs`` take precedence over these defaults.
    """
    defaults = {
        'block': SEBottleneck,
        'layers': [3, 8, 36, 3],
        'groups': 64,
        'reduction': 16,
        'downsample_kernel_size': 3,
        'downsample_padding': 1,
        'inplanes': 128,
        'input_3x3': True,
    }
    return _create_senet('legacy_senet154', pretrained, **{**defaults, **kwargs})
@register_model
def legacy_seresnext26_32x4d(pretrained=False, **kwargs) -> SENet:
    """Create 'legacy_seresnext26_32x4d': SEResNeXtBottleneck, stage depths 2-2-2-2, groups 32, SE reduction 16.

    Caller-supplied ``kwargs`` take precedence over these defaults.
    """
    defaults = {
        'block': SEResNeXtBottleneck,
        'layers': [2, 2, 2, 2],
        'groups': 32,
        'reduction': 16,
    }
    return _create_senet('legacy_seresnext26_32x4d', pretrained, **{**defaults, **kwargs})
@register_model
def legacy_seresnext50_32x4d(pretrained=False, **kwargs) -> SENet:
    """Create 'legacy_seresnext50_32x4d': SEResNeXtBottleneck, stage depths 3-4-6-3, groups 32, SE reduction 16.

    Caller-supplied ``kwargs`` take precedence over these defaults.
    """
    defaults = {
        'block': SEResNeXtBottleneck,
        'layers': [3, 4, 6, 3],
        'groups': 32,
        'reduction': 16,
    }
    return _create_senet('legacy_seresnext50_32x4d', pretrained, **{**defaults, **kwargs})
@register_model
def legacy_seresnext101_32x4d(pretrained=False, **kwargs) -> SENet:
    """Create 'legacy_seresnext101_32x4d': SEResNeXtBottleneck, stage depths 3-4-23-3, groups 32, SE reduction 16.

    Caller-supplied ``kwargs`` take precedence over these defaults.
    """
    defaults = {
        'block': SEResNeXtBottleneck,
        'layers': [3, 4, 23, 3],
        'groups': 32,
        'reduction': 16,
    }
    return _create_senet('legacy_seresnext101_32x4d', pretrained, **{**defaults, **kwargs})
|