efficientformer.py 23 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656657658659660661662663664665666667668669670671672673674675676677678679680681682683684685686
  1. """ EfficientFormer
  2. @article{li2022efficientformer,
  3. title={EfficientFormer: Vision Transformers at MobileNet Speed},
  4. author={Li, Yanyu and Yuan, Geng and Wen, Yang and Hu, Eric and Evangelidis, Georgios and Tulyakov,
  5. Sergey and Wang, Yanzhi and Ren, Jian},
  6. journal={arXiv preprint arXiv:2206.01191},
  7. year={2022}
  8. }
  9. Based on Apache 2.0 licensed code at https://github.com/snap-research/EfficientFormer, Copyright (c) 2022 Snap Inc.
  10. Modifications and timm support by / Copyright 2022, Ross Wightman
  11. """
  12. from typing import Dict, List, Optional, Tuple, Type, Union
  13. import torch
  14. import torch.nn as nn
  15. from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
  16. from timm.layers import (
  17. DropPath,
  18. LayerScale,
  19. LayerScale2d,
  20. Mlp,
  21. calculate_drop_path_rates,
  22. trunc_normal_,
  23. to_2tuple,
  24. ndgrid,
  25. )
  26. from ._builder import build_model_with_cfg
  27. from ._features import feature_take_indices
  28. from ._manipulate import checkpoint_seq
  29. from ._registry import generate_default_cfgs, register_model
# Names exported by `from <module> import *`. Only the model class is listed
# statically here.
__all__ = ['EfficientFormer']  # model_registry will add each entrypoint fn to this
  31. EfficientFormer_width = {
  32. 'l1': (48, 96, 224, 448),
  33. 'l3': (64, 128, 320, 512),
  34. 'l7': (96, 192, 384, 768),
  35. }
  36. EfficientFormer_depth = {
  37. 'l1': (3, 2, 6, 4),
  38. 'l3': (4, 4, 12, 6),
  39. 'l7': (6, 6, 18, 8),
  40. }
  41. class Attention(torch.nn.Module):
  42. attention_bias_cache: Dict[str, torch.Tensor]
  43. def __init__(
  44. self,
  45. dim: int = 384,
  46. key_dim: int = 32,
  47. num_heads: int = 8,
  48. attn_ratio: float = 4,
  49. resolution: int = 7,
  50. device=None,
  51. dtype=None,
  52. ):
  53. dd = {'device': device, 'dtype': dtype}
  54. super().__init__()
  55. self.num_heads = num_heads
  56. self.scale = key_dim ** -0.5
  57. self.key_dim = key_dim
  58. self.key_attn_dim = key_dim * num_heads
  59. self.val_dim = int(attn_ratio * key_dim)
  60. self.val_attn_dim = self.val_dim * num_heads
  61. self.attn_ratio = attn_ratio
  62. self.qkv = nn.Linear(dim, self.key_attn_dim * 2 + self.val_attn_dim, **dd)
  63. self.proj = nn.Linear(self.val_attn_dim, dim, **dd)
  64. resolution = to_2tuple(resolution)
  65. pos = torch.stack(ndgrid(
  66. torch.arange(resolution[0], device=device, dtype=torch.long),
  67. torch.arange(resolution[1], device=device, dtype=torch.long)
  68. )).flatten(1)
  69. rel_pos = (pos[..., :, None] - pos[..., None, :]).abs()
  70. rel_pos = (rel_pos[0] * resolution[1]) + rel_pos[1]
  71. self.attention_biases = torch.nn.Parameter(torch.zeros(num_heads, resolution[0] * resolution[1], **dd))
  72. self.register_buffer('attention_bias_idxs', rel_pos)
  73. self.attention_bias_cache = {} # per-device attention_biases cache (data-parallel compat)
  74. @torch.no_grad()
  75. def train(self, mode=True):
  76. super().train(mode)
  77. if mode and self.attention_bias_cache:
  78. self.attention_bias_cache = {} # clear ab cache
  79. def get_attention_biases(self, device: torch.device) -> torch.Tensor:
  80. if torch.jit.is_tracing() or self.training:
  81. return self.attention_biases[:, self.attention_bias_idxs]
  82. else:
  83. device_key = str(device)
  84. if device_key not in self.attention_bias_cache:
  85. self.attention_bias_cache[device_key] = self.attention_biases[:, self.attention_bias_idxs]
  86. return self.attention_bias_cache[device_key]
  87. def forward(self, x): # x (B,N,C)
  88. B, N, C = x.shape
  89. qkv = self.qkv(x)
  90. qkv = qkv.reshape(B, N, self.num_heads, -1).permute(0, 2, 1, 3)
  91. q, k, v = qkv.split([self.key_dim, self.key_dim, self.val_dim], dim=3)
  92. attn = (q @ k.transpose(-2, -1)) * self.scale
  93. attn = attn + self.get_attention_biases(x.device)
  94. attn = attn.softmax(dim=-1)
  95. x = (attn @ v).transpose(1, 2).reshape(B, N, self.val_attn_dim)
  96. x = self.proj(x)
  97. return x
  98. class Stem4(nn.Sequential):
  99. def __init__(
  100. self,
  101. in_chs: int,
  102. out_chs: int,
  103. act_layer: Type[nn.Module] = nn.ReLU,
  104. norm_layer: Type[nn.Module] = nn.BatchNorm2d,
  105. device=None,
  106. dtype=None,
  107. ):
  108. dd = {'device': device, 'dtype': dtype}
  109. super().__init__()
  110. self.stride = 4
  111. self.add_module('conv1', nn.Conv2d(in_chs, out_chs // 2, kernel_size=3, stride=2, padding=1, **dd))
  112. self.add_module('norm1', norm_layer(out_chs // 2, **dd))
  113. self.add_module('act1', act_layer())
  114. self.add_module('conv2', nn.Conv2d(out_chs // 2, out_chs, kernel_size=3, stride=2, padding=1, **dd))
  115. self.add_module('norm2', norm_layer(out_chs, **dd))
  116. self.add_module('act2', act_layer())
  117. class Downsample(nn.Module):
  118. """
  119. Downsampling via strided conv w/ norm
  120. Input: tensor in shape [B, C, H, W]
  121. Output: tensor in shape [B, C, H/stride, W/stride]
  122. """
  123. def __init__(
  124. self,
  125. in_chs: int,
  126. out_chs: int,
  127. kernel_size: int = 3,
  128. stride: int = 2,
  129. padding: Optional[int] = None,
  130. norm_layer: Type[nn.Module] = nn.BatchNorm2d,
  131. device=None,
  132. dtype=None,
  133. ):
  134. dd = {'device': device, 'dtype': dtype}
  135. super().__init__()
  136. if padding is None:
  137. padding = kernel_size // 2
  138. self.conv = nn.Conv2d(in_chs, out_chs, kernel_size=kernel_size, stride=stride, padding=padding, **dd)
  139. self.norm = norm_layer(out_chs, **dd)
  140. def forward(self, x):
  141. x = self.conv(x)
  142. x = self.norm(x)
  143. return x
  144. class Flat(nn.Module):
  145. def __init__(self, ):
  146. super().__init__()
  147. def forward(self, x):
  148. x = x.flatten(2).transpose(1, 2)
  149. return x
  150. class Pooling(nn.Module):
  151. """
  152. Implementation of pooling for PoolFormer
  153. --pool_size: pooling size
  154. """
  155. def __init__(self, pool_size: int = 3):
  156. super().__init__()
  157. self.pool = nn.AvgPool2d(pool_size, stride=1, padding=pool_size // 2, count_include_pad=False)
  158. def forward(self, x):
  159. return self.pool(x) - x
  160. class ConvMlpWithNorm(nn.Module):
  161. """
  162. Implementation of MLP with 1*1 convolutions.
  163. Input: tensor with shape [B, C, H, W]
  164. """
  165. def __init__(
  166. self,
  167. in_features: int,
  168. hidden_features: Optional[int] = None,
  169. out_features: Optional[int] = None,
  170. act_layer: Type[nn.Module] = nn.GELU,
  171. norm_layer: Type[nn.Module] = nn.BatchNorm2d,
  172. drop: float = 0.,
  173. device=None,
  174. dtype=None,
  175. ):
  176. dd = {'device': device, 'dtype': dtype}
  177. super().__init__()
  178. out_features = out_features or in_features
  179. hidden_features = hidden_features or in_features
  180. self.fc1 = nn.Conv2d(in_features, hidden_features, 1, **dd)
  181. self.norm1 = norm_layer(hidden_features, **dd) if norm_layer is not None else nn.Identity()
  182. self.act = act_layer()
  183. self.fc2 = nn.Conv2d(hidden_features, out_features, 1, **dd)
  184. self.norm2 = norm_layer(out_features, **dd) if norm_layer is not None else nn.Identity()
  185. self.drop = nn.Dropout(drop)
  186. def forward(self, x):
  187. x = self.fc1(x)
  188. x = self.norm1(x)
  189. x = self.act(x)
  190. x = self.drop(x)
  191. x = self.fc2(x)
  192. x = self.norm2(x)
  193. x = self.drop(x)
  194. return x
  195. class MetaBlock1d(nn.Module):
  196. def __init__(
  197. self,
  198. dim: int,
  199. mlp_ratio: float = 4.,
  200. act_layer: Type[nn.Module] = nn.GELU,
  201. norm_layer: Type[nn.Module] = nn.LayerNorm,
  202. proj_drop: float = 0.,
  203. drop_path: float = 0.,
  204. layer_scale_init_value: float = 1e-5,
  205. device=None,
  206. dtype=None,
  207. ):
  208. dd = {'device': device, 'dtype': dtype}
  209. super().__init__()
  210. self.norm1 = norm_layer(dim, **dd)
  211. self.token_mixer = Attention(dim, **dd)
  212. self.ls1 = LayerScale(dim, layer_scale_init_value, **dd)
  213. self.drop_path1 = DropPath(drop_path) if drop_path > 0. else nn.Identity()
  214. self.norm2 = norm_layer(dim, **dd)
  215. self.mlp = Mlp(
  216. in_features=dim,
  217. hidden_features=int(dim * mlp_ratio),
  218. act_layer=act_layer,
  219. drop=proj_drop,
  220. **dd,
  221. )
  222. self.ls2 = LayerScale(dim, layer_scale_init_value, **dd)
  223. self.drop_path2 = DropPath(drop_path) if drop_path > 0. else nn.Identity()
  224. def forward(self, x):
  225. x = x + self.drop_path1(self.ls1(self.token_mixer(self.norm1(x))))
  226. x = x + self.drop_path2(self.ls2(self.mlp(self.norm2(x))))
  227. return x
  228. class MetaBlock2d(nn.Module):
  229. def __init__(
  230. self,
  231. dim: int,
  232. pool_size: int = 3,
  233. mlp_ratio: float = 4.,
  234. act_layer: Type[nn.Module] = nn.GELU,
  235. norm_layer: Type[nn.Module] = nn.BatchNorm2d,
  236. proj_drop: float = 0.,
  237. drop_path: float = 0.,
  238. layer_scale_init_value: float = 1e-5,
  239. device=None,
  240. dtype=None,
  241. ):
  242. dd = {'device': device, 'dtype': dtype}
  243. super().__init__()
  244. self.token_mixer = Pooling(pool_size=pool_size)
  245. self.ls1 = LayerScale2d(dim, layer_scale_init_value, **dd)
  246. self.drop_path1 = DropPath(drop_path) if drop_path > 0. else nn.Identity()
  247. self.mlp = ConvMlpWithNorm(
  248. dim,
  249. hidden_features=int(dim * mlp_ratio),
  250. act_layer=act_layer,
  251. norm_layer=norm_layer,
  252. drop=proj_drop,
  253. **dd,
  254. )
  255. self.ls2 = LayerScale2d(dim, layer_scale_init_value, **dd)
  256. self.drop_path2 = DropPath(drop_path) if drop_path > 0. else nn.Identity()
  257. def forward(self, x):
  258. x = x + self.drop_path1(self.ls1(self.token_mixer(x)))
  259. x = x + self.drop_path2(self.ls2(self.mlp(x)))
  260. return x
  261. class EfficientFormerStage(nn.Module):
  262. def __init__(
  263. self,
  264. dim: int,
  265. dim_out: int,
  266. depth: int ,
  267. downsample: bool = True,
  268. num_vit: int = 1,
  269. pool_size: int = 3,
  270. mlp_ratio: float = 4.,
  271. act_layer: Type[nn.Module] = nn.GELU,
  272. norm_layer: Type[nn.Module] = nn.BatchNorm2d,
  273. norm_layer_cl: Type[nn.Module] = nn.LayerNorm,
  274. proj_drop: float = .0,
  275. drop_path: float = 0.,
  276. layer_scale_init_value: float = 1e-5,
  277. device=None,
  278. dtype=None,
  279. ):
  280. dd = {'device': device, 'dtype': dtype}
  281. super().__init__()
  282. self.grad_checkpointing = False
  283. if downsample:
  284. self.downsample = Downsample(in_chs=dim, out_chs=dim_out, norm_layer=norm_layer, **dd)
  285. dim = dim_out
  286. else:
  287. assert dim == dim_out
  288. self.downsample = nn.Identity()
  289. blocks = []
  290. if num_vit and num_vit >= depth:
  291. blocks.append(Flat())
  292. for block_idx in range(depth):
  293. remain_idx = depth - block_idx - 1
  294. if num_vit and num_vit > remain_idx:
  295. blocks.append(
  296. MetaBlock1d(
  297. dim,
  298. mlp_ratio=mlp_ratio,
  299. act_layer=act_layer,
  300. norm_layer=norm_layer_cl,
  301. proj_drop=proj_drop,
  302. drop_path=drop_path[block_idx],
  303. layer_scale_init_value=layer_scale_init_value,
  304. **dd,
  305. ))
  306. else:
  307. blocks.append(
  308. MetaBlock2d(
  309. dim,
  310. pool_size=pool_size,
  311. mlp_ratio=mlp_ratio,
  312. act_layer=act_layer,
  313. norm_layer=norm_layer,
  314. proj_drop=proj_drop,
  315. drop_path=drop_path[block_idx],
  316. layer_scale_init_value=layer_scale_init_value,
  317. **dd,
  318. ))
  319. if num_vit and num_vit == remain_idx:
  320. blocks.append(Flat())
  321. self.blocks = nn.Sequential(*blocks)
  322. def forward(self, x):
  323. x = self.downsample(x)
  324. if self.grad_checkpointing and not torch.jit.is_scripting():
  325. x = checkpoint_seq(self.blocks, x)
  326. else:
  327. x = self.blocks(x)
  328. return x
  329. class EfficientFormer(nn.Module):
  330. def __init__(
  331. self,
  332. depths: Tuple[int, ...] = (3, 2, 6, 4),
  333. embed_dims: Tuple[int, ...] = (48, 96, 224, 448),
  334. in_chans: int = 3,
  335. num_classes: int = 1000,
  336. global_pool: str = 'avg',
  337. downsamples: Optional[Tuple[bool, ...]] = None,
  338. num_vit: int = 0,
  339. mlp_ratios: float = 4,
  340. pool_size: int = 3,
  341. layer_scale_init_value: float = 1e-5,
  342. act_layer: Type[nn.Module] = nn.GELU,
  343. norm_layer: Type[nn.Module] = nn.BatchNorm2d,
  344. norm_layer_cl: Type[nn.Module] = nn.LayerNorm,
  345. drop_rate: float = 0.,
  346. proj_drop_rate: float = 0.,
  347. drop_path_rate: float = 0.,
  348. device=None,
  349. dtype=None,
  350. **kwargs
  351. ):
  352. super().__init__()
  353. dd = {'device': device, 'dtype': dtype}
  354. self.num_classes = num_classes
  355. self.in_chans = in_chans
  356. self.global_pool = global_pool
  357. self.stem = Stem4(in_chans, embed_dims[0], norm_layer=norm_layer, **dd)
  358. prev_dim = embed_dims[0]
  359. # stochastic depth decay rule
  360. self.num_stages = len(depths)
  361. last_stage = self.num_stages - 1
  362. dpr = calculate_drop_path_rates(drop_path_rate, depths, stagewise=True)
  363. downsamples = downsamples or (False,) + (True,) * (self.num_stages - 1)
  364. stages = []
  365. self.feature_info = []
  366. for i in range(self.num_stages):
  367. stage = EfficientFormerStage(
  368. prev_dim,
  369. embed_dims[i],
  370. depths[i],
  371. downsample=downsamples[i],
  372. num_vit=num_vit if i == last_stage else 0,
  373. pool_size=pool_size,
  374. mlp_ratio=mlp_ratios,
  375. act_layer=act_layer,
  376. norm_layer_cl=norm_layer_cl,
  377. norm_layer=norm_layer,
  378. proj_drop=proj_drop_rate,
  379. drop_path=dpr[i],
  380. layer_scale_init_value=layer_scale_init_value,
  381. **dd,
  382. )
  383. prev_dim = embed_dims[i]
  384. stages.append(stage)
  385. self.feature_info += [dict(num_chs=embed_dims[i], reduction=2**(i+2), module=f'stages.{i}')]
  386. self.stages = nn.Sequential(*stages)
  387. # Classifier head
  388. self.num_features = self.head_hidden_size = embed_dims[-1]
  389. self.norm = norm_layer_cl(self.num_features, **dd)
  390. self.head_drop = nn.Dropout(drop_rate)
  391. self.head = nn.Linear(self.num_features, num_classes, **dd) if num_classes > 0 else nn.Identity()
  392. # assuming model is always distilled (valid for current checkpoints, will split def if that changes)
  393. self.head_dist = nn.Linear(embed_dims[-1], num_classes, **dd) if num_classes > 0 else nn.Identity()
  394. self.distilled_training = False # must set this True to train w/ distillation token
  395. self.apply(self._init_weights)
  396. # init for classification
  397. def _init_weights(self, m):
  398. if isinstance(m, nn.Linear):
  399. trunc_normal_(m.weight, std=.02)
  400. if isinstance(m, nn.Linear) and m.bias is not None:
  401. nn.init.constant_(m.bias, 0)
  402. @torch.jit.ignore
  403. def no_weight_decay(self):
  404. return {k for k, _ in self.named_parameters() if 'attention_biases' in k}
  405. @torch.jit.ignore
  406. def group_matcher(self, coarse=False):
  407. matcher = dict(
  408. stem=r'^stem', # stem and embed
  409. blocks=[(r'^stages\.(\d+)', None), (r'^norm', (99999,))]
  410. )
  411. return matcher
  412. @torch.jit.ignore
  413. def set_grad_checkpointing(self, enable=True):
  414. for s in self.stages:
  415. s.grad_checkpointing = enable
  416. @torch.jit.ignore
  417. def get_classifier(self) -> nn.Module:
  418. return self.head, self.head_dist
  419. def reset_classifier(self, num_classes: int, global_pool: Optional[str] = None):
  420. self.num_classes = num_classes
  421. if global_pool is not None:
  422. self.global_pool = global_pool
  423. self.head = nn.Linear(self.num_features, num_classes) if num_classes > 0 else nn.Identity()
  424. self.head_dist = nn.Linear(self.num_features, num_classes) if num_classes > 0 else nn.Identity()
  425. @torch.jit.ignore
  426. def set_distilled_training(self, enable=True):
  427. self.distilled_training = enable
  428. def forward_intermediates(
  429. self,
  430. x: torch.Tensor,
  431. indices: Optional[Union[int, List[int]]] = None,
  432. norm: bool = False,
  433. stop_early: bool = False,
  434. output_fmt: str = 'NCHW',
  435. intermediates_only: bool = False,
  436. ) -> Union[List[torch.Tensor], Tuple[torch.Tensor, List[torch.Tensor]]]:
  437. """ Forward features that returns intermediates.
  438. Args:
  439. x: Input image tensor
  440. indices: Take last n blocks if int, all if None, select matching indices if sequence
  441. norm: Apply norm layer to compatible intermediates
  442. stop_early: Stop iterating over blocks when last desired intermediate hit
  443. output_fmt: Shape of intermediate feature outputs
  444. intermediates_only: Only return intermediate features
  445. Returns:
  446. """
  447. assert output_fmt in ('NCHW',), 'Output shape must be NCHW.'
  448. intermediates = []
  449. take_indices, max_index = feature_take_indices(len(self.stages), indices)
  450. # forward pass
  451. x = self.stem(x)
  452. B, C, H, W = x.shape
  453. last_idx = self.num_stages - 1
  454. if torch.jit.is_scripting() or not stop_early: # can't slice blocks in torchscript
  455. stages = self.stages
  456. else:
  457. stages = self.stages[:max_index + 1]
  458. feat_idx = 0
  459. for feat_idx, stage in enumerate(stages):
  460. x = stage(x)
  461. if feat_idx < last_idx:
  462. B, C, H, W = x.shape
  463. if feat_idx in take_indices:
  464. if feat_idx == last_idx:
  465. x_inter = self.norm(x) if norm else x
  466. intermediates.append(x_inter.reshape(B, H // 2, W // 2, -1).permute(0, 3, 1, 2))
  467. else:
  468. intermediates.append(x)
  469. if intermediates_only:
  470. return intermediates
  471. if feat_idx == last_idx:
  472. x = self.norm(x)
  473. return x, intermediates
  474. def prune_intermediate_layers(
  475. self,
  476. indices: Union[int, List[int]] = 1,
  477. prune_norm: bool = False,
  478. prune_head: bool = True,
  479. ):
  480. """ Prune layers not required for specified intermediates.
  481. """
  482. take_indices, max_index = feature_take_indices(len(self.stages), indices)
  483. self.stages = self.stages[:max_index + 1] # truncate blocks w/ stem as idx 0
  484. if prune_norm:
  485. self.norm = nn.Identity()
  486. if prune_head:
  487. self.reset_classifier(0, '')
  488. return take_indices
  489. def forward_features(self, x):
  490. x = self.stem(x)
  491. x = self.stages(x)
  492. x = self.norm(x)
  493. return x
  494. def forward_head(self, x, pre_logits: bool = False):
  495. if self.global_pool == 'avg':
  496. x = x.mean(dim=1)
  497. x = self.head_drop(x)
  498. if pre_logits:
  499. return x
  500. x, x_dist = self.head(x), self.head_dist(x)
  501. if self.distilled_training and self.training and not torch.jit.is_scripting():
  502. # only return separate classification predictions when training in distilled mode
  503. return x, x_dist
  504. else:
  505. # during standard train/finetune, inference average the classifier predictions
  506. return (x + x_dist) / 2
  507. def forward(self, x):
  508. x = self.forward_features(x)
  509. x = self.forward_head(x)
  510. return x
  511. def checkpoint_filter_fn(state_dict, model):
  512. """ Remap original checkpoints -> timm """
  513. if 'stem.0.weight' in state_dict:
  514. return state_dict # non-original checkpoint, no remapping needed
  515. out_dict = {}
  516. import re
  517. stage_idx = 0
  518. for k, v in state_dict.items():
  519. if k.startswith('patch_embed'):
  520. k = k.replace('patch_embed.0', 'stem.conv1')
  521. k = k.replace('patch_embed.1', 'stem.norm1')
  522. k = k.replace('patch_embed.3', 'stem.conv2')
  523. k = k.replace('patch_embed.4', 'stem.norm2')
  524. if re.match(r'network\.(\d+)\.proj\.weight', k):
  525. stage_idx += 1
  526. k = re.sub(r'network.(\d+).(\d+)', f'stages.{stage_idx}.blocks.\\2', k)
  527. k = re.sub(r'network.(\d+).proj', f'stages.{stage_idx}.downsample.conv', k)
  528. k = re.sub(r'network.(\d+).norm', f'stages.{stage_idx}.downsample.norm', k)
  529. k = re.sub(r'layer_scale_([0-9])', r'ls\1.gamma', k)
  530. k = k.replace('dist_head', 'head_dist')
  531. out_dict[k] = v
  532. return out_dict
  533. def _cfg(url='', **kwargs):
  534. return {
  535. 'url': url,
  536. 'num_classes': 1000, 'input_size': (3, 224, 224), 'pool_size': None, 'fixed_input_size': True,
  537. 'crop_pct': .95, 'interpolation': 'bicubic',
  538. 'mean': IMAGENET_DEFAULT_MEAN, 'std': IMAGENET_DEFAULT_STD,
  539. 'first_conv': 'stem.conv1', 'classifier': ('head', 'head_dist'),
  540. 'license': 'apache-2.0',
  541. **kwargs
  542. }
  543. default_cfgs = generate_default_cfgs({
  544. 'efficientformer_l1.snap_dist_in1k': _cfg(
  545. hf_hub_id='timm/',
  546. ),
  547. 'efficientformer_l3.snap_dist_in1k': _cfg(
  548. hf_hub_id='timm/',
  549. ),
  550. 'efficientformer_l7.snap_dist_in1k': _cfg(
  551. hf_hub_id='timm/',
  552. ),
  553. })
  554. def _create_efficientformer(variant, pretrained=False, **kwargs):
  555. out_indices = kwargs.pop('out_indices', 4)
  556. model = build_model_with_cfg(
  557. EfficientFormer, variant, pretrained,
  558. pretrained_filter_fn=checkpoint_filter_fn,
  559. feature_cfg=dict(out_indices=out_indices, feature_cls='getter'),
  560. **kwargs,
  561. )
  562. return model
  563. @register_model
  564. def efficientformer_l1(pretrained=False, **kwargs) -> EfficientFormer:
  565. model_args = dict(
  566. depths=EfficientFormer_depth['l1'],
  567. embed_dims=EfficientFormer_width['l1'],
  568. num_vit=1,
  569. )
  570. return _create_efficientformer('efficientformer_l1', pretrained=pretrained, **dict(model_args, **kwargs))
  571. @register_model
  572. def efficientformer_l3(pretrained=False, **kwargs) -> EfficientFormer:
  573. model_args = dict(
  574. depths=EfficientFormer_depth['l3'],
  575. embed_dims=EfficientFormer_width['l3'],
  576. num_vit=4,
  577. )
  578. return _create_efficientformer('efficientformer_l3', pretrained=pretrained, **dict(model_args, **kwargs))
  579. @register_model
  580. def efficientformer_l7(pretrained=False, **kwargs) -> EfficientFormer:
  581. model_args = dict(
  582. depths=EfficientFormer_depth['l7'],
  583. embed_dims=EfficientFormer_width['l7'],
  584. num_vit=8,
  585. )
  586. return _create_efficientformer('efficientformer_l7', pretrained=pretrained, **dict(model_args, **kwargs))