edgenext.py 25 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656657658659660661662663664665666667668669670671672673674675676677678679680681682683684685686687688689690691692693694695696697698699700701702703704705706
  1. """ EdgeNeXt
  2. Paper: `EdgeNeXt: Efficiently Amalgamated CNN-Transformer Architecture for Mobile Vision Applications`
  3. - https://arxiv.org/abs/2206.10589
  4. Original code and weights from https://github.com/mmaaz60/EdgeNeXt
  5. Modifications and additions for timm by / Copyright 2022, Ross Wightman
  6. """
  7. import math
  8. from functools import partial
  9. from typing import List, Optional, Tuple, Type, Union
  10. import torch
  11. import torch.nn.functional as F
  12. from torch import nn
  13. from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
  14. from timm.layers import (
  15. DropPath,
  16. calculate_drop_path_rates,
  17. LayerNorm2d,
  18. Mlp,
  19. create_conv2d,
  20. NormMlpClassifierHead,
  21. ClassifierHead,
  22. trunc_normal_tf_,
  23. )
  24. from ._builder import build_model_with_cfg
  25. from ._features import feature_take_indices
  26. from ._features_fx import register_notrace_module
  27. from ._manipulate import named_apply, checkpoint_seq
  28. from ._registry import register_model, generate_default_cfgs
  29. __all__ = ['EdgeNeXt'] # model_registry will add each entrypoint fn to this
  30. @register_notrace_module # reason: FX can't symbolically trace torch.arange in forward method
  31. class PositionalEncodingFourier(nn.Module):
  32. def __init__(
  33. self,
  34. hidden_dim: int = 32,
  35. dim: int = 768,
  36. temperature: float = 10000.,
  37. device=None,
  38. dtype=None,
  39. ):
  40. dd = {'device': device, 'dtype': dtype}
  41. super().__init__()
  42. self.token_projection = nn.Conv2d(hidden_dim * 2, dim, kernel_size=1, **dd)
  43. self.scale = 2 * math.pi
  44. self.temperature = temperature
  45. self.hidden_dim = hidden_dim
  46. self.dim = dim
  47. def forward(self, shape: Tuple[int, int, int]):
  48. device = self.token_projection.weight.device
  49. dtype = self.token_projection.weight.dtype
  50. inv_mask = ~torch.zeros(shape).to(device=device, dtype=torch.bool)
  51. y_embed = inv_mask.cumsum(1, dtype=torch.float32)
  52. x_embed = inv_mask.cumsum(2, dtype=torch.float32)
  53. eps = 1e-6
  54. y_embed = y_embed / (y_embed[:, -1:, :] + eps) * self.scale
  55. x_embed = x_embed / (x_embed[:, :, -1:] + eps) * self.scale
  56. dim_t = torch.arange(self.hidden_dim, dtype=torch.int64, device=device).to(torch.float32)
  57. dim_t = self.temperature ** (2 * torch.div(dim_t, 2, rounding_mode='floor') / self.hidden_dim)
  58. pos_x = x_embed[:, :, :, None] / dim_t
  59. pos_y = y_embed[:, :, :, None] / dim_t
  60. pos_x = torch.stack(
  61. (pos_x[:, :, :, 0::2].sin(),
  62. pos_x[:, :, :, 1::2].cos()), dim=4).flatten(3)
  63. pos_y = torch.stack(
  64. (pos_y[:, :, :, 0::2].sin(),
  65. pos_y[:, :, :, 1::2].cos()), dim=4).flatten(3)
  66. pos = torch.cat((pos_y, pos_x), dim=3).permute(0, 3, 1, 2)
  67. pos = self.token_projection(pos.to(dtype))
  68. return pos
  69. class ConvBlock(nn.Module):
  70. def __init__(
  71. self,
  72. dim: int,
  73. dim_out: Optional[int] = None,
  74. kernel_size: int = 7,
  75. stride: int = 1,
  76. conv_bias: bool = True,
  77. expand_ratio: float = 4,
  78. ls_init_value: float = 1e-6,
  79. norm_layer: Type[nn.Module] = partial(nn.LayerNorm, eps=1e-6),
  80. act_layer: Type[nn.Module] = nn.GELU,
  81. drop_path: float = 0.,
  82. device=None,
  83. dtype=None,
  84. ):
  85. dd = {'device': device, 'dtype': dtype}
  86. super().__init__()
  87. dim_out = dim_out or dim
  88. self.shortcut_after_dw = stride > 1 or dim != dim_out
  89. self.conv_dw = create_conv2d(
  90. dim,
  91. dim_out,
  92. kernel_size=kernel_size,
  93. stride=stride,
  94. depthwise=True,
  95. bias=conv_bias,
  96. **dd,
  97. )
  98. self.norm = norm_layer(dim_out, **dd)
  99. self.mlp = Mlp(
  100. dim_out,
  101. int(expand_ratio * dim_out),
  102. act_layer=act_layer,
  103. **dd,
  104. )
  105. self.gamma = nn.Parameter(ls_init_value * torch.ones(dim_out, **dd)) if ls_init_value > 0 else None
  106. self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
  107. def forward(self, x):
  108. shortcut = x
  109. x = self.conv_dw(x)
  110. if self.shortcut_after_dw:
  111. shortcut = x
  112. x = x.permute(0, 2, 3, 1) # (N, C, H, W) -> (N, H, W, C)
  113. x = self.norm(x)
  114. x = self.mlp(x)
  115. if self.gamma is not None:
  116. x = self.gamma * x
  117. x = x.permute(0, 3, 1, 2) # (N, H, W, C) -> (N, C, H, W)
  118. x = shortcut + self.drop_path(x)
  119. return x
  120. class CrossCovarianceAttn(nn.Module):
  121. def __init__(
  122. self,
  123. dim: int,
  124. num_heads: int = 8,
  125. qkv_bias: bool = False,
  126. attn_drop: float = 0.,
  127. proj_drop: float = 0.,
  128. device=None,
  129. dtype=None,
  130. ):
  131. dd = {'device': device, 'dtype': dtype}
  132. super().__init__()
  133. self.num_heads = num_heads
  134. self.temperature = nn.Parameter(torch.ones(num_heads, 1, 1, **dd))
  135. self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias, **dd)
  136. self.attn_drop = nn.Dropout(attn_drop)
  137. self.proj = nn.Linear(dim, dim, **dd)
  138. self.proj_drop = nn.Dropout(proj_drop)
  139. def forward(self, x):
  140. B, N, C = x.shape
  141. qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, -1).permute(2, 0, 3, 4, 1)
  142. q, k, v = qkv.unbind(0)
  143. # NOTE, this is NOT spatial attn, q, k, v are B, num_heads, C, L --> C x C attn map
  144. attn = (F.normalize(q, dim=-1) @ F.normalize(k, dim=-1).transpose(-2, -1)) * self.temperature
  145. attn = attn.softmax(dim=-1)
  146. attn = self.attn_drop(attn)
  147. x = (attn @ v)
  148. x = x.permute(0, 3, 1, 2).reshape(B, N, C)
  149. x = self.proj(x)
  150. x = self.proj_drop(x)
  151. return x
  152. @torch.jit.ignore
  153. def no_weight_decay(self):
  154. return {'temperature'}
  155. class SplitTransposeBlock(nn.Module):
  156. def __init__(
  157. self,
  158. dim: int,
  159. num_scales: int = 1,
  160. num_heads: int = 8,
  161. expand_ratio: float = 4,
  162. use_pos_emb: bool = True,
  163. conv_bias: bool = True,
  164. qkv_bias: bool = True,
  165. ls_init_value: float = 1e-6,
  166. norm_layer: Type[nn.Module] = partial(nn.LayerNorm, eps=1e-6),
  167. act_layer: Type[nn.Module] = nn.GELU,
  168. drop_path: float = 0.,
  169. attn_drop: float = 0.,
  170. proj_drop: float = 0.,
  171. device=None,
  172. dtype=None,
  173. ):
  174. dd = {'device': device, 'dtype': dtype}
  175. super().__init__()
  176. width = max(int(math.ceil(dim / num_scales)), int(math.floor(dim // num_scales)))
  177. self.width = width
  178. self.num_scales = max(1, num_scales - 1)
  179. convs = []
  180. for i in range(self.num_scales):
  181. convs.append(create_conv2d(width, width, kernel_size=3, depthwise=True, bias=conv_bias, **dd))
  182. self.convs = nn.ModuleList(convs)
  183. self.pos_embd = None
  184. if use_pos_emb:
  185. self.pos_embd = PositionalEncodingFourier(dim=dim, **dd)
  186. self.norm_xca = norm_layer(dim, **dd)
  187. self.gamma_xca = nn.Parameter(ls_init_value * torch.ones(dim, **dd)) if ls_init_value > 0 else None
  188. self.xca = CrossCovarianceAttn(
  189. dim,
  190. num_heads=num_heads,
  191. qkv_bias=qkv_bias,
  192. attn_drop=attn_drop,
  193. proj_drop=proj_drop,
  194. **dd,
  195. )
  196. self.norm = norm_layer(dim, eps=1e-6, **dd)
  197. self.mlp = Mlp(
  198. dim,
  199. int(expand_ratio * dim),
  200. act_layer=act_layer,
  201. **dd,
  202. )
  203. self.gamma = nn.Parameter(ls_init_value * torch.ones(dim, **dd)) if ls_init_value > 0 else None
  204. self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
  205. def forward(self, x):
  206. shortcut = x
  207. # scales code re-written for torchscript as per my res2net fixes -rw
  208. # NOTE torch.split(x, self.width, 1) causing issues with ONNX export
  209. spx = x.chunk(len(self.convs) + 1, dim=1)
  210. spo = []
  211. sp = spx[0]
  212. for i, conv in enumerate(self.convs):
  213. if i > 0:
  214. sp = sp + spx[i]
  215. sp = conv(sp)
  216. spo.append(sp)
  217. spo.append(spx[-1])
  218. x = torch.cat(spo, 1)
  219. # XCA
  220. B, C, H, W = x.shape
  221. x = x.reshape(B, C, H * W).permute(0, 2, 1)
  222. if self.pos_embd is not None:
  223. pos_encoding = self.pos_embd((B, H, W)).reshape(B, -1, x.shape[1]).permute(0, 2, 1)
  224. x = x + pos_encoding
  225. x = x + self.drop_path(self.gamma_xca * self.xca(self.norm_xca(x)))
  226. x = x.reshape(B, H, W, C)
  227. # Inverted Bottleneck
  228. x = self.norm(x)
  229. x = self.mlp(x)
  230. if self.gamma is not None:
  231. x = self.gamma * x
  232. x = x.permute(0, 3, 1, 2) # (N, H, W, C) -> (N, C, H, W)
  233. x = shortcut + self.drop_path(x)
  234. return x
  235. class EdgeNeXtStage(nn.Module):
  236. def __init__(
  237. self,
  238. in_chs: int,
  239. out_chs: int,
  240. stride: int = 2,
  241. depth: int = 2,
  242. num_global_blocks: int = 1,
  243. num_heads: int = 4,
  244. scales: int = 2,
  245. kernel_size: int = 7,
  246. expand_ratio: float = 4,
  247. use_pos_emb: bool = False,
  248. downsample_block: bool = False,
  249. conv_bias: float = True,
  250. ls_init_value: float = 1.0,
  251. drop_path_rates: Optional[List[float]] = None,
  252. norm_layer: Type[nn.Module] = LayerNorm2d,
  253. norm_layer_cl: Type[nn.Module] = partial(nn.LayerNorm, eps=1e-6),
  254. act_layer: Type[nn.Module] = nn.GELU,
  255. device=None,
  256. dtype=None,
  257. ):
  258. dd = {'device': device, 'dtype': dtype}
  259. super().__init__()
  260. self.grad_checkpointing = False
  261. if downsample_block or stride == 1:
  262. self.downsample = nn.Identity()
  263. else:
  264. self.downsample = nn.Sequential(
  265. norm_layer(in_chs, **dd),
  266. nn.Conv2d(in_chs, out_chs, kernel_size=2, stride=2, bias=conv_bias, **dd)
  267. )
  268. in_chs = out_chs
  269. stage_blocks = []
  270. for i in range(depth):
  271. if i < depth - num_global_blocks:
  272. stage_blocks.append(
  273. ConvBlock(
  274. dim=in_chs,
  275. dim_out=out_chs,
  276. stride=stride if downsample_block and i == 0 else 1,
  277. conv_bias=conv_bias,
  278. kernel_size=kernel_size,
  279. expand_ratio=expand_ratio,
  280. ls_init_value=ls_init_value,
  281. drop_path=drop_path_rates[i],
  282. norm_layer=norm_layer_cl,
  283. act_layer=act_layer,
  284. **dd,
  285. )
  286. )
  287. else:
  288. stage_blocks.append(
  289. SplitTransposeBlock(
  290. dim=in_chs,
  291. num_scales=scales,
  292. num_heads=num_heads,
  293. expand_ratio=expand_ratio,
  294. use_pos_emb=use_pos_emb,
  295. conv_bias=conv_bias,
  296. ls_init_value=ls_init_value,
  297. drop_path=drop_path_rates[i],
  298. norm_layer=norm_layer_cl,
  299. act_layer=act_layer,
  300. **dd,
  301. )
  302. )
  303. in_chs = out_chs
  304. self.blocks = nn.Sequential(*stage_blocks)
  305. def forward(self, x):
  306. x = self.downsample(x)
  307. if self.grad_checkpointing and not torch.jit.is_scripting():
  308. x = checkpoint_seq(self.blocks, x)
  309. else:
  310. x = self.blocks(x)
  311. return x
  312. class EdgeNeXt(nn.Module):
  313. def __init__(
  314. self,
  315. in_chans: int = 3,
  316. num_classes: int = 1000,
  317. global_pool: str = 'avg',
  318. dims: Tuple[int, ...] = (24, 48, 88, 168),
  319. depths: Tuple[int, ...] = (3, 3, 9, 3),
  320. global_block_counts: Tuple[int, ...] = (0, 1, 1, 1),
  321. kernel_sizes: Tuple[int, ...] = (3, 5, 7, 9),
  322. heads: Tuple[int, ...] = (8, 8, 8, 8),
  323. d2_scales: Tuple[int, ...] = (2, 2, 3, 4),
  324. use_pos_emb: Tuple[bool, ...] = (False, True, False, False),
  325. ls_init_value: float = 1e-6,
  326. head_init_scale: float = 1.,
  327. expand_ratio: float = 4,
  328. downsample_block: bool = False,
  329. conv_bias: bool = True,
  330. stem_type: str = 'patch',
  331. head_norm_first: bool = False,
  332. act_layer: Type[nn.Module] = nn.GELU,
  333. drop_path_rate: float = 0.,
  334. drop_rate: float = 0.,
  335. device=None,
  336. dtype=None,
  337. ):
  338. super().__init__()
  339. dd = {'device': device, 'dtype': dtype}
  340. self.num_classes = num_classes
  341. self.in_chans = in_chans
  342. self.global_pool = global_pool
  343. self.drop_rate = drop_rate
  344. norm_layer = partial(LayerNorm2d, eps=1e-6)
  345. norm_layer_cl = partial(nn.LayerNorm, eps=1e-6)
  346. self.feature_info = []
  347. assert stem_type in ('patch', 'overlap')
  348. if stem_type == 'patch':
  349. self.stem = nn.Sequential(
  350. nn.Conv2d(in_chans, dims[0], kernel_size=4, stride=4, bias=conv_bias, **dd,),
  351. norm_layer(dims[0], **dd),
  352. )
  353. else:
  354. self.stem = nn.Sequential(
  355. nn.Conv2d(in_chans, dims[0], kernel_size=9, stride=4, padding=9 // 2, bias=conv_bias, **dd),
  356. norm_layer(dims[0], **dd),
  357. )
  358. curr_stride = 4
  359. stages = []
  360. dp_rates = calculate_drop_path_rates(drop_path_rate, depths, stagewise=True)
  361. in_chs = dims[0]
  362. for i in range(4):
  363. stride = 2 if curr_stride == 2 or i > 0 else 1
  364. # FIXME support dilation / output_stride
  365. curr_stride *= stride
  366. stages.append(EdgeNeXtStage(
  367. in_chs=in_chs,
  368. out_chs=dims[i],
  369. stride=stride,
  370. depth=depths[i],
  371. num_global_blocks=global_block_counts[i],
  372. num_heads=heads[i],
  373. drop_path_rates=dp_rates[i],
  374. scales=d2_scales[i],
  375. expand_ratio=expand_ratio,
  376. kernel_size=kernel_sizes[i],
  377. use_pos_emb=use_pos_emb[i],
  378. ls_init_value=ls_init_value,
  379. downsample_block=downsample_block,
  380. conv_bias=conv_bias,
  381. norm_layer=norm_layer,
  382. norm_layer_cl=norm_layer_cl,
  383. act_layer=act_layer,
  384. **dd,
  385. ))
  386. # NOTE feature_info use currently assumes stage 0 == stride 1, rest are stride 2
  387. in_chs = dims[i]
  388. self.feature_info += [dict(num_chs=in_chs, reduction=curr_stride, module=f'stages.{i}')]
  389. self.stages = nn.Sequential(*stages)
  390. self.num_features = self.head_hidden_size = dims[-1]
  391. if head_norm_first:
  392. self.norm_pre = norm_layer(self.num_features, **dd)
  393. self.head = ClassifierHead(
  394. self.num_features,
  395. num_classes,
  396. pool_type=global_pool,
  397. drop_rate=self.drop_rate,
  398. **dd,
  399. )
  400. else:
  401. self.norm_pre = nn.Identity()
  402. self.head = NormMlpClassifierHead(
  403. self.num_features,
  404. num_classes,
  405. pool_type=global_pool,
  406. drop_rate=self.drop_rate,
  407. norm_layer=norm_layer,
  408. **dd,
  409. )
  410. named_apply(partial(_init_weights, head_init_scale=head_init_scale), self)
  411. @torch.jit.ignore
  412. def group_matcher(self, coarse=False):
  413. return dict(
  414. stem=r'^stem',
  415. blocks=r'^stages\.(\d+)' if coarse else [
  416. (r'^stages\.(\d+)\.downsample', (0,)), # blocks
  417. (r'^stages\.(\d+)\.blocks\.(\d+)', None),
  418. (r'^norm_pre', (99999,))
  419. ]
  420. )
  421. @torch.jit.ignore
  422. def set_grad_checkpointing(self, enable=True):
  423. for s in self.stages:
  424. s.grad_checkpointing = enable
  425. @torch.jit.ignore
  426. def get_classifier(self) -> nn.Module:
  427. return self.head.fc
  428. def reset_classifier(self, num_classes: int, global_pool: Optional[str] = None):
  429. self.num_classes = num_classes
  430. self.head.reset(num_classes, global_pool)
  431. def forward_intermediates(
  432. self,
  433. x: torch.Tensor,
  434. indices: Optional[Union[int, List[int]]] = None,
  435. norm: bool = False,
  436. stop_early: bool = False,
  437. output_fmt: str = 'NCHW',
  438. intermediates_only: bool = False,
  439. ) -> Union[List[torch.Tensor], Tuple[torch.Tensor, List[torch.Tensor]]]:
  440. """ Forward features that returns intermediates.
  441. Args:
  442. x: Input image tensor
  443. indices: Take last n blocks if int, all if None, select matching indices if sequence
  444. norm: Apply norm layer to compatible intermediates
  445. stop_early: Stop iterating over blocks when last desired intermediate hit
  446. output_fmt: Shape of intermediate feature outputs
  447. intermediates_only: Only return intermediate features
  448. Returns:
  449. """
  450. assert output_fmt in ('NCHW',), 'Output shape must be NCHW.'
  451. intermediates = []
  452. take_indices, max_index = feature_take_indices(len(self.stages), indices)
  453. # forward pass
  454. x = self.stem(x)
  455. last_idx = len(self.stages) - 1
  456. if torch.jit.is_scripting() or not stop_early: # can't slice blocks in torchscript
  457. stages = self.stages
  458. else:
  459. stages = self.stages[:max_index + 1]
  460. for feat_idx, stage in enumerate(stages):
  461. x = stage(x)
  462. if feat_idx in take_indices:
  463. if norm and feat_idx == last_idx:
  464. x_inter = self.norm_pre(x) # applying final norm to last intermediate
  465. else:
  466. x_inter = x
  467. intermediates.append(x_inter)
  468. if intermediates_only:
  469. return intermediates
  470. if feat_idx == last_idx:
  471. x = self.norm_pre(x)
  472. return x, intermediates
  473. def prune_intermediate_layers(
  474. self,
  475. indices: Union[int, List[int]] = 1,
  476. prune_norm: bool = False,
  477. prune_head: bool = True,
  478. ):
  479. """ Prune layers not required for specified intermediates.
  480. """
  481. take_indices, max_index = feature_take_indices(len(self.stages), indices)
  482. self.stages = self.stages[:max_index + 1] # truncate blocks w/ stem as idx 0
  483. if prune_norm:
  484. self.norm_pre = nn.Identity()
  485. if prune_head:
  486. self.reset_classifier(0, '')
  487. return take_indices
  488. def forward_features(self, x):
  489. x = self.stem(x)
  490. x = self.stages(x)
  491. x = self.norm_pre(x)
  492. return x
  493. def forward_head(self, x, pre_logits: bool = False):
  494. return self.head(x, pre_logits=True) if pre_logits else self.head(x)
  495. def forward(self, x):
  496. x = self.forward_features(x)
  497. x = self.forward_head(x)
  498. return x
  499. def _init_weights(module, name=None, head_init_scale=1.0):
  500. if isinstance(module, nn.Conv2d):
  501. trunc_normal_tf_(module.weight, std=.02)
  502. if module.bias is not None:
  503. nn.init.zeros_(module.bias)
  504. elif isinstance(module, nn.Linear):
  505. trunc_normal_tf_(module.weight, std=.02)
  506. nn.init.zeros_(module.bias)
  507. if name and 'head.' in name:
  508. module.weight.data.mul_(head_init_scale)
  509. module.bias.data.mul_(head_init_scale)
  510. def checkpoint_filter_fn(state_dict, model):
  511. """ Remap FB checkpoints -> timm """
  512. if 'head.norm.weight' in state_dict or 'norm_pre.weight' in state_dict:
  513. return state_dict # non-FB checkpoint
  514. # models were released as train checkpoints... :/
  515. if 'model_ema' in state_dict:
  516. state_dict = state_dict['model_ema']
  517. elif 'model' in state_dict:
  518. state_dict = state_dict['model']
  519. elif 'state_dict' in state_dict:
  520. state_dict = state_dict['state_dict']
  521. out_dict = {}
  522. import re
  523. for k, v in state_dict.items():
  524. k = k.replace('downsample_layers.0.', 'stem.')
  525. k = re.sub(r'stages.([0-9]+).([0-9]+)', r'stages.\1.blocks.\2', k)
  526. k = re.sub(r'downsample_layers.([0-9]+).([0-9]+)', r'stages.\1.downsample.\2', k)
  527. k = k.replace('dwconv', 'conv_dw')
  528. k = k.replace('pwconv', 'mlp.fc')
  529. k = k.replace('head.', 'head.fc.')
  530. if k.startswith('norm.'):
  531. k = k.replace('norm', 'head.norm')
  532. if v.ndim == 2 and 'head' not in k:
  533. model_shape = model.state_dict()[k].shape
  534. v = v.reshape(model_shape)
  535. out_dict[k] = v
  536. return out_dict
  537. def _create_edgenext(variant, pretrained=False, **kwargs):
  538. model = build_model_with_cfg(
  539. EdgeNeXt, variant, pretrained,
  540. pretrained_filter_fn=checkpoint_filter_fn,
  541. feature_cfg=dict(out_indices=(0, 1, 2, 3), flatten_sequential=True),
  542. **kwargs)
  543. return model
  544. def _cfg(url='', **kwargs):
  545. return {
  546. 'url': url,
  547. 'num_classes': 1000, 'input_size': (3, 256, 256), 'pool_size': (8, 8),
  548. 'crop_pct': 0.9, 'interpolation': 'bicubic',
  549. 'mean': IMAGENET_DEFAULT_MEAN, 'std': IMAGENET_DEFAULT_STD,
  550. 'first_conv': 'stem.0', 'classifier': 'head.fc',
  551. 'license': 'mit',
  552. **kwargs
  553. }
  554. default_cfgs = generate_default_cfgs({
  555. 'edgenext_xx_small.in1k': _cfg(
  556. hf_hub_id='timm/',
  557. test_input_size=(3, 288, 288), test_crop_pct=1.0),
  558. 'edgenext_x_small.in1k': _cfg(
  559. hf_hub_id='timm/',
  560. test_input_size=(3, 288, 288), test_crop_pct=1.0),
  561. 'edgenext_small.usi_in1k': _cfg( # USI weights
  562. hf_hub_id='timm/',
  563. crop_pct=0.95, test_input_size=(3, 320, 320), test_crop_pct=1.0,
  564. ),
  565. 'edgenext_base.usi_in1k': _cfg( # USI weights
  566. hf_hub_id='timm/',
  567. crop_pct=0.95, test_input_size=(3, 320, 320), test_crop_pct=1.0,
  568. ),
  569. 'edgenext_base.in21k_ft_in1k': _cfg( # USI weights
  570. hf_hub_id='timm/',
  571. crop_pct=0.95, test_input_size=(3, 320, 320), test_crop_pct=1.0,
  572. ),
  573. 'edgenext_small_rw.sw_in1k': _cfg(
  574. hf_hub_id='timm/',
  575. test_input_size=(3, 320, 320), test_crop_pct=1.0,
  576. ),
  577. })
  578. @register_model
  579. def edgenext_xx_small(pretrained=False, **kwargs) -> EdgeNeXt:
  580. # 1.33M & 260.58M @ 256 resolution
  581. # 71.23% Top-1 accuracy
  582. # No AA, Color Jitter=0.4, No Mixup & Cutmix, DropPath=0.0, BS=4096, lr=0.006, multi-scale-sampler
  583. # Jetson FPS=51.66 versus 47.67 for MobileViT_XXS
  584. # For A100: FPS @ BS=1: 212.13 & @ BS=256: 7042.06 versus FPS @ BS=1: 96.68 & @ BS=256: 4624.71 for MobileViT_XXS
  585. model_args = dict(depths=(2, 2, 6, 2), dims=(24, 48, 88, 168), heads=(4, 4, 4, 4))
  586. return _create_edgenext('edgenext_xx_small', pretrained=pretrained, **dict(model_args, **kwargs))
  587. @register_model
  588. def edgenext_x_small(pretrained=False, **kwargs) -> EdgeNeXt:
  589. # 2.34M & 538.0M @ 256 resolution
  590. # 75.00% Top-1 accuracy
  591. # No AA, No Mixup & Cutmix, DropPath=0.0, BS=4096, lr=0.006, multi-scale-sampler
  592. # Jetson FPS=31.61 versus 28.49 for MobileViT_XS
  593. # For A100: FPS @ BS=1: 179.55 & @ BS=256: 4404.95 versus FPS @ BS=1: 94.55 & @ BS=256: 2361.53 for MobileViT_XS
  594. model_args = dict(depths=(3, 3, 9, 3), dims=(32, 64, 100, 192), heads=(4, 4, 4, 4))
  595. return _create_edgenext('edgenext_x_small', pretrained=pretrained, **dict(model_args, **kwargs))
  596. @register_model
  597. def edgenext_small(pretrained=False, **kwargs) -> EdgeNeXt:
  598. # 5.59M & 1260.59M @ 256 resolution
  599. # 79.43% Top-1 accuracy
  600. # AA=True, No Mixup & Cutmix, DropPath=0.1, BS=4096, lr=0.006, multi-scale-sampler
  601. # Jetson FPS=20.47 versus 18.86 for MobileViT_S
  602. # For A100: FPS @ BS=1: 172.33 & @ BS=256: 3010.25 versus FPS @ BS=1: 93.84 & @ BS=256: 1785.92 for MobileViT_S
  603. model_args = dict(depths=(3, 3, 9, 3), dims=(48, 96, 160, 304))
  604. return _create_edgenext('edgenext_small', pretrained=pretrained, **dict(model_args, **kwargs))
  605. @register_model
  606. def edgenext_base(pretrained=False, **kwargs) -> EdgeNeXt:
  607. # 18.51M & 3840.93M @ 256 resolution
  608. # 82.5% (normal) 83.7% (USI) Top-1 accuracy
  609. # AA=True, Mixup & Cutmix, DropPath=0.1, BS=4096, lr=0.006, multi-scale-sampler
  610. # Jetson FPS=xx.xx versus xx.xx for MobileViT_S
  611. # For A100: FPS @ BS=1: xxx.xx & @ BS=256: xxxx.xx
  612. model_args = dict(depths=[3, 3, 9, 3], dims=[80, 160, 288, 584])
  613. return _create_edgenext('edgenext_base', pretrained=pretrained, **dict(model_args, **kwargs))
  614. @register_model
  615. def edgenext_small_rw(pretrained=False, **kwargs) -> EdgeNeXt:
  616. model_args = dict(
  617. depths=(3, 3, 9, 3), dims=(48, 96, 192, 384),
  618. downsample_block=True, conv_bias=False, stem_type='overlap')
  619. return _create_edgenext('edgenext_small_rw', pretrained=pretrained, **dict(model_args, **kwargs))