efficientvit_mit.py 41 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
37047057067077087097107117127137147157167177187197207217227237247257267277287297307317327337347357367377387397407417427437447457467477487497507517527537547557567577587597607617627637647657667677687697707717727737747757767777787797807817827837847857867877887897907917927937947957967977987998008018028038048058068078088098108118128138148158168178188198208218228238248258268278288298308318328338348358368378388398408418428438448458468478488498508518528538548558568578588598608618628638648658668678688698708718728738748758768778788798808818828838848858868878888898908918928938948958968978988999009019029039049059069079089099109119129139149159169179189199209219229239249259269279289299309319329339349359369379389399409419429439449459469479489499509519529539549559569579589599609619629639649659669679689699709719729739749759769779789799809819829839849859869879889899909919929939949959969979989991000100110021003100410051006100710081009101010111012101310141015101610171018101910201021102210231024102510261027102810291030103110321033103410351036103710381039104010411042104310441045104610471048104910501051105210531054105510561057105810591060106110621063106410651066106710681069107010711072107310741075107610771078107910801081108210831084108510861087108810891090109110921093109410951096109710981099110011011102110311041105110611071108110911101111111211131114111511161117111811191120112111221123112411251126112711281129113011311132113311341135113611371138113911401141114211431144114511461147114811491150115111521153115411551156115711581159116011611162116311641165116611671168116911701171117211731174117511761177117811791180118111821183118411851186118711881189119011911192119311941195119611971198119912001201120212031204120512061207120812091210121112121213121412151216121712181219122012211222122312241225122612271228122912301231123212331234123512361237123812391240124112421243124412451246124712481249125012511252125312541255125612571258125912601261126212631264126512661267126812691270127112721273127412751276127
71278127912801281
""" EfficientViT (by MIT Song Han's Lab)

Paper: `EfficientViT: Enhanced Linear Attention for High-Resolution Low-Computation Visual Recognition`
    - https://arxiv.org/abs/2205.14756

Adapted from the official implementation at https://github.com/mit-han-lab/efficientvit

This module exports two model classes: ``EfficientVit`` and ``EfficientVitLarge``.
"""
__all__ = ['EfficientVit', 'EfficientVitLarge']
  7. from typing import List, Optional, Tuple, Type, Union
  8. from functools import partial
  9. import torch
  10. import torch.nn as nn
  11. import torch.nn.functional as F
  12. from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
  13. from timm.layers import SelectAdaptivePool2d, create_conv2d, GELUTanh
  14. from ._builder import build_model_with_cfg
  15. from ._features import feature_take_indices
  16. from ._features_fx import register_notrace_module
  17. from ._manipulate import checkpoint_seq
  18. from ._registry import register_model, generate_default_cfgs
  19. def val2list(x: list or tuple or any, repeat_time=1):
  20. if isinstance(x, (list, tuple)):
  21. return list(x)
  22. return [x for _ in range(repeat_time)]
  23. def val2tuple(x: list or tuple or any, min_len: int = 1, idx_repeat: int = -1):
  24. # repeat elements if necessary
  25. x = val2list(x)
  26. if len(x) > 0:
  27. x[idx_repeat:idx_repeat] = [x[idx_repeat] for _ in range(min_len - len(x))]
  28. return tuple(x)
  29. def get_same_padding(kernel_size: int or tuple[int, ...]) -> int or tuple[int, ...]:
  30. if isinstance(kernel_size, tuple):
  31. return tuple([get_same_padding(ks) for ks in kernel_size])
  32. else:
  33. assert kernel_size % 2 > 0, "kernel size should be odd number"
  34. return kernel_size // 2
  35. class ConvNormAct(nn.Module):
  36. def __init__(
  37. self,
  38. in_channels: int,
  39. out_channels: int,
  40. kernel_size: Union[int, Tuple[int, int]] = 3,
  41. stride: int = 1,
  42. dilation: int = 1,
  43. groups: int = 1,
  44. bias: bool = False,
  45. dropout: float = 0.,
  46. norm_layer: Optional[Type[nn.Module]] = nn.BatchNorm2d,
  47. act_layer: Optional[Type[nn.Module]] = nn.ReLU,
  48. device=None,
  49. dtype=None,
  50. ):
  51. dd = {'device': device, 'dtype': dtype}
  52. super().__init__()
  53. self.dropout = nn.Dropout(dropout, inplace=False)
  54. self.conv = create_conv2d(
  55. in_channels,
  56. out_channels,
  57. kernel_size=kernel_size,
  58. stride=stride,
  59. dilation=dilation,
  60. groups=groups,
  61. bias=bias,
  62. **dd,
  63. )
  64. self.norm = norm_layer(num_features=out_channels, **dd) if norm_layer else nn.Identity()
  65. self.act = act_layer(inplace=True) if act_layer is not None else nn.Identity()
  66. def forward(self, x):
  67. x = self.dropout(x)
  68. x = self.conv(x)
  69. x = self.norm(x)
  70. x = self.act(x)
  71. return x
  72. class DSConv(nn.Module):
  73. def __init__(
  74. self,
  75. in_channels: int,
  76. out_channels: int,
  77. kernel_size: int = 3,
  78. stride: int = 1,
  79. use_bias: Union[bool, Tuple[bool, bool]] = False,
  80. norm_layer: Union[Type[nn.Module], Tuple[Optional[Type[nn.Module]], ...]] = nn.BatchNorm2d,
  81. act_layer: Union[Type[nn.Module], Tuple[Optional[Type[nn.Module]], ...]] = (nn.ReLU6, None),
  82. device=None,
  83. dtype=None,
  84. ):
  85. dd = {'device': device, 'dtype': dtype}
  86. super().__init__()
  87. use_bias = val2tuple(use_bias, 2)
  88. norm_layer = val2tuple(norm_layer, 2)
  89. act_layer = val2tuple(act_layer, 2)
  90. self.depth_conv = ConvNormAct(
  91. in_channels,
  92. in_channels,
  93. kernel_size,
  94. stride,
  95. groups=in_channels,
  96. norm_layer=norm_layer[0],
  97. act_layer=act_layer[0],
  98. bias=use_bias[0],
  99. **dd,
  100. )
  101. self.point_conv = ConvNormAct(
  102. in_channels,
  103. out_channels,
  104. 1,
  105. norm_layer=norm_layer[1],
  106. act_layer=act_layer[1],
  107. bias=use_bias[1],
  108. **dd,
  109. )
  110. def forward(self, x):
  111. x = self.depth_conv(x)
  112. x = self.point_conv(x)
  113. return x
  114. class ConvBlock(nn.Module):
  115. def __init__(
  116. self,
  117. in_channels: int,
  118. out_channels: int,
  119. kernel_size: int = 3,
  120. stride: int = 1,
  121. mid_channels: Optional[int] = None,
  122. expand_ratio: float = 1,
  123. use_bias: Union[bool, Tuple[bool, bool]] = False,
  124. norm_layer: Union[Type[nn.Module], Tuple[Optional[Type[nn.Module]], ...]] = nn.BatchNorm2d,
  125. act_layer: Union[Type[nn.Module], Tuple[Optional[Type[nn.Module]], ...]] = (nn.ReLU6, None),
  126. device=None,
  127. dtype=None,
  128. ):
  129. dd = {'device': device, 'dtype': dtype}
  130. super().__init__()
  131. use_bias = val2tuple(use_bias, 2)
  132. norm_layer = val2tuple(norm_layer, 2)
  133. act_layer = val2tuple(act_layer, 2)
  134. mid_channels = mid_channels or round(in_channels * expand_ratio)
  135. self.conv1 = ConvNormAct(
  136. in_channels,
  137. mid_channels,
  138. kernel_size,
  139. stride,
  140. norm_layer=norm_layer[0],
  141. act_layer=act_layer[0],
  142. bias=use_bias[0],
  143. **dd,
  144. )
  145. self.conv2 = ConvNormAct(
  146. mid_channels,
  147. out_channels,
  148. kernel_size,
  149. 1,
  150. norm_layer=norm_layer[1],
  151. act_layer=act_layer[1],
  152. bias=use_bias[1],
  153. **dd,
  154. )
  155. def forward(self, x):
  156. x = self.conv1(x)
  157. x = self.conv2(x)
  158. return x
  159. class MBConv(nn.Module):
  160. def __init__(
  161. self,
  162. in_channels: int,
  163. out_channels: int,
  164. kernel_size: int = 3,
  165. stride: int = 1,
  166. mid_channels: Optional[int] = None,
  167. expand_ratio: float = 6,
  168. use_bias: Union[bool, Tuple[bool, ...]] = False,
  169. norm_layer: Union[Type[nn.Module], Tuple[Optional[Type[nn.Module]], ...]] = nn.BatchNorm2d,
  170. act_layer: Union[Type[nn.Module], Tuple[Optional[Type[nn.Module]], ...]] = (nn.ReLU6, nn.ReLU6, None),
  171. device=None,
  172. dtype=None,
  173. ):
  174. dd = {'device': device, 'dtype': dtype}
  175. super().__init__()
  176. use_bias = val2tuple(use_bias, 3)
  177. norm_layer = val2tuple(norm_layer, 3)
  178. act_layer = val2tuple(act_layer, 3)
  179. mid_channels = mid_channels or round(in_channels * expand_ratio)
  180. self.inverted_conv = ConvNormAct(
  181. in_channels,
  182. mid_channels,
  183. 1,
  184. stride=1,
  185. norm_layer=norm_layer[0],
  186. act_layer=act_layer[0],
  187. bias=use_bias[0],
  188. **dd,
  189. )
  190. self.depth_conv = ConvNormAct(
  191. mid_channels,
  192. mid_channels,
  193. kernel_size,
  194. stride=stride,
  195. groups=mid_channels,
  196. norm_layer=norm_layer[1],
  197. act_layer=act_layer[1],
  198. bias=use_bias[1],
  199. **dd,
  200. )
  201. self.point_conv = ConvNormAct(
  202. mid_channels,
  203. out_channels,
  204. 1,
  205. norm_layer=norm_layer[2],
  206. act_layer=act_layer[2],
  207. bias=use_bias[2],
  208. **dd,
  209. )
  210. def forward(self, x):
  211. x = self.inverted_conv(x)
  212. x = self.depth_conv(x)
  213. x = self.point_conv(x)
  214. return x
  215. class FusedMBConv(nn.Module):
  216. def __init__(
  217. self,
  218. in_channels: int,
  219. out_channels: int,
  220. kernel_size: int = 3,
  221. stride: int = 1,
  222. mid_channels: Optional[int] = None,
  223. expand_ratio: float = 6,
  224. groups: int = 1,
  225. use_bias: Union[bool, Tuple[bool, ...]] = False,
  226. norm_layer: Union[Type[nn.Module], Tuple[Optional[Type[nn.Module]], ...]] = nn.BatchNorm2d,
  227. act_layer: Union[Type[nn.Module], Tuple[Optional[Type[nn.Module]], ...]] = (nn.ReLU6, None),
  228. device=None,
  229. dtype=None,
  230. ):
  231. dd = {'device': device, 'dtype': dtype}
  232. super().__init__()
  233. use_bias = val2tuple(use_bias, 2)
  234. norm_layer = val2tuple(norm_layer, 2)
  235. act_layer = val2tuple(act_layer, 2)
  236. mid_channels = mid_channels or round(in_channels * expand_ratio)
  237. self.spatial_conv = ConvNormAct(
  238. in_channels,
  239. mid_channels,
  240. kernel_size,
  241. stride=stride,
  242. groups=groups,
  243. norm_layer=norm_layer[0],
  244. act_layer=act_layer[0],
  245. bias=use_bias[0],
  246. **dd,
  247. )
  248. self.point_conv = ConvNormAct(
  249. mid_channels,
  250. out_channels,
  251. 1,
  252. norm_layer=norm_layer[1],
  253. act_layer=act_layer[1],
  254. bias=use_bias[1],
  255. **dd,
  256. )
  257. def forward(self, x):
  258. x = self.spatial_conv(x)
  259. x = self.point_conv(x)
  260. return x
class LiteMLA(nn.Module):
    """Lightweight multi-scale linear attention.

    A 1x1 ``ConvNormAct`` produces q/k/v with ``3 * heads * dim`` channels. Each entry
    of ``scales`` adds an aggregation branch over that tensor: a depthwise
    ``scale x scale`` conv followed by a grouped 1x1 conv. The original tensor and all
    branch outputs are concatenated along channels.

    Attention is computed in linear form, ``q @ (k^T @ [v, 1])``, with
    ``kernel_func`` (default ``nn.ReLU``) applied to q and k. The output is normalized
    by the column produced by the appended ones, then projected to ``out_channels``
    by a final 1x1 ``ConvNormAct``.

    Args:
        in_channels: Input channel count.
        out_channels: Output channel count of the final projection.
        heads: Number of heads. If falsy, derived as ``int(in_channels // dim * heads_ratio)``.
        heads_ratio: Multiplier used when deriving ``heads``.
        dim: Channels per head.
        use_bias: Bias flags ``(qkv conv and aggregation convs, proj conv)``.
        norm_layer: Norm layers ``(qkv, proj)``.
        act_layer: Activations ``(qkv, proj)``.
        kernel_func: Activation class applied to q and k (built with ``inplace=False``).
        scales: Kernel sizes of the aggregation branches (must be odd, see ``get_same_padding``).
        eps: Added to the normalizer to avoid division by zero.
        device: Device for created parameters.
        dtype: Dtype for created parameters.
    """
    def __init__(
            self,
            in_channels: int,
            out_channels: int,
            heads: Optional[int] = None,
            heads_ratio: float = 1.0,
            dim: int = 8,
            use_bias: Union[bool, Tuple[bool, ...]] = False,
            norm_layer: Union[Type[nn.Module], Tuple[Optional[Type[nn.Module]], ...]] = (None, nn.BatchNorm2d),
            act_layer: Union[Type[nn.Module], Tuple[Optional[Type[nn.Module]], ...]] = (None, None),
            kernel_func: Type[nn.Module] = nn.ReLU,
            scales: Tuple[int, ...] = (5,),
            eps: float = 1e-5,
            device=None,
            dtype=None,
    ):
        dd = {'device': device, 'dtype': dtype}
        super().__init__()
        self.eps = eps
        heads = heads or int(in_channels // dim * heads_ratio)
        total_dim = heads * dim
        use_bias = val2tuple(use_bias, 2)
        norm_layer = val2tuple(norm_layer, 2)
        act_layer = val2tuple(act_layer, 2)
        self.dim = dim
        # 1x1 projection to concatenated q, k, v (3 * total_dim channels)
        self.qkv = ConvNormAct(
            in_channels,
            3 * total_dim,
            1,
            bias=use_bias[0],
            norm_layer=norm_layer[0],
            act_layer=act_layer[0],
            **dd,
        )
        # one branch per scale: depthwise spatial conv, then a 1x1 conv grouped into 3 * heads groups;
        # both keep the channel count at 3 * total_dim
        self.aggreg = nn.ModuleList([
            nn.Sequential(
                nn.Conv2d(
                    3 * total_dim,
                    3 * total_dim,
                    scale,
                    padding=get_same_padding(scale),
                    groups=3 * total_dim,
                    bias=use_bias[0],
                    **dd,
                ),
                nn.Conv2d(3 * total_dim, 3 * total_dim, 1, groups=3 * heads, bias=use_bias[0], **dd),
            )
            for scale in scales
        ])
        self.kernel_func = kernel_func(inplace=False)
        # input channels: the original qkv branch plus one per scale, each contributing total_dim after attention
        self.proj = ConvNormAct(
            total_dim * (1 + len(scales)),
            out_channels,
            1,
            bias=use_bias[1],
            norm_layer=norm_layer[1],
            act_layer=act_layer[1],
            **dd,
        )

    def _attn(self, q, k, v):
        # Runs in float32 and casts the result back to v's original dtype.
        dtype = v.dtype
        q, k, v = q.float(), k.float(), v.float()
        # linear attention: aggregate k^T v over positions first (last dim is dim + 1 when v was padded in forward())
        kv = k.transpose(-1, -2) @ v
        out = q @ kv
        # the last column comes from v's appended ones, i.e. the normalizer q . sum(k)
        out = out[..., :-1] / (out[..., -1:] + self.eps)
        return out.to(dtype)

    def forward(self, x):
        # x is unpacked as 4-D (B, C, H, W)
        B, _, H, W = x.shape
        # generate multi-scale q, k, v
        qkv = self.qkv(x)
        multi_scale_qkv = [qkv]
        for op in self.aggreg:
            multi_scale_qkv.append(op(qkv))
        multi_scale_qkv = torch.cat(multi_scale_qkv, dim=1)
        # split channels into groups of 3 * dim and flatten space -> (B, groups, H * W, 3 * dim),
        # where groups = (1 + len(scales)) * heads
        multi_scale_qkv = multi_scale_qkv.reshape(B, -1, 3 * self.dim, H * W).transpose(-1, -2)
        q, k, v = multi_scale_qkv.chunk(3, dim=-1)
        # lightweight global attention
        q = self.kernel_func(q)
        k = self.kernel_func(k)
        # append a ones column to v so the attention matmul also yields the normalizer
        v = F.pad(v, (0, 1), mode="constant", value=1.)
        if not torch.jit.is_scripting():
            # disable autocast so the attention math stays in the float32 computed by _attn
            with torch.autocast(device_type=v.device.type, enabled=False):
                out = self._attn(q, k, v)
        else:
            out = self._attn(q, k, v)
        # final projection: back to (B, groups * dim, H, W)
        out = out.transpose(-1, -2).reshape(B, -1, H, W)
        out = self.proj(out)
        return out

# NOTE(review): registers LiteMLA with timm's FX helper; presumably so FX tracing treats it as a leaf — confirm
register_notrace_module(LiteMLA)
  353. class EfficientVitBlock(nn.Module):
  354. def __init__(
  355. self,
  356. in_channels: int,
  357. heads_ratio: float = 1.0,
  358. head_dim: int = 32,
  359. expand_ratio: float = 4,
  360. norm_layer: Type[nn.Module] = nn.BatchNorm2d,
  361. act_layer: Type[nn.Module] = nn.Hardswish,
  362. device=None,
  363. dtype=None,
  364. ):
  365. dd = {'device': device, 'dtype': dtype}
  366. super().__init__()
  367. self.context_module = ResidualBlock(
  368. LiteMLA(
  369. in_channels=in_channels,
  370. out_channels=in_channels,
  371. heads_ratio=heads_ratio,
  372. dim=head_dim,
  373. norm_layer=(None, norm_layer),
  374. **dd,
  375. ),
  376. nn.Identity(),
  377. )
  378. self.local_module = ResidualBlock(
  379. MBConv(
  380. in_channels=in_channels,
  381. out_channels=in_channels,
  382. expand_ratio=expand_ratio,
  383. use_bias=(True, True, False),
  384. norm_layer=(None, None, norm_layer),
  385. act_layer=(act_layer, act_layer, None),
  386. **dd,
  387. ),
  388. nn.Identity(),
  389. )
  390. def forward(self, x):
  391. x = self.context_module(x)
  392. x = self.local_module(x)
  393. return x
  394. class ResidualBlock(nn.Module):
  395. def __init__(
  396. self,
  397. main: Optional[nn.Module],
  398. shortcut: Optional[nn.Module] = None,
  399. pre_norm: Optional[nn.Module] = None,
  400. ):
  401. super().__init__()
  402. self.pre_norm = pre_norm if pre_norm is not None else nn.Identity()
  403. self.main = main
  404. self.shortcut = shortcut
  405. def forward(self, x):
  406. res = self.main(self.pre_norm(x))
  407. if self.shortcut is not None:
  408. res = res + self.shortcut(x)
  409. return res
  410. def build_local_block(
  411. in_channels: int,
  412. out_channels: int,
  413. stride: int,
  414. expand_ratio: float,
  415. norm_layer: str,
  416. act_layer: str,
  417. fewer_norm: bool = False,
  418. block_type: str = "default",
  419. device=None,
  420. dtype=None,
  421. ):
  422. dd = {'device': device, 'dtype': dtype}
  423. assert block_type in ["default", "large", "fused"]
  424. if expand_ratio == 1:
  425. if block_type == "default":
  426. block = DSConv(
  427. in_channels=in_channels,
  428. out_channels=out_channels,
  429. stride=stride,
  430. use_bias=(True, False) if fewer_norm else False,
  431. norm_layer=(None, norm_layer) if fewer_norm else norm_layer,
  432. act_layer=(act_layer, None),
  433. **dd,
  434. )
  435. else:
  436. block = ConvBlock(
  437. in_channels=in_channels,
  438. out_channels=out_channels,
  439. stride=stride,
  440. use_bias=(True, False) if fewer_norm else False,
  441. norm_layer=(None, norm_layer) if fewer_norm else norm_layer,
  442. act_layer=(act_layer, None),
  443. **dd,
  444. )
  445. else:
  446. if block_type == "default":
  447. block = MBConv(
  448. in_channels=in_channels,
  449. out_channels=out_channels,
  450. stride=stride,
  451. expand_ratio=expand_ratio,
  452. use_bias=(True, True, False) if fewer_norm else False,
  453. norm_layer=(None, None, norm_layer) if fewer_norm else norm_layer,
  454. act_layer=(act_layer, act_layer, None),
  455. **dd,
  456. )
  457. else:
  458. block = FusedMBConv(
  459. in_channels=in_channels,
  460. out_channels=out_channels,
  461. stride=stride,
  462. expand_ratio=expand_ratio,
  463. use_bias=(True, False) if fewer_norm else False,
  464. norm_layer=(None, norm_layer) if fewer_norm else norm_layer,
  465. act_layer=(act_layer, None),
  466. **dd,
  467. )
  468. return block
  469. class Stem(nn.Sequential):
  470. def __init__(
  471. self,
  472. in_chs: int,
  473. out_chs: int,
  474. depth: int,
  475. norm_layer: Type[nn.Module],
  476. act_layer: Type[nn.Module],
  477. block_type: str = 'default',
  478. device=None,
  479. dtype=None,
  480. ):
  481. super().__init__()
  482. dd = {'device': device, 'dtype': dtype}
  483. self.stride = 2
  484. self.add_module(
  485. 'in_conv',
  486. ConvNormAct(
  487. in_chs,
  488. out_chs,
  489. kernel_size=3,
  490. stride=2,
  491. norm_layer=norm_layer,
  492. act_layer=act_layer,
  493. **dd,
  494. )
  495. )
  496. stem_block = 0
  497. for _ in range(depth):
  498. self.add_module(f'res{stem_block}', ResidualBlock(
  499. build_local_block(
  500. in_channels=out_chs,
  501. out_channels=out_chs,
  502. stride=1,
  503. expand_ratio=1,
  504. norm_layer=norm_layer,
  505. act_layer=act_layer,
  506. block_type=block_type,
  507. **dd,
  508. ),
  509. nn.Identity(),
  510. ))
  511. stem_block += 1
  512. class EfficientVitStage(nn.Module):
  513. def __init__(
  514. self,
  515. in_chs: int,
  516. out_chs: int,
  517. depth: int,
  518. norm_layer: Type[nn.Module],
  519. act_layer: Type[nn.Module],
  520. expand_ratio: float,
  521. head_dim: int,
  522. vit_stage: bool = False,
  523. device=None,
  524. dtype=None,
  525. ):
  526. dd = {'device': device, 'dtype': dtype}
  527. super().__init__()
  528. blocks = [ResidualBlock(
  529. build_local_block(
  530. in_channels=in_chs,
  531. out_channels=out_chs,
  532. stride=2,
  533. expand_ratio=expand_ratio,
  534. norm_layer=norm_layer,
  535. act_layer=act_layer,
  536. fewer_norm=vit_stage,
  537. **dd,
  538. ),
  539. None,
  540. )]
  541. in_chs = out_chs
  542. if vit_stage:
  543. # for stage 3, 4
  544. for _ in range(depth):
  545. blocks.append(
  546. EfficientVitBlock(
  547. in_channels=in_chs,
  548. head_dim=head_dim,
  549. expand_ratio=expand_ratio,
  550. norm_layer=norm_layer,
  551. act_layer=act_layer,
  552. **dd,
  553. )
  554. )
  555. else:
  556. # for stage 1, 2
  557. for i in range(1, depth):
  558. blocks.append(ResidualBlock(
  559. build_local_block(
  560. in_channels=in_chs,
  561. out_channels=out_chs,
  562. stride=1,
  563. expand_ratio=expand_ratio,
  564. norm_layer=norm_layer,
  565. act_layer=act_layer,
  566. **dd,
  567. ),
  568. nn.Identity(),
  569. ))
  570. self.blocks = nn.Sequential(*blocks)
  571. def forward(self, x):
  572. return self.blocks(x)
  573. class EfficientVitLargeStage(nn.Module):
  574. def __init__(
  575. self,
  576. in_chs: int,
  577. out_chs: int,
  578. depth: int,
  579. norm_layer: Type[nn.Module],
  580. act_layer: Type[nn.Module],
  581. head_dim: int,
  582. vit_stage: bool = False,
  583. fewer_norm: bool = False,
  584. device=None,
  585. dtype=None,
  586. ):
  587. dd = {'device': device, 'dtype': dtype}
  588. super().__init__()
  589. blocks = [ResidualBlock(
  590. build_local_block(
  591. in_channels=in_chs,
  592. out_channels=out_chs,
  593. stride=2,
  594. expand_ratio=24 if vit_stage else 16,
  595. norm_layer=norm_layer,
  596. act_layer=act_layer,
  597. fewer_norm=vit_stage or fewer_norm,
  598. block_type='default' if fewer_norm else 'fused',
  599. **dd,
  600. ),
  601. None,
  602. )]
  603. in_chs = out_chs
  604. if vit_stage:
  605. # for stage 4
  606. for _ in range(depth):
  607. blocks.append(
  608. EfficientVitBlock(
  609. in_channels=in_chs,
  610. head_dim=head_dim,
  611. expand_ratio=6,
  612. norm_layer=norm_layer,
  613. act_layer=act_layer,
  614. **dd,
  615. )
  616. )
  617. else:
  618. # for stage 1, 2, 3
  619. for i in range(depth):
  620. blocks.append(ResidualBlock(
  621. build_local_block(
  622. in_channels=in_chs,
  623. out_channels=out_chs,
  624. stride=1,
  625. expand_ratio=4,
  626. norm_layer=norm_layer,
  627. act_layer=act_layer,
  628. fewer_norm=fewer_norm,
  629. block_type='default' if fewer_norm else 'fused',
  630. **dd,
  631. ),
  632. nn.Identity(),
  633. ))
  634. self.blocks = nn.Sequential(*blocks)
  635. def forward(self, x):
  636. return self.blocks(x)
  637. class ClassifierHead(nn.Module):
  638. def __init__(
  639. self,
  640. in_channels: int,
  641. widths: List[int],
  642. num_classes: int = 1000,
  643. dropout: float = 0.,
  644. norm_layer: Type[nn.Module] = nn.BatchNorm2d,
  645. act_layer: Optional[Type[nn.Module]] = nn.Hardswish,
  646. pool_type: str = 'avg',
  647. norm_eps: float = 1e-5,
  648. device=None,
  649. dtype=None,
  650. ):
  651. dd = {'device': device, 'dtype': dtype}
  652. super().__init__()
  653. self.widths = widths
  654. self.num_features = widths[-1]
  655. assert pool_type, 'Cannot disable pooling'
  656. self.in_conv = ConvNormAct(in_channels, widths[0], 1, norm_layer=norm_layer, act_layer=act_layer, **dd)
  657. self.global_pool = SelectAdaptivePool2d(pool_type=pool_type, flatten=True)
  658. self.classifier = nn.Sequential(
  659. nn.Linear(widths[0], widths[1], bias=False, **dd),
  660. nn.LayerNorm(widths[1], eps=norm_eps, **dd),
  661. act_layer(inplace=True) if act_layer is not None else nn.Identity(),
  662. nn.Dropout(dropout, inplace=False),
  663. nn.Linear(widths[1], num_classes, bias=True, **dd) if num_classes > 0 else nn.Identity(),
  664. )
  665. def reset(self, num_classes: int, pool_type: Optional[str] = None):
  666. if pool_type is not None:
  667. assert pool_type, 'Cannot disable pooling'
  668. self.global_pool = SelectAdaptivePool2d(pool_type=pool_type, flatten=True,)
  669. if num_classes > 0:
  670. self.classifier[-1] = nn.Linear(self.num_features, num_classes, bias=True)
  671. else:
  672. self.classifier[-1] = nn.Identity()
  673. def forward(self, x, pre_logits: bool = False):
  674. x = self.in_conv(x)
  675. x = self.global_pool(x)
  676. if pre_logits:
  677. # cannot slice or iterate with torchscript so, this
  678. x = self.classifier[0](x)
  679. x = self.classifier[1](x)
  680. x = self.classifier[2](x)
  681. x = self.classifier[3](x)
  682. else:
  683. x = self.classifier(x)
  684. return x
  685. class EfficientVit(nn.Module):
  686. def __init__(
  687. self,
  688. in_chans: int = 3,
  689. widths: Tuple[int, ...] = (),
  690. depths: Tuple[int, ...] = (),
  691. head_dim: int = 32,
  692. expand_ratio: float = 4,
  693. norm_layer: Type[nn.Module] = nn.BatchNorm2d,
  694. act_layer: Type[nn.Module] = nn.Hardswish,
  695. global_pool: str = 'avg',
  696. head_widths: Tuple[int, ...] = (),
  697. drop_rate: float = 0.0,
  698. num_classes: int = 1000,
  699. device=None,
  700. dtype=None,
  701. ):
  702. dd = {'device': device, 'dtype': dtype}
  703. super().__init__()
  704. self.grad_checkpointing = False
  705. self.global_pool = global_pool
  706. self.num_classes = num_classes
  707. self.in_chans = in_chans
  708. # input stem
  709. self.stem = Stem(in_chans, widths[0], depths[0], norm_layer, act_layer, **dd)
  710. stride = self.stem.stride
  711. # stages
  712. self.feature_info = []
  713. self.stages = nn.Sequential()
  714. in_channels = widths[0]
  715. for i, (w, d) in enumerate(zip(widths[1:], depths[1:])):
  716. self.stages.append(EfficientVitStage(
  717. in_channels,
  718. w,
  719. depth=d,
  720. norm_layer=norm_layer,
  721. act_layer=act_layer,
  722. expand_ratio=expand_ratio,
  723. head_dim=head_dim,
  724. vit_stage=i >= 2,
  725. **dd,
  726. ))
  727. stride *= 2
  728. in_channels = w
  729. self.feature_info += [dict(num_chs=in_channels, reduction=stride, module=f'stages.{i}')]
  730. self.num_features = in_channels
  731. self.head = ClassifierHead(
  732. self.num_features,
  733. widths=head_widths,
  734. num_classes=num_classes,
  735. dropout=drop_rate,
  736. pool_type=self.global_pool,
  737. **dd,
  738. )
  739. self.head_hidden_size = self.head.num_features
  740. @torch.jit.ignore
  741. def group_matcher(self, coarse=False):
  742. matcher = dict(
  743. stem=r'^stem',
  744. blocks=r'^stages\.(\d+)' if coarse else [
  745. (r'^stages\.(\d+).downsample', (0,)),
  746. (r'^stages\.(\d+)\.\w+\.(\d+)', None),
  747. ]
  748. )
  749. return matcher
  750. @torch.jit.ignore
  751. def set_grad_checkpointing(self, enable=True):
  752. self.grad_checkpointing = enable
  753. @torch.jit.ignore
  754. def get_classifier(self) -> nn.Module:
  755. return self.head.classifier[-1]
  756. def reset_classifier(self, num_classes: int, global_pool: Optional[str] = None):
  757. self.num_classes = num_classes
  758. self.head.reset(num_classes, global_pool)
  759. def forward_intermediates(
  760. self,
  761. x: torch.Tensor,
  762. indices: Optional[Union[int, List[int]]] = None,
  763. norm: bool = False,
  764. stop_early: bool = False,
  765. output_fmt: str = 'NCHW',
  766. intermediates_only: bool = False,
  767. ) -> Union[List[torch.Tensor], Tuple[torch.Tensor, List[torch.Tensor]]]:
  768. """ Forward features that returns intermediates.
  769. Args:
  770. x: Input image tensor
  771. indices: Take last n blocks if int, all if None, select matching indices if sequence
  772. norm: Apply norm layer to compatible intermediates
  773. stop_early: Stop iterating over blocks when last desired intermediate hit
  774. output_fmt: Shape of intermediate feature outputs
  775. intermediates_only: Only return intermediate features
  776. Returns:
  777. """
  778. assert output_fmt in ('NCHW',), 'Output shape must be NCHW.'
  779. intermediates = []
  780. take_indices, max_index = feature_take_indices(len(self.stages), indices)
  781. # forward pass
  782. x = self.stem(x)
  783. if torch.jit.is_scripting() or not stop_early: # can't slice blocks in torchscript
  784. stages = self.stages
  785. else:
  786. stages = self.stages[:max_index + 1]
  787. for feat_idx, stage in enumerate(stages):
  788. if self.grad_checkpointing and not torch.jit.is_scripting():
  789. x = checkpoint_seq(stages, x)
  790. else:
  791. x = stage(x)
  792. if feat_idx in take_indices:
  793. intermediates.append(x)
  794. if intermediates_only:
  795. return intermediates
  796. return x, intermediates
  797. def prune_intermediate_layers(
  798. self,
  799. indices: Union[int, List[int]] = 1,
  800. prune_norm: bool = False,
  801. prune_head: bool = True,
  802. ):
  803. """ Prune layers not required for specified intermediates.
  804. """
  805. take_indices, max_index = feature_take_indices(len(self.stages), indices)
  806. self.stages = self.stages[:max_index + 1] # truncate blocks w/ stem as idx 0
  807. if prune_head:
  808. self.reset_classifier(0, '')
  809. return take_indices
  810. def forward_features(self, x):
  811. x = self.stem(x)
  812. if self.grad_checkpointing and not torch.jit.is_scripting():
  813. x = checkpoint_seq(self.stages, x)
  814. else:
  815. x = self.stages(x)
  816. return x
  817. def forward_head(self, x, pre_logits: bool = False):
  818. return self.head(x, pre_logits=pre_logits) if pre_logits else self.head(x)
  819. def forward(self, x):
  820. x = self.forward_features(x)
  821. x = self.forward_head(x)
  822. return x
  823. class EfficientVitLarge(nn.Module):
  824. def __init__(
  825. self,
  826. in_chans: int = 3,
  827. widths: Tuple[int, ...] = (),
  828. depths: Tuple[int, ...] = (),
  829. head_dim: int = 32,
  830. norm_layer: Type[nn.Module] = nn.BatchNorm2d,
  831. act_layer: Type[nn.Module] = GELUTanh,
  832. global_pool: str = 'avg',
  833. head_widths: Tuple[int, ...] = (),
  834. drop_rate: float = 0.0,
  835. num_classes: int = 1000,
  836. norm_eps: float = 1e-7,
  837. device=None,
  838. dtype=None,
  839. ):
  840. dd = {'device': device, 'dtype': dtype}
  841. super().__init__()
  842. self.grad_checkpointing = False
  843. self.global_pool = global_pool
  844. self.num_classes = num_classes
  845. self.in_chans = in_chans
  846. self.norm_eps = norm_eps
  847. norm_layer = partial(norm_layer, eps=self.norm_eps)
  848. # input stem
  849. self.stem = Stem(in_chans, widths[0], depths[0], norm_layer, act_layer, block_type='large', **dd)
  850. stride = self.stem.stride
  851. # stages
  852. self.feature_info = []
  853. self.stages = nn.Sequential()
  854. in_channels = widths[0]
  855. for i, (w, d) in enumerate(zip(widths[1:], depths[1:])):
  856. self.stages.append(EfficientVitLargeStage(
  857. in_channels,
  858. w,
  859. depth=d,
  860. norm_layer=norm_layer,
  861. act_layer=act_layer,
  862. head_dim=head_dim,
  863. vit_stage=i >= 3,
  864. fewer_norm=i >= 2,
  865. **dd,
  866. ))
  867. stride *= 2
  868. in_channels = w
  869. self.feature_info += [dict(num_chs=in_channels, reduction=stride, module=f'stages.{i}')]
  870. self.num_features = in_channels
  871. self.head = ClassifierHead(
  872. self.num_features,
  873. widths=head_widths,
  874. num_classes=num_classes,
  875. dropout=drop_rate,
  876. pool_type=self.global_pool,
  877. act_layer=act_layer,
  878. norm_eps=self.norm_eps,
  879. **dd,
  880. )
  881. self.head_hidden_size = self.head.num_features
  882. @torch.jit.ignore
  883. def group_matcher(self, coarse=False):
  884. matcher = dict(
  885. stem=r'^stem',
  886. blocks=r'^stages\.(\d+)' if coarse else [
  887. (r'^stages\.(\d+).downsample', (0,)),
  888. (r'^stages\.(\d+)\.\w+\.(\d+)', None),
  889. ]
  890. )
  891. return matcher
  892. @torch.jit.ignore
  893. def set_grad_checkpointing(self, enable=True):
  894. self.grad_checkpointing = enable
  895. @torch.jit.ignore
  896. def get_classifier(self) -> nn.Module:
  897. return self.head.classifier[-1]
  898. def reset_classifier(self, num_classes: int, global_pool: Optional[str] = None):
  899. self.num_classes = num_classes
  900. self.head.reset(num_classes, global_pool)
  901. def forward_intermediates(
  902. self,
  903. x: torch.Tensor,
  904. indices: Optional[Union[int, List[int]]] = None,
  905. norm: bool = False,
  906. stop_early: bool = False,
  907. output_fmt: str = 'NCHW',
  908. intermediates_only: bool = False,
  909. ) -> Union[List[torch.Tensor], Tuple[torch.Tensor, List[torch.Tensor]]]:
  910. """ Forward features that returns intermediates.
  911. Args:
  912. x: Input image tensor
  913. indices: Take last n blocks if int, all if None, select matching indices if sequence
  914. norm: Apply norm layer to compatible intermediates
  915. stop_early: Stop iterating over blocks when last desired intermediate hit
  916. output_fmt: Shape of intermediate feature outputs
  917. intermediates_only: Only return intermediate features
  918. Returns:
  919. """
  920. assert output_fmt in ('NCHW',), 'Output shape must be NCHW.'
  921. intermediates = []
  922. take_indices, max_index = feature_take_indices(len(self.stages), indices)
  923. # forward pass
  924. x = self.stem(x)
  925. if torch.jit.is_scripting() or not stop_early: # can't slice blocks in torchscript
  926. stages = self.stages
  927. else:
  928. stages = self.stages[:max_index + 1]
  929. for feat_idx, stage in enumerate(stages):
  930. if self.grad_checkpointing and not torch.jit.is_scripting():
  931. x = checkpoint_seq(stages, x)
  932. else:
  933. x = stage(x)
  934. if feat_idx in take_indices:
  935. intermediates.append(x)
  936. if intermediates_only:
  937. return intermediates
  938. return x, intermediates
  939. def prune_intermediate_layers(
  940. self,
  941. indices: Union[int, List[int]] = 1,
  942. prune_norm: bool = False,
  943. prune_head: bool = True,
  944. ):
  945. """ Prune layers not required for specified intermediates.
  946. """
  947. take_indices, max_index = feature_take_indices(len(self.stages), indices)
  948. self.stages = self.stages[:max_index + 1] # truncate blocks w/ stem as idx 0
  949. if prune_head:
  950. self.reset_classifier(0, '')
  951. return take_indices
  952. def forward_features(self, x):
  953. x = self.stem(x)
  954. if self.grad_checkpointing and not torch.jit.is_scripting():
  955. x = checkpoint_seq(self.stages, x)
  956. else:
  957. x = self.stages(x)
  958. return x
  959. def forward_head(self, x, pre_logits: bool = False):
  960. return self.head(x, pre_logits=pre_logits) if pre_logits else self.head(x)
  961. def forward(self, x):
  962. x = self.forward_features(x)
  963. x = self.forward_head(x)
  964. return x
  965. def _cfg(url='', **kwargs):
  966. return {
  967. 'url': url,
  968. 'num_classes': 1000,
  969. 'mean': IMAGENET_DEFAULT_MEAN,
  970. 'std': IMAGENET_DEFAULT_STD,
  971. 'first_conv': 'stem.in_conv.conv',
  972. 'classifier': 'head.classifier.4',
  973. 'crop_pct': 0.95,
  974. 'license': 'apache-2.0',
  975. 'input_size': (3, 224, 224),
  976. 'pool_size': (7, 7),
  977. **kwargs,
  978. }
  979. default_cfgs = generate_default_cfgs({
  980. 'efficientvit_b0.r224_in1k': _cfg(
  981. hf_hub_id='timm/',
  982. ),
  983. 'efficientvit_b1.r224_in1k': _cfg(
  984. hf_hub_id='timm/',
  985. ),
  986. 'efficientvit_b1.r256_in1k': _cfg(
  987. hf_hub_id='timm/',
  988. input_size=(3, 256, 256), pool_size=(8, 8), crop_pct=1.0,
  989. ),
  990. 'efficientvit_b1.r288_in1k': _cfg(
  991. hf_hub_id='timm/',
  992. input_size=(3, 288, 288), pool_size=(9, 9), crop_pct=1.0,
  993. ),
  994. 'efficientvit_b2.r224_in1k': _cfg(
  995. hf_hub_id='timm/',
  996. ),
  997. 'efficientvit_b2.r256_in1k': _cfg(
  998. hf_hub_id='timm/',
  999. input_size=(3, 256, 256), pool_size=(8, 8), crop_pct=1.0,
  1000. ),
  1001. 'efficientvit_b2.r288_in1k': _cfg(
  1002. hf_hub_id='timm/',
  1003. input_size=(3, 288, 288), pool_size=(9, 9), crop_pct=1.0,
  1004. ),
  1005. 'efficientvit_b3.r224_in1k': _cfg(
  1006. hf_hub_id='timm/',
  1007. ),
  1008. 'efficientvit_b3.r256_in1k': _cfg(
  1009. hf_hub_id='timm/',
  1010. input_size=(3, 256, 256), pool_size=(8, 8), crop_pct=1.0,
  1011. ),
  1012. 'efficientvit_b3.r288_in1k': _cfg(
  1013. hf_hub_id='timm/',
  1014. input_size=(3, 288, 288), pool_size=(9, 9), crop_pct=1.0,
  1015. ),
  1016. 'efficientvit_l1.r224_in1k': _cfg(
  1017. hf_hub_id='timm/',
  1018. crop_pct=1.0,
  1019. ),
  1020. 'efficientvit_l2.r224_in1k': _cfg(
  1021. hf_hub_id='timm/',
  1022. crop_pct=1.0,
  1023. ),
  1024. 'efficientvit_l2.r256_in1k': _cfg(
  1025. hf_hub_id='timm/',
  1026. input_size=(3, 256, 256), pool_size=(8, 8), crop_pct=1.0,
  1027. ),
  1028. 'efficientvit_l2.r288_in1k': _cfg(
  1029. hf_hub_id='timm/',
  1030. input_size=(3, 288, 288), pool_size=(9, 9), crop_pct=1.0,
  1031. ),
  1032. 'efficientvit_l2.r384_in1k': _cfg(
  1033. hf_hub_id='timm/',
  1034. input_size=(3, 384, 384), pool_size=(12, 12), crop_pct=1.0,
  1035. ),
  1036. 'efficientvit_l3.r224_in1k': _cfg(
  1037. hf_hub_id='timm/',
  1038. crop_pct=1.0,
  1039. ),
  1040. 'efficientvit_l3.r256_in1k': _cfg(
  1041. hf_hub_id='timm/',
  1042. input_size=(3, 256, 256), pool_size=(8, 8), crop_pct=1.0,
  1043. ),
  1044. 'efficientvit_l3.r320_in1k': _cfg(
  1045. hf_hub_id='timm/',
  1046. input_size=(3, 320, 320), pool_size=(10, 10), crop_pct=1.0,
  1047. ),
  1048. 'efficientvit_l3.r384_in1k': _cfg(
  1049. hf_hub_id='timm/',
  1050. input_size=(3, 384, 384), pool_size=(12, 12), crop_pct=1.0,
  1051. ),
  1052. # 'efficientvit_l0_sam.sam': _cfg(
  1053. # # hf_hub_id='timm/',
  1054. # input_size=(3, 512, 512), crop_pct=1.0,
  1055. # num_classes=0,
  1056. # ),
  1057. # 'efficientvit_l1_sam.sam': _cfg(
  1058. # # hf_hub_id='timm/',
  1059. # input_size=(3, 512, 512), crop_pct=1.0,
  1060. # num_classes=0,
  1061. # ),
  1062. # 'efficientvit_l2_sam.sam': _cfg(
  1063. # # hf_hub_id='timm/',f
  1064. # input_size=(3, 512, 512), crop_pct=1.0,
  1065. # num_classes=0,
  1066. # ),
  1067. })
  1068. def _create_efficientvit(variant, pretrained=False, **kwargs):
  1069. out_indices = kwargs.pop('out_indices', (0, 1, 2, 3))
  1070. model = build_model_with_cfg(
  1071. EfficientVit,
  1072. variant,
  1073. pretrained,
  1074. feature_cfg=dict(flatten_sequential=True, out_indices=out_indices),
  1075. **kwargs
  1076. )
  1077. return model
  1078. def _create_efficientvit_large(variant, pretrained=False, **kwargs):
  1079. out_indices = kwargs.pop('out_indices', (0, 1, 2, 3))
  1080. model = build_model_with_cfg(
  1081. EfficientVitLarge,
  1082. variant,
  1083. pretrained,
  1084. feature_cfg=dict(flatten_sequential=True, out_indices=out_indices),
  1085. **kwargs
  1086. )
  1087. return model
  1088. @register_model
  1089. def efficientvit_b0(pretrained=False, **kwargs):
  1090. model_args = dict(
  1091. widths=(8, 16, 32, 64, 128), depths=(1, 2, 2, 2, 2), head_dim=16, head_widths=(1024, 1280))
  1092. return _create_efficientvit('efficientvit_b0', pretrained=pretrained, **dict(model_args, **kwargs))
  1093. @register_model
  1094. def efficientvit_b1(pretrained=False, **kwargs):
  1095. model_args = dict(
  1096. widths=(16, 32, 64, 128, 256), depths=(1, 2, 3, 3, 4), head_dim=16, head_widths=(1536, 1600))
  1097. return _create_efficientvit('efficientvit_b1', pretrained=pretrained, **dict(model_args, **kwargs))
  1098. @register_model
  1099. def efficientvit_b2(pretrained=False, **kwargs):
  1100. model_args = dict(
  1101. widths=(24, 48, 96, 192, 384), depths=(1, 3, 4, 4, 6), head_dim=32, head_widths=(2304, 2560))
  1102. return _create_efficientvit('efficientvit_b2', pretrained=pretrained, **dict(model_args, **kwargs))
  1103. @register_model
  1104. def efficientvit_b3(pretrained=False, **kwargs):
  1105. model_args = dict(
  1106. widths=(32, 64, 128, 256, 512), depths=(1, 4, 6, 6, 9), head_dim=32, head_widths=(2304, 2560))
  1107. return _create_efficientvit('efficientvit_b3', pretrained=pretrained, **dict(model_args, **kwargs))
  1108. @register_model
  1109. def efficientvit_l1(pretrained=False, **kwargs):
  1110. model_args = dict(
  1111. widths=(32, 64, 128, 256, 512), depths=(1, 1, 1, 6, 6), head_dim=32, head_widths=(3072, 3200))
  1112. return _create_efficientvit_large('efficientvit_l1', pretrained=pretrained, **dict(model_args, **kwargs))
  1113. @register_model
  1114. def efficientvit_l2(pretrained=False, **kwargs):
  1115. model_args = dict(
  1116. widths=(32, 64, 128, 256, 512), depths=(1, 2, 2, 8, 8), head_dim=32, head_widths=(3072, 3200))
  1117. return _create_efficientvit_large('efficientvit_l2', pretrained=pretrained, **dict(model_args, **kwargs))
  1118. @register_model
  1119. def efficientvit_l3(pretrained=False, **kwargs):
  1120. model_args = dict(
  1121. widths=(64, 128, 256, 512, 1024), depths=(1, 2, 2, 8, 8), head_dim=32, head_widths=(6144, 6400))
  1122. return _create_efficientvit_large('efficientvit_l3', pretrained=pretrained, **dict(model_args, **kwargs))
# FIXME: the SAM backbone models below stay disabled until the pending v2 SAM models are available
  1124. # @register_model
  1125. # def efficientvit_l0_sam(pretrained=False, **kwargs):
  1126. # # only backbone for segment-anything-model weights
  1127. # model_args = dict(
  1128. # widths=(32, 64, 128, 256, 512), depths=(1, 1, 1, 4, 4), head_dim=32, num_classes=0, norm_eps=1e-6)
  1129. # return _create_efficientvit_large('efficientvit_l0_sam', pretrained=pretrained, **dict(model_args, **kwargs))
  1130. #
  1131. #
  1132. # @register_model
  1133. # def efficientvit_l1_sam(pretrained=False, **kwargs):
  1134. # # only backbone for segment-anything-model weights
  1135. # model_args = dict(
  1136. # widths=(32, 64, 128, 256, 512), depths=(1, 1, 1, 6, 6), head_dim=32, num_classes=0, norm_eps=1e-6)
  1137. # return _create_efficientvit_large('efficientvit_l1_sam', pretrained=pretrained, **dict(model_args, **kwargs))
  1138. #
  1139. #
  1140. # @register_model
  1141. # def efficientvit_l2_sam(pretrained=False, **kwargs):
  1142. # # only backbone for segment-anything-model weights
  1143. # model_args = dict(
  1144. # widths=(32, 64, 128, 256, 512), depths=(1, 2, 2, 8, 8), head_dim=32, num_classes=0, norm_eps=1e-6)
  1145. # return _create_efficientvit_large('efficientvit_l2_sam', pretrained=pretrained, **dict(model_args, **kwargs))