nasnet.py 29 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
3704705706707708709710711712713714715716717718719
""" NasNet-A (Large)

NASNet-A Large implementation adapted from Cadene's pretrained models:
https://github.com/Cadene/pretrained-models.pytorch
"""
  5. from functools import partial
  6. from typing import Optional, Type
  7. import torch
  8. import torch.nn as nn
  9. from timm.layers import ConvNormAct, create_conv2d, create_pool2d, create_classifier
  10. from ._builder import build_model_with_cfg
  11. from ._registry import register_model, generate_default_cfgs
  12. __all__ = ['NASNetALarge']
  13. class ActConvBn(nn.Module):
  14. def __init__(
  15. self,
  16. in_channels: int,
  17. out_channels: int,
  18. kernel_size: int,
  19. stride: int = 1,
  20. padding: str = '',
  21. device=None,
  22. dtype=None,
  23. ):
  24. dd = {'device': device, 'dtype': dtype}
  25. super().__init__()
  26. self.act = nn.ReLU()
  27. self.conv = create_conv2d(
  28. in_channels, out_channels, kernel_size=kernel_size, stride=stride, padding=padding, **dd)
  29. self.bn = nn.BatchNorm2d(out_channels, eps=0.001, momentum=0.1, **dd)
  30. def forward(self, x):
  31. x = self.act(x)
  32. x = self.conv(x)
  33. x = self.bn(x)
  34. return x
  35. class SeparableConv2d(nn.Module):
  36. def __init__(
  37. self,
  38. in_channels: int,
  39. out_channels: int,
  40. kernel_size: int,
  41. stride: int,
  42. padding: str = '',
  43. device=None,
  44. dtype=None,
  45. ):
  46. dd = {'device': device, 'dtype': dtype}
  47. super().__init__()
  48. self.depthwise_conv2d = create_conv2d(
  49. in_channels,
  50. in_channels,
  51. kernel_size=kernel_size,
  52. stride=stride,
  53. padding=padding,
  54. groups=in_channels,
  55. **dd,
  56. )
  57. self.pointwise_conv2d = create_conv2d(
  58. in_channels,
  59. out_channels,
  60. kernel_size=1,
  61. padding=0,
  62. **dd,
  63. )
  64. def forward(self, x):
  65. x = self.depthwise_conv2d(x)
  66. x = self.pointwise_conv2d(x)
  67. return x
  68. class BranchSeparables(nn.Module):
  69. def __init__(
  70. self,
  71. in_channels: int,
  72. out_channels: int,
  73. kernel_size: int,
  74. stride: int = 1,
  75. pad_type: str = '',
  76. stem_cell: bool = False,
  77. device=None,
  78. dtype=None,
  79. ):
  80. dd = {'device': device, 'dtype': dtype}
  81. super().__init__()
  82. middle_channels = out_channels if stem_cell else in_channels
  83. self.act_1 = nn.ReLU()
  84. self.separable_1 = SeparableConv2d(
  85. in_channels, middle_channels, kernel_size, stride=stride, padding=pad_type, **dd)
  86. self.bn_sep_1 = nn.BatchNorm2d(middle_channels, eps=0.001, momentum=0.1, **dd)
  87. self.act_2 = nn.ReLU(inplace=True)
  88. self.separable_2 = SeparableConv2d(
  89. middle_channels, out_channels, kernel_size, stride=1, padding=pad_type, **dd)
  90. self.bn_sep_2 = nn.BatchNorm2d(out_channels, eps=0.001, momentum=0.1, **dd)
  91. def forward(self, x):
  92. x = self.act_1(x)
  93. x = self.separable_1(x)
  94. x = self.bn_sep_1(x)
  95. x = self.act_2(x)
  96. x = self.separable_2(x)
  97. x = self.bn_sep_2(x)
  98. return x
  99. class CellStem0(nn.Module):
  100. def __init__(
  101. self,
  102. stem_size: int,
  103. num_channels: int = 42,
  104. pad_type: str = '',
  105. device=None,
  106. dtype=None,
  107. ):
  108. dd = {'device': device, 'dtype': dtype}
  109. super().__init__()
  110. self.num_channels = num_channels
  111. self.stem_size = stem_size
  112. self.conv_1x1 = ActConvBn(self.stem_size, self.num_channels, 1, stride=1, **dd)
  113. self.comb_iter_0_left = BranchSeparables(
  114. self.num_channels, self.num_channels, 5, 2, pad_type, **dd)
  115. self.comb_iter_0_right = BranchSeparables(
  116. self.stem_size, self.num_channels, 7, 2, pad_type, stem_cell=True, **dd)
  117. self.comb_iter_1_left = create_pool2d('max', 3, 2, padding=pad_type)
  118. self.comb_iter_1_right = BranchSeparables(
  119. self.stem_size, self.num_channels, 7, 2, pad_type, stem_cell=True, **dd)
  120. self.comb_iter_2_left = create_pool2d('avg', 3, 2, count_include_pad=False, padding=pad_type)
  121. self.comb_iter_2_right = BranchSeparables(
  122. self.stem_size, self.num_channels, 5, 2, pad_type, stem_cell=True, **dd)
  123. self.comb_iter_3_right = create_pool2d('avg', 3, 1, count_include_pad=False, padding=pad_type)
  124. self.comb_iter_4_left = BranchSeparables(
  125. self.num_channels, self.num_channels, 3, 1, pad_type, **dd)
  126. self.comb_iter_4_right = create_pool2d('max', 3, 2, padding=pad_type)
  127. def forward(self, x):
  128. x1 = self.conv_1x1(x)
  129. x_comb_iter_0_left = self.comb_iter_0_left(x1)
  130. x_comb_iter_0_right = self.comb_iter_0_right(x)
  131. x_comb_iter_0 = x_comb_iter_0_left + x_comb_iter_0_right
  132. x_comb_iter_1_left = self.comb_iter_1_left(x1)
  133. x_comb_iter_1_right = self.comb_iter_1_right(x)
  134. x_comb_iter_1 = x_comb_iter_1_left + x_comb_iter_1_right
  135. x_comb_iter_2_left = self.comb_iter_2_left(x1)
  136. x_comb_iter_2_right = self.comb_iter_2_right(x)
  137. x_comb_iter_2 = x_comb_iter_2_left + x_comb_iter_2_right
  138. x_comb_iter_3_right = self.comb_iter_3_right(x_comb_iter_0)
  139. x_comb_iter_3 = x_comb_iter_3_right + x_comb_iter_1
  140. x_comb_iter_4_left = self.comb_iter_4_left(x_comb_iter_0)
  141. x_comb_iter_4_right = self.comb_iter_4_right(x1)
  142. x_comb_iter_4 = x_comb_iter_4_left + x_comb_iter_4_right
  143. x_out = torch.cat([x_comb_iter_1, x_comb_iter_2, x_comb_iter_3, x_comb_iter_4], 1)
  144. return x_out
  145. class CellStem1(nn.Module):
  146. def __init__(
  147. self,
  148. stem_size: int,
  149. num_channels: int,
  150. pad_type: str = '',
  151. device=None,
  152. dtype=None,
  153. ):
  154. dd = {'device': device, 'dtype': dtype}
  155. super().__init__()
  156. self.num_channels = num_channels
  157. self.stem_size = stem_size
  158. self.conv_1x1 = ActConvBn(2 * self.num_channels, self.num_channels, 1, stride=1, **dd)
  159. self.act = nn.ReLU()
  160. self.path_1 = nn.Sequential()
  161. self.path_1.add_module('avgpool', nn.AvgPool2d(1, stride=2, count_include_pad=False))
  162. self.path_1.add_module('conv', nn.Conv2d(self.stem_size, self.num_channels // 2, 1, stride=1, bias=False, **dd))
  163. self.path_2 = nn.Sequential()
  164. self.path_2.add_module('pad', nn.ZeroPad2d((-1, 1, -1, 1)))
  165. self.path_2.add_module('avgpool', nn.AvgPool2d(1, stride=2, count_include_pad=False))
  166. self.path_2.add_module('conv', nn.Conv2d(self.stem_size, self.num_channels // 2, 1, stride=1, bias=False, **dd))
  167. self.final_path_bn = nn.BatchNorm2d(self.num_channels, eps=0.001, momentum=0.1, **dd)
  168. self.comb_iter_0_left = BranchSeparables(self.num_channels, self.num_channels, 5, 2, pad_type, **dd)
  169. self.comb_iter_0_right = BranchSeparables(self.num_channels, self.num_channels, 7, 2, pad_type, **dd)
  170. self.comb_iter_1_left = create_pool2d('max', 3, 2, padding=pad_type)
  171. self.comb_iter_1_right = BranchSeparables(self.num_channels, self.num_channels, 7, 2, pad_type, **dd)
  172. self.comb_iter_2_left = create_pool2d('avg', 3, 2, count_include_pad=False, padding=pad_type)
  173. self.comb_iter_2_right = BranchSeparables(self.num_channels, self.num_channels, 5, 2, pad_type, **dd)
  174. self.comb_iter_3_right = create_pool2d('avg', 3, 1, count_include_pad=False, padding=pad_type)
  175. self.comb_iter_4_left = BranchSeparables(self.num_channels, self.num_channels, 3, 1, pad_type, **dd)
  176. self.comb_iter_4_right = create_pool2d('max', 3, 2, padding=pad_type)
  177. def forward(self, x_conv0, x_stem_0):
  178. x_left = self.conv_1x1(x_stem_0)
  179. x_relu = self.act(x_conv0)
  180. # path 1
  181. x_path1 = self.path_1(x_relu)
  182. # path 2
  183. x_path2 = self.path_2(x_relu)
  184. # final path
  185. x_right = self.final_path_bn(torch.cat([x_path1, x_path2], 1))
  186. x_comb_iter_0_left = self.comb_iter_0_left(x_left)
  187. x_comb_iter_0_right = self.comb_iter_0_right(x_right)
  188. x_comb_iter_0 = x_comb_iter_0_left + x_comb_iter_0_right
  189. x_comb_iter_1_left = self.comb_iter_1_left(x_left)
  190. x_comb_iter_1_right = self.comb_iter_1_right(x_right)
  191. x_comb_iter_1 = x_comb_iter_1_left + x_comb_iter_1_right
  192. x_comb_iter_2_left = self.comb_iter_2_left(x_left)
  193. x_comb_iter_2_right = self.comb_iter_2_right(x_right)
  194. x_comb_iter_2 = x_comb_iter_2_left + x_comb_iter_2_right
  195. x_comb_iter_3_right = self.comb_iter_3_right(x_comb_iter_0)
  196. x_comb_iter_3 = x_comb_iter_3_right + x_comb_iter_1
  197. x_comb_iter_4_left = self.comb_iter_4_left(x_comb_iter_0)
  198. x_comb_iter_4_right = self.comb_iter_4_right(x_left)
  199. x_comb_iter_4 = x_comb_iter_4_left + x_comb_iter_4_right
  200. x_out = torch.cat([x_comb_iter_1, x_comb_iter_2, x_comb_iter_3, x_comb_iter_4], 1)
  201. return x_out
  202. class FirstCell(nn.Module):
  203. def __init__(
  204. self,
  205. in_chs_left: int,
  206. out_chs_left: int,
  207. in_chs_right: int,
  208. out_chs_right: int,
  209. pad_type: str = '',
  210. device=None,
  211. dtype=None,
  212. ):
  213. dd = {'device': device, 'dtype': dtype}
  214. super().__init__()
  215. self.conv_1x1 = ActConvBn(in_chs_right, out_chs_right, 1, stride=1, **dd)
  216. self.act = nn.ReLU()
  217. self.path_1 = nn.Sequential()
  218. self.path_1.add_module('avgpool', nn.AvgPool2d(1, stride=2, count_include_pad=False))
  219. self.path_1.add_module('conv', nn.Conv2d(in_chs_left, out_chs_left, 1, stride=1, bias=False, **dd))
  220. self.path_2 = nn.Sequential()
  221. self.path_2.add_module('pad', nn.ZeroPad2d((-1, 1, -1, 1)))
  222. self.path_2.add_module('avgpool', nn.AvgPool2d(1, stride=2, count_include_pad=False))
  223. self.path_2.add_module('conv', nn.Conv2d(in_chs_left, out_chs_left, 1, stride=1, bias=False, **dd))
  224. self.final_path_bn = nn.BatchNorm2d(out_chs_left * 2, eps=0.001, momentum=0.1, **dd)
  225. self.comb_iter_0_left = BranchSeparables(out_chs_right, out_chs_right, 5, 1, pad_type, **dd)
  226. self.comb_iter_0_right = BranchSeparables(out_chs_right, out_chs_right, 3, 1, pad_type, **dd)
  227. self.comb_iter_1_left = BranchSeparables(out_chs_right, out_chs_right, 5, 1, pad_type, **dd)
  228. self.comb_iter_1_right = BranchSeparables(out_chs_right, out_chs_right, 3, 1, pad_type, **dd)
  229. self.comb_iter_2_left = create_pool2d('avg', 3, 1, count_include_pad=False, padding=pad_type)
  230. self.comb_iter_3_left = create_pool2d('avg', 3, 1, count_include_pad=False, padding=pad_type)
  231. self.comb_iter_3_right = create_pool2d('avg', 3, 1, count_include_pad=False, padding=pad_type)
  232. self.comb_iter_4_left = BranchSeparables(out_chs_right, out_chs_right, 3, 1, pad_type, **dd)
  233. def forward(self, x, x_prev):
  234. x_relu = self.act(x_prev)
  235. x_path1 = self.path_1(x_relu)
  236. x_path2 = self.path_2(x_relu)
  237. x_left = self.final_path_bn(torch.cat([x_path1, x_path2], 1))
  238. x_right = self.conv_1x1(x)
  239. x_comb_iter_0_left = self.comb_iter_0_left(x_right)
  240. x_comb_iter_0_right = self.comb_iter_0_right(x_left)
  241. x_comb_iter_0 = x_comb_iter_0_left + x_comb_iter_0_right
  242. x_comb_iter_1_left = self.comb_iter_1_left(x_left)
  243. x_comb_iter_1_right = self.comb_iter_1_right(x_left)
  244. x_comb_iter_1 = x_comb_iter_1_left + x_comb_iter_1_right
  245. x_comb_iter_2_left = self.comb_iter_2_left(x_right)
  246. x_comb_iter_2 = x_comb_iter_2_left + x_left
  247. x_comb_iter_3_left = self.comb_iter_3_left(x_left)
  248. x_comb_iter_3_right = self.comb_iter_3_right(x_left)
  249. x_comb_iter_3 = x_comb_iter_3_left + x_comb_iter_3_right
  250. x_comb_iter_4_left = self.comb_iter_4_left(x_right)
  251. x_comb_iter_4 = x_comb_iter_4_left + x_right
  252. x_out = torch.cat([x_left, x_comb_iter_0, x_comb_iter_1, x_comb_iter_2, x_comb_iter_3, x_comb_iter_4], 1)
  253. return x_out
  254. class NormalCell(nn.Module):
  255. def __init__(
  256. self,
  257. in_chs_left: int,
  258. out_chs_left: int,
  259. in_chs_right: int,
  260. out_chs_right: int,
  261. pad_type: str = '',
  262. device=None,
  263. dtype=None,
  264. ):
  265. dd = {'device': device, 'dtype': dtype}
  266. super().__init__()
  267. self.conv_prev_1x1 = ActConvBn(in_chs_left, out_chs_left, 1, stride=1, padding=pad_type, **dd)
  268. self.conv_1x1 = ActConvBn(in_chs_right, out_chs_right, 1, stride=1, padding=pad_type, **dd)
  269. self.comb_iter_0_left = BranchSeparables(out_chs_right, out_chs_right, 5, 1, pad_type, **dd)
  270. self.comb_iter_0_right = BranchSeparables(out_chs_left, out_chs_left, 3, 1, pad_type, **dd)
  271. self.comb_iter_1_left = BranchSeparables(out_chs_left, out_chs_left, 5, 1, pad_type, **dd)
  272. self.comb_iter_1_right = BranchSeparables(out_chs_left, out_chs_left, 3, 1, pad_type, **dd)
  273. self.comb_iter_2_left = create_pool2d('avg', 3, 1, count_include_pad=False, padding=pad_type)
  274. self.comb_iter_3_left = create_pool2d('avg', 3, 1, count_include_pad=False, padding=pad_type)
  275. self.comb_iter_3_right = create_pool2d('avg', 3, 1, count_include_pad=False, padding=pad_type)
  276. self.comb_iter_4_left = BranchSeparables(out_chs_right, out_chs_right, 3, 1, pad_type, **dd)
  277. def forward(self, x, x_prev):
  278. x_left = self.conv_prev_1x1(x_prev)
  279. x_right = self.conv_1x1(x)
  280. x_comb_iter_0_left = self.comb_iter_0_left(x_right)
  281. x_comb_iter_0_right = self.comb_iter_0_right(x_left)
  282. x_comb_iter_0 = x_comb_iter_0_left + x_comb_iter_0_right
  283. x_comb_iter_1_left = self.comb_iter_1_left(x_left)
  284. x_comb_iter_1_right = self.comb_iter_1_right(x_left)
  285. x_comb_iter_1 = x_comb_iter_1_left + x_comb_iter_1_right
  286. x_comb_iter_2_left = self.comb_iter_2_left(x_right)
  287. x_comb_iter_2 = x_comb_iter_2_left + x_left
  288. x_comb_iter_3_left = self.comb_iter_3_left(x_left)
  289. x_comb_iter_3_right = self.comb_iter_3_right(x_left)
  290. x_comb_iter_3 = x_comb_iter_3_left + x_comb_iter_3_right
  291. x_comb_iter_4_left = self.comb_iter_4_left(x_right)
  292. x_comb_iter_4 = x_comb_iter_4_left + x_right
  293. x_out = torch.cat([x_left, x_comb_iter_0, x_comb_iter_1, x_comb_iter_2, x_comb_iter_3, x_comb_iter_4], 1)
  294. return x_out
  295. class ReductionCell0(nn.Module):
  296. def __init__(
  297. self,
  298. in_chs_left: int,
  299. out_chs_left: int,
  300. in_chs_right: int,
  301. out_chs_right: int,
  302. pad_type: str = '',
  303. device=None,
  304. dtype=None,
  305. ):
  306. dd = {'device': device, 'dtype': dtype}
  307. super().__init__()
  308. self.conv_prev_1x1 = ActConvBn(in_chs_left, out_chs_left, 1, stride=1, padding=pad_type, **dd)
  309. self.conv_1x1 = ActConvBn(in_chs_right, out_chs_right, 1, stride=1, padding=pad_type, **dd)
  310. self.comb_iter_0_left = BranchSeparables(out_chs_right, out_chs_right, 5, 2, pad_type, **dd)
  311. self.comb_iter_0_right = BranchSeparables(out_chs_right, out_chs_right, 7, 2, pad_type, **dd)
  312. self.comb_iter_1_left = create_pool2d('max', 3, 2, padding=pad_type)
  313. self.comb_iter_1_right = BranchSeparables(out_chs_right, out_chs_right, 7, 2, pad_type, **dd)
  314. self.comb_iter_2_left = create_pool2d('avg', 3, 2, count_include_pad=False, padding=pad_type)
  315. self.comb_iter_2_right = BranchSeparables(out_chs_right, out_chs_right, 5, 2, pad_type, **dd)
  316. self.comb_iter_3_right = create_pool2d('avg', 3, 1, count_include_pad=False, padding=pad_type)
  317. self.comb_iter_4_left = BranchSeparables(out_chs_right, out_chs_right, 3, 1, pad_type, **dd)
  318. self.comb_iter_4_right = create_pool2d('max', 3, 2, padding=pad_type)
  319. def forward(self, x, x_prev):
  320. x_left = self.conv_prev_1x1(x_prev)
  321. x_right = self.conv_1x1(x)
  322. x_comb_iter_0_left = self.comb_iter_0_left(x_right)
  323. x_comb_iter_0_right = self.comb_iter_0_right(x_left)
  324. x_comb_iter_0 = x_comb_iter_0_left + x_comb_iter_0_right
  325. x_comb_iter_1_left = self.comb_iter_1_left(x_right)
  326. x_comb_iter_1_right = self.comb_iter_1_right(x_left)
  327. x_comb_iter_1 = x_comb_iter_1_left + x_comb_iter_1_right
  328. x_comb_iter_2_left = self.comb_iter_2_left(x_right)
  329. x_comb_iter_2_right = self.comb_iter_2_right(x_left)
  330. x_comb_iter_2 = x_comb_iter_2_left + x_comb_iter_2_right
  331. x_comb_iter_3_right = self.comb_iter_3_right(x_comb_iter_0)
  332. x_comb_iter_3 = x_comb_iter_3_right + x_comb_iter_1
  333. x_comb_iter_4_left = self.comb_iter_4_left(x_comb_iter_0)
  334. x_comb_iter_4_right = self.comb_iter_4_right(x_right)
  335. x_comb_iter_4 = x_comb_iter_4_left + x_comb_iter_4_right
  336. x_out = torch.cat([x_comb_iter_1, x_comb_iter_2, x_comb_iter_3, x_comb_iter_4], 1)
  337. return x_out
  338. class ReductionCell1(nn.Module):
  339. def __init__(
  340. self,
  341. in_chs_left: int,
  342. out_chs_left: int,
  343. in_chs_right: int,
  344. out_chs_right: int,
  345. pad_type: str = '',
  346. device=None,
  347. dtype=None,
  348. ):
  349. dd = {'device': device, 'dtype': dtype}
  350. super().__init__()
  351. self.conv_prev_1x1 = ActConvBn(in_chs_left, out_chs_left, 1, stride=1, padding=pad_type, **dd)
  352. self.conv_1x1 = ActConvBn(in_chs_right, out_chs_right, 1, stride=1, padding=pad_type, **dd)
  353. self.comb_iter_0_left = BranchSeparables(out_chs_right, out_chs_right, 5, 2, pad_type, **dd)
  354. self.comb_iter_0_right = BranchSeparables(out_chs_right, out_chs_right, 7, 2, pad_type, **dd)
  355. self.comb_iter_1_left = create_pool2d('max', 3, 2, padding=pad_type)
  356. self.comb_iter_1_right = BranchSeparables(out_chs_right, out_chs_right, 7, 2, pad_type, **dd)
  357. self.comb_iter_2_left = create_pool2d('avg', 3, 2, count_include_pad=False, padding=pad_type)
  358. self.comb_iter_2_right = BranchSeparables(out_chs_right, out_chs_right, 5, 2, pad_type, **dd)
  359. self.comb_iter_3_right = create_pool2d('avg', 3, 1, count_include_pad=False, padding=pad_type)
  360. self.comb_iter_4_left = BranchSeparables(out_chs_right, out_chs_right, 3, 1, pad_type, **dd)
  361. self.comb_iter_4_right = create_pool2d('max', 3, 2, padding=pad_type)
  362. def forward(self, x, x_prev):
  363. x_left = self.conv_prev_1x1(x_prev)
  364. x_right = self.conv_1x1(x)
  365. x_comb_iter_0_left = self.comb_iter_0_left(x_right)
  366. x_comb_iter_0_right = self.comb_iter_0_right(x_left)
  367. x_comb_iter_0 = x_comb_iter_0_left + x_comb_iter_0_right
  368. x_comb_iter_1_left = self.comb_iter_1_left(x_right)
  369. x_comb_iter_1_right = self.comb_iter_1_right(x_left)
  370. x_comb_iter_1 = x_comb_iter_1_left + x_comb_iter_1_right
  371. x_comb_iter_2_left = self.comb_iter_2_left(x_right)
  372. x_comb_iter_2_right = self.comb_iter_2_right(x_left)
  373. x_comb_iter_2 = x_comb_iter_2_left + x_comb_iter_2_right
  374. x_comb_iter_3_right = self.comb_iter_3_right(x_comb_iter_0)
  375. x_comb_iter_3 = x_comb_iter_3_right + x_comb_iter_1
  376. x_comb_iter_4_left = self.comb_iter_4_left(x_comb_iter_0)
  377. x_comb_iter_4_right = self.comb_iter_4_right(x_right)
  378. x_comb_iter_4 = x_comb_iter_4_left + x_comb_iter_4_right
  379. x_out = torch.cat([x_comb_iter_1, x_comb_iter_2, x_comb_iter_3, x_comb_iter_4], 1)
  380. return x_out
  381. class NASNetALarge(nn.Module):
  382. """NASNetALarge (6 @ 4032) """
  383. def __init__(
  384. self,
  385. num_classes: int = 1000,
  386. in_chans: int = 3,
  387. stem_size: int = 96,
  388. channel_multiplier: int = 2,
  389. num_features: int = 4032,
  390. output_stride: int = 32,
  391. drop_rate: float = 0.,
  392. global_pool: str = 'avg',
  393. pad_type: str = 'same',
  394. device=None,
  395. dtype=None,
  396. ):
  397. super().__init__()
  398. dd = {'device': device, 'dtype': dtype}
  399. self.num_classes = num_classes
  400. self.in_chans = in_chans
  401. self.stem_size = stem_size
  402. self.num_features = self.head_hidden_size = num_features
  403. self.channel_multiplier = channel_multiplier
  404. assert output_stride == 32
  405. channels = self.num_features // 24
  406. # 24 is default value for the architecture
  407. self.conv0 = ConvNormAct(
  408. in_channels=in_chans,
  409. out_channels=self.stem_size,
  410. kernel_size=3,
  411. padding=0,
  412. stride=2,
  413. norm_layer=partial(nn.BatchNorm2d, eps=0.001, momentum=0.1),
  414. apply_act=False,
  415. **dd,
  416. )
  417. self.cell_stem_0 = CellStem0(
  418. self.stem_size, num_channels=channels // (channel_multiplier ** 2), pad_type=pad_type, **dd)
  419. self.cell_stem_1 = CellStem1(
  420. self.stem_size, num_channels=channels // channel_multiplier, pad_type=pad_type, **dd)
  421. self.cell_0 = FirstCell(
  422. in_chs_left=channels, out_chs_left=channels // 2,
  423. in_chs_right=2 * channels, out_chs_right=channels, pad_type=pad_type, **dd)
  424. self.cell_1 = NormalCell(
  425. in_chs_left=2 * channels, out_chs_left=channels,
  426. in_chs_right=6 * channels, out_chs_right=channels, pad_type=pad_type, **dd)
  427. self.cell_2 = NormalCell(
  428. in_chs_left=6 * channels, out_chs_left=channels,
  429. in_chs_right=6 * channels, out_chs_right=channels, pad_type=pad_type, **dd)
  430. self.cell_3 = NormalCell(
  431. in_chs_left=6 * channels, out_chs_left=channels,
  432. in_chs_right=6 * channels, out_chs_right=channels, pad_type=pad_type, **dd)
  433. self.cell_4 = NormalCell(
  434. in_chs_left=6 * channels, out_chs_left=channels,
  435. in_chs_right=6 * channels, out_chs_right=channels, pad_type=pad_type, **dd)
  436. self.cell_5 = NormalCell(
  437. in_chs_left=6 * channels, out_chs_left=channels,
  438. in_chs_right=6 * channels, out_chs_right=channels, pad_type=pad_type, **dd)
  439. self.reduction_cell_0 = ReductionCell0(
  440. in_chs_left=6 * channels, out_chs_left=2 * channels,
  441. in_chs_right=6 * channels, out_chs_right=2 * channels, pad_type=pad_type, **dd)
  442. self.cell_6 = FirstCell(
  443. in_chs_left=6 * channels, out_chs_left=channels,
  444. in_chs_right=8 * channels, out_chs_right=2 * channels, pad_type=pad_type, **dd)
  445. self.cell_7 = NormalCell(
  446. in_chs_left=8 * channels, out_chs_left=2 * channels,
  447. in_chs_right=12 * channels, out_chs_right=2 * channels, pad_type=pad_type, **dd)
  448. self.cell_8 = NormalCell(
  449. in_chs_left=12 * channels, out_chs_left=2 * channels,
  450. in_chs_right=12 * channels, out_chs_right=2 * channels, pad_type=pad_type, **dd)
  451. self.cell_9 = NormalCell(
  452. in_chs_left=12 * channels, out_chs_left=2 * channels,
  453. in_chs_right=12 * channels, out_chs_right=2 * channels, pad_type=pad_type, **dd)
  454. self.cell_10 = NormalCell(
  455. in_chs_left=12 * channels, out_chs_left=2 * channels,
  456. in_chs_right=12 * channels, out_chs_right=2 * channels, pad_type=pad_type, **dd)
  457. self.cell_11 = NormalCell(
  458. in_chs_left=12 * channels, out_chs_left=2 * channels,
  459. in_chs_right=12 * channels, out_chs_right=2 * channels, pad_type=pad_type, **dd)
  460. self.reduction_cell_1 = ReductionCell1(
  461. in_chs_left=12 * channels, out_chs_left=4 * channels,
  462. in_chs_right=12 * channels, out_chs_right=4 * channels, pad_type=pad_type, **dd)
  463. self.cell_12 = FirstCell(
  464. in_chs_left=12 * channels, out_chs_left=2 * channels,
  465. in_chs_right=16 * channels, out_chs_right=4 * channels, pad_type=pad_type, **dd)
  466. self.cell_13 = NormalCell(
  467. in_chs_left=16 * channels, out_chs_left=4 * channels,
  468. in_chs_right=24 * channels, out_chs_right=4 * channels, pad_type=pad_type, **dd)
  469. self.cell_14 = NormalCell(
  470. in_chs_left=24 * channels, out_chs_left=4 * channels,
  471. in_chs_right=24 * channels, out_chs_right=4 * channels, pad_type=pad_type, **dd)
  472. self.cell_15 = NormalCell(
  473. in_chs_left=24 * channels, out_chs_left=4 * channels,
  474. in_chs_right=24 * channels, out_chs_right=4 * channels, pad_type=pad_type, **dd)
  475. self.cell_16 = NormalCell(
  476. in_chs_left=24 * channels, out_chs_left=4 * channels,
  477. in_chs_right=24 * channels, out_chs_right=4 * channels, pad_type=pad_type, **dd)
  478. self.cell_17 = NormalCell(
  479. in_chs_left=24 * channels, out_chs_left=4 * channels,
  480. in_chs_right=24 * channels, out_chs_right=4 * channels, pad_type=pad_type, **dd)
  481. self.act = nn.ReLU(inplace=True)
  482. self.feature_info = [
  483. dict(num_chs=96, reduction=2, module='conv0'),
  484. dict(num_chs=168, reduction=4, module='cell_stem_1.conv_1x1.act'),
  485. dict(num_chs=1008, reduction=8, module='reduction_cell_0.conv_1x1.act'),
  486. dict(num_chs=2016, reduction=16, module='reduction_cell_1.conv_1x1.act'),
  487. dict(num_chs=4032, reduction=32, module='act'),
  488. ]
  489. self.global_pool, self.head_drop, self.last_linear = create_classifier(
  490. self.num_features,
  491. self.num_classes,
  492. pool_type=global_pool,
  493. drop_rate=drop_rate,
  494. **dd,
  495. )
  496. @torch.jit.ignore
  497. def group_matcher(self, coarse=False):
  498. matcher = dict(
  499. stem=r'^conv0|cell_stem_[01]',
  500. blocks=[
  501. (r'^cell_(\d+)', None),
  502. (r'^reduction_cell_0', (6,)),
  503. (r'^reduction_cell_1', (12,)),
  504. ]
  505. )
  506. return matcher
  507. @torch.jit.ignore
  508. def set_grad_checkpointing(self, enable=True):
  509. assert not enable, 'gradient checkpointing not supported'
  510. @torch.jit.ignore
  511. def get_classifier(self) -> nn.Module:
  512. return self.last_linear
  513. def reset_classifier(self, num_classes: int, global_pool: str = 'avg'):
  514. self.num_classes = num_classes
  515. self.global_pool, self.last_linear = create_classifier(
  516. self.num_features, self.num_classes, pool_type=global_pool)
  517. def forward_features(self, x):
  518. x_conv0 = self.conv0(x)
  519. x_stem_0 = self.cell_stem_0(x_conv0)
  520. x_stem_1 = self.cell_stem_1(x_conv0, x_stem_0)
  521. x_cell_0 = self.cell_0(x_stem_1, x_stem_0)
  522. x_cell_1 = self.cell_1(x_cell_0, x_stem_1)
  523. x_cell_2 = self.cell_2(x_cell_1, x_cell_0)
  524. x_cell_3 = self.cell_3(x_cell_2, x_cell_1)
  525. x_cell_4 = self.cell_4(x_cell_3, x_cell_2)
  526. x_cell_5 = self.cell_5(x_cell_4, x_cell_3)
  527. x_reduction_cell_0 = self.reduction_cell_0(x_cell_5, x_cell_4)
  528. x_cell_6 = self.cell_6(x_reduction_cell_0, x_cell_4)
  529. x_cell_7 = self.cell_7(x_cell_6, x_reduction_cell_0)
  530. x_cell_8 = self.cell_8(x_cell_7, x_cell_6)
  531. x_cell_9 = self.cell_9(x_cell_8, x_cell_7)
  532. x_cell_10 = self.cell_10(x_cell_9, x_cell_8)
  533. x_cell_11 = self.cell_11(x_cell_10, x_cell_9)
  534. x_reduction_cell_1 = self.reduction_cell_1(x_cell_11, x_cell_10)
  535. x_cell_12 = self.cell_12(x_reduction_cell_1, x_cell_10)
  536. x_cell_13 = self.cell_13(x_cell_12, x_reduction_cell_1)
  537. x_cell_14 = self.cell_14(x_cell_13, x_cell_12)
  538. x_cell_15 = self.cell_15(x_cell_14, x_cell_13)
  539. x_cell_16 = self.cell_16(x_cell_15, x_cell_14)
  540. x_cell_17 = self.cell_17(x_cell_16, x_cell_15)
  541. x = self.act(x_cell_17)
  542. return x
  543. def forward_head(self, x, pre_logits: bool = False):
  544. x = self.global_pool(x)
  545. x = self.head_drop(x)
  546. return x if pre_logits else self.last_linear(x)
  547. def forward(self, x):
  548. x = self.forward_features(x)
  549. x = self.forward_head(x)
  550. return x
  551. def _create_nasnet(variant, pretrained=False, **kwargs):
  552. return build_model_with_cfg(
  553. NASNetALarge,
  554. variant,
  555. pretrained,
  556. feature_cfg=dict(feature_cls='hook', no_rewrite=True), # not possible to re-write this model
  557. **kwargs,
  558. )
  559. default_cfgs = generate_default_cfgs({
  560. 'nasnetalarge.tf_in1k': {
  561. 'hf_hub_id': 'timm/',
  562. 'url': 'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/nasnetalarge-dc4a7b8b.pth',
  563. 'input_size': (3, 331, 331),
  564. 'pool_size': (11, 11),
  565. 'crop_pct': 0.911,
  566. 'interpolation': 'bicubic',
  567. 'mean': (0.5, 0.5, 0.5),
  568. 'std': (0.5, 0.5, 0.5),
  569. 'num_classes': 1000,
  570. 'first_conv': 'conv0.conv',
  571. 'classifier': 'last_linear',
  572. 'license': 'apache-2.0',
  573. },
  574. })
  575. @register_model
  576. def nasnetalarge(pretrained=False, **kwargs) -> NASNetALarge:
  577. """NASNet-A large model architecture.
  578. """
  579. model_kwargs = dict(pad_type='same', **kwargs)
  580. return _create_nasnet('nasnetalarge', pretrained, **model_kwargs)