senet.py 20 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542
  1. """
  2. SEResNet implementation from Cadene's pretrained models
  3. https://github.com/Cadene/pretrained-models.pytorch/blob/master/pretrainedmodels/models/senet.py
  4. Additional credit to https://github.com/creafz
  5. Original model: https://github.com/hujie-frank/SENet
  6. ResNet code gently borrowed from
  7. https://github.com/pytorch/vision/blob/master/torchvision/models/resnet.py
  8. FIXME I'm deprecating this model and moving them to ResNet as I don't want to maintain duplicate
  9. support for extras like dilation, switchable BN/activations, feature extraction, etc that don't exist here.
  10. """
  11. import math
  12. from collections import OrderedDict
  13. from typing import Type, Optional, Tuple
  14. import torch
  15. import torch.nn as nn
  16. import torch.nn.functional as F
  17. from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
  18. from timm.layers import create_classifier
  19. from ._builder import build_model_with_cfg
  20. from ._registry import register_model, generate_default_cfgs
# Public export list for star-imports: only the model class itself. The `legacy_*`
# constructors defined below are not listed here; they are decorated with
# `@register_model` instead (presumably exposing them through timm's model registry).
__all__ = ['SENet']
  22. def _weight_init(m):
  23. if isinstance(m, nn.Conv2d):
  24. nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
  25. elif isinstance(m, nn.BatchNorm2d):
  26. nn.init.constant_(m.weight, 1.)
  27. nn.init.constant_(m.bias, 0.)
  28. class SEModule(nn.Module):
  29. def __init__(self, channels: int, reduction: int, device=None, dtype=None):
  30. dd = {'device': device, 'dtype': dtype}
  31. super().__init__()
  32. self.fc1 = nn.Conv2d(channels, channels // reduction, kernel_size=1, **dd)
  33. self.relu = nn.ReLU(inplace=True)
  34. self.fc2 = nn.Conv2d(channels // reduction, channels, kernel_size=1, **dd)
  35. self.sigmoid = nn.Sigmoid()
  36. def forward(self, x):
  37. module_input = x
  38. x = x.mean((2, 3), keepdim=True)
  39. x = self.fc1(x)
  40. x = self.relu(x)
  41. x = self.fc2(x)
  42. x = self.sigmoid(x)
  43. return module_input * x
  44. class Bottleneck(nn.Module):
  45. """
  46. Base class for bottlenecks that implements `forward()` method.
  47. """
  48. def forward(self, x):
  49. shortcut = x
  50. out = self.conv1(x)
  51. out = self.bn1(out)
  52. out = self.relu(out)
  53. out = self.conv2(out)
  54. out = self.bn2(out)
  55. out = self.relu(out)
  56. out = self.conv3(out)
  57. out = self.bn3(out)
  58. if self.downsample is not None:
  59. shortcut = self.downsample(x)
  60. out = self.se_module(out) + shortcut
  61. out = self.relu(out)
  62. return out
  63. class SEBottleneck(Bottleneck):
  64. """
  65. Bottleneck for SENet154.
  66. """
  67. expansion = 4
  68. def __init__(
  69. self,
  70. inplanes: int,
  71. planes: int,
  72. groups: int,
  73. reduction: int,
  74. stride: int = 1,
  75. downsample: Optional[nn.Module] = None,
  76. device=None,
  77. dtype=None,
  78. ):
  79. dd = {'device': device, 'dtype': dtype}
  80. super().__init__()
  81. self.conv1 = nn.Conv2d(inplanes, planes * 2, kernel_size=1, bias=False, **dd)
  82. self.bn1 = nn.BatchNorm2d(planes * 2, **dd)
  83. self.conv2 = nn.Conv2d(
  84. planes * 2,
  85. planes * 4,
  86. kernel_size=3,
  87. stride=stride,
  88. padding=1,
  89. groups=groups,
  90. bias=False,
  91. **dd,
  92. )
  93. self.bn2 = nn.BatchNorm2d(planes * 4, **dd)
  94. self.conv3 = nn.Conv2d(planes * 4, planes * 4, kernel_size=1, bias=False, **dd)
  95. self.bn3 = nn.BatchNorm2d(planes * 4, **dd)
  96. self.relu = nn.ReLU(inplace=True)
  97. self.se_module = SEModule(planes * 4, reduction=reduction, **dd)
  98. self.downsample = downsample
  99. self.stride = stride
  100. class SEResNetBottleneck(Bottleneck):
  101. """
  102. ResNet bottleneck with a Squeeze-and-Excitation module. It follows Caffe
  103. implementation and uses `stride=stride` in `conv1` and not in `conv2`
  104. (the latter is used in the torchvision implementation of ResNet).
  105. """
  106. expansion = 4
  107. def __init__(
  108. self,
  109. inplanes: int,
  110. planes: int,
  111. groups: int,
  112. reduction: int,
  113. stride: int = 1,
  114. downsample: Optional[nn.Module] = None,
  115. device=None,
  116. dtype=None,
  117. ):
  118. dd = {'device': device, 'dtype': dtype}
  119. super().__init__()
  120. self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False, stride=stride, **dd)
  121. self.bn1 = nn.BatchNorm2d(planes, **dd)
  122. self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, padding=1, groups=groups, bias=False, **dd)
  123. self.bn2 = nn.BatchNorm2d(planes, **dd)
  124. self.conv3 = nn.Conv2d(planes, planes * 4, kernel_size=1, bias=False, **dd)
  125. self.bn3 = nn.BatchNorm2d(planes * 4, **dd)
  126. self.relu = nn.ReLU(inplace=True)
  127. self.se_module = SEModule(planes * 4, reduction=reduction, **dd)
  128. self.downsample = downsample
  129. self.stride = stride
  130. class SEResNeXtBottleneck(Bottleneck):
  131. """
  132. ResNeXt bottleneck type C with a Squeeze-and-Excitation module.
  133. """
  134. expansion = 4
  135. def __init__(
  136. self,
  137. inplanes: int,
  138. planes: int,
  139. groups: int,
  140. reduction: int,
  141. stride: int = 1,
  142. downsample: Optional[nn.Module] = None,
  143. base_width: int = 4,
  144. device=None,
  145. dtype=None,
  146. ):
  147. dd = {'device': device, 'dtype': dtype}
  148. super().__init__()
  149. width = math.floor(planes * (base_width / 64)) * groups
  150. self.conv1 = nn.Conv2d(inplanes, width, kernel_size=1, bias=False, stride=1, **dd)
  151. self.bn1 = nn.BatchNorm2d(width, **dd)
  152. self.conv2 = nn.Conv2d(width, width, kernel_size=3, stride=stride, padding=1, groups=groups, bias=False, **dd)
  153. self.bn2 = nn.BatchNorm2d(width, **dd)
  154. self.conv3 = nn.Conv2d(width, planes * 4, kernel_size=1, bias=False, **dd)
  155. self.bn3 = nn.BatchNorm2d(planes * 4, **dd)
  156. self.relu = nn.ReLU(inplace=True)
  157. self.se_module = SEModule(planes * 4, reduction=reduction, **dd)
  158. self.downsample = downsample
  159. self.stride = stride
  160. class SEResNetBlock(nn.Module):
  161. expansion = 1
  162. def __init__(
  163. self,
  164. inplanes: int,
  165. planes: int,
  166. groups: int,
  167. reduction: int,
  168. stride: int = 1,
  169. downsample: Optional[nn.Module] = None,
  170. device=None,
  171. dtype=None,
  172. ):
  173. dd = {'device': device, 'dtype': dtype}
  174. super().__init__()
  175. self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=3, padding=1, stride=stride, bias=False, **dd)
  176. self.bn1 = nn.BatchNorm2d(planes, **dd)
  177. self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, padding=1, groups=groups, bias=False, **dd)
  178. self.bn2 = nn.BatchNorm2d(planes, **dd)
  179. self.relu = nn.ReLU(inplace=True)
  180. self.se_module = SEModule(planes, reduction=reduction, **dd)
  181. self.downsample = downsample
  182. self.stride = stride
  183. def forward(self, x):
  184. shortcut = x
  185. out = self.conv1(x)
  186. out = self.bn1(out)
  187. out = self.relu(out)
  188. out = self.conv2(out)
  189. out = self.bn2(out)
  190. out = self.relu(out)
  191. if self.downsample is not None:
  192. shortcut = self.downsample(x)
  193. out = self.se_module(out) + shortcut
  194. out = self.relu(out)
  195. return out
  196. class SENet(nn.Module):
  197. def __init__(
  198. self,
  199. block: Type[nn.Module],
  200. layers: Tuple[int, ...],
  201. groups: int,
  202. reduction: int,
  203. drop_rate: float = 0.2,
  204. in_chans: int = 3,
  205. inplanes: int = 64,
  206. input_3x3: bool = False,
  207. downsample_kernel_size: int = 1,
  208. downsample_padding: int = 0,
  209. num_classes: int = 1000,
  210. global_pool: str = 'avg',
  211. device=None,
  212. dtype=None,
  213. ):
  214. """
  215. Parameters
  216. ----------
  217. block (nn.Module): Bottleneck class.
  218. - For SENet154: SEBottleneck
  219. - For SE-ResNet models: SEResNetBottleneck
  220. - For SE-ResNeXt models: SEResNeXtBottleneck
  221. layers (list of ints): Number of residual blocks for 4 layers of the
  222. network (layer1...layer4).
  223. groups (int): Number of groups for the 3x3 convolution in each
  224. bottleneck block.
  225. - For SENet154: 64
  226. - For SE-ResNet models: 1
  227. - For SE-ResNeXt models: 32
  228. reduction (int): Reduction ratio for Squeeze-and-Excitation modules.
  229. - For all models: 16
  230. dropout_p (float or None): Drop probability for the Dropout layer.
  231. If `None` the Dropout layer is not used.
  232. - For SENet154: 0.2
  233. - For SE-ResNet models: None
  234. - For SE-ResNeXt models: None
  235. inplanes (int): Number of input channels for layer1.
  236. - For SENet154: 128
  237. - For SE-ResNet models: 64
  238. - For SE-ResNeXt models: 64
  239. input_3x3 (bool): If `True`, use three 3x3 convolutions instead of
  240. a single 7x7 convolution in layer0.
  241. - For SENet154: True
  242. - For SE-ResNet models: False
  243. - For SE-ResNeXt models: False
  244. downsample_kernel_size (int): Kernel size for downsampling convolutions
  245. in layer2, layer3 and layer4.
  246. - For SENet154: 3
  247. - For SE-ResNet models: 1
  248. - For SE-ResNeXt models: 1
  249. downsample_padding (int): Padding for downsampling convolutions in
  250. layer2, layer3 and layer4.
  251. - For SENet154: 1
  252. - For SE-ResNet models: 0
  253. - For SE-ResNeXt models: 0
  254. num_classes (int): Number of outputs in `last_linear` layer.
  255. - For all models: 1000
  256. """
  257. super().__init__()
  258. dd = {'device': device, 'dtype': dtype}
  259. self.inplanes = inplanes
  260. self.num_classes = num_classes
  261. self.in_chans = in_chans
  262. self.drop_rate = drop_rate
  263. if input_3x3:
  264. layer0_modules = [
  265. ('conv1', nn.Conv2d(in_chans, 64, 3, stride=2, padding=1, bias=False, **dd)),
  266. ('bn1', nn.BatchNorm2d(64, **dd)),
  267. ('relu1', nn.ReLU(inplace=True)),
  268. ('conv2', nn.Conv2d(64, 64, 3, stride=1, padding=1, bias=False, **dd)),
  269. ('bn2', nn.BatchNorm2d(64, **dd)),
  270. ('relu2', nn.ReLU(inplace=True)),
  271. ('conv3', nn.Conv2d(64, inplanes, 3, stride=1, padding=1, bias=False, **dd)),
  272. ('bn3', nn.BatchNorm2d(inplanes, **dd)),
  273. ('relu3', nn.ReLU(inplace=True)),
  274. ]
  275. else:
  276. layer0_modules = [
  277. ('conv1', nn.Conv2d(in_chans, inplanes, kernel_size=7, stride=2, padding=3, bias=False, **dd)),
  278. ('bn1', nn.BatchNorm2d(inplanes, **dd)),
  279. ('relu1', nn.ReLU(inplace=True)),
  280. ]
  281. self.layer0 = nn.Sequential(OrderedDict(layer0_modules))
  282. # To preserve compatibility with Caffe weights `ceil_mode=True` is used instead of `padding=1`.
  283. self.pool0 = nn.MaxPool2d(3, stride=2, ceil_mode=True)
  284. self.feature_info = [dict(num_chs=inplanes, reduction=2, module='layer0')]
  285. self.layer1 = self._make_layer(
  286. block,
  287. planes=64,
  288. blocks=layers[0],
  289. groups=groups,
  290. reduction=reduction,
  291. downsample_kernel_size=1,
  292. downsample_padding=0,
  293. **dd,
  294. )
  295. self.feature_info += [dict(num_chs=64 * block.expansion, reduction=4, module='layer1')]
  296. self.layer2 = self._make_layer(
  297. block,
  298. planes=128,
  299. blocks=layers[1],
  300. stride=2,
  301. groups=groups,
  302. reduction=reduction,
  303. downsample_kernel_size=downsample_kernel_size,
  304. downsample_padding=downsample_padding,
  305. **dd,
  306. )
  307. self.feature_info += [dict(num_chs=128 * block.expansion, reduction=8, module='layer2')]
  308. self.layer3 = self._make_layer(
  309. block,
  310. planes=256,
  311. blocks=layers[2],
  312. stride=2,
  313. groups=groups,
  314. reduction=reduction,
  315. downsample_kernel_size=downsample_kernel_size,
  316. downsample_padding=downsample_padding,
  317. **dd,
  318. )
  319. self.feature_info += [dict(num_chs=256 * block.expansion, reduction=16, module='layer3')]
  320. self.layer4 = self._make_layer(
  321. block,
  322. planes=512,
  323. blocks=layers[3],
  324. stride=2,
  325. groups=groups,
  326. reduction=reduction,
  327. downsample_kernel_size=downsample_kernel_size,
  328. downsample_padding=downsample_padding,
  329. **dd,
  330. )
  331. self.feature_info += [dict(num_chs=512 * block.expansion, reduction=32, module='layer4')]
  332. self.num_features = self.head_hidden_size = 512 * block.expansion
  333. self.global_pool, self.last_linear = create_classifier(
  334. self.num_features,
  335. self.num_classes,
  336. pool_type=global_pool,
  337. **dd,
  338. )
  339. for m in self.modules():
  340. _weight_init(m)
  341. def _make_layer(self, block, planes, blocks, groups, reduction, stride=1,
  342. downsample_kernel_size=1, downsample_padding=0, device=None, dtype=None):
  343. dd = {'device': device, 'dtype': dtype}
  344. downsample = None
  345. if stride != 1 or self.inplanes != planes * block.expansion:
  346. downsample = nn.Sequential(
  347. nn.Conv2d(
  348. self.inplanes, planes * block.expansion, kernel_size=downsample_kernel_size,
  349. stride=stride, padding=downsample_padding, bias=False, **dd),
  350. nn.BatchNorm2d(planes * block.expansion, **dd),
  351. )
  352. layers = [block(self.inplanes, planes, groups, reduction, stride, downsample, **dd)]
  353. self.inplanes = planes * block.expansion
  354. for i in range(1, blocks):
  355. layers.append(block(self.inplanes, planes, groups, reduction, **dd))
  356. return nn.Sequential(*layers)
  357. @torch.jit.ignore
  358. def group_matcher(self, coarse=False):
  359. matcher = dict(stem=r'^layer0', blocks=r'^layer(\d+)' if coarse else r'^layer(\d+)\.(\d+)')
  360. return matcher
  361. @torch.jit.ignore
  362. def set_grad_checkpointing(self, enable=True):
  363. assert not enable, 'gradient checkpointing not supported'
  364. @torch.jit.ignore
  365. def get_classifier(self) -> nn.Module:
  366. return self.last_linear
  367. def reset_classifier(self, num_classes: int, global_pool: str = 'avg'):
  368. self.num_classes = num_classes
  369. self.global_pool, self.last_linear = create_classifier(
  370. self.num_features, self.num_classes, pool_type=global_pool)
  371. def forward_features(self, x):
  372. x = self.layer0(x)
  373. x = self.pool0(x)
  374. x = self.layer1(x)
  375. x = self.layer2(x)
  376. x = self.layer3(x)
  377. x = self.layer4(x)
  378. return x
  379. def forward_head(self, x, pre_logits: bool = False):
  380. x = self.global_pool(x)
  381. if self.drop_rate > 0.:
  382. x = F.dropout(x, p=self.drop_rate, training=self.training)
  383. return x if pre_logits else self.last_linear(x)
  384. def forward(self, x):
  385. x = self.forward_features(x)
  386. x = self.forward_head(x)
  387. return x
  388. def _create_senet(variant, pretrained=False, **kwargs):
  389. return build_model_with_cfg(SENet, variant, pretrained, **kwargs)
  390. def _cfg(url='', **kwargs):
  391. return {
  392. 'url': url, 'num_classes': 1000, 'input_size': (3, 224, 224), 'pool_size': (7, 7),
  393. 'crop_pct': 0.875, 'interpolation': 'bilinear',
  394. 'mean': IMAGENET_DEFAULT_MEAN, 'std': IMAGENET_DEFAULT_STD,
  395. 'first_conv': 'layer0.conv1', 'classifier': 'last_linear', 'license': 'apache-2.0',
  396. **kwargs
  397. }
# Pretrained-weight configs keyed '<model_name>.<tag>'; the 'in1k' tag presumably
# denotes ImageNet-1k weights (TODO confirm). Each entry starts from the `_cfg`
# defaults (224x224 input, crop_pct 0.875, bilinear interpolation, ImageNet mean/std)
# and only sets its weight URL; seresnet18 and seresnext26_32x4d additionally
# override interpolation to 'bicubic'.
# NOTE: URLs (including their hash suffixes) must stay byte-exact.
default_cfgs = generate_default_cfgs({
    'legacy_senet154.in1k': _cfg(
        url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/legacy_senet154-e9eb9fe6.pth'),
    'legacy_seresnet18.in1k': _cfg(
        url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/seresnet18-4bb0ce65.pth',
        interpolation='bicubic'),
    'legacy_seresnet34.in1k': _cfg(
        url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/seresnet34-a4004e63.pth'),
    'legacy_seresnet50.in1k': _cfg(
        url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-cadene/se_resnet50-ce0d4300.pth'),
    'legacy_seresnet101.in1k': _cfg(
        url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-cadene/se_resnet101-7e38fcc6.pth'),
    'legacy_seresnet152.in1k': _cfg(
        url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-cadene/se_resnet152-d17c99b7.pth'),
    'legacy_seresnext26_32x4d.in1k': _cfg(
        url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/seresnext26_32x4d-65ebdb501.pth',
        interpolation='bicubic'),
    'legacy_seresnext50_32x4d.in1k': _cfg(
        url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/legacy_se_resnext50_32x4d-f3651bad.pth'),
    'legacy_seresnext101_32x4d.in1k': _cfg(
        url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/legacy_se_resnext101_32x4d-37725eac.pth'),
})
  420. @register_model
  421. def legacy_seresnet18(pretrained=False, **kwargs) -> SENet:
  422. model_args = dict(
  423. block=SEResNetBlock, layers=[2, 2, 2, 2], groups=1, reduction=16)
  424. return _create_senet('legacy_seresnet18', pretrained, **dict(model_args, **kwargs))
  425. @register_model
  426. def legacy_seresnet34(pretrained=False, **kwargs) -> SENet:
  427. model_args = dict(
  428. block=SEResNetBlock, layers=[3, 4, 6, 3], groups=1, reduction=16)
  429. return _create_senet('legacy_seresnet34', pretrained, **dict(model_args, **kwargs))
  430. @register_model
  431. def legacy_seresnet50(pretrained=False, **kwargs) -> SENet:
  432. model_args = dict(
  433. block=SEResNetBottleneck, layers=[3, 4, 6, 3], groups=1, reduction=16)
  434. return _create_senet('legacy_seresnet50', pretrained, **dict(model_args, **kwargs))
  435. @register_model
  436. def legacy_seresnet101(pretrained=False, **kwargs) -> SENet:
  437. model_args = dict(
  438. block=SEResNetBottleneck, layers=[3, 4, 23, 3], groups=1, reduction=16)
  439. return _create_senet('legacy_seresnet101', pretrained, **dict(model_args, **kwargs))
  440. @register_model
  441. def legacy_seresnet152(pretrained=False, **kwargs) -> SENet:
  442. model_args = dict(
  443. block=SEResNetBottleneck, layers=[3, 8, 36, 3], groups=1, reduction=16)
  444. return _create_senet('legacy_seresnet152', pretrained, **dict(model_args, **kwargs))
  445. @register_model
  446. def legacy_senet154(pretrained=False, **kwargs) -> SENet:
  447. model_args = dict(
  448. block=SEBottleneck, layers=[3, 8, 36, 3], groups=64, reduction=16,
  449. downsample_kernel_size=3, downsample_padding=1, inplanes=128, input_3x3=True)
  450. return _create_senet('legacy_senet154', pretrained, **dict(model_args, **kwargs))
  451. @register_model
  452. def legacy_seresnext26_32x4d(pretrained=False, **kwargs) -> SENet:
  453. model_args = dict(
  454. block=SEResNeXtBottleneck, layers=[2, 2, 2, 2], groups=32, reduction=16)
  455. return _create_senet('legacy_seresnext26_32x4d', pretrained, **dict(model_args, **kwargs))
  456. @register_model
  457. def legacy_seresnext50_32x4d(pretrained=False, **kwargs) -> SENet:
  458. model_args = dict(
  459. block=SEResNeXtBottleneck, layers=[3, 4, 6, 3], groups=32, reduction=16)
  460. return _create_senet('legacy_seresnext50_32x4d', pretrained, **dict(model_args, **kwargs))
  461. @register_model
  462. def legacy_seresnext101_32x4d(pretrained=False, **kwargs) -> SENet:
  463. model_args = dict(
  464. block=SEResNeXtBottleneck, layers=[3, 4, 23, 3], groups=32, reduction=16)
  465. return _create_senet('legacy_seresnext101_32x4d', pretrained, **dict(model_args, **kwargs))