inception_v3.py 18 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508
  1. """ Inception-V3
  2. Originally from torchvision Inception3 model
  3. Licensed BSD 3-Clause https://github.com/pytorch/vision/blob/master/LICENSE
  4. """
  5. from functools import partial
  6. from typing import Optional, Type
  7. import torch
  8. import torch.nn as nn
  9. import torch.nn.functional as F
  10. from timm.data import IMAGENET_DEFAULT_STD, IMAGENET_DEFAULT_MEAN, IMAGENET_INCEPTION_MEAN, IMAGENET_INCEPTION_STD
  11. from timm.layers import trunc_normal_, create_classifier, Linear, ConvNormAct
  12. from ._builder import build_model_with_cfg
  13. from ._builder import resolve_pretrained_cfg
  14. from ._manipulate import flatten_modules
  15. from ._registry import register_model, generate_default_cfgs, register_model_deprecations
  16. __all__ = ['InceptionV3'] # model_registry will add each entrypoint fn to this
  17. class InceptionA(nn.Module):
  18. def __init__(
  19. self,
  20. in_channels: int,
  21. pool_features: int,
  22. conv_block: Optional[Type[nn.Module]] = None,
  23. device=None,
  24. dtype=None,
  25. ):
  26. dd = {'device': device, 'dtype': dtype}
  27. super().__init__()
  28. conv_block = conv_block or ConvNormAct
  29. self.branch1x1 = conv_block(in_channels, 64, kernel_size=1, **dd)
  30. self.branch5x5_1 = conv_block(in_channels, 48, kernel_size=1, **dd)
  31. self.branch5x5_2 = conv_block(48, 64, kernel_size=5, padding=2, **dd)
  32. self.branch3x3dbl_1 = conv_block(in_channels, 64, kernel_size=1, **dd)
  33. self.branch3x3dbl_2 = conv_block(64, 96, kernel_size=3, padding=1, **dd)
  34. self.branch3x3dbl_3 = conv_block(96, 96, kernel_size=3, padding=1, **dd)
  35. self.branch_pool = conv_block(in_channels, pool_features, kernel_size=1, **dd)
  36. def _forward(self, x):
  37. branch1x1 = self.branch1x1(x)
  38. branch5x5 = self.branch5x5_1(x)
  39. branch5x5 = self.branch5x5_2(branch5x5)
  40. branch3x3dbl = self.branch3x3dbl_1(x)
  41. branch3x3dbl = self.branch3x3dbl_2(branch3x3dbl)
  42. branch3x3dbl = self.branch3x3dbl_3(branch3x3dbl)
  43. branch_pool = F.avg_pool2d(x, kernel_size=3, stride=1, padding=1)
  44. branch_pool = self.branch_pool(branch_pool)
  45. outputs = [branch1x1, branch5x5, branch3x3dbl, branch_pool]
  46. return outputs
  47. def forward(self, x):
  48. outputs = self._forward(x)
  49. return torch.cat(outputs, 1)
  50. class InceptionB(nn.Module):
  51. def __init__(
  52. self,
  53. in_channels: int,
  54. conv_block: Optional[Type[nn.Module]] = None,
  55. device=None,
  56. dtype=None,
  57. ):
  58. dd = {'device': device, 'dtype': dtype}
  59. super().__init__()
  60. conv_block = conv_block or ConvNormAct
  61. self.branch3x3 = conv_block(in_channels, 384, kernel_size=3, stride=2, **dd)
  62. self.branch3x3dbl_1 = conv_block(in_channels, 64, kernel_size=1, **dd)
  63. self.branch3x3dbl_2 = conv_block(64, 96, kernel_size=3, padding=1, **dd)
  64. self.branch3x3dbl_3 = conv_block(96, 96, kernel_size=3, stride=2, **dd)
  65. def _forward(self, x):
  66. branch3x3 = self.branch3x3(x)
  67. branch3x3dbl = self.branch3x3dbl_1(x)
  68. branch3x3dbl = self.branch3x3dbl_2(branch3x3dbl)
  69. branch3x3dbl = self.branch3x3dbl_3(branch3x3dbl)
  70. branch_pool = F.max_pool2d(x, kernel_size=3, stride=2)
  71. outputs = [branch3x3, branch3x3dbl, branch_pool]
  72. return outputs
  73. def forward(self, x):
  74. outputs = self._forward(x)
  75. return torch.cat(outputs, 1)
  76. class InceptionC(nn.Module):
  77. def __init__(
  78. self,
  79. in_channels: int,
  80. channels_7x7: int,
  81. conv_block: Optional[Type[nn.Module]] = None,
  82. device=None,
  83. dtype=None,
  84. ):
  85. dd = {'device': device, 'dtype': dtype}
  86. super().__init__()
  87. conv_block = conv_block or ConvNormAct
  88. self.branch1x1 = conv_block(in_channels, 192, kernel_size=1, **dd)
  89. c7 = channels_7x7
  90. self.branch7x7_1 = conv_block(in_channels, c7, kernel_size=1, **dd)
  91. self.branch7x7_2 = conv_block(c7, c7, kernel_size=(1, 7), padding=(0, 3), **dd)
  92. self.branch7x7_3 = conv_block(c7, 192, kernel_size=(7, 1), padding=(3, 0), **dd)
  93. self.branch7x7dbl_1 = conv_block(in_channels, c7, kernel_size=1, **dd)
  94. self.branch7x7dbl_2 = conv_block(c7, c7, kernel_size=(7, 1), padding=(3, 0), **dd)
  95. self.branch7x7dbl_3 = conv_block(c7, c7, kernel_size=(1, 7), padding=(0, 3), **dd)
  96. self.branch7x7dbl_4 = conv_block(c7, c7, kernel_size=(7, 1), padding=(3, 0), **dd)
  97. self.branch7x7dbl_5 = conv_block(c7, 192, kernel_size=(1, 7), padding=(0, 3), **dd)
  98. self.branch_pool = conv_block(in_channels, 192, kernel_size=1, **dd)
  99. def _forward(self, x):
  100. branch1x1 = self.branch1x1(x)
  101. branch7x7 = self.branch7x7_1(x)
  102. branch7x7 = self.branch7x7_2(branch7x7)
  103. branch7x7 = self.branch7x7_3(branch7x7)
  104. branch7x7dbl = self.branch7x7dbl_1(x)
  105. branch7x7dbl = self.branch7x7dbl_2(branch7x7dbl)
  106. branch7x7dbl = self.branch7x7dbl_3(branch7x7dbl)
  107. branch7x7dbl = self.branch7x7dbl_4(branch7x7dbl)
  108. branch7x7dbl = self.branch7x7dbl_5(branch7x7dbl)
  109. branch_pool = F.avg_pool2d(x, kernel_size=3, stride=1, padding=1)
  110. branch_pool = self.branch_pool(branch_pool)
  111. outputs = [branch1x1, branch7x7, branch7x7dbl, branch_pool]
  112. return outputs
  113. def forward(self, x):
  114. outputs = self._forward(x)
  115. return torch.cat(outputs, 1)
  116. class InceptionD(nn.Module):
  117. def __init__(
  118. self,
  119. in_channels: int,
  120. conv_block: Optional[Type[nn.Module]] = None,
  121. device=None,
  122. dtype=None,
  123. ):
  124. dd = {'device': device, 'dtype': dtype}
  125. super().__init__()
  126. conv_block = conv_block or ConvNormAct
  127. self.branch3x3_1 = conv_block(in_channels, 192, kernel_size=1, **dd)
  128. self.branch3x3_2 = conv_block(192, 320, kernel_size=3, stride=2, **dd)
  129. self.branch7x7x3_1 = conv_block(in_channels, 192, kernel_size=1, **dd)
  130. self.branch7x7x3_2 = conv_block(192, 192, kernel_size=(1, 7), padding=(0, 3), **dd)
  131. self.branch7x7x3_3 = conv_block(192, 192, kernel_size=(7, 1), padding=(3, 0), **dd)
  132. self.branch7x7x3_4 = conv_block(192, 192, kernel_size=3, stride=2, **dd)
  133. def _forward(self, x):
  134. branch3x3 = self.branch3x3_1(x)
  135. branch3x3 = self.branch3x3_2(branch3x3)
  136. branch7x7x3 = self.branch7x7x3_1(x)
  137. branch7x7x3 = self.branch7x7x3_2(branch7x7x3)
  138. branch7x7x3 = self.branch7x7x3_3(branch7x7x3)
  139. branch7x7x3 = self.branch7x7x3_4(branch7x7x3)
  140. branch_pool = F.max_pool2d(x, kernel_size=3, stride=2)
  141. outputs = [branch3x3, branch7x7x3, branch_pool]
  142. return outputs
  143. def forward(self, x):
  144. outputs = self._forward(x)
  145. return torch.cat(outputs, 1)
  146. class InceptionE(nn.Module):
  147. def __init__(
  148. self,
  149. in_channels: int,
  150. conv_block: Optional[Type[nn.Module]] = None,
  151. device=None,
  152. dtype=None,
  153. ):
  154. dd = {'device': device, 'dtype': dtype}
  155. super().__init__()
  156. conv_block = conv_block or ConvNormAct
  157. self.branch1x1 = conv_block(in_channels, 320, kernel_size=1, **dd)
  158. self.branch3x3_1 = conv_block(in_channels, 384, kernel_size=1, **dd)
  159. self.branch3x3_2a = conv_block(384, 384, kernel_size=(1, 3), padding=(0, 1), **dd)
  160. self.branch3x3_2b = conv_block(384, 384, kernel_size=(3, 1), padding=(1, 0), **dd)
  161. self.branch3x3dbl_1 = conv_block(in_channels, 448, kernel_size=1, **dd)
  162. self.branch3x3dbl_2 = conv_block(448, 384, kernel_size=3, padding=1, **dd)
  163. self.branch3x3dbl_3a = conv_block(384, 384, kernel_size=(1, 3), padding=(0, 1), **dd)
  164. self.branch3x3dbl_3b = conv_block(384, 384, kernel_size=(3, 1), padding=(1, 0), **dd)
  165. self.branch_pool = conv_block(in_channels, 192, kernel_size=1, **dd)
  166. def _forward(self, x):
  167. branch1x1 = self.branch1x1(x)
  168. branch3x3 = self.branch3x3_1(x)
  169. branch3x3 = [
  170. self.branch3x3_2a(branch3x3),
  171. self.branch3x3_2b(branch3x3),
  172. ]
  173. branch3x3 = torch.cat(branch3x3, 1)
  174. branch3x3dbl = self.branch3x3dbl_1(x)
  175. branch3x3dbl = self.branch3x3dbl_2(branch3x3dbl)
  176. branch3x3dbl = [
  177. self.branch3x3dbl_3a(branch3x3dbl),
  178. self.branch3x3dbl_3b(branch3x3dbl),
  179. ]
  180. branch3x3dbl = torch.cat(branch3x3dbl, 1)
  181. branch_pool = F.avg_pool2d(x, kernel_size=3, stride=1, padding=1)
  182. branch_pool = self.branch_pool(branch_pool)
  183. outputs = [branch1x1, branch3x3, branch3x3dbl, branch_pool]
  184. return outputs
  185. def forward(self, x):
  186. outputs = self._forward(x)
  187. return torch.cat(outputs, 1)
  188. class InceptionAux(nn.Module):
  189. def __init__(
  190. self,
  191. in_channels: int,
  192. num_classes: int,
  193. conv_block: Optional[Type[nn.Module]] = None,
  194. device=None,
  195. dtype=None,
  196. ):
  197. dd = {'device': device, 'dtype': dtype}
  198. super().__init__()
  199. conv_block = conv_block or ConvNormAct
  200. self.conv0 = conv_block(in_channels, 128, kernel_size=1, **dd)
  201. self.conv1 = conv_block(128, 768, kernel_size=5, **dd)
  202. self.conv1.stddev = 0.01
  203. self.fc = Linear(768, num_classes, **dd)
  204. self.fc.stddev = 0.001
  205. def forward(self, x):
  206. # N x 768 x 17 x 17
  207. x = F.avg_pool2d(x, kernel_size=5, stride=3)
  208. # N x 768 x 5 x 5
  209. x = self.conv0(x)
  210. # N x 128 x 5 x 5
  211. x = self.conv1(x)
  212. # N x 768 x 1 x 1
  213. # Adaptive average pooling
  214. x = F.adaptive_avg_pool2d(x, (1, 1))
  215. # N x 768 x 1 x 1
  216. x = torch.flatten(x, 1)
  217. # N x 768
  218. x = self.fc(x)
  219. # N x 1000
  220. return x
  221. class InceptionV3(nn.Module):
  222. """Inception-V3
  223. """
  224. aux_logits: torch.jit.Final[bool]
  225. def __init__(
  226. self,
  227. num_classes: int = 1000,
  228. in_chans: int = 3,
  229. drop_rate: float = 0.,
  230. global_pool: str = 'avg',
  231. aux_logits: bool = False,
  232. norm_layer: str = 'batchnorm2d',
  233. norm_eps: float = 1e-3,
  234. act_layer: str = 'relu',
  235. device=None,
  236. dtype=None,
  237. ):
  238. super().__init__()
  239. dd = {'device': device, 'dtype': dtype}
  240. self.num_classes = num_classes
  241. self.in_chans = in_chans
  242. self.aux_logits = aux_logits
  243. conv_block = partial(
  244. ConvNormAct,
  245. padding=0,
  246. norm_layer=norm_layer,
  247. act_layer=act_layer,
  248. norm_kwargs=dict(eps=norm_eps),
  249. act_kwargs=dict(inplace=True),
  250. )
  251. self.Conv2d_1a_3x3 = conv_block(in_chans, 32, kernel_size=3, stride=2, **dd)
  252. self.Conv2d_2a_3x3 = conv_block(32, 32, kernel_size=3, **dd)
  253. self.Conv2d_2b_3x3 = conv_block(32, 64, kernel_size=3, padding=1, **dd)
  254. self.Pool1 = nn.MaxPool2d(kernel_size=3, stride=2)
  255. self.Conv2d_3b_1x1 = conv_block(64, 80, kernel_size=1, **dd)
  256. self.Conv2d_4a_3x3 = conv_block(80, 192, kernel_size=3, **dd)
  257. self.Pool2 = nn.MaxPool2d(kernel_size=3, stride=2)
  258. self.Mixed_5b = InceptionA(192, pool_features=32, conv_block=conv_block, **dd)
  259. self.Mixed_5c = InceptionA(256, pool_features=64, conv_block=conv_block, **dd)
  260. self.Mixed_5d = InceptionA(288, pool_features=64, conv_block=conv_block, **dd)
  261. self.Mixed_6a = InceptionB(288, conv_block=conv_block, **dd)
  262. self.Mixed_6b = InceptionC(768, channels_7x7=128, conv_block=conv_block, **dd)
  263. self.Mixed_6c = InceptionC(768, channels_7x7=160, conv_block=conv_block, **dd)
  264. self.Mixed_6d = InceptionC(768, channels_7x7=160, conv_block=conv_block, **dd)
  265. self.Mixed_6e = InceptionC(768, channels_7x7=192, conv_block=conv_block, **dd)
  266. if aux_logits:
  267. self.AuxLogits = InceptionAux(768, num_classes, conv_block=conv_block, **dd)
  268. else:
  269. self.AuxLogits = None
  270. self.Mixed_7a = InceptionD(768, conv_block=conv_block, **dd)
  271. self.Mixed_7b = InceptionE(1280, conv_block=conv_block, **dd)
  272. self.Mixed_7c = InceptionE(2048, conv_block=conv_block, **dd)
  273. self.feature_info = [
  274. dict(num_chs=64, reduction=2, module='Conv2d_2b_3x3'),
  275. dict(num_chs=192, reduction=4, module='Conv2d_4a_3x3'),
  276. dict(num_chs=288, reduction=8, module='Mixed_5d'),
  277. dict(num_chs=768, reduction=16, module='Mixed_6e'),
  278. dict(num_chs=2048, reduction=32, module='Mixed_7c'),
  279. ]
  280. self.num_features = self.head_hidden_size = 2048
  281. self.global_pool, self.head_drop, self.fc = create_classifier(
  282. self.num_features,
  283. self.num_classes,
  284. pool_type=global_pool,
  285. drop_rate=drop_rate,
  286. **dd,
  287. )
  288. for m in self.modules():
  289. if isinstance(m, nn.Conv2d) or isinstance(m, nn.Linear):
  290. stddev = m.stddev if hasattr(m, 'stddev') else 0.1
  291. trunc_normal_(m.weight, std=stddev)
  292. elif isinstance(m, nn.BatchNorm2d):
  293. nn.init.constant_(m.weight, 1)
  294. nn.init.constant_(m.bias, 0)
  295. @torch.jit.ignore
  296. def group_matcher(self, coarse=False):
  297. module_map = {k: i for i, (k, _) in enumerate(flatten_modules(self.named_children(), prefix=()))}
  298. module_map.pop(('fc',))
  299. def _matcher(name):
  300. if any([name.startswith(n) for n in ('Conv2d_1', 'Conv2d_2')]):
  301. return 0
  302. elif any([name.startswith(n) for n in ('Conv2d_3', 'Conv2d_4')]):
  303. return 1
  304. else:
  305. for k in module_map.keys():
  306. if k == tuple(name.split('.')[:len(k)]):
  307. return module_map[k]
  308. return float('inf')
  309. return _matcher
  310. @torch.jit.ignore
  311. def set_grad_checkpointing(self, enable=True):
  312. assert not enable, 'gradient checkpointing not supported'
  313. @torch.jit.ignore
  314. def get_classifier(self) -> nn.Module:
  315. return self.fc
  316. def reset_classifier(self, num_classes: int, global_pool: str = 'avg'):
  317. self.num_classes = num_classes
  318. self.global_pool, self.fc = create_classifier(self.num_features, self.num_classes, pool_type=global_pool)
  319. def forward_preaux(self, x):
  320. x = self.Conv2d_1a_3x3(x) # N x 32 x 149 x 149
  321. x = self.Conv2d_2a_3x3(x) # N x 32 x 147 x 147
  322. x = self.Conv2d_2b_3x3(x) # N x 64 x 147 x 147
  323. x = self.Pool1(x) # N x 64 x 73 x 73
  324. x = self.Conv2d_3b_1x1(x) # N x 80 x 73 x 73
  325. x = self.Conv2d_4a_3x3(x) # N x 192 x 71 x 71
  326. x = self.Pool2(x) # N x 192 x 35 x 35
  327. x = self.Mixed_5b(x) # N x 256 x 35 x 35
  328. x = self.Mixed_5c(x) # N x 288 x 35 x 35
  329. x = self.Mixed_5d(x) # N x 288 x 35 x 35
  330. x = self.Mixed_6a(x) # N x 768 x 17 x 17
  331. x = self.Mixed_6b(x) # N x 768 x 17 x 17
  332. x = self.Mixed_6c(x) # N x 768 x 17 x 17
  333. x = self.Mixed_6d(x) # N x 768 x 17 x 17
  334. x = self.Mixed_6e(x) # N x 768 x 17 x 17
  335. return x
  336. def forward_postaux(self, x):
  337. x = self.Mixed_7a(x) # N x 1280 x 8 x 8
  338. x = self.Mixed_7b(x) # N x 2048 x 8 x 8
  339. x = self.Mixed_7c(x) # N x 2048 x 8 x 8
  340. return x
  341. def forward_features(self, x):
  342. x = self.forward_preaux(x)
  343. if self.aux_logits:
  344. aux = self.AuxLogits(x)
  345. x = self.forward_postaux(x)
  346. return x, aux
  347. x = self.forward_postaux(x)
  348. return x
  349. def forward_head(self, x, pre_logits: bool = False):
  350. x = self.global_pool(x)
  351. x = self.head_drop(x)
  352. if pre_logits:
  353. return x
  354. x = self.fc(x)
  355. return x
  356. def forward(self, x):
  357. if self.aux_logits:
  358. x, aux = self.forward_features(x)
  359. x = self.forward_head(x)
  360. return x, aux
  361. x = self.forward_features(x)
  362. x = self.forward_head(x)
  363. return x
  364. def _create_inception_v3(variant, pretrained=False, **kwargs):
  365. pretrained_cfg = resolve_pretrained_cfg(variant, pretrained_cfg=kwargs.pop('pretrained_cfg', None))
  366. aux_logits = kwargs.get('aux_logits', False)
  367. has_aux_logits = False
  368. if pretrained_cfg:
  369. # only torchvision pretrained weights have aux logits
  370. has_aux_logits = pretrained_cfg.tag == 'tv_in1k'
  371. if aux_logits:
  372. assert not kwargs.pop('features_only', False)
  373. load_strict = has_aux_logits
  374. else:
  375. load_strict = not has_aux_logits
  376. return build_model_with_cfg(
  377. InceptionV3,
  378. variant,
  379. pretrained,
  380. pretrained_cfg=pretrained_cfg,
  381. pretrained_strict=load_strict,
  382. **kwargs,
  383. )
  384. def _cfg(url='', **kwargs):
  385. return {
  386. 'url': url,
  387. 'num_classes': 1000, 'input_size': (3, 299, 299), 'pool_size': (8, 8),
  388. 'crop_pct': 0.875, 'interpolation': 'bicubic',
  389. 'mean': IMAGENET_INCEPTION_MEAN, 'std': IMAGENET_INCEPTION_STD,
  390. 'first_conv': 'Conv2d_1a_3x3.conv', 'classifier': 'fc', 'license': 'apache-2.0',
  391. **kwargs
  392. }
  393. default_cfgs = generate_default_cfgs({
  394. # original PyTorch weights, ported from Tensorflow but modified
  395. 'inception_v3.tv_in1k': _cfg(
  396. # NOTE checkpoint has aux logit layer weights
  397. hf_hub_id='timm/',
  398. url='https://download.pytorch.org/models/inception_v3_google-1a9a5a14.pth'),
  399. # my port of Tensorflow SLIM weights (http://download.tensorflow.org/models/inception_v3_2016_08_28.tar.gz)
  400. 'inception_v3.tf_in1k': _cfg(hf_hub_id='timm/'),
  401. # my port of Tensorflow adversarially trained Inception V3 from
  402. # http://download.tensorflow.org/models/adv_inception_v3_2017_08_18.tar.gz
  403. 'inception_v3.tf_adv_in1k': _cfg(hf_hub_id='timm/'),
  404. # from gluon pretrained models, best performing in terms of accuracy/loss metrics
  405. # https://gluon-cv.mxnet.io/model_zoo/classification.html
  406. 'inception_v3.gluon_in1k': _cfg(
  407. hf_hub_id='timm/',
  408. mean=IMAGENET_DEFAULT_MEAN, # also works well with inception defaults
  409. std=IMAGENET_DEFAULT_STD, # also works well with inception defaults
  410. )
  411. })
  412. @register_model
  413. def inception_v3(pretrained=False, **kwargs) -> InceptionV3:
  414. model = _create_inception_v3('inception_v3', pretrained=pretrained, **kwargs)
  415. return model
  416. register_model_deprecations(__name__, {
  417. 'tf_inception_v3': 'inception_v3.tf_in1k',
  418. 'adv_inception_v3': 'inception_v3.tf_adv_in1k',
  419. 'gluon_inception_v3': 'inception_v3.gluon_in1k',
  420. })