efficientformer_v2.py 32 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656657658659660661662663664665666667668669670671672673674675676677678679680681682683684685686687688689690691692693694695696697698699700701702703704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792793794795796797798799800801802803804805806807808809810811812813814815816817818819820821822823824825826827828829830831832833834835836837838839840841842843844845846847848849850851852853854855856857858859860861862863864865866867868869870871872873874875876877878879880881882883884885886887888889890891892893894895896897898899900901902903904905906907908909910911912913914915916917918919920921922923924925926927928929930931932933934935936937938939940941942943944945946
  1. """ EfficientFormer-V2
  2. @article{
  3. li2022rethinking,
  4. title={Rethinking Vision Transformers for MobileNet Size and Speed},
  5. author={Li, Yanyu and Hu, Ju and Wen, Yang and Evangelidis, Georgios and Salahi, Kamyar and Wang, Yanzhi and Tulyakov, Sergey and Ren, Jian},
  6. journal={arXiv preprint arXiv:2212.08059},
  7. year={2022}
  8. }
  9. Significantly refactored and cleaned up for timm from original at: https://github.com/snap-research/EfficientFormer
  10. Original code licensed Apache 2.0, Copyright (c) 2022 Snap Inc.
  11. Modifications and timm support by / Copyright 2023, Ross Wightman
  12. """
  13. import math
  14. from functools import partial
  15. from typing import Dict, List, Optional, Tuple, Type, Union
  16. import torch
  17. import torch.nn as nn
  18. from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
  19. from timm.layers import (
  20. create_conv2d,
  21. create_norm_layer,
  22. get_act_layer,
  23. get_norm_layer,
  24. ConvNormAct,
  25. LayerScale2d,
  26. DropPath,
  27. calculate_drop_path_rates,
  28. trunc_normal_,
  29. to_2tuple,
  30. to_ntuple,
  31. ndgrid,
  32. )
  33. from ._builder import build_model_with_cfg
  34. from ._features import feature_take_indices
  35. from ._manipulate import checkpoint_seq
  36. from ._registry import generate_default_cfgs, register_model
  37. __all__ = ['EfficientFormerV2']
  38. EfficientFormer_width = {
  39. 'L': (40, 80, 192, 384), # 26m 83.3% 6attn
  40. 'S2': (32, 64, 144, 288), # 12m 81.6% 4attn dp0.02
  41. 'S1': (32, 48, 120, 224), # 6.1m 79.0
  42. 'S0': (32, 48, 96, 176), # 75.0 75.7
  43. }
  44. EfficientFormer_depth = {
  45. 'L': (5, 5, 15, 10), # 26m 83.3%
  46. 'S2': (4, 4, 12, 8), # 12m
  47. 'S1': (3, 3, 9, 6), # 79.0
  48. 'S0': (2, 2, 6, 4), # 75.7
  49. }
  50. EfficientFormer_expansion_ratios = {
  51. 'L': (4, 4, (4, 4, 4, 4, 3, 3, 3, 3, 3, 3, 3, 4, 4, 4, 4), (4, 4, 4, 3, 3, 3, 3, 4, 4, 4)),
  52. 'S2': (4, 4, (4, 4, 3, 3, 3, 3, 3, 3, 4, 4, 4, 4), (4, 4, 3, 3, 3, 3, 4, 4)),
  53. 'S1': (4, 4, (4, 4, 3, 3, 3, 3, 4, 4, 4), (4, 4, 3, 3, 4, 4)),
  54. 'S0': (4, 4, (4, 3, 3, 3, 4, 4), (4, 3, 3, 4)),
  55. }
  56. class ConvNorm(nn.Module):
  57. def __init__(
  58. self,
  59. in_channels: int,
  60. out_channels: int,
  61. kernel_size: int = 1,
  62. stride: int = 1,
  63. padding: Union[int, str] = '',
  64. dilation: int = 1,
  65. groups: int = 1,
  66. bias: bool = True,
  67. norm_layer: str = 'batchnorm2d',
  68. norm_kwargs: Optional[Dict] = None,
  69. device=None,
  70. dtype=None,
  71. ):
  72. dd = {'device': device, 'dtype': dtype}
  73. norm_kwargs = norm_kwargs or {}
  74. super().__init__()
  75. self.conv = create_conv2d(
  76. in_channels,
  77. out_channels,
  78. kernel_size,
  79. stride=stride,
  80. padding=padding,
  81. dilation=dilation,
  82. groups=groups,
  83. bias=bias,
  84. **dd,
  85. )
  86. self.bn = create_norm_layer(norm_layer, out_channels, **norm_kwargs, **dd)
  87. def forward(self, x):
  88. x = self.conv(x)
  89. x = self.bn(x)
  90. return x
  91. class Attention2d(torch.nn.Module):
  92. attention_bias_cache: Dict[str, torch.Tensor]
  93. def __init__(
  94. self,
  95. dim: int = 384,
  96. key_dim: int = 32,
  97. num_heads: int = 8,
  98. attn_ratio: int = 4,
  99. resolution: Union[int, Tuple[int, int]] = 7,
  100. act_layer: Type[nn.Module] = nn.GELU,
  101. stride: Optional[int] = None,
  102. device=None,
  103. dtype=None,
  104. ):
  105. dd = {'device': device, 'dtype': dtype}
  106. super().__init__()
  107. self.num_heads = num_heads
  108. self.scale = key_dim ** -0.5
  109. self.key_dim = key_dim
  110. resolution = to_2tuple(resolution)
  111. if stride is not None:
  112. resolution = tuple([math.ceil(r / stride) for r in resolution])
  113. self.stride_conv = ConvNorm(dim, dim, kernel_size=3, stride=stride, groups=dim, **dd)
  114. self.upsample = nn.Upsample(scale_factor=stride, mode='bilinear')
  115. else:
  116. self.stride_conv = None
  117. self.upsample = None
  118. self.resolution = resolution
  119. self.N = self.resolution[0] * self.resolution[1]
  120. self.d = int(attn_ratio * key_dim)
  121. self.dh = int(attn_ratio * key_dim) * num_heads
  122. self.attn_ratio = attn_ratio
  123. kh = self.key_dim * self.num_heads
  124. self.q = ConvNorm(dim, kh, **dd)
  125. self.k = ConvNorm(dim, kh, **dd)
  126. self.v = ConvNorm(dim, self.dh, **dd)
  127. self.v_local = ConvNorm(self.dh, self.dh, kernel_size=3, groups=self.dh, **dd)
  128. self.talking_head1 = nn.Conv2d(self.num_heads, self.num_heads, kernel_size=1, **dd)
  129. self.talking_head2 = nn.Conv2d(self.num_heads, self.num_heads, kernel_size=1, **dd)
  130. self.act = act_layer()
  131. self.proj = ConvNorm(self.dh, dim, 1, **dd)
  132. self.attention_biases = torch.nn.Parameter(torch.empty(num_heads, self.N, **dd))
  133. self.register_buffer(
  134. 'attention_bias_idxs',
  135. torch.empty((self.N, self.N), device=device, dtype=torch.long),
  136. persistent=False,
  137. )
  138. self.attention_bias_cache = {}
  139. # TODO: skip init when on meta device when safe to do so
  140. self.reset_parameters()
  141. @torch.no_grad()
  142. def train(self, mode=True):
  143. super().train(mode)
  144. if mode and self.attention_bias_cache:
  145. self.attention_bias_cache = {} # clear ab cache
  146. def reset_parameters(self) -> None:
  147. """Initialize parameters and buffers."""
  148. nn.init.zeros_(self.attention_biases)
  149. self._init_buffers()
  150. def _compute_attention_bias_idxs(self, device=None):
  151. """Compute relative position indices for attention bias."""
  152. pos = torch.stack(ndgrid(
  153. torch.arange(self.resolution[0], device=device, dtype=torch.long),
  154. torch.arange(self.resolution[1], device=device, dtype=torch.long),
  155. )).flatten(1)
  156. rel_pos = (pos[..., :, None] - pos[..., None, :]).abs()
  157. rel_pos = (rel_pos[0] * self.resolution[1]) + rel_pos[1]
  158. return rel_pos
  159. def _init_buffers(self) -> None:
  160. """Compute and fill non-persistent buffer values."""
  161. self.attention_bias_idxs.copy_(
  162. self._compute_attention_bias_idxs(device=self.attention_bias_idxs.device)
  163. )
  164. self.attention_bias_cache = {}
  165. def init_non_persistent_buffers(self) -> None:
  166. """Initialize non-persistent buffers."""
  167. self._init_buffers()
  168. def get_attention_biases(self, device: torch.device) -> torch.Tensor:
  169. if torch.jit.is_tracing() or self.training:
  170. return self.attention_biases[:, self.attention_bias_idxs]
  171. else:
  172. device_key = str(device)
  173. if device_key not in self.attention_bias_cache:
  174. self.attention_bias_cache[device_key] = self.attention_biases[:, self.attention_bias_idxs]
  175. return self.attention_bias_cache[device_key]
  176. def forward(self, x):
  177. B, C, H, W = x.shape
  178. if self.stride_conv is not None:
  179. x = self.stride_conv(x)
  180. q = self.q(x).reshape(B, self.num_heads, -1, self.N).permute(0, 1, 3, 2)
  181. k = self.k(x).reshape(B, self.num_heads, -1, self.N).permute(0, 1, 2, 3)
  182. v = self.v(x)
  183. v_local = self.v_local(v)
  184. v = v.reshape(B, self.num_heads, -1, self.N).permute(0, 1, 3, 2)
  185. attn = (q @ k) * self.scale
  186. attn = attn + self.get_attention_biases(x.device)
  187. attn = self.talking_head1(attn)
  188. attn = attn.softmax(dim=-1)
  189. attn = self.talking_head2(attn)
  190. x = (attn @ v).transpose(2, 3)
  191. x = x.reshape(B, self.dh, self.resolution[0], self.resolution[1]) + v_local
  192. if self.upsample is not None:
  193. x = self.upsample(x)
  194. x = self.act(x)
  195. x = self.proj(x)
  196. return x
  197. class LocalGlobalQuery(torch.nn.Module):
  198. def __init__(
  199. self,
  200. in_dim: int,
  201. out_dim: int,
  202. device=None,
  203. dtype=None,
  204. ):
  205. dd = {'device': device, 'dtype': dtype}
  206. super().__init__()
  207. self.pool = nn.AvgPool2d(1, 2, 0)
  208. self.local = nn.Conv2d(in_dim, in_dim, kernel_size=3, stride=2, padding=1, groups=in_dim, **dd)
  209. self.proj = ConvNorm(in_dim, out_dim, 1, **dd)
  210. def forward(self, x):
  211. local_q = self.local(x)
  212. pool_q = self.pool(x)
  213. q = local_q + pool_q
  214. q = self.proj(q)
  215. return q
  216. class Attention2dDownsample(torch.nn.Module):
  217. attention_bias_cache: Dict[str, torch.Tensor]
  218. def __init__(
  219. self,
  220. dim: int = 384,
  221. key_dim: int = 16,
  222. num_heads: int = 8,
  223. attn_ratio: int = 4,
  224. resolution: Union[int, Tuple[int, int]] = 7,
  225. out_dim: Optional[int] = None,
  226. act_layer: Type[nn.Module] = nn.GELU,
  227. device=None,
  228. dtype=None,
  229. ):
  230. dd = {'device': device, 'dtype': dtype}
  231. super().__init__()
  232. self.num_heads = num_heads
  233. self.scale = key_dim ** -0.5
  234. self.key_dim = key_dim
  235. self.resolution = to_2tuple(resolution)
  236. self.resolution2 = tuple([math.ceil(r / 2) for r in self.resolution])
  237. self.N = self.resolution[0] * self.resolution[1]
  238. self.N2 = self.resolution2[0] * self.resolution2[1]
  239. self.d = int(attn_ratio * key_dim)
  240. self.dh = int(attn_ratio * key_dim) * num_heads
  241. self.attn_ratio = attn_ratio
  242. self.out_dim = out_dim or dim
  243. kh = self.key_dim * self.num_heads
  244. self.q = LocalGlobalQuery(dim, kh, **dd)
  245. self.k = ConvNorm(dim, kh, 1, **dd)
  246. self.v = ConvNorm(dim, self.dh, 1, **dd)
  247. self.v_local = ConvNorm(self.dh, self.dh, kernel_size=3, stride=2, groups=self.dh, **dd)
  248. self.act = act_layer()
  249. self.proj = ConvNorm(self.dh, self.out_dim, 1, **dd)
  250. self.attention_biases = nn.Parameter(torch.empty(num_heads, self.N, **dd))
  251. self.register_buffer(
  252. 'attention_bias_idxs',
  253. torch.empty((self.N2, self.N), device=device, dtype=torch.long),
  254. persistent=False,
  255. )
  256. self.attention_bias_cache = {}
  257. # TODO: skip init when on meta device when safe to do so
  258. self.reset_parameters()
  259. @torch.no_grad()
  260. def train(self, mode=True):
  261. super().train(mode)
  262. if mode and self.attention_bias_cache:
  263. self.attention_bias_cache = {} # clear ab cache
  264. def reset_parameters(self) -> None:
  265. """Initialize parameters and buffers."""
  266. nn.init.zeros_(self.attention_biases)
  267. self._init_buffers()
  268. def _compute_attention_bias_idxs(self, device=None):
  269. """Compute relative position indices for attention bias."""
  270. k_pos = torch.stack(ndgrid(
  271. torch.arange(self.resolution[0], device=device, dtype=torch.long),
  272. torch.arange(self.resolution[1], device=device, dtype=torch.long),
  273. )).flatten(1)
  274. q_pos = torch.stack(ndgrid(
  275. torch.arange(0, self.resolution[0], step=2, device=device, dtype=torch.long),
  276. torch.arange(0, self.resolution[1], step=2, device=device, dtype=torch.long),
  277. )).flatten(1)
  278. rel_pos = (q_pos[..., :, None] - k_pos[..., None, :]).abs()
  279. rel_pos = (rel_pos[0] * self.resolution[1]) + rel_pos[1]
  280. return rel_pos
  281. def _init_buffers(self) -> None:
  282. """Compute and fill non-persistent buffer values."""
  283. self.attention_bias_idxs.copy_(
  284. self._compute_attention_bias_idxs(device=self.attention_bias_idxs.device)
  285. )
  286. self.attention_bias_cache = {}
  287. def init_non_persistent_buffers(self) -> None:
  288. """Initialize non-persistent buffers."""
  289. self._init_buffers()
  290. def get_attention_biases(self, device: torch.device) -> torch.Tensor:
  291. if torch.jit.is_tracing() or self.training:
  292. return self.attention_biases[:, self.attention_bias_idxs]
  293. else:
  294. device_key = str(device)
  295. if device_key not in self.attention_bias_cache:
  296. self.attention_bias_cache[device_key] = self.attention_biases[:, self.attention_bias_idxs]
  297. return self.attention_bias_cache[device_key]
  298. def forward(self, x):
  299. B, C, H, W = x.shape
  300. q = self.q(x).reshape(B, self.num_heads, -1, self.N2).permute(0, 1, 3, 2)
  301. k = self.k(x).reshape(B, self.num_heads, -1, self.N).permute(0, 1, 2, 3)
  302. v = self.v(x)
  303. v_local = self.v_local(v)
  304. v = v.reshape(B, self.num_heads, -1, self.N).permute(0, 1, 3, 2)
  305. attn = (q @ k) * self.scale
  306. attn = attn + self.get_attention_biases(x.device)
  307. attn = attn.softmax(dim=-1)
  308. x = (attn @ v).transpose(2, 3)
  309. x = x.reshape(B, self.dh, self.resolution2[0], self.resolution2[1]) + v_local
  310. x = self.act(x)
  311. x = self.proj(x)
  312. return x
  313. class Downsample(nn.Module):
  314. def __init__(
  315. self,
  316. in_chs: int,
  317. out_chs: int,
  318. kernel_size: Union[int, Tuple[int, int]] = 3,
  319. stride: Union[int, Tuple[int, int]] = 2,
  320. padding: Union[int, Tuple[int, int]] = 1,
  321. resolution: Union[int, Tuple[int, int]] = 7,
  322. use_attn: bool = False,
  323. act_layer: Type[nn.Module] = nn.GELU,
  324. norm_layer: Optional[Type[nn.Module]] = nn.BatchNorm2d,
  325. device=None,
  326. dtype=None,
  327. ):
  328. dd = {'device': device, 'dtype': dtype}
  329. super().__init__()
  330. kernel_size = to_2tuple(kernel_size)
  331. stride = to_2tuple(stride)
  332. padding = to_2tuple(padding)
  333. norm_layer = norm_layer or nn.Identity()
  334. self.conv = ConvNorm(
  335. in_chs,
  336. out_chs,
  337. kernel_size=kernel_size,
  338. stride=stride,
  339. padding=padding,
  340. norm_layer=norm_layer,
  341. **dd,
  342. )
  343. if use_attn:
  344. self.attn = Attention2dDownsample(
  345. dim=in_chs,
  346. out_dim=out_chs,
  347. resolution=resolution,
  348. act_layer=act_layer,
  349. **dd,
  350. )
  351. else:
  352. self.attn = None
  353. def forward(self, x):
  354. out = self.conv(x)
  355. if self.attn is not None:
  356. return self.attn(x) + out
  357. return out
  358. class ConvMlpWithNorm(nn.Module):
  359. """
  360. Implementation of MLP with 1*1 convolutions.
  361. Input: tensor with shape [B, C, H, W]
  362. """
  363. def __init__(
  364. self,
  365. in_features: int,
  366. hidden_features: Optional[int] = None,
  367. out_features: Optional[int] = None,
  368. act_layer: Type[nn.Module] = nn.GELU,
  369. norm_layer: Type[nn.Module] = nn.BatchNorm2d,
  370. drop: float = 0.,
  371. mid_conv: bool = False,
  372. device=None,
  373. dtype=None,
  374. ):
  375. dd = {'device': device, 'dtype': dtype}
  376. super().__init__()
  377. out_features = out_features or in_features
  378. hidden_features = hidden_features or in_features
  379. self.fc1 = ConvNormAct(
  380. in_features,
  381. hidden_features,
  382. 1,
  383. bias=True,
  384. norm_layer=norm_layer,
  385. act_layer=act_layer,
  386. **dd,
  387. )
  388. if mid_conv:
  389. self.mid = ConvNormAct(
  390. hidden_features,
  391. hidden_features,
  392. 3,
  393. groups=hidden_features,
  394. bias=True,
  395. norm_layer=norm_layer,
  396. act_layer=act_layer,
  397. **dd,
  398. )
  399. else:
  400. self.mid = nn.Identity()
  401. self.drop1 = nn.Dropout(drop)
  402. self.fc2 = ConvNorm(hidden_features, out_features, 1, norm_layer=norm_layer, **dd)
  403. self.drop2 = nn.Dropout(drop)
  404. def forward(self, x):
  405. x = self.fc1(x)
  406. x = self.mid(x)
  407. x = self.drop1(x)
  408. x = self.fc2(x)
  409. x = self.drop2(x)
  410. return x
  411. class EfficientFormerV2Block(nn.Module):
  412. def __init__(
  413. self,
  414. dim: int,
  415. mlp_ratio: float = 4.,
  416. act_layer: Type[nn.Module] = nn.GELU,
  417. norm_layer: Type[nn.Module] = nn.BatchNorm2d,
  418. proj_drop: float = 0.,
  419. drop_path: float = 0.,
  420. layer_scale_init_value: Optional[float] = 1e-5,
  421. resolution: Union[int, Tuple[int, int]] = 7,
  422. stride: Optional[int] = None,
  423. use_attn: bool = True,
  424. device=None,
  425. dtype=None,
  426. ):
  427. dd = {'device': device, 'dtype': dtype}
  428. super().__init__()
  429. if use_attn:
  430. self.token_mixer = Attention2d(
  431. dim,
  432. resolution=resolution,
  433. act_layer=act_layer,
  434. stride=stride,
  435. **dd,
  436. )
  437. self.ls1 = LayerScale2d(
  438. dim, layer_scale_init_value, **dd) if layer_scale_init_value is not None else nn.Identity()
  439. self.drop_path1 = DropPath(drop_path) if drop_path > 0. else nn.Identity()
  440. else:
  441. self.token_mixer = None
  442. self.ls1 = None
  443. self.drop_path1 = None
  444. self.mlp = ConvMlpWithNorm(
  445. in_features=dim,
  446. hidden_features=int(dim * mlp_ratio),
  447. act_layer=act_layer,
  448. norm_layer=norm_layer,
  449. drop=proj_drop,
  450. mid_conv=True,
  451. **dd,
  452. )
  453. self.ls2 = LayerScale2d(
  454. dim, layer_scale_init_value, **dd) if layer_scale_init_value is not None else nn.Identity()
  455. self.drop_path2 = DropPath(drop_path) if drop_path > 0. else nn.Identity()
  456. def forward(self, x):
  457. if self.token_mixer is not None:
  458. x = x + self.drop_path1(self.ls1(self.token_mixer(x)))
  459. x = x + self.drop_path2(self.ls2(self.mlp(x)))
  460. return x
  461. class Stem4(nn.Sequential):
  462. def __init__(
  463. self,
  464. in_chs: int,
  465. out_chs: int,
  466. act_layer: Type[nn.Module] = nn.GELU,
  467. norm_layer: Type[nn.Module] = nn.BatchNorm2d,
  468. device=None,
  469. dtype=None,
  470. ):
  471. dd = {'device': device, 'dtype': dtype}
  472. super().__init__()
  473. self.stride = 4
  474. self.conv1 = ConvNormAct(
  475. in_chs,
  476. out_chs // 2,
  477. kernel_size=3,
  478. stride=2, padding=1,
  479. bias=True,
  480. norm_layer=norm_layer,
  481. act_layer=act_layer,
  482. **dd,
  483. )
  484. self.conv2 = ConvNormAct(
  485. out_chs // 2,
  486. out_chs,
  487. kernel_size=3,
  488. stride=2,
  489. padding=1,
  490. bias=True,
  491. norm_layer=norm_layer,
  492. act_layer=act_layer,
  493. **dd,
  494. )
  495. class EfficientFormerV2Stage(nn.Module):
  496. def __init__(
  497. self,
  498. dim: int,
  499. dim_out: int,
  500. depth: int,
  501. resolution: Union[int, Tuple[int, int]] = 7,
  502. downsample: bool = True,
  503. block_stride: Optional[int] = None,
  504. downsample_use_attn: bool = False,
  505. block_use_attn: bool = False,
  506. num_vit: int = 1,
  507. mlp_ratio: Union[float, Tuple[float, ...]] = 4.,
  508. proj_drop: float = .0,
  509. drop_path: Union[float, List[float]] = 0.,
  510. layer_scale_init_value: Optional[float] = 1e-5,
  511. act_layer: Type[nn.Module] = nn.GELU,
  512. norm_layer: Type[nn.Module] = nn.BatchNorm2d,
  513. device=None,
  514. dtype=None,
  515. ):
  516. dd = {'device': device, 'dtype': dtype}
  517. super().__init__()
  518. self.grad_checkpointing = False
  519. mlp_ratio = to_ntuple(depth)(mlp_ratio)
  520. resolution = to_2tuple(resolution)
  521. if downsample:
  522. self.downsample = Downsample(
  523. dim,
  524. dim_out,
  525. use_attn=downsample_use_attn,
  526. resolution=resolution,
  527. norm_layer=norm_layer,
  528. act_layer=act_layer,
  529. **dd,
  530. )
  531. dim = dim_out
  532. resolution = tuple([math.ceil(r / 2) for r in resolution])
  533. else:
  534. assert dim == dim_out
  535. self.downsample = nn.Identity()
  536. blocks = []
  537. for block_idx in range(depth):
  538. remain_idx = depth - num_vit - 1
  539. b = EfficientFormerV2Block(
  540. dim,
  541. resolution=resolution,
  542. stride=block_stride,
  543. mlp_ratio=mlp_ratio[block_idx],
  544. use_attn=block_use_attn and block_idx > remain_idx,
  545. proj_drop=proj_drop,
  546. drop_path=drop_path[block_idx],
  547. layer_scale_init_value=layer_scale_init_value,
  548. act_layer=act_layer,
  549. norm_layer=norm_layer,
  550. **dd,
  551. )
  552. blocks += [b]
  553. self.blocks = nn.Sequential(*blocks)
  554. def forward(self, x):
  555. x = self.downsample(x)
  556. if self.grad_checkpointing and not torch.jit.is_scripting():
  557. x = checkpoint_seq(self.blocks, x)
  558. else:
  559. x = self.blocks(x)
  560. return x
  561. class EfficientFormerV2(nn.Module):
  562. def __init__(
  563. self,
  564. depths: Tuple[int, ...],
  565. in_chans: int = 3,
  566. img_size: Union[int, Tuple[int, int]] = 224,
  567. global_pool: str = 'avg',
  568. embed_dims: Optional[Tuple[int, ...]] = None,
  569. downsamples: Optional[Tuple[bool, ...]] = None,
  570. mlp_ratios: Union[float, Tuple[float, ...], Tuple[Tuple[float, ...], ...]] = 4,
  571. norm_layer: str = 'batchnorm2d',
  572. norm_eps: float = 1e-5,
  573. act_layer: str = 'gelu',
  574. num_classes: int = 1000,
  575. drop_rate: float = 0.,
  576. proj_drop_rate: float = 0.,
  577. drop_path_rate: float = 0.,
  578. layer_scale_init_value: Optional[float] = 1e-5,
  579. num_vit: int = 0,
  580. distillation: bool = True,
  581. device=None,
  582. dtype=None,
  583. ):
  584. super().__init__()
  585. dd = {'device': device, 'dtype': dtype}
  586. assert global_pool in ('avg', '')
  587. self.num_classes = num_classes
  588. self.in_chans = in_chans
  589. self.global_pool = global_pool
  590. self.feature_info = []
  591. img_size = to_2tuple(img_size)
  592. norm_layer = partial(get_norm_layer(norm_layer), eps=norm_eps)
  593. act_layer = get_act_layer(act_layer)
  594. self.stem = Stem4(in_chans, embed_dims[0], act_layer=act_layer, norm_layer=norm_layer, **dd)
  595. prev_dim = embed_dims[0]
  596. stride = 4
  597. num_stages = len(depths)
  598. dpr = calculate_drop_path_rates(drop_path_rate, depths, stagewise=True)
  599. downsamples = downsamples or (False,) + (True,) * (len(depths) - 1)
  600. mlp_ratios = to_ntuple(num_stages)(mlp_ratios)
  601. stages = []
  602. for i in range(num_stages):
  603. curr_resolution = tuple([math.ceil(s / stride) for s in img_size])
  604. stage = EfficientFormerV2Stage(
  605. prev_dim,
  606. embed_dims[i],
  607. depth=depths[i],
  608. resolution=curr_resolution,
  609. downsample=downsamples[i],
  610. block_stride=2 if i == 2 else None,
  611. downsample_use_attn=i >= 3,
  612. block_use_attn=i >= 2,
  613. num_vit=num_vit,
  614. mlp_ratio=mlp_ratios[i],
  615. proj_drop=proj_drop_rate,
  616. drop_path=dpr[i],
  617. layer_scale_init_value=layer_scale_init_value,
  618. act_layer=act_layer,
  619. norm_layer=norm_layer,
  620. **dd,
  621. )
  622. if downsamples[i]:
  623. stride *= 2
  624. prev_dim = embed_dims[i]
  625. self.feature_info += [dict(num_chs=prev_dim, reduction=stride, module=f'stages.{i}')]
  626. stages.append(stage)
  627. self.stages = nn.Sequential(*stages)
  628. # Classifier head
  629. self.num_features = self.head_hidden_size = embed_dims[-1]
  630. self.norm = norm_layer(embed_dims[-1], **dd)
  631. self.head_drop = nn.Dropout(drop_rate)
  632. self.head = nn.Linear(embed_dims[-1], num_classes, **dd) if num_classes > 0 else nn.Identity()
  633. self.dist = distillation
  634. if self.dist:
  635. self.head_dist = nn.Linear(embed_dims[-1], num_classes, **dd) if num_classes > 0 else nn.Identity()
  636. else:
  637. self.head_dist = None
  638. # TODO: skip init when on meta device when safe to do so
  639. self.init_weights(needs_reset=False)
  640. self.distilled_training = False
  641. def _init_weights(self, m, needs_reset: bool = True):
  642. if isinstance(m, nn.Linear):
  643. trunc_normal_(m.weight, std=.02)
  644. if m.bias is not None:
  645. nn.init.constant_(m.bias, 0)
  646. elif needs_reset and hasattr(m, 'reset_parameters'):
  647. m.reset_parameters()
  648. def init_weights(self, needs_reset: bool = True):
  649. self.apply(partial(self._init_weights, needs_reset=needs_reset))
  650. @torch.jit.ignore
  651. def no_weight_decay(self):
  652. return {k for k, _ in self.named_parameters() if 'attention_biases' in k}
  653. @torch.jit.ignore
  654. def group_matcher(self, coarse=False):
  655. matcher = dict(
  656. stem=r'^stem', # stem and embed
  657. blocks=[(r'^stages\.(\d+)', None), (r'^norm', (99999,))]
  658. )
  659. return matcher
  660. @torch.jit.ignore
  661. def set_grad_checkpointing(self, enable=True):
  662. for s in self.stages:
  663. s.grad_checkpointing = enable
  664. @torch.jit.ignore
  665. def get_classifier(self) -> nn.Module:
  666. return self.head, self.head_dist
  667. def reset_classifier(self, num_classes: int, global_pool: Optional[str] = None):
  668. self.num_classes = num_classes
  669. if global_pool is not None:
  670. self.global_pool = global_pool
  671. self.head = nn.Linear(self.num_features, num_classes) if num_classes > 0 else nn.Identity()
  672. self.head_dist = nn.Linear(self.num_features, num_classes) if num_classes > 0 else nn.Identity()
  673. @torch.jit.ignore
  674. def set_distilled_training(self, enable=True):
  675. self.distilled_training = enable
  676. def forward_intermediates(
  677. self,
  678. x: torch.Tensor,
  679. indices: Optional[Union[int, List[int]]] = None,
  680. norm: bool = False,
  681. stop_early: bool = False,
  682. output_fmt: str = 'NCHW',
  683. intermediates_only: bool = False,
  684. ) -> Union[List[torch.Tensor], Tuple[torch.Tensor, List[torch.Tensor]]]:
  685. """ Forward features that returns intermediates.
  686. Args:
  687. x: Input image tensor
  688. indices: Take last n blocks if int, all if None, select matching indices if sequence
  689. norm: Apply norm layer to compatible intermediates
  690. stop_early: Stop iterating over blocks when last desired intermediate hit
  691. output_fmt: Shape of intermediate feature outputs
  692. intermediates_only: Only return intermediate features
  693. Returns:
  694. """
  695. assert output_fmt in ('NCHW',), 'Output shape must be NCHW.'
  696. intermediates = []
  697. take_indices, max_index = feature_take_indices(len(self.stages), indices)
  698. # forward pass
  699. x = self.stem(x)
  700. last_idx = len(self.stages) - 1
  701. if torch.jit.is_scripting() or not stop_early: # can't slice blocks in torchscript
  702. stages = self.stages
  703. else:
  704. stages = self.stages[:max_index + 1]
  705. for feat_idx, stage in enumerate(stages):
  706. x = stage(x)
  707. if feat_idx in take_indices:
  708. if feat_idx == last_idx:
  709. x_inter = self.norm(x) if norm else x
  710. intermediates.append(x_inter)
  711. else:
  712. intermediates.append(x)
  713. if intermediates_only:
  714. return intermediates
  715. if feat_idx == last_idx:
  716. x = self.norm(x)
  717. return x, intermediates
  718. def prune_intermediate_layers(
  719. self,
  720. indices: Union[int, List[int]] = 1,
  721. prune_norm: bool = False,
  722. prune_head: bool = True,
  723. ):
  724. """ Prune layers not required for specified intermediates.
  725. """
  726. take_indices, max_index = feature_take_indices(len(self.stages), indices)
  727. self.stages = self.stages[:max_index + 1] # truncate blocks w/ stem as idx 0
  728. if prune_norm:
  729. self.norm = nn.Identity()
  730. if prune_head:
  731. self.reset_classifier(0, '')
  732. return take_indices
  733. def forward_features(self, x):
  734. x = self.stem(x)
  735. x = self.stages(x)
  736. x = self.norm(x)
  737. return x
  738. def forward_head(self, x, pre_logits: bool = False):
  739. if self.global_pool == 'avg':
  740. x = x.mean(dim=(2, 3))
  741. x = self.head_drop(x)
  742. if pre_logits:
  743. return x
  744. x, x_dist = self.head(x), self.head_dist(x)
  745. if self.distilled_training and self.training and not torch.jit.is_scripting():
  746. # only return separate classification predictions when training in distilled mode
  747. return x, x_dist
  748. else:
  749. # during standard train/finetune, inference average the classifier predictions
  750. return (x + x_dist) / 2
  751. def forward(self, x):
  752. x = self.forward_features(x)
  753. x = self.forward_head(x)
  754. return x
  755. def _cfg(url='', **kwargs):
  756. return {
  757. 'url': url,
  758. 'num_classes': 1000, 'input_size': (3, 224, 224), 'pool_size': None, 'fixed_input_size': True,
  759. 'crop_pct': .95, 'interpolation': 'bicubic',
  760. 'mean': IMAGENET_DEFAULT_MEAN, 'std': IMAGENET_DEFAULT_STD,
  761. 'classifier': ('head', 'head_dist'), 'first_conv': 'stem.conv1.conv',
  762. 'license': 'apache-2.0',
  763. **kwargs
  764. }
  765. default_cfgs = generate_default_cfgs({
  766. 'efficientformerv2_s0.snap_dist_in1k': _cfg(
  767. hf_hub_id='timm/',
  768. ),
  769. 'efficientformerv2_s1.snap_dist_in1k': _cfg(
  770. hf_hub_id='timm/',
  771. ),
  772. 'efficientformerv2_s2.snap_dist_in1k': _cfg(
  773. hf_hub_id='timm/',
  774. ),
  775. 'efficientformerv2_l.snap_dist_in1k': _cfg(
  776. hf_hub_id='timm/',
  777. ),
  778. })
  779. def _create_efficientformerv2(variant, pretrained=False, **kwargs):
  780. out_indices = kwargs.pop('out_indices', (0, 1, 2, 3))
  781. model = build_model_with_cfg(
  782. EfficientFormerV2, variant, pretrained,
  783. feature_cfg=dict(flatten_sequential=True, out_indices=out_indices),
  784. **kwargs)
  785. return model
  786. @register_model
  787. def efficientformerv2_s0(pretrained=False, **kwargs) -> EfficientFormerV2:
  788. model_args = dict(
  789. depths=EfficientFormer_depth['S0'],
  790. embed_dims=EfficientFormer_width['S0'],
  791. num_vit=2,
  792. drop_path_rate=0.0,
  793. mlp_ratios=EfficientFormer_expansion_ratios['S0'],
  794. )
  795. return _create_efficientformerv2('efficientformerv2_s0', pretrained=pretrained, **dict(model_args, **kwargs))
  796. @register_model
  797. def efficientformerv2_s1(pretrained=False, **kwargs) -> EfficientFormerV2:
  798. model_args = dict(
  799. depths=EfficientFormer_depth['S1'],
  800. embed_dims=EfficientFormer_width['S1'],
  801. num_vit=2,
  802. drop_path_rate=0.0,
  803. mlp_ratios=EfficientFormer_expansion_ratios['S1'],
  804. )
  805. return _create_efficientformerv2('efficientformerv2_s1', pretrained=pretrained, **dict(model_args, **kwargs))
  806. @register_model
  807. def efficientformerv2_s2(pretrained=False, **kwargs) -> EfficientFormerV2:
  808. model_args = dict(
  809. depths=EfficientFormer_depth['S2'],
  810. embed_dims=EfficientFormer_width['S2'],
  811. num_vit=4,
  812. drop_path_rate=0.02,
  813. mlp_ratios=EfficientFormer_expansion_ratios['S2'],
  814. )
  815. return _create_efficientformerv2('efficientformerv2_s2', pretrained=pretrained, **dict(model_args, **kwargs))
  816. @register_model
  817. def efficientformerv2_l(pretrained=False, **kwargs) -> EfficientFormerV2:
  818. model_args = dict(
  819. depths=EfficientFormer_depth['L'],
  820. embed_dims=EfficientFormer_width['L'],
  821. num_vit=6,
  822. drop_path_rate=0.1,
  823. mlp_ratios=EfficientFormer_expansion_ratios['L'],
  824. )
  825. return _create_efficientformerv2('efficientformerv2_l', pretrained=pretrained, **dict(model_args, **kwargs))