inception_v4.py 15 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445
""" PyTorch Inception-V4 implementation

Sourced from https://github.com/Cadene/tensorflow-model-zoo.torch (MIT License), which is
based upon Google's TensorFlow implementation and pretrained weights (Apache 2.0 License).
"""
  5. from functools import partial
  6. from typing import List, Optional, Tuple, Union, Type
  7. import torch
  8. import torch.nn as nn
  9. from timm.data import IMAGENET_INCEPTION_MEAN, IMAGENET_INCEPTION_STD
  10. from timm.layers import create_classifier, ConvNormAct
  11. from ._builder import build_model_with_cfg
  12. from ._features import feature_take_indices
  13. from ._registry import register_model, generate_default_cfgs
  14. __all__ = ['InceptionV4']
  15. class Mixed3a(nn.Module):
  16. def __init__(
  17. self,
  18. conv_block: Type[nn.Module] = ConvNormAct,
  19. device=None,
  20. dtype=None,
  21. ):
  22. dd = {'device': device, 'dtype': dtype}
  23. super().__init__()
  24. self.maxpool = nn.MaxPool2d(3, stride=2)
  25. self.conv = conv_block(64, 96, kernel_size=3, stride=2, **dd)
  26. def forward(self, x):
  27. x0 = self.maxpool(x)
  28. x1 = self.conv(x)
  29. out = torch.cat((x0, x1), 1)
  30. return out
  31. class Mixed4a(nn.Module):
  32. def __init__(
  33. self,
  34. conv_block: Type[nn.Module] = ConvNormAct,
  35. device=None,
  36. dtype=None,
  37. ):
  38. dd = {'device': device, 'dtype': dtype}
  39. super().__init__()
  40. self.branch0 = nn.Sequential(
  41. conv_block(160, 64, kernel_size=1, stride=1, **dd),
  42. conv_block(64, 96, kernel_size=3, stride=1, **dd)
  43. )
  44. self.branch1 = nn.Sequential(
  45. conv_block(160, 64, kernel_size=1, stride=1, **dd),
  46. conv_block(64, 64, kernel_size=(1, 7), stride=1, padding=(0, 3), **dd),
  47. conv_block(64, 64, kernel_size=(7, 1), stride=1, padding=(3, 0), **dd),
  48. conv_block(64, 96, kernel_size=(3, 3), stride=1, **dd)
  49. )
  50. def forward(self, x):
  51. x0 = self.branch0(x)
  52. x1 = self.branch1(x)
  53. out = torch.cat((x0, x1), 1)
  54. return out
  55. class Mixed5a(nn.Module):
  56. def __init__(
  57. self,
  58. conv_block: Type[nn.Module] = ConvNormAct,
  59. device=None,
  60. dtype=None,
  61. ):
  62. dd = {'device': device, 'dtype': dtype}
  63. super().__init__()
  64. self.conv = conv_block(192, 192, kernel_size=3, stride=2, **dd)
  65. self.maxpool = nn.MaxPool2d(3, stride=2)
  66. def forward(self, x):
  67. x0 = self.conv(x)
  68. x1 = self.maxpool(x)
  69. out = torch.cat((x0, x1), 1)
  70. return out
  71. class InceptionA(nn.Module):
  72. def __init__(
  73. self,
  74. conv_block: Type[nn.Module] = ConvNormAct,
  75. device=None,
  76. dtype=None,
  77. ):
  78. dd = {'device': device, 'dtype': dtype}
  79. super().__init__()
  80. self.branch0 = conv_block(384, 96, kernel_size=1, stride=1, **dd)
  81. self.branch1 = nn.Sequential(
  82. conv_block(384, 64, kernel_size=1, stride=1, **dd),
  83. conv_block(64, 96, kernel_size=3, stride=1, padding=1, **dd)
  84. )
  85. self.branch2 = nn.Sequential(
  86. conv_block(384, 64, kernel_size=1, stride=1, **dd),
  87. conv_block(64, 96, kernel_size=3, stride=1, padding=1, **dd),
  88. conv_block(96, 96, kernel_size=3, stride=1, padding=1, **dd)
  89. )
  90. self.branch3 = nn.Sequential(
  91. nn.AvgPool2d(3, stride=1, padding=1, count_include_pad=False),
  92. conv_block(384, 96, kernel_size=1, stride=1, **dd)
  93. )
  94. def forward(self, x):
  95. x0 = self.branch0(x)
  96. x1 = self.branch1(x)
  97. x2 = self.branch2(x)
  98. x3 = self.branch3(x)
  99. out = torch.cat((x0, x1, x2, x3), 1)
  100. return out
  101. class ReductionA(nn.Module):
  102. def __init__(
  103. self,
  104. conv_block: Type[nn.Module] = ConvNormAct,
  105. device=None,
  106. dtype=None,
  107. ):
  108. dd = {'device': device, 'dtype': dtype}
  109. super().__init__()
  110. self.branch0 = conv_block(384, 384, kernel_size=3, stride=2, **dd)
  111. self.branch1 = nn.Sequential(
  112. conv_block(384, 192, kernel_size=1, stride=1, **dd),
  113. conv_block(192, 224, kernel_size=3, stride=1, padding=1, **dd),
  114. conv_block(224, 256, kernel_size=3, stride=2, **dd)
  115. )
  116. self.branch2 = nn.MaxPool2d(3, stride=2)
  117. def forward(self, x):
  118. x0 = self.branch0(x)
  119. x1 = self.branch1(x)
  120. x2 = self.branch2(x)
  121. out = torch.cat((x0, x1, x2), 1)
  122. return out
  123. class InceptionB(nn.Module):
  124. def __init__(
  125. self,
  126. conv_block: Type[nn.Module] = ConvNormAct,
  127. device=None,
  128. dtype=None,
  129. ):
  130. dd = {'device': device, 'dtype': dtype}
  131. super().__init__()
  132. self.branch0 = conv_block(1024, 384, kernel_size=1, stride=1, **dd)
  133. self.branch1 = nn.Sequential(
  134. conv_block(1024, 192, kernel_size=1, stride=1, **dd),
  135. conv_block(192, 224, kernel_size=(1, 7), stride=1, padding=(0, 3), **dd),
  136. conv_block(224, 256, kernel_size=(7, 1), stride=1, padding=(3, 0), **dd)
  137. )
  138. self.branch2 = nn.Sequential(
  139. conv_block(1024, 192, kernel_size=1, stride=1, **dd),
  140. conv_block(192, 192, kernel_size=(7, 1), stride=1, padding=(3, 0), **dd),
  141. conv_block(192, 224, kernel_size=(1, 7), stride=1, padding=(0, 3), **dd),
  142. conv_block(224, 224, kernel_size=(7, 1), stride=1, padding=(3, 0), **dd),
  143. conv_block(224, 256, kernel_size=(1, 7), stride=1, padding=(0, 3), **dd)
  144. )
  145. self.branch3 = nn.Sequential(
  146. nn.AvgPool2d(3, stride=1, padding=1, count_include_pad=False),
  147. conv_block(1024, 128, kernel_size=1, stride=1, **dd)
  148. )
  149. def forward(self, x):
  150. x0 = self.branch0(x)
  151. x1 = self.branch1(x)
  152. x2 = self.branch2(x)
  153. x3 = self.branch3(x)
  154. out = torch.cat((x0, x1, x2, x3), 1)
  155. return out
  156. class ReductionB(nn.Module):
  157. def __init__(
  158. self,
  159. conv_block: Type[nn.Module] = ConvNormAct,
  160. device=None,
  161. dtype=None,
  162. ):
  163. dd = {'device': device, 'dtype': dtype}
  164. super().__init__()
  165. self.branch0 = nn.Sequential(
  166. conv_block(1024, 192, kernel_size=1, stride=1, **dd),
  167. conv_block(192, 192, kernel_size=3, stride=2, **dd)
  168. )
  169. self.branch1 = nn.Sequential(
  170. conv_block(1024, 256, kernel_size=1, stride=1, **dd),
  171. conv_block(256, 256, kernel_size=(1, 7), stride=1, padding=(0, 3), **dd),
  172. conv_block(256, 320, kernel_size=(7, 1), stride=1, padding=(3, 0), **dd),
  173. conv_block(320, 320, kernel_size=3, stride=2, **dd)
  174. )
  175. self.branch2 = nn.MaxPool2d(3, stride=2)
  176. def forward(self, x):
  177. x0 = self.branch0(x)
  178. x1 = self.branch1(x)
  179. x2 = self.branch2(x)
  180. out = torch.cat((x0, x1, x2), 1)
  181. return out
  182. class InceptionC(nn.Module):
  183. def __init__(
  184. self,
  185. conv_block: Type[nn.Module] = ConvNormAct,
  186. device=None,
  187. dtype=None,
  188. ):
  189. dd = {'device': device, 'dtype': dtype}
  190. super().__init__()
  191. self.branch0 = conv_block(1536, 256, kernel_size=1, stride=1, **dd)
  192. self.branch1_0 = conv_block(1536, 384, kernel_size=1, stride=1, **dd)
  193. self.branch1_1a = conv_block(384, 256, kernel_size=(1, 3), stride=1, padding=(0, 1), **dd)
  194. self.branch1_1b = conv_block(384, 256, kernel_size=(3, 1), stride=1, padding=(1, 0), **dd)
  195. self.branch2_0 = conv_block(1536, 384, kernel_size=1, stride=1, **dd)
  196. self.branch2_1 = conv_block(384, 448, kernel_size=(3, 1), stride=1, padding=(1, 0), **dd)
  197. self.branch2_2 = conv_block(448, 512, kernel_size=(1, 3), stride=1, padding=(0, 1), **dd)
  198. self.branch2_3a = conv_block(512, 256, kernel_size=(1, 3), stride=1, padding=(0, 1), **dd)
  199. self.branch2_3b = conv_block(512, 256, kernel_size=(3, 1), stride=1, padding=(1, 0), **dd)
  200. self.branch3 = nn.Sequential(
  201. nn.AvgPool2d(3, stride=1, padding=1, count_include_pad=False),
  202. conv_block(1536, 256, kernel_size=1, stride=1, **dd)
  203. )
  204. def forward(self, x):
  205. x0 = self.branch0(x)
  206. x1_0 = self.branch1_0(x)
  207. x1_1a = self.branch1_1a(x1_0)
  208. x1_1b = self.branch1_1b(x1_0)
  209. x1 = torch.cat((x1_1a, x1_1b), 1)
  210. x2_0 = self.branch2_0(x)
  211. x2_1 = self.branch2_1(x2_0)
  212. x2_2 = self.branch2_2(x2_1)
  213. x2_3a = self.branch2_3a(x2_2)
  214. x2_3b = self.branch2_3b(x2_2)
  215. x2 = torch.cat((x2_3a, x2_3b), 1)
  216. x3 = self.branch3(x)
  217. out = torch.cat((x0, x1, x2, x3), 1)
  218. return out
  219. class InceptionV4(nn.Module):
  220. def __init__(
  221. self,
  222. num_classes: int = 1000,
  223. in_chans: int = 3,
  224. output_stride: int = 32,
  225. drop_rate: float = 0.,
  226. global_pool: str = 'avg',
  227. norm_layer: str = 'batchnorm2d',
  228. norm_eps: float = 1e-3,
  229. act_layer: str = 'relu',
  230. device=None,
  231. dtype=None,
  232. ) -> None:
  233. dd = {'device': device, 'dtype': dtype}
  234. super().__init__()
  235. assert output_stride == 32
  236. self.num_classes = num_classes
  237. self.in_chans = in_chans
  238. self.num_features = self.head_hidden_size = 1536
  239. conv_block = partial(
  240. ConvNormAct,
  241. padding=0,
  242. norm_layer=norm_layer,
  243. act_layer=act_layer,
  244. norm_kwargs=dict(eps=norm_eps),
  245. act_kwargs=dict(inplace=True),
  246. )
  247. features = [
  248. conv_block(in_chans, 32, kernel_size=3, stride=2, **dd),
  249. conv_block(32, 32, kernel_size=3, stride=1, **dd),
  250. conv_block(32, 64, kernel_size=3, stride=1, padding=1, **dd),
  251. Mixed3a(conv_block, **dd),
  252. Mixed4a(conv_block, **dd),
  253. Mixed5a(conv_block, **dd),
  254. ]
  255. features += [InceptionA(conv_block, **dd) for _ in range(4)]
  256. features += [ReductionA(conv_block, **dd)] # Mixed6a
  257. features += [InceptionB(conv_block, **dd) for _ in range(7)]
  258. features += [ReductionB(conv_block, **dd)] # Mixed7a
  259. features += [InceptionC(conv_block, **dd) for _ in range(3)]
  260. self.features = nn.Sequential(*features)
  261. self.feature_info = [
  262. dict(num_chs=64, reduction=2, module='features.2'),
  263. dict(num_chs=160, reduction=4, module='features.3'),
  264. dict(num_chs=384, reduction=8, module='features.9'),
  265. dict(num_chs=1024, reduction=16, module='features.17'),
  266. dict(num_chs=1536, reduction=32, module='features.21'),
  267. ]
  268. self.global_pool, self.head_drop, self.last_linear = create_classifier(
  269. self.num_features,
  270. self.num_classes,
  271. pool_type=global_pool,
  272. drop_rate=drop_rate,
  273. **dd,
  274. )
  275. @torch.jit.ignore
  276. def group_matcher(self, coarse=False):
  277. return dict(
  278. stem=r'^features\.[012]\.',
  279. blocks=r'^features\.(\d+)'
  280. )
  281. @torch.jit.ignore
  282. def set_grad_checkpointing(self, enable=True):
  283. assert not enable, 'gradient checkpointing not supported'
  284. @torch.jit.ignore
  285. def get_classifier(self) -> nn.Module:
  286. return self.last_linear
  287. def reset_classifier(self, num_classes: int, global_pool: str = 'avg'):
  288. self.num_classes = num_classes
  289. self.global_pool, self.last_linear = create_classifier(
  290. self.num_features, self.num_classes, pool_type=global_pool)
  291. def forward_intermediates(
  292. self,
  293. x: torch.Tensor,
  294. indices: Optional[Union[int, List[int]]] = None,
  295. norm: bool = False,
  296. stop_early: bool = False,
  297. output_fmt: str = 'NCHW',
  298. intermediates_only: bool = False,
  299. ) -> Union[List[torch.Tensor], Tuple[torch.Tensor, List[torch.Tensor]]]:
  300. """ Forward features that returns intermediates.
  301. Args:
  302. x: Input image tensor
  303. indices: Take last n blocks if int, all if None, select matching indices if sequence
  304. norm: Apply norm layer to compatible intermediates
  305. stop_early: Stop iterating over blocks when last desired intermediate hit
  306. output_fmt: Shape of intermediate feature outputs
  307. intermediates_only: Only return intermediate features
  308. Returns:
  309. """
  310. assert output_fmt in ('NCHW',), 'Output shape must be NCHW.'
  311. intermediates = []
  312. stage_ends = [int(info['module'].split('.')[-1]) for info in self.feature_info]
  313. take_indices, max_index = feature_take_indices(len(stage_ends), indices)
  314. take_indices = [stage_ends[i] for i in take_indices]
  315. max_index = stage_ends[max_index]
  316. # forward pass
  317. if torch.jit.is_scripting() or not stop_early: # can't slice blocks in torchscript
  318. stages = self.features
  319. else:
  320. stages = self.features[:max_index + 1]
  321. for feat_idx, stage in enumerate(stages):
  322. x = stage(x)
  323. if feat_idx in take_indices:
  324. intermediates.append(x)
  325. if intermediates_only:
  326. return intermediates
  327. return x, intermediates
  328. def prune_intermediate_layers(
  329. self,
  330. indices: Union[int, List[int]] = 1,
  331. prune_norm: bool = False,
  332. prune_head: bool = True,
  333. ):
  334. """ Prune layers not required for specified intermediates.
  335. """
  336. stage_ends = [int(info['module'].split('.')[-1]) for info in self.feature_info]
  337. take_indices, max_index = feature_take_indices(len(stage_ends), indices)
  338. max_index = stage_ends[max_index]
  339. self.features = self.features[:max_index + 1] # truncate blocks w/ stem as idx 0
  340. if prune_head:
  341. self.reset_classifier(0, '')
  342. return take_indices
  343. def forward_features(self, x):
  344. return self.features(x)
  345. def forward_head(self, x, pre_logits: bool = False):
  346. x = self.global_pool(x)
  347. x = self.head_drop(x)
  348. return x if pre_logits else self.last_linear(x)
  349. def forward(self, x):
  350. x = self.forward_features(x)
  351. x = self.forward_head(x)
  352. return x
  353. def _create_inception_v4(variant, pretrained=False, **kwargs) -> InceptionV4:
  354. return build_model_with_cfg(
  355. InceptionV4,
  356. variant,
  357. pretrained,
  358. feature_cfg=dict(flatten_sequential=True),
  359. **kwargs,
  360. )
  361. default_cfgs = generate_default_cfgs({
  362. 'inception_v4.tf_in1k': {
  363. 'hf_hub_id': 'timm/',
  364. 'num_classes': 1000, 'input_size': (3, 299, 299), 'pool_size': (8, 8),
  365. 'crop_pct': 0.875, 'interpolation': 'bicubic',
  366. 'mean': IMAGENET_INCEPTION_MEAN, 'std': IMAGENET_INCEPTION_STD,
  367. 'first_conv': 'features.0.conv', 'classifier': 'last_linear',
  368. 'license': 'apache-2.0',
  369. }
  370. })
  371. @register_model
  372. def inception_v4(pretrained=False, **kwargs):
  373. return _create_inception_v4('inception_v4', pretrained, **kwargs)