inception_resnet_v2.py 13 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392
  1. """ Pytorch Inception-Resnet-V2 implementation
  2. Sourced from https://github.com/Cadene/tensorflow-model-zoo.torch (MIT License) which is
  3. based upon Google's Tensorflow implementation and pretrained weights (Apache 2.0 License)
  4. """
  5. from functools import partial
  6. from typing import Type, Optional
  7. import torch
  8. import torch.nn as nn
  9. from timm.data import IMAGENET_INCEPTION_MEAN, IMAGENET_INCEPTION_STD
  10. from timm.layers import create_classifier, ConvNormAct
  11. from ._builder import build_model_with_cfg
  12. from ._manipulate import flatten_modules
  13. from ._registry import register_model, generate_default_cfgs, register_model_deprecations
# Names exported by ``from <this module> import *``; only the model class is listed here
# (the ``inception_resnet_v2`` factory below is made available through ``@register_model``).
__all__ = ['InceptionResnetV2']
  15. class Mixed_5b(nn.Module):
  16. def __init__(
  17. self,
  18. conv_block: Optional[Type[nn.Module]] = None,
  19. device=None,
  20. dtype=None,
  21. ):
  22. dd = {'device': device, 'dtype': dtype}
  23. super().__init__()
  24. conv_block = conv_block or ConvNormAct
  25. self.branch0 = conv_block(192, 96, kernel_size=1, stride=1, **dd)
  26. self.branch1 = nn.Sequential(
  27. conv_block(192, 48, kernel_size=1, stride=1, **dd),
  28. conv_block(48, 64, kernel_size=5, stride=1, padding=2, **dd)
  29. )
  30. self.branch2 = nn.Sequential(
  31. conv_block(192, 64, kernel_size=1, stride=1, **dd),
  32. conv_block(64, 96, kernel_size=3, stride=1, padding=1, **dd),
  33. conv_block(96, 96, kernel_size=3, stride=1, padding=1, **dd)
  34. )
  35. self.branch3 = nn.Sequential(
  36. nn.AvgPool2d(3, stride=1, padding=1, count_include_pad=False),
  37. conv_block(192, 64, kernel_size=1, stride=1, **dd)
  38. )
  39. def forward(self, x):
  40. x0 = self.branch0(x)
  41. x1 = self.branch1(x)
  42. x2 = self.branch2(x)
  43. x3 = self.branch3(x)
  44. out = torch.cat((x0, x1, x2, x3), 1)
  45. return out
  46. class Block35(nn.Module):
  47. def __init__(
  48. self,
  49. scale: float = 1.0,
  50. conv_block: Optional[Type[nn.Module]] = None,
  51. device=None,
  52. dtype=None,
  53. ):
  54. dd = {'device': device, 'dtype': dtype}
  55. super().__init__()
  56. self.scale = scale
  57. conv_block = conv_block or ConvNormAct
  58. self.branch0 = conv_block(320, 32, kernel_size=1, stride=1, **dd)
  59. self.branch1 = nn.Sequential(
  60. conv_block(320, 32, kernel_size=1, stride=1, **dd),
  61. conv_block(32, 32, kernel_size=3, stride=1, padding=1, **dd)
  62. )
  63. self.branch2 = nn.Sequential(
  64. conv_block(320, 32, kernel_size=1, stride=1, **dd),
  65. conv_block(32, 48, kernel_size=3, stride=1, padding=1, **dd),
  66. conv_block(48, 64, kernel_size=3, stride=1, padding=1, **dd)
  67. )
  68. self.conv2d = nn.Conv2d(128, 320, kernel_size=1, stride=1, **dd)
  69. self.act = nn.ReLU()
  70. def forward(self, x):
  71. x0 = self.branch0(x)
  72. x1 = self.branch1(x)
  73. x2 = self.branch2(x)
  74. out = torch.cat((x0, x1, x2), 1)
  75. out = self.conv2d(out)
  76. out = out * self.scale + x
  77. out = self.act(out)
  78. return out
  79. class Mixed_6a(nn.Module):
  80. def __init__(
  81. self,
  82. conv_block: Optional[Type[nn.Module]] = None,
  83. device=None,
  84. dtype=None,
  85. ):
  86. dd = {'device': device, 'dtype': dtype}
  87. super().__init__()
  88. conv_block = conv_block or ConvNormAct
  89. self.branch0 = conv_block(320, 384, kernel_size=3, stride=2, **dd)
  90. self.branch1 = nn.Sequential(
  91. conv_block(320, 256, kernel_size=1, stride=1, **dd),
  92. conv_block(256, 256, kernel_size=3, stride=1, padding=1, **dd),
  93. conv_block(256, 384, kernel_size=3, stride=2, **dd)
  94. )
  95. self.branch2 = nn.MaxPool2d(3, stride=2)
  96. def forward(self, x):
  97. x0 = self.branch0(x)
  98. x1 = self.branch1(x)
  99. x2 = self.branch2(x)
  100. out = torch.cat((x0, x1, x2), 1)
  101. return out
  102. class Block17(nn.Module):
  103. def __init__(
  104. self,
  105. scale: float = 1.0,
  106. conv_block: Optional[Type[nn.Module]] = None,
  107. device=None,
  108. dtype=None,
  109. ):
  110. dd = {'device': device, 'dtype': dtype}
  111. super().__init__()
  112. self.scale = scale
  113. conv_block = conv_block or ConvNormAct
  114. self.branch0 = conv_block(1088, 192, kernel_size=1, stride=1, **dd)
  115. self.branch1 = nn.Sequential(
  116. conv_block(1088, 128, kernel_size=1, stride=1, **dd),
  117. conv_block(128, 160, kernel_size=(1, 7), stride=1, padding=(0, 3), **dd),
  118. conv_block(160, 192, kernel_size=(7, 1), stride=1, padding=(3, 0), **dd)
  119. )
  120. self.conv2d = nn.Conv2d(384, 1088, kernel_size=1, stride=1, **dd)
  121. self.act = nn.ReLU()
  122. def forward(self, x):
  123. x0 = self.branch0(x)
  124. x1 = self.branch1(x)
  125. out = torch.cat((x0, x1), 1)
  126. out = self.conv2d(out)
  127. out = out * self.scale + x
  128. out = self.act(out)
  129. return out
  130. class Mixed_7a(nn.Module):
  131. def __init__(
  132. self,
  133. conv_block: Optional[Type[nn.Module]] = None,
  134. device=None,
  135. dtype=None,
  136. ):
  137. dd = {'device': device, 'dtype': dtype}
  138. super().__init__()
  139. conv_block = conv_block or ConvNormAct
  140. self.branch0 = nn.Sequential(
  141. conv_block(1088, 256, kernel_size=1, stride=1, **dd),
  142. conv_block(256, 384, kernel_size=3, stride=2, **dd)
  143. )
  144. self.branch1 = nn.Sequential(
  145. conv_block(1088, 256, kernel_size=1, stride=1, **dd),
  146. conv_block(256, 288, kernel_size=3, stride=2, **dd)
  147. )
  148. self.branch2 = nn.Sequential(
  149. conv_block(1088, 256, kernel_size=1, stride=1, **dd),
  150. conv_block(256, 288, kernel_size=3, stride=1, padding=1, **dd),
  151. conv_block(288, 320, kernel_size=3, stride=2, **dd)
  152. )
  153. self.branch3 = nn.MaxPool2d(3, stride=2)
  154. def forward(self, x):
  155. x0 = self.branch0(x)
  156. x1 = self.branch1(x)
  157. x2 = self.branch2(x)
  158. x3 = self.branch3(x)
  159. out = torch.cat((x0, x1, x2, x3), 1)
  160. return out
  161. class Block8(nn.Module):
  162. def __init__(
  163. self,
  164. scale: float = 1.0,
  165. no_relu: bool = False,
  166. conv_block: Optional[Type[nn.Module]] = None,
  167. device=None,
  168. dtype=None,
  169. ):
  170. dd = {'device': device, 'dtype': dtype}
  171. super().__init__()
  172. self.scale = scale
  173. conv_block = conv_block or ConvNormAct
  174. self.branch0 = conv_block(2080, 192, kernel_size=1, stride=1, **dd)
  175. self.branch1 = nn.Sequential(
  176. conv_block(2080, 192, kernel_size=1, stride=1, **dd),
  177. conv_block(192, 224, kernel_size=(1, 3), stride=1, padding=(0, 1), **dd),
  178. conv_block(224, 256, kernel_size=(3, 1), stride=1, padding=(1, 0), **dd)
  179. )
  180. self.conv2d = nn.Conv2d(448, 2080, kernel_size=1, stride=1, **dd)
  181. self.relu = None if no_relu else nn.ReLU()
  182. def forward(self, x):
  183. x0 = self.branch0(x)
  184. x1 = self.branch1(x)
  185. out = torch.cat((x0, x1), 1)
  186. out = self.conv2d(out)
  187. out = out * self.scale + x
  188. if self.relu is not None:
  189. out = self.relu(out)
  190. return out
  191. class InceptionResnetV2(nn.Module):
  192. def __init__(
  193. self,
  194. num_classes: int = 1000,
  195. in_chans: int = 3,
  196. drop_rate: float = 0.,
  197. output_stride: int = 32,
  198. global_pool: str = 'avg',
  199. norm_layer: str = 'batchnorm2d',
  200. norm_eps: float = 1e-3,
  201. act_layer: str = 'relu',
  202. device=None,
  203. dtype=None,
  204. ) -> None:
  205. super().__init__()
  206. dd = {'device': device, 'dtype': dtype}
  207. self.num_classes = num_classes
  208. self.in_chans = in_chans
  209. self.num_features = self.head_hidden_size = 1536
  210. assert output_stride == 32
  211. conv_block = partial(
  212. ConvNormAct,
  213. padding=0,
  214. norm_layer=norm_layer,
  215. act_layer=act_layer,
  216. norm_kwargs=dict(eps=norm_eps),
  217. act_kwargs=dict(inplace=True),
  218. )
  219. self.conv2d_1a = conv_block(in_chans, 32, kernel_size=3, stride=2, **dd)
  220. self.conv2d_2a = conv_block(32, 32, kernel_size=3, stride=1, **dd)
  221. self.conv2d_2b = conv_block(32, 64, kernel_size=3, stride=1, padding=1, **dd)
  222. self.feature_info = [dict(num_chs=64, reduction=2, module='conv2d_2b')]
  223. self.maxpool_3a = nn.MaxPool2d(3, stride=2)
  224. self.conv2d_3b = conv_block(64, 80, kernel_size=1, stride=1, **dd)
  225. self.conv2d_4a = conv_block(80, 192, kernel_size=3, stride=1, **dd)
  226. self.feature_info += [dict(num_chs=192, reduction=4, module='conv2d_4a')]
  227. self.maxpool_5a = nn.MaxPool2d(3, stride=2)
  228. self.mixed_5b = Mixed_5b(conv_block=conv_block, **dd)
  229. self.repeat = nn.Sequential(*[Block35(scale=0.17, conv_block=conv_block, **dd) for _ in range(10)])
  230. self.feature_info += [dict(num_chs=320, reduction=8, module='repeat')]
  231. self.mixed_6a = Mixed_6a(conv_block=conv_block, **dd)
  232. self.repeat_1 = nn.Sequential(*[Block17(scale=0.10, conv_block=conv_block, **dd) for _ in range(20)])
  233. self.feature_info += [dict(num_chs=1088, reduction=16, module='repeat_1')]
  234. self.mixed_7a = Mixed_7a(conv_block=conv_block, **dd)
  235. self.repeat_2 = nn.Sequential(*[Block8(scale=0.20, conv_block=conv_block, **dd) for _ in range(9)])
  236. self.block8 = Block8(no_relu=True, conv_block=conv_block, **dd)
  237. self.conv2d_7b = conv_block(2080, self.num_features, kernel_size=1, stride=1, **dd)
  238. self.feature_info += [dict(num_chs=self.num_features, reduction=32, module='conv2d_7b')]
  239. self.global_pool, self.head_drop, self.classif = create_classifier(
  240. self.num_features,
  241. self.num_classes,
  242. pool_type=global_pool,
  243. drop_rate=drop_rate,
  244. **dd,
  245. )
  246. @torch.jit.ignore
  247. def group_matcher(self, coarse=False):
  248. module_map = {k: i for i, (k, _) in enumerate(flatten_modules(self.named_children(), prefix=()))}
  249. module_map.pop(('classif',))
  250. def _matcher(name):
  251. if any([name.startswith(n) for n in ('conv2d_1', 'conv2d_2')]):
  252. return 0
  253. elif any([name.startswith(n) for n in ('conv2d_3', 'conv2d_4')]):
  254. return 1
  255. elif any([name.startswith(n) for n in ('block8', 'conv2d_7')]):
  256. return len(module_map) + 1
  257. else:
  258. for k in module_map.keys():
  259. if k == tuple(name.split('.')[:len(k)]):
  260. return module_map[k]
  261. return float('inf')
  262. return _matcher
  263. @torch.jit.ignore
  264. def set_grad_checkpointing(self, enable=True):
  265. assert not enable, "checkpointing not supported"
  266. @torch.jit.ignore
  267. def get_classifier(self) -> nn.Module:
  268. return self.classif
  269. def reset_classifier(self, num_classes: int, global_pool: str = 'avg'):
  270. self.num_classes = num_classes
  271. self.global_pool, self.classif = create_classifier(self.num_features, self.num_classes, pool_type=global_pool)
  272. def forward_features(self, x):
  273. x = self.conv2d_1a(x)
  274. x = self.conv2d_2a(x)
  275. x = self.conv2d_2b(x)
  276. x = self.maxpool_3a(x)
  277. x = self.conv2d_3b(x)
  278. x = self.conv2d_4a(x)
  279. x = self.maxpool_5a(x)
  280. x = self.mixed_5b(x)
  281. x = self.repeat(x)
  282. x = self.mixed_6a(x)
  283. x = self.repeat_1(x)
  284. x = self.mixed_7a(x)
  285. x = self.repeat_2(x)
  286. x = self.block8(x)
  287. x = self.conv2d_7b(x)
  288. return x
  289. def forward_head(self, x, pre_logits: bool = False):
  290. x = self.global_pool(x)
  291. x = self.head_drop(x)
  292. return x if pre_logits else self.classif(x)
  293. def forward(self, x):
  294. x = self.forward_features(x)
  295. x = self.forward_head(x)
  296. return x
  297. def _create_inception_resnet_v2(variant, pretrained=False, **kwargs):
  298. return build_model_with_cfg(InceptionResnetV2, variant, pretrained, **kwargs)
  299. default_cfgs = generate_default_cfgs({
  300. # ported from http://download.tensorflow.org/models/inception_resnet_v2_2016_08_30.tar.gz
  301. 'inception_resnet_v2.tf_in1k': {
  302. 'hf_hub_id': 'timm/',
  303. 'num_classes': 1000, 'input_size': (3, 299, 299), 'pool_size': (8, 8),
  304. 'crop_pct': 0.8975, 'interpolation': 'bicubic',
  305. 'mean': IMAGENET_INCEPTION_MEAN, 'std': IMAGENET_INCEPTION_STD,
  306. 'first_conv': 'conv2d_1a.conv', 'classifier': 'classif',
  307. 'license': 'apache-2.0',
  308. },
  309. # As per https://arxiv.org/abs/1705.07204 and
  310. # ported from http://download.tensorflow.org/models/ens_adv_inception_resnet_v2_2017_08_18.tar.gz
  311. 'inception_resnet_v2.tf_ens_adv_in1k': {
  312. 'hf_hub_id': 'timm/',
  313. 'num_classes': 1000, 'input_size': (3, 299, 299), 'pool_size': (8, 8),
  314. 'crop_pct': 0.8975, 'interpolation': 'bicubic',
  315. 'mean': IMAGENET_INCEPTION_MEAN, 'std': IMAGENET_INCEPTION_STD,
  316. 'first_conv': 'conv2d_1a.conv', 'classifier': 'classif',
  317. 'license': 'apache-2.0',
  318. }
  319. })
  320. @register_model
  321. def inception_resnet_v2(pretrained=False, **kwargs) -> InceptionResnetV2:
  322. return _create_inception_resnet_v2('inception_resnet_v2', pretrained=pretrained, **kwargs)
  323. register_model_deprecations(__name__, {
  324. 'ens_adv_inception_resnet_v2': 'inception_resnet_v2.tf_ens_adv_in1k',
  325. })