_efficientnet_blocks.py 27 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
3704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761
  1. """ EfficientNet, MobileNetV3, etc Blocks
  2. Hacked together by / Copyright 2019, Ross Wightman
  3. """
from typing import Callable, Dict, Optional, Tuple, Type, Union

import torch
import torch.nn as nn
from torch.nn import functional as F

from timm.layers import (
    create_conv2d,
    DropPath,
    make_divisible,
    create_act_layer,
    create_aa,
    to_2tuple,
    LayerType,
    ConvNormAct,
    get_norm_act_layer,
    MultiQueryAttention2d,
    Attention2d,
    LayerScale2d,
)
  22. __all__ = [
  23. 'SqueezeExcite', 'ConvBnAct', 'DepthwiseSeparableConv', 'InvertedResidual', 'CondConvResidual', 'EdgeResidual',
  24. 'UniversalInvertedResidual', 'MobileAttention'
  25. ]
  26. ModuleType = Type[nn.Module]
  27. def num_groups(group_size: Optional[int], channels: int):
  28. if not group_size: # 0 or None
  29. return 1 # normal conv with 1 group
  30. else:
  31. # NOTE group_size == 1 -> depthwise conv
  32. assert channels % group_size == 0
  33. return channels // group_size
  34. class SqueezeExcite(nn.Module):
  35. """ Squeeze-and-Excitation w/ specific features for EfficientNet/MobileNet family
  36. Args:
  37. in_chs (int): input channels to layer
  38. rd_ratio (float): ratio of squeeze reduction
  39. act_layer (nn.Module): activation layer of containing block
  40. gate_layer (Callable): attention gate function
  41. force_act_layer (nn.Module): override block's activation fn if this is set/bound
  42. rd_round_fn (Callable): specify a fn to calculate rounding of reduced chs
  43. """
  44. def __init__(
  45. self,
  46. in_chs: int,
  47. rd_ratio: float = 0.25,
  48. rd_channels: Optional[int] = None,
  49. act_layer: LayerType = nn.ReLU,
  50. gate_layer: LayerType = nn.Sigmoid,
  51. force_act_layer: Optional[LayerType] = None,
  52. rd_round_fn: Optional[Callable] = None,
  53. device=None,
  54. dtype=None,
  55. ):
  56. dd = {'device': device, 'dtype': dtype}
  57. super().__init__()
  58. if rd_channels is None:
  59. rd_round_fn = rd_round_fn or round
  60. rd_channels = rd_round_fn(in_chs * rd_ratio)
  61. act_layer = force_act_layer or act_layer
  62. self.conv_reduce = nn.Conv2d(in_chs, rd_channels, 1, bias=True, **dd)
  63. self.act1 = create_act_layer(act_layer, inplace=True)
  64. self.conv_expand = nn.Conv2d(rd_channels, in_chs, 1, bias=True, **dd)
  65. self.gate = create_act_layer(gate_layer)
  66. def forward(self, x):
  67. x_se = x.mean((2, 3), keepdim=True)
  68. x_se = self.conv_reduce(x_se)
  69. x_se = self.act1(x_se)
  70. x_se = self.conv_expand(x_se)
  71. return x * self.gate(x_se)
  72. class ConvBnAct(nn.Module):
  73. """ Conv + Norm Layer + Activation w/ optional skip connection
  74. """
  75. def __init__(
  76. self,
  77. in_chs: int,
  78. out_chs: int,
  79. kernel_size: int,
  80. stride: int = 1,
  81. dilation: int = 1,
  82. group_size: int = 0,
  83. pad_type: Union[int, str] = '',
  84. skip: bool = False,
  85. act_layer: Optional[LayerType] = nn.ReLU,
  86. norm_layer: LayerType = nn.BatchNorm2d,
  87. aa_layer: Optional[LayerType] = None,
  88. drop_path_rate: float = 0.,
  89. device=None,
  90. dtype=None,
  91. ):
  92. dd = {'device': device, 'dtype': dtype}
  93. super().__init__()
  94. norm_act_layer = get_norm_act_layer(norm_layer, act_layer)
  95. groups = num_groups(group_size, in_chs)
  96. self.has_skip = skip and stride == 1 and in_chs == out_chs
  97. use_aa = aa_layer is not None and stride > 1 # FIXME handle dilation
  98. self.conv = create_conv2d(
  99. in_chs,
  100. out_chs,
  101. kernel_size,
  102. stride=1 if use_aa else stride,
  103. dilation=dilation,
  104. groups=groups,
  105. padding=pad_type,
  106. **dd,
  107. )
  108. self.bn1 = norm_act_layer(out_chs, inplace=True, **dd)
  109. self.aa = create_aa(aa_layer, channels=out_chs, stride=stride, enable=use_aa, **dd)
  110. self.drop_path = DropPath(drop_path_rate) if drop_path_rate else nn.Identity()
  111. def feature_info(self, location):
  112. if location == 'expansion': # output of conv after act, same as block coutput
  113. return dict(module='bn1', hook_type='forward', num_chs=self.conv.out_channels)
  114. else: # location == 'bottleneck', block output
  115. return dict(module='', num_chs=self.conv.out_channels)
  116. def forward(self, x):
  117. shortcut = x
  118. x = self.conv(x)
  119. x = self.bn1(x)
  120. x = self.aa(x)
  121. if self.has_skip:
  122. x = self.drop_path(x) + shortcut
  123. return x
  124. class DepthwiseSeparableConv(nn.Module):
  125. """ Depthwise-separable block
  126. Used for DS convs in MobileNet-V1 and in the place of IR blocks that have no expansion
  127. (factor of 1.0). This is an alternative to having a IR with an optional first pw conv.
  128. """
  129. def __init__(
  130. self,
  131. in_chs: int,
  132. out_chs: int,
  133. dw_kernel_size: int = 3,
  134. stride: int = 1,
  135. dilation: int = 1,
  136. group_size: int = 1,
  137. pad_type: str = '',
  138. noskip: bool = False,
  139. pw_kernel_size: int = 1,
  140. pw_act: bool = False,
  141. s2d: int = 0,
  142. act_layer: LayerType = nn.ReLU,
  143. norm_layer: LayerType = nn.BatchNorm2d,
  144. aa_layer: Optional[LayerType] = None,
  145. se_layer: Optional[ModuleType] = None,
  146. drop_path_rate: float = 0.,
  147. device=None,
  148. dtype=None,
  149. ):
  150. dd = {'device': device, 'dtype': dtype}
  151. super().__init__()
  152. norm_act_layer = get_norm_act_layer(norm_layer, act_layer)
  153. self.has_skip = (stride == 1 and in_chs == out_chs) and not noskip
  154. self.has_pw_act = pw_act # activation after point-wise conv
  155. use_aa = aa_layer is not None and stride > 1 # FIXME handle dilation
  156. # Space to depth
  157. if s2d == 1:
  158. sd_chs = int(in_chs * 4)
  159. self.conv_s2d = create_conv2d(in_chs, sd_chs, kernel_size=2, stride=2, padding='same', **dd)
  160. self.bn_s2d = norm_act_layer(sd_chs, **dd)
  161. dw_kernel_size = (dw_kernel_size + 1) // 2
  162. dw_pad_type = 'same' if dw_kernel_size == 2 else pad_type
  163. in_chs = sd_chs
  164. use_aa = False # disable AA
  165. else:
  166. self.conv_s2d = None
  167. self.bn_s2d = None
  168. dw_pad_type = pad_type
  169. groups = num_groups(group_size, in_chs)
  170. self.conv_dw = create_conv2d(
  171. in_chs,
  172. in_chs,
  173. dw_kernel_size,
  174. stride=1 if use_aa else stride,
  175. dilation=dilation,
  176. padding=dw_pad_type,
  177. groups=groups,
  178. **dd,
  179. )
  180. self.bn1 = norm_act_layer(in_chs, inplace=True, **dd)
  181. self.aa = create_aa(aa_layer, channels=out_chs, stride=stride, enable=use_aa, **dd)
  182. # Squeeze-and-excitation
  183. self.se = se_layer(in_chs, act_layer=act_layer, **dd) if se_layer else nn.Identity()
  184. self.conv_pw = create_conv2d(in_chs, out_chs, pw_kernel_size, padding=pad_type, **dd)
  185. self.bn2 = norm_act_layer(out_chs, inplace=True, apply_act=self.has_pw_act, **dd)
  186. self.drop_path = DropPath(drop_path_rate) if drop_path_rate else nn.Identity()
  187. def feature_info(self, location):
  188. if location == 'expansion': # after SE, input to PW
  189. return dict(module='conv_pw', hook_type='forward_pre', num_chs=self.conv_pw.in_channels)
  190. else: # location == 'bottleneck', block output
  191. return dict(module='', num_chs=self.conv_pw.out_channels)
  192. def forward(self, x):
  193. shortcut = x
  194. if self.conv_s2d is not None:
  195. x = self.conv_s2d(x)
  196. x = self.bn_s2d(x)
  197. x = self.conv_dw(x)
  198. x = self.bn1(x)
  199. x = self.aa(x)
  200. x = self.se(x)
  201. x = self.conv_pw(x)
  202. x = self.bn2(x)
  203. if self.has_skip:
  204. x = self.drop_path(x) + shortcut
  205. return x
  206. class InvertedResidual(nn.Module):
  207. """ Inverted residual block w/ optional SE
  208. Originally used in MobileNet-V2 - https://arxiv.org/abs/1801.04381v4, this layer is often
  209. referred to as 'MBConv' for (Mobile inverted bottleneck conv) and is also used in
  210. * MNasNet - https://arxiv.org/abs/1807.11626
  211. * EfficientNet - https://arxiv.org/abs/1905.11946
  212. * MobileNet-V3 - https://arxiv.org/abs/1905.02244
  213. """
  214. def __init__(
  215. self,
  216. in_chs: int,
  217. out_chs: int,
  218. dw_kernel_size: int = 3,
  219. stride: int = 1,
  220. dilation: int = 1,
  221. group_size: int = 1,
  222. pad_type: str = '',
  223. noskip: bool = False,
  224. exp_ratio: float = 1.0,
  225. exp_kernel_size: int = 1,
  226. pw_kernel_size: int = 1,
  227. s2d: int = 0,
  228. act_layer: LayerType = nn.ReLU,
  229. norm_layer: LayerType = nn.BatchNorm2d,
  230. aa_layer: Optional[LayerType] = None,
  231. se_layer: Optional[ModuleType] = None,
  232. conv_kwargs: Optional[Dict] = None,
  233. drop_path_rate: float = 0.,
  234. device=None,
  235. dtype=None,
  236. ):
  237. dd = {'device': device, 'dtype': dtype}
  238. super().__init__()
  239. norm_act_layer = get_norm_act_layer(norm_layer, act_layer)
  240. conv_kwargs = conv_kwargs or {}
  241. self.has_skip = (in_chs == out_chs and stride == 1) and not noskip
  242. use_aa = aa_layer is not None and stride > 1 # FIXME handle dilation
  243. # Space to depth
  244. if s2d == 1:
  245. sd_chs = int(in_chs * 4)
  246. self.conv_s2d = create_conv2d(in_chs, sd_chs, kernel_size=2, stride=2, padding='same', **dd)
  247. self.bn_s2d = norm_act_layer(sd_chs, **dd)
  248. dw_kernel_size = (dw_kernel_size + 1) // 2
  249. dw_pad_type = 'same' if dw_kernel_size == 2 else pad_type
  250. in_chs = sd_chs
  251. use_aa = False # disable AA
  252. else:
  253. self.conv_s2d = None
  254. self.bn_s2d = None
  255. dw_pad_type = pad_type
  256. mid_chs = make_divisible(in_chs * exp_ratio)
  257. groups = num_groups(group_size, mid_chs)
  258. # Point-wise expansion
  259. self.conv_pw = create_conv2d(in_chs, mid_chs, exp_kernel_size, padding=pad_type, **conv_kwargs, **dd)
  260. self.bn1 = norm_act_layer(mid_chs, inplace=True, **dd)
  261. # Depth-wise convolution
  262. self.conv_dw = create_conv2d(
  263. mid_chs,
  264. mid_chs,
  265. dw_kernel_size,
  266. stride=1 if use_aa else stride,
  267. dilation=dilation,
  268. groups=groups,
  269. padding=dw_pad_type,
  270. **conv_kwargs,
  271. **dd,
  272. )
  273. self.bn2 = norm_act_layer(mid_chs, inplace=True, **dd)
  274. self.aa = create_aa(aa_layer, channels=mid_chs, stride=stride, enable=use_aa, **dd)
  275. # Squeeze-and-excitation
  276. self.se = se_layer(mid_chs, act_layer=act_layer, **dd) if se_layer else nn.Identity()
  277. # Point-wise linear projection
  278. self.conv_pwl = create_conv2d(mid_chs, out_chs, pw_kernel_size, padding=pad_type, **conv_kwargs, **dd)
  279. self.bn3 = norm_act_layer(out_chs, apply_act=False, **dd)
  280. self.drop_path = DropPath(drop_path_rate) if drop_path_rate else nn.Identity()
  281. def feature_info(self, location):
  282. if location == 'expansion': # after SE, input to PWL
  283. return dict(module='conv_pwl', hook_type='forward_pre', num_chs=self.conv_pwl.in_channels)
  284. else: # location == 'bottleneck', block output
  285. return dict(module='', num_chs=self.conv_pwl.out_channels)
  286. def forward(self, x):
  287. shortcut = x
  288. if self.conv_s2d is not None:
  289. x = self.conv_s2d(x)
  290. x = self.bn_s2d(x)
  291. x = self.conv_pw(x)
  292. x = self.bn1(x)
  293. x = self.conv_dw(x)
  294. x = self.bn2(x)
  295. x = self.aa(x)
  296. x = self.se(x)
  297. x = self.conv_pwl(x)
  298. x = self.bn3(x)
  299. if self.has_skip:
  300. x = self.drop_path(x) + shortcut
  301. return x
  302. class UniversalInvertedResidual(nn.Module):
  303. """ Universal Inverted Residual Block (aka Universal Inverted Bottleneck, UIB)
  304. For MobileNetV4 - https://arxiv.org/abs/, referenced from
  305. https://github.com/tensorflow/models/blob/d93c7e932de27522b2fa3b115f58d06d6f640537/official/vision/modeling/layers/nn_blocks.py#L778
  306. """
  307. def __init__(
  308. self,
  309. in_chs: int,
  310. out_chs: int,
  311. dw_kernel_size_start: int = 0,
  312. dw_kernel_size_mid: int = 3,
  313. dw_kernel_size_end: int = 0,
  314. stride: int = 1,
  315. dilation: int = 1,
  316. group_size: int = 1,
  317. pad_type: str = '',
  318. noskip: bool = False,
  319. exp_ratio: float = 1.0,
  320. act_layer: LayerType = nn.ReLU,
  321. norm_layer: LayerType = nn.BatchNorm2d,
  322. aa_layer: Optional[LayerType] = None,
  323. se_layer: Optional[ModuleType] = None,
  324. conv_kwargs: Optional[Dict] = None,
  325. drop_path_rate: float = 0.,
  326. layer_scale_init_value: Optional[float] = 1e-5,
  327. device=None,
  328. dtype=None,
  329. ):
  330. dd = {'device': device, 'dtype': dtype}
  331. super().__init__()
  332. conv_kwargs = conv_kwargs or {}
  333. self.has_skip = (in_chs == out_chs and stride == 1) and not noskip
  334. if stride > 1:
  335. assert dw_kernel_size_start or dw_kernel_size_mid or dw_kernel_size_end
  336. # FIXME dilation isn't right w/ extra ks > 1 convs
  337. if dw_kernel_size_start:
  338. dw_start_stride = stride if not dw_kernel_size_mid else 1
  339. dw_start_groups = num_groups(group_size, in_chs)
  340. self.dw_start = ConvNormAct(
  341. in_chs, in_chs, dw_kernel_size_start,
  342. stride=dw_start_stride,
  343. dilation=dilation, # FIXME
  344. groups=dw_start_groups,
  345. padding=pad_type,
  346. apply_act=False,
  347. act_layer=act_layer,
  348. norm_layer=norm_layer,
  349. aa_layer=aa_layer,
  350. **conv_kwargs,
  351. **dd,
  352. )
  353. else:
  354. self.dw_start = nn.Identity()
  355. # Point-wise expansion
  356. mid_chs = make_divisible(in_chs * exp_ratio)
  357. self.pw_exp = ConvNormAct(
  358. in_chs, mid_chs, 1,
  359. padding=pad_type,
  360. act_layer=act_layer,
  361. norm_layer=norm_layer,
  362. **conv_kwargs,
  363. **dd,
  364. )
  365. # Middle depth-wise convolution
  366. if dw_kernel_size_mid:
  367. groups = num_groups(group_size, mid_chs)
  368. self.dw_mid = ConvNormAct(
  369. mid_chs, mid_chs, dw_kernel_size_mid,
  370. stride=stride,
  371. dilation=dilation, # FIXME
  372. groups=groups,
  373. padding=pad_type,
  374. act_layer=act_layer,
  375. norm_layer=norm_layer,
  376. aa_layer=aa_layer,
  377. **conv_kwargs,
  378. **dd,
  379. )
  380. else:
  381. # keeping mid as identity so it can be hooked more easily for features
  382. self.dw_mid = nn.Identity()
  383. # Squeeze-and-excitation
  384. self.se = se_layer(mid_chs, act_layer=act_layer, **dd) if se_layer else nn.Identity()
  385. # Point-wise linear projection
  386. self.pw_proj = ConvNormAct(
  387. mid_chs, out_chs, 1,
  388. padding=pad_type,
  389. apply_act=False,
  390. act_layer=act_layer,
  391. norm_layer=norm_layer,
  392. **conv_kwargs,
  393. **dd,
  394. )
  395. if dw_kernel_size_end:
  396. dw_end_stride = stride if not dw_kernel_size_start and not dw_kernel_size_mid else 1
  397. dw_end_groups = num_groups(group_size, out_chs)
  398. if dw_end_stride > 1:
  399. assert not aa_layer
  400. self.dw_end = ConvNormAct(
  401. out_chs, out_chs, dw_kernel_size_end,
  402. stride=dw_end_stride,
  403. dilation=dilation,
  404. groups=dw_end_groups,
  405. padding=pad_type,
  406. apply_act=False,
  407. act_layer=act_layer,
  408. norm_layer=norm_layer,
  409. **conv_kwargs,
  410. **dd,
  411. )
  412. else:
  413. self.dw_end = nn.Identity()
  414. if layer_scale_init_value is not None:
  415. self.layer_scale = LayerScale2d(out_chs, layer_scale_init_value, **dd)
  416. else:
  417. self.layer_scale = nn.Identity()
  418. self.drop_path = DropPath(drop_path_rate) if drop_path_rate else nn.Identity()
  419. def feature_info(self, location):
  420. if location == 'expansion': # after SE, input to PWL
  421. return dict(module='pw_proj.conv', hook_type='forward_pre', num_chs=self.pw_proj.conv.in_channels)
  422. else: # location == 'bottleneck', block output
  423. return dict(module='', num_chs=self.pw_proj.conv.out_channels)
  424. def forward(self, x):
  425. shortcut = x
  426. x = self.dw_start(x)
  427. x = self.pw_exp(x)
  428. x = self.dw_mid(x)
  429. x = self.se(x)
  430. x = self.pw_proj(x)
  431. x = self.dw_end(x)
  432. x = self.layer_scale(x)
  433. if self.has_skip:
  434. x = self.drop_path(x) + shortcut
  435. return x
  436. class MobileAttention(nn.Module):
  437. """ Mobile Attention Block
  438. For MobileNetV4 - https://arxiv.org/abs/, referenced from
  439. https://github.com/tensorflow/models/blob/d93c7e932de27522b2fa3b115f58d06d6f640537/official/vision/modeling/layers/nn_blocks.py#L1504
  440. """
  441. def __init__(
  442. self,
  443. in_chs: int,
  444. out_chs: int,
  445. stride: int = 1,
  446. dw_kernel_size: int = 3,
  447. dilation: int = 1,
  448. group_size: int = 1,
  449. pad_type: str = '',
  450. num_heads: int = 8,
  451. key_dim: int = 64,
  452. value_dim: int = 64,
  453. use_multi_query: bool = False,
  454. query_strides: int = (1, 1),
  455. kv_stride: int = 1,
  456. cpe_dw_kernel_size: int = 3,
  457. noskip: bool = False,
  458. act_layer: LayerType = nn.ReLU,
  459. norm_layer: LayerType = nn.BatchNorm2d,
  460. aa_layer: Optional[LayerType] = None,
  461. drop_path_rate: float = 0.,
  462. attn_drop: float = 0.0,
  463. proj_drop: float = 0.0,
  464. layer_scale_init_value: Optional[float] = 1e-5,
  465. use_bias: bool = False,
  466. use_cpe: bool = False,
  467. device=None,
  468. dtype=None,
  469. ):
  470. dd = {'device': device, 'dtype': dtype}
  471. super().__init__()
  472. norm_act_layer = get_norm_act_layer(norm_layer, act_layer)
  473. self.has_skip = (stride == 1 and in_chs == out_chs) and not noskip
  474. self.query_strides = to_2tuple(query_strides)
  475. self.kv_stride = kv_stride
  476. self.has_query_stride = any([s > 1 for s in self.query_strides])
  477. # This CPE is different than the one suggested in the original paper.
  478. # https://arxiv.org/abs/2102.10882
  479. # 1. Rather than adding one CPE before the attention blocks, we add a CPE
  480. # into every attention block.
  481. # 2. We replace the expensive Conv2D by a Separable DW Conv.
  482. if use_cpe:
  483. self.conv_cpe_dw = create_conv2d(
  484. in_chs, in_chs,
  485. kernel_size=cpe_dw_kernel_size,
  486. dilation=dilation,
  487. depthwise=True,
  488. bias=True,
  489. **dd,
  490. )
  491. else:
  492. self.conv_cpe_dw = None
  493. self.norm = norm_act_layer(in_chs, apply_act=False, **dd)
  494. if num_heads is None:
  495. assert in_chs % key_dim == 0
  496. num_heads = in_chs // key_dim
  497. if use_multi_query:
  498. self.attn = MultiQueryAttention2d(
  499. in_chs,
  500. dim_out=out_chs,
  501. num_heads=num_heads,
  502. key_dim=key_dim,
  503. value_dim=value_dim,
  504. query_strides=query_strides,
  505. kv_stride=kv_stride,
  506. dw_kernel_size=dw_kernel_size,
  507. dilation=dilation,
  508. padding=pad_type,
  509. attn_drop=attn_drop,
  510. proj_drop=proj_drop,
  511. norm_layer=norm_layer,
  512. # use_bias=use_bias, # why not here if used w/ mhsa?
  513. **dd,
  514. )
  515. else:
  516. self.attn = Attention2d(
  517. in_chs,
  518. dim_out=out_chs,
  519. num_heads=num_heads,
  520. attn_drop=attn_drop,
  521. proj_drop=proj_drop,
  522. bias=use_bias,
  523. **dd,
  524. )
  525. if layer_scale_init_value is not None:
  526. self.layer_scale = LayerScale2d(out_chs, layer_scale_init_value, **dd)
  527. else:
  528. self.layer_scale = nn.Identity()
  529. self.drop_path = DropPath(drop_path_rate) if drop_path_rate else nn.Identity()
  530. def feature_info(self, location):
  531. if location == 'expansion': # after SE, input to PW
  532. return dict(module='conv_pw', hook_type='forward_pre', num_chs=self.conv_pw.in_channels)
  533. else: # location == 'bottleneck', block output
  534. return dict(module='', num_chs=self.conv_pw.out_channels)
  535. def forward(self, x):
  536. if self.conv_cpe_dw is not None:
  537. x_cpe = self.conv_cpe_dw(x)
  538. x = x + x_cpe
  539. shortcut = x
  540. x = self.norm(x)
  541. x = self.attn(x)
  542. x = self.layer_scale(x)
  543. if self.has_skip:
  544. x = self.drop_path(x) + shortcut
  545. return x
  546. class CondConvResidual(InvertedResidual):
  547. """ Inverted residual block w/ CondConv routing"""
  548. def __init__(
  549. self,
  550. in_chs: int,
  551. out_chs: int,
  552. dw_kernel_size: int = 3,
  553. stride: int = 1,
  554. dilation: int = 1,
  555. group_size: int = 1,
  556. pad_type: str = '',
  557. noskip: bool = False,
  558. exp_ratio: float = 1.0,
  559. exp_kernel_size: int = 1,
  560. pw_kernel_size: int = 1,
  561. act_layer: LayerType = nn.ReLU,
  562. norm_layer: LayerType = nn.BatchNorm2d,
  563. aa_layer: Optional[LayerType] = None,
  564. se_layer: Optional[ModuleType] = None,
  565. num_experts: int = 0,
  566. drop_path_rate: float = 0.,
  567. device=None,
  568. dtype=None,
  569. ):
  570. dd = {'device': device, 'dtype': dtype}
  571. self.num_experts = num_experts
  572. conv_kwargs = dict(num_experts=self.num_experts)
  573. super().__init__(
  574. in_chs,
  575. out_chs,
  576. dw_kernel_size=dw_kernel_size,
  577. stride=stride,
  578. dilation=dilation,
  579. group_size=group_size,
  580. pad_type=pad_type,
  581. noskip=noskip,
  582. exp_ratio=exp_ratio,
  583. exp_kernel_size=exp_kernel_size,
  584. pw_kernel_size=pw_kernel_size,
  585. act_layer=act_layer,
  586. norm_layer=norm_layer,
  587. aa_layer=aa_layer,
  588. se_layer=se_layer,
  589. conv_kwargs=conv_kwargs,
  590. drop_path_rate=drop_path_rate,
  591. **dd,
  592. )
  593. self.routing_fn = nn.Linear(in_chs, self.num_experts, **dd)
  594. def forward(self, x):
  595. shortcut = x
  596. pooled_inputs = F.adaptive_avg_pool2d(x, 1).flatten(1) # CondConv routing
  597. routing_weights = torch.sigmoid(self.routing_fn(pooled_inputs))
  598. x = self.conv_pw(x, routing_weights)
  599. x = self.bn1(x)
  600. x = self.conv_dw(x, routing_weights)
  601. x = self.bn2(x)
  602. x = self.se(x)
  603. x = self.conv_pwl(x, routing_weights)
  604. x = self.bn3(x)
  605. if self.has_skip:
  606. x = self.drop_path(x) + shortcut
  607. return x
  608. class EdgeResidual(nn.Module):
  609. """ Residual block with expansion convolution followed by pointwise-linear w/ stride
  610. Originally introduced in `EfficientNet-EdgeTPU: Creating Accelerator-Optimized Neural Networks with AutoML`
  611. - https://ai.googleblog.com/2019/08/efficientnet-edgetpu-creating.html
  612. This layer is also called FusedMBConv in the MobileDet, EfficientNet-X, and EfficientNet-V2 papers
  613. * MobileDet - https://arxiv.org/abs/2004.14525
  614. * EfficientNet-X - https://arxiv.org/abs/2102.05610
  615. * EfficientNet-V2 - https://arxiv.org/abs/2104.00298
  616. """
  617. def __init__(
  618. self,
  619. in_chs: int,
  620. out_chs: int,
  621. exp_kernel_size: int = 3,
  622. stride: int = 1,
  623. dilation: int = 1,
  624. group_size: int = 0,
  625. pad_type: str = '',
  626. force_in_chs: int = 0,
  627. noskip: bool = False,
  628. exp_ratio: float = 1.0,
  629. pw_kernel_size: int = 1,
  630. act_layer: LayerType = nn.ReLU,
  631. norm_layer: LayerType = nn.BatchNorm2d,
  632. aa_layer: Optional[LayerType] = None,
  633. se_layer: Optional[ModuleType] = None,
  634. drop_path_rate: float = 0.,
  635. device=None,
  636. dtype=None,
  637. ):
  638. dd = {'device': device, 'dtype': dtype}
  639. super().__init__()
  640. norm_act_layer = get_norm_act_layer(norm_layer, act_layer)
  641. if force_in_chs > 0:
  642. mid_chs = make_divisible(force_in_chs * exp_ratio)
  643. else:
  644. mid_chs = make_divisible(in_chs * exp_ratio)
  645. groups = num_groups(group_size, mid_chs) # NOTE: Using out_chs of conv_exp for groups calc
  646. self.has_skip = (in_chs == out_chs and stride == 1) and not noskip
  647. use_aa = aa_layer is not None and stride > 1 # FIXME handle dilation
  648. # Expansion convolution
  649. self.conv_exp = create_conv2d(
  650. in_chs,
  651. mid_chs,
  652. exp_kernel_size,
  653. stride=1 if use_aa else stride,
  654. dilation=dilation,
  655. groups=groups,
  656. padding=pad_type,
  657. **dd,
  658. )
  659. self.bn1 = norm_act_layer(mid_chs, inplace=True, **dd)
  660. self.aa = create_aa(aa_layer, channels=mid_chs, stride=stride, enable=use_aa, **dd)
  661. # Squeeze-and-excitation
  662. self.se = se_layer(mid_chs, act_layer=act_layer, **dd) if se_layer else nn.Identity()
  663. # Point-wise linear projection
  664. self.conv_pwl = create_conv2d(mid_chs, out_chs, pw_kernel_size, padding=pad_type, **dd)
  665. self.bn2 = norm_act_layer(out_chs, apply_act=False, **dd)
  666. self.drop_path = DropPath(drop_path_rate) if drop_path_rate else nn.Identity()
  667. def feature_info(self, location):
  668. if location == 'expansion': # after SE, before PWL
  669. return dict(module='conv_pwl', hook_type='forward_pre', num_chs=self.conv_pwl.in_channels)
  670. else: # location == 'bottleneck', block output
  671. return dict(module='', num_chs=self.conv_pwl.out_channels)
  672. def forward(self, x):
  673. shortcut = x
  674. x = self.conv_exp(x)
  675. x = self.bn1(x)
  676. x = self.aa(x)
  677. x = self.se(x)
  678. x = self.conv_pwl(x)
  679. x = self.bn2(x)
  680. if self.has_skip:
  681. x = self.drop_path(x) + shortcut
  682. return x