pnasnet.py 17 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464
  1. """
  2. pnasnet5large implementation grabbed from Cadene's pretrained models
  3. Additional credit to https://github.com/creafz
  4. https://github.com/Cadene/pretrained-models.pytorch/blob/master/pretrainedmodels/models/pnasnet.py
  5. """
  6. from collections import OrderedDict
  7. from functools import partial
  8. from typing import Type
  9. import torch
  10. import torch.nn as nn
  11. from timm.layers import ConvNormAct, create_conv2d, create_pool2d, create_classifier
  12. from ._builder import build_model_with_cfg
  13. from ._registry import register_model, generate_default_cfgs
  14. __all__ = ['PNASNet5Large']
  15. class SeparableConv2d(nn.Module):
  16. def __init__(
  17. self,
  18. in_channels: int,
  19. out_channels: int,
  20. kernel_size: int,
  21. stride: int,
  22. padding: str = '',
  23. device=None,
  24. dtype=None,
  25. ):
  26. dd = {'device': device, 'dtype': dtype}
  27. super().__init__()
  28. self.depthwise_conv2d = create_conv2d(
  29. in_channels,
  30. in_channels,
  31. kernel_size=kernel_size,
  32. stride=stride,
  33. padding=padding,
  34. groups=in_channels,
  35. **dd,
  36. )
  37. self.pointwise_conv2d = create_conv2d(
  38. in_channels,
  39. out_channels,
  40. kernel_size=1,
  41. padding=padding,
  42. **dd,
  43. )
  44. def forward(self, x):
  45. x = self.depthwise_conv2d(x)
  46. x = self.pointwise_conv2d(x)
  47. return x
  48. class BranchSeparables(nn.Module):
  49. def __init__(
  50. self,
  51. in_channels: int,
  52. out_channels: int,
  53. kernel_size: int,
  54. stride: int = 1,
  55. stem_cell: bool = False,
  56. padding: str = '',
  57. device=None,
  58. dtype=None,
  59. ):
  60. dd = {'device': device, 'dtype': dtype}
  61. super().__init__()
  62. middle_channels = out_channels if stem_cell else in_channels
  63. self.act_1 = nn.ReLU()
  64. self.separable_1 = SeparableConv2d(
  65. in_channels,
  66. middle_channels,
  67. kernel_size,
  68. stride=stride,
  69. padding=padding,
  70. **dd,
  71. )
  72. self.bn_sep_1 = nn.BatchNorm2d(middle_channels, eps=0.001, **dd)
  73. self.act_2 = nn.ReLU()
  74. self.separable_2 = SeparableConv2d(
  75. middle_channels,
  76. out_channels,
  77. kernel_size,
  78. stride=1,
  79. padding=padding,
  80. **dd,
  81. )
  82. self.bn_sep_2 = nn.BatchNorm2d(out_channels, eps=0.001, **dd)
  83. def forward(self, x):
  84. x = self.act_1(x)
  85. x = self.separable_1(x)
  86. x = self.bn_sep_1(x)
  87. x = self.act_2(x)
  88. x = self.separable_2(x)
  89. x = self.bn_sep_2(x)
  90. return x
  91. class ActConvBn(nn.Module):
  92. def __init__(
  93. self,
  94. in_channels: int,
  95. out_channels: int,
  96. kernel_size: int,
  97. stride: int = 1,
  98. padding: str = '',
  99. device=None,
  100. dtype=None,
  101. ):
  102. dd = {'device': device, 'dtype': dtype}
  103. super().__init__()
  104. self.act = nn.ReLU()
  105. self.conv = create_conv2d(
  106. in_channels,
  107. out_channels,
  108. kernel_size=kernel_size,
  109. stride=stride,
  110. padding=padding,
  111. **dd,
  112. )
  113. self.bn = nn.BatchNorm2d(out_channels, eps=0.001, **dd)
  114. def forward(self, x):
  115. x = self.act(x)
  116. x = self.conv(x)
  117. x = self.bn(x)
  118. return x
  119. class FactorizedReduction(nn.Module):
  120. def __init__(
  121. self,
  122. in_channels: int,
  123. out_channels: int,
  124. padding: str = '',
  125. device=None,
  126. dtype=None,
  127. ):
  128. dd = {'device': device, 'dtype': dtype}
  129. super().__init__()
  130. self.act = nn.ReLU()
  131. self.path_1 = nn.Sequential(OrderedDict([
  132. ('avgpool', nn.AvgPool2d(1, stride=2, count_include_pad=False)),
  133. ('conv', create_conv2d(in_channels, out_channels // 2, kernel_size=1, padding=padding, **dd)),
  134. ]))
  135. self.path_2 = nn.Sequential(OrderedDict([
  136. ('pad', nn.ZeroPad2d((-1, 1, -1, 1))), # shift
  137. ('avgpool', nn.AvgPool2d(1, stride=2, count_include_pad=False)),
  138. ('conv', create_conv2d(in_channels, out_channels // 2, kernel_size=1, padding=padding, **dd)),
  139. ]))
  140. self.final_path_bn = nn.BatchNorm2d(out_channels, eps=0.001, **dd)
  141. def forward(self, x):
  142. x = self.act(x)
  143. x_path1 = self.path_1(x)
  144. x_path2 = self.path_2(x)
  145. out = self.final_path_bn(torch.cat([x_path1, x_path2], 1))
  146. return out
  147. class CellBase(nn.Module):
  148. def cell_forward(self, x_left, x_right):
  149. x_comb_iter_0_left = self.comb_iter_0_left(x_left)
  150. x_comb_iter_0_right = self.comb_iter_0_right(x_left)
  151. x_comb_iter_0 = x_comb_iter_0_left + x_comb_iter_0_right
  152. x_comb_iter_1_left = self.comb_iter_1_left(x_right)
  153. x_comb_iter_1_right = self.comb_iter_1_right(x_right)
  154. x_comb_iter_1 = x_comb_iter_1_left + x_comb_iter_1_right
  155. x_comb_iter_2_left = self.comb_iter_2_left(x_right)
  156. x_comb_iter_2_right = self.comb_iter_2_right(x_right)
  157. x_comb_iter_2 = x_comb_iter_2_left + x_comb_iter_2_right
  158. x_comb_iter_3_left = self.comb_iter_3_left(x_comb_iter_2)
  159. x_comb_iter_3_right = self.comb_iter_3_right(x_right)
  160. x_comb_iter_3 = x_comb_iter_3_left + x_comb_iter_3_right
  161. x_comb_iter_4_left = self.comb_iter_4_left(x_left)
  162. if self.comb_iter_4_right is not None:
  163. x_comb_iter_4_right = self.comb_iter_4_right(x_right)
  164. else:
  165. x_comb_iter_4_right = x_right
  166. x_comb_iter_4 = x_comb_iter_4_left + x_comb_iter_4_right
  167. x_out = torch.cat([x_comb_iter_0, x_comb_iter_1, x_comb_iter_2, x_comb_iter_3, x_comb_iter_4], 1)
  168. return x_out
  169. class CellStem0(CellBase):
  170. def __init__(
  171. self,
  172. in_chs_left: int,
  173. out_chs_left: int,
  174. in_chs_right: int,
  175. out_chs_right: int,
  176. pad_type: str = '',
  177. device=None,
  178. dtype=None,
  179. ):
  180. dd = {'device': device, 'dtype': dtype}
  181. super().__init__()
  182. self.conv_1x1 = ActConvBn(in_chs_right, out_chs_right, kernel_size=1, padding=pad_type, **dd)
  183. self.comb_iter_0_left = BranchSeparables(
  184. in_chs_left, out_chs_left, kernel_size=5, stride=2, stem_cell=True, padding=pad_type, **dd)
  185. self.comb_iter_0_right = nn.Sequential(OrderedDict([
  186. ('max_pool', create_pool2d('max', 3, stride=2, padding=pad_type)),
  187. ('conv', create_conv2d(in_chs_left, out_chs_left, kernel_size=1, padding=pad_type, **dd)),
  188. ('bn', nn.BatchNorm2d(out_chs_left, eps=0.001, **dd)),
  189. ]))
  190. self.comb_iter_1_left = BranchSeparables(
  191. out_chs_right, out_chs_right, kernel_size=7, stride=2, padding=pad_type, **dd)
  192. self.comb_iter_1_right = create_pool2d('max', 3, stride=2, padding=pad_type)
  193. self.comb_iter_2_left = BranchSeparables(
  194. out_chs_right, out_chs_right, kernel_size=5, stride=2, padding=pad_type, **dd)
  195. self.comb_iter_2_right = BranchSeparables(
  196. out_chs_right, out_chs_right, kernel_size=3, stride=2, padding=pad_type, **dd)
  197. self.comb_iter_3_left = BranchSeparables(
  198. out_chs_right, out_chs_right, kernel_size=3, padding=pad_type, **dd)
  199. self.comb_iter_3_right = create_pool2d('max', 3, stride=2, padding=pad_type)
  200. self.comb_iter_4_left = BranchSeparables(
  201. in_chs_right, out_chs_right, kernel_size=3, stride=2, stem_cell=True, padding=pad_type, **dd)
  202. self.comb_iter_4_right = ActConvBn(
  203. out_chs_right, out_chs_right, kernel_size=1, stride=2, padding=pad_type, **dd)
  204. def forward(self, x_left):
  205. x_right = self.conv_1x1(x_left)
  206. x_out = self.cell_forward(x_left, x_right)
  207. return x_out
  208. class Cell(CellBase):
  209. def __init__(
  210. self,
  211. in_chs_left: int,
  212. out_chs_left: int,
  213. in_chs_right: int,
  214. out_chs_right: int,
  215. pad_type: str = '',
  216. is_reduction: bool = False,
  217. match_prev_layer_dims: bool = False,
  218. device=None,
  219. dtype=None,
  220. ):
  221. dd = {'device': device, 'dtype': dtype}
  222. super().__init__()
  223. # If `is_reduction` is set to `True` stride 2 is used for
  224. # convolution and pooling layers to reduce the spatial size of
  225. # the output of a cell approximately by a factor of 2.
  226. stride = 2 if is_reduction else 1
  227. # If `match_prev_layer_dimensions` is set to `True`
  228. # `FactorizedReduction` is used to reduce the spatial size
  229. # of the left input of a cell approximately by a factor of 2.
  230. self.match_prev_layer_dimensions = match_prev_layer_dims
  231. if match_prev_layer_dims:
  232. self.conv_prev_1x1 = FactorizedReduction(in_chs_left, out_chs_left, padding=pad_type, **dd)
  233. else:
  234. self.conv_prev_1x1 = ActConvBn(in_chs_left, out_chs_left, kernel_size=1, padding=pad_type, **dd)
  235. self.conv_1x1 = ActConvBn(in_chs_right, out_chs_right, kernel_size=1, padding=pad_type, **dd)
  236. self.comb_iter_0_left = BranchSeparables(
  237. out_chs_left, out_chs_left, kernel_size=5, stride=stride, padding=pad_type, **dd)
  238. self.comb_iter_0_right = create_pool2d('max', 3, stride=stride, padding=pad_type)
  239. self.comb_iter_1_left = BranchSeparables(
  240. out_chs_right, out_chs_right, kernel_size=7, stride=stride, padding=pad_type, **dd)
  241. self.comb_iter_1_right = create_pool2d('max', 3, stride=stride, padding=pad_type)
  242. self.comb_iter_2_left = BranchSeparables(
  243. out_chs_right, out_chs_right, kernel_size=5, stride=stride, padding=pad_type, **dd)
  244. self.comb_iter_2_right = BranchSeparables(
  245. out_chs_right, out_chs_right, kernel_size=3, stride=stride, padding=pad_type, **dd)
  246. self.comb_iter_3_left = BranchSeparables(out_chs_right, out_chs_right, kernel_size=3, **dd)
  247. self.comb_iter_3_right = create_pool2d('max', 3, stride=stride, padding=pad_type)
  248. self.comb_iter_4_left = BranchSeparables(
  249. out_chs_left, out_chs_left, kernel_size=3, stride=stride, padding=pad_type, **dd)
  250. if is_reduction:
  251. self.comb_iter_4_right = ActConvBn(
  252. out_chs_right, out_chs_right, kernel_size=1, stride=stride, padding=pad_type, **dd)
  253. else:
  254. self.comb_iter_4_right = None
  255. def forward(self, x_left, x_right):
  256. x_left = self.conv_prev_1x1(x_left)
  257. x_right = self.conv_1x1(x_right)
  258. x_out = self.cell_forward(x_left, x_right)
  259. return x_out
class PNASNet5Large(nn.Module):
    """PNASNet-5 Large image classifier.

    The stem is ``conv_0 -> cell_stem_0 -> cell_stem_1``. It is followed by twelve
    cells (``cell_0`` .. ``cell_11``), a final ReLU (``act``), and a
    pool / dropout / linear head built by ``create_classifier``. ``cell_stem_1``,
    ``cell_4`` and ``cell_8`` are reduction cells (``is_reduction=True``).
    """

    def __init__(
            self,
            num_classes: int = 1000,
            in_chans: int = 3,
            output_stride: int = 32,
            drop_rate: float = 0.,
            global_pool: str = 'avg',
            pad_type: str = '',
            device=None,
            dtype=None,
    ):
        """
        Args:
            num_classes: Number of classifier outputs.
            in_chans: Number of input image channels.
            output_stride: Only 32 is supported (enforced with an assert).
            drop_rate: Dropout rate passed to ``create_classifier`` for the head.
            global_pool: Pooling type passed to ``create_classifier``.
            pad_type: Padding spec forwarded to every stem cell and cell.
            device: Factory kwarg for parameter creation.
            dtype: Factory kwarg for parameter creation.
        """
        super().__init__()
        dd = {'device': device, 'dtype': dtype}
        self.num_classes = num_classes
        self.in_chans = in_chans
        # Width of cell_11's output: 5 concatenated branches x 864 channels.
        self.num_features = self.head_hidden_size = 4320
        assert output_stride == 32

        self.conv_0 = ConvNormAct(
            in_chans, 96, kernel_size=3, stride=2, padding=0,
            norm_layer=partial(nn.BatchNorm2d, eps=0.001, momentum=0.1), apply_act=False, **dd)

        # Each cell concatenates five branches, so its output width is 5x its out_chs:
        # 5*54=270 (stem_0), 5*108=540 (stem_1), 5*216=1080 (cells 0-3),
        # 5*432=2160 (cells 4-7), 5*864=4320 (cells 8-11).
        # These widths reappear below as the in_chs of the cells consuming each output.
        self.cell_stem_0 = CellStem0(
            in_chs_left=96, out_chs_left=54, in_chs_right=96, out_chs_right=54, pad_type=pad_type, **dd)

        # match_prev_layer_dims=True wherever the left input predates the latest spatial
        # reduction (cell_stem_1, cell_0, cell_5, cell_9), so it must be downsampled too.
        self.cell_stem_1 = Cell(
            in_chs_left=96, out_chs_left=108, in_chs_right=270, out_chs_right=108, pad_type=pad_type,
            match_prev_layer_dims=True, is_reduction=True, **dd)
        self.cell_0 = Cell(
            in_chs_left=270, out_chs_left=216, in_chs_right=540, out_chs_right=216, pad_type=pad_type,
            match_prev_layer_dims=True, **dd)
        self.cell_1 = Cell(
            in_chs_left=540, out_chs_left=216, in_chs_right=1080, out_chs_right=216, pad_type=pad_type, **dd)
        self.cell_2 = Cell(
            in_chs_left=1080, out_chs_left=216, in_chs_right=1080, out_chs_right=216, pad_type=pad_type, **dd)
        self.cell_3 = Cell(
            in_chs_left=1080, out_chs_left=216, in_chs_right=1080, out_chs_right=216, pad_type=pad_type, **dd)
        self.cell_4 = Cell(
            in_chs_left=1080, out_chs_left=432, in_chs_right=1080, out_chs_right=432, pad_type=pad_type,
            is_reduction=True, **dd)
        self.cell_5 = Cell(
            in_chs_left=1080, out_chs_left=432, in_chs_right=2160, out_chs_right=432, pad_type=pad_type,
            match_prev_layer_dims=True, **dd)
        self.cell_6 = Cell(
            in_chs_left=2160, out_chs_left=432, in_chs_right=2160, out_chs_right=432, pad_type=pad_type, **dd)
        self.cell_7 = Cell(
            in_chs_left=2160, out_chs_left=432, in_chs_right=2160, out_chs_right=432, pad_type=pad_type, **dd)
        self.cell_8 = Cell(
            in_chs_left=2160, out_chs_left=864, in_chs_right=2160, out_chs_right=864, pad_type=pad_type,
            is_reduction=True, **dd)
        self.cell_9 = Cell(
            in_chs_left=2160, out_chs_left=864, in_chs_right=4320, out_chs_right=864, pad_type=pad_type,
            match_prev_layer_dims=True, **dd)
        self.cell_10 = Cell(
            in_chs_left=4320, out_chs_left=864, in_chs_right=4320, out_chs_right=864, pad_type=pad_type, **dd)
        self.cell_11 = Cell(
            in_chs_left=4320, out_chs_left=864, in_chs_right=4320, out_chs_right=864, pad_type=pad_type, **dd)
        self.act = nn.ReLU()

        # Feature taps for hook-based feature extraction (see _create_pnasnet).
        # The reduction 4/8/16 taps point at the first ReLU of the next reduction cell's
        # right-input projection, whose input is the preceding cell's output.
        # NOTE(review): whether the hook captures that ReLU's input or output is decided
        # by timm's hook feature wrapper — confirm.
        self.feature_info = [
            dict(num_chs=96, reduction=2, module='conv_0'),
            dict(num_chs=270, reduction=4, module='cell_stem_1.conv_1x1.act'),
            dict(num_chs=1080, reduction=8, module='cell_4.conv_1x1.act'),
            dict(num_chs=2160, reduction=16, module='cell_8.conv_1x1.act'),
            dict(num_chs=4320, reduction=32, module='act'),
        ]

        self.global_pool, self.head_drop, self.last_linear = create_classifier(
            self.num_features, self.num_classes, pool_type=global_pool, drop_rate=drop_rate, **dd)

    @torch.jit.ignore
    def group_matcher(self, coarse=False):
        """Return regexes grouping parameters into the stem and per-cell blocks."""
        return dict(stem=r'^conv_0|cell_stem_[01]', blocks=r'^cell_(\d+)')

    @torch.jit.ignore
    def set_grad_checkpointing(self, enable=True):
        """Gradient checkpointing is not implemented; asserts that ``enable`` is False."""
        assert not enable, 'gradient checkpointing not supported'

    @torch.jit.ignore
    def get_classifier(self) -> nn.Module:
        """Return the final linear classifier module."""
        return self.last_linear

    def reset_classifier(self, num_classes: int, global_pool: str = 'avg', device=None, dtype=None):
        """Replace the pooling layer and classifier for ``num_classes`` outputs.

        ``self.head_drop`` is not touched, so the dropout configured at construction
        is kept.
        """
        dd = {'device': device, 'dtype': dtype}
        self.num_classes = num_classes
        # No drop_rate is passed here, and the result is unpacked into two values.
        # Presumably create_classifier returns (pool, fc) in that case — verify against timm.layers.
        self.global_pool, self.last_linear = create_classifier(
            self.num_features, self.num_classes, pool_type=global_pool, **dd)

    def forward_features(self, x):
        """Run the stem and all cells, returning the ReLU'd output of ``cell_11``.

        Each cell receives (output from two steps back, output of the previous step).
        """
        x_conv_0 = self.conv_0(x)
        x_stem_0 = self.cell_stem_0(x_conv_0)
        x_stem_1 = self.cell_stem_1(x_conv_0, x_stem_0)
        x_cell_0 = self.cell_0(x_stem_0, x_stem_1)
        x_cell_1 = self.cell_1(x_stem_1, x_cell_0)
        x_cell_2 = self.cell_2(x_cell_0, x_cell_1)
        x_cell_3 = self.cell_3(x_cell_1, x_cell_2)
        x_cell_4 = self.cell_4(x_cell_2, x_cell_3)
        x_cell_5 = self.cell_5(x_cell_3, x_cell_4)
        x_cell_6 = self.cell_6(x_cell_4, x_cell_5)
        x_cell_7 = self.cell_7(x_cell_5, x_cell_6)
        x_cell_8 = self.cell_8(x_cell_6, x_cell_7)
        x_cell_9 = self.cell_9(x_cell_7, x_cell_8)
        x_cell_10 = self.cell_10(x_cell_8, x_cell_9)
        x_cell_11 = self.cell_11(x_cell_9, x_cell_10)
        x = self.act(x_cell_11)
        return x

    def forward_head(self, x, pre_logits: bool = False):
        """Pool and apply dropout.

        When ``pre_logits`` is True, return the pooled features; otherwise return the
        logits from ``last_linear``.
        """
        x = self.global_pool(x)
        x = self.head_drop(x)
        return x if pre_logits else self.last_linear(x)

    def forward(self, x):
        x = self.forward_features(x)
        x = self.forward_head(x)
        return x
  365. def _create_pnasnet(variant, pretrained=False, **kwargs):
  366. return build_model_with_cfg(
  367. PNASNet5Large,
  368. variant,
  369. pretrained,
  370. feature_cfg=dict(feature_cls='hook', no_rewrite=True), # not possible to re-write this model
  371. **kwargs,
  372. )
  373. default_cfgs = generate_default_cfgs({
  374. 'pnasnet5large.tf_in1k': {
  375. 'hf_hub_id': 'timm/',
  376. 'input_size': (3, 331, 331),
  377. 'pool_size': (11, 11),
  378. 'crop_pct': 0.911,
  379. 'interpolation': 'bicubic',
  380. 'mean': (0.5, 0.5, 0.5),
  381. 'std': (0.5, 0.5, 0.5),
  382. 'num_classes': 1000,
  383. 'first_conv': 'conv_0.conv',
  384. 'classifier': 'last_linear',
  385. 'license': 'apache-2.0',
  386. },
  387. })
  388. @register_model
  389. def pnasnet5large(pretrained=False, **kwargs) -> PNASNet5Large:
  390. r"""PNASNet-5 model architecture from the
  391. `"Progressive Neural Architecture Search"
  392. <https://arxiv.org/abs/1712.00559>`_ paper.
  393. """
  394. model_kwargs = dict(pad_type='same', **kwargs)
  395. return _create_pnasnet('pnasnet5large', pretrained, **model_kwargs)