mvitv2.py 40 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
37047057067077087097107117127137147157167177187197207217227237247257267277287297307317327337347357367377387397407417427437447457467477487497507517527537547557567577587597607617627637647657667677687697707717727737747757767777787797807817827837847857867877887897907917927937947957967977987998008018028038048058068078088098108118128138148158168178188198208218228238248258268278288298308318328338348358368378388398408418428438448458468478488498508518528538548558568578588598608618628638648658668678688698708718728738748758768778788798808818828838848858868878888898908918928938948958968978988999009019029039049059069079089099109119129139149159169179189199209219229239249259269279289299309319329339349359369379389399409419429439449459469479489499509519529539549559569579589599609619629639649659669679689699709719729739749759769779789799809819829839849859869879889899909919929939949959969979989991000100110021003100410051006100710081009101010111012101310141015101610171018101910201021102210231024102510261027102810291030103110321033103410351036103710381039104010411042104310441045104610471048104910501051105210531054105510561057105810591060106110621063106410651066106710681069107010711072107310741075107610771078107910801081108210831084108510861087108810891090109110921093109410951096109710981099110011011102110311041105110611071108110911101111111211131114111511161117111811191120112111221123112411251126112711281129113011311132113311341135113611371138113911401141114211431144114511461147114811491150115111521153115411551156
  1. """ Multi-Scale Vision Transformer v2
  2. @inproceedings{li2021improved,
  3. title={MViTv2: Improved multiscale vision transformers for classification and detection},
  4. author={Li, Yanghao and Wu, Chao-Yuan and Fan, Haoqi and Mangalam, Karttikeya and Xiong, Bo and Malik, Jitendra and Feichtenhofer, Christoph},
  5. booktitle={CVPR},
  6. year={2022}
  7. }
  8. Code adapted from original Apache 2.0 licensed impl at https://github.com/facebookresearch/mvit
  9. Original copyright below.
  10. Modifications and timm support by / Copyright 2022, Ross Wightman
  11. """
  12. # Copyright (c) Meta Platforms, Inc. and affiliates. All Rights Reserved.
  13. import operator
  14. from collections import OrderedDict
  15. from dataclasses import dataclass
  16. from functools import partial, reduce
  17. from typing import Union, List, Tuple, Optional, Any, Type
  18. import torch
  19. from torch import nn
  20. from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
  21. from timm.layers import Mlp, DropPath, calculate_drop_path_rates, trunc_normal_tf_, get_norm_layer, to_2tuple
  22. from ._builder import build_model_with_cfg
  23. from ._features import feature_take_indices
  24. from ._features_fx import register_notrace_function
  25. from ._manipulate import checkpoint
  26. from ._registry import register_model, generate_default_cfgs
  27. __all__ = ['MultiScaleVit', 'MultiScaleVitCfg'] # model_registry will add each entrypoint fn to this
  28. @dataclass
  29. class MultiScaleVitCfg:
  30. depths: Tuple[int, ...] = (2, 3, 16, 3)
  31. embed_dim: Union[int, Tuple[int, ...]] = 96
  32. num_heads: Union[int, Tuple[int, ...]] = 1
  33. mlp_ratio: float = 4.
  34. pool_first: bool = False
  35. expand_attn: bool = True
  36. qkv_bias: bool = True
  37. use_cls_token: bool = False
  38. use_abs_pos: bool = False
  39. residual_pooling: bool = True
  40. mode: str = 'conv'
  41. kernel_qkv: Tuple[int, int] = (3, 3)
  42. stride_q: Optional[Tuple[Tuple[int, int]]] = ((1, 1), (2, 2), (2, 2), (2, 2))
  43. stride_kv: Optional[Tuple[Tuple[int, int]]] = None
  44. stride_kv_adaptive: Optional[Tuple[int, int]] = (4, 4)
  45. patch_kernel: Tuple[int, int] = (7, 7)
  46. patch_stride: Tuple[int, int] = (4, 4)
  47. patch_padding: Tuple[int, int] = (3, 3)
  48. pool_type: str = 'max'
  49. rel_pos_type: str = 'spatial'
  50. act_layer: Union[str, Tuple[str, str]] = 'gelu'
  51. norm_layer: Union[str, Tuple[str, str]] = 'layernorm'
  52. norm_eps: float = 1e-6
  53. def __post_init__(self):
  54. num_stages = len(self.depths)
  55. if not isinstance(self.embed_dim, (tuple, list)):
  56. self.embed_dim = tuple(self.embed_dim * 2 ** i for i in range(num_stages))
  57. assert len(self.embed_dim) == num_stages
  58. if not isinstance(self.num_heads, (tuple, list)):
  59. self.num_heads = tuple(self.num_heads * 2 ** i for i in range(num_stages))
  60. assert len(self.num_heads) == num_stages
  61. if self.stride_kv_adaptive is not None and self.stride_kv is None:
  62. _stride_kv = self.stride_kv_adaptive
  63. pool_kv_stride = []
  64. for i in range(num_stages):
  65. if min(self.stride_q[i]) > 1:
  66. _stride_kv = [
  67. max(_stride_kv[d] // self.stride_q[i][d], 1)
  68. for d in range(len(_stride_kv))
  69. ]
  70. pool_kv_stride.append(tuple(_stride_kv))
  71. self.stride_kv = tuple(pool_kv_stride)
  72. def prod(iterable):
  73. return reduce(operator.mul, iterable, 1)
  74. class PatchEmbed(nn.Module):
  75. """
  76. PatchEmbed.
  77. """
  78. def __init__(
  79. self,
  80. dim_in: int = 3,
  81. dim_out: int = 768,
  82. kernel: Tuple[int, int] = (7, 7),
  83. stride: Tuple[int, int] = (4, 4),
  84. padding: Tuple[int, int] = (3, 3),
  85. device=None,
  86. dtype=None,
  87. ):
  88. super().__init__()
  89. dd = {'device': device, 'dtype': dtype}
  90. self.proj = nn.Conv2d(
  91. dim_in,
  92. dim_out,
  93. kernel_size=kernel,
  94. stride=stride,
  95. padding=padding,
  96. **dd,
  97. )
  98. def forward(self, x) -> Tuple[torch.Tensor, List[int]]:
  99. x = self.proj(x)
  100. # B C H W -> B HW C
  101. return x.flatten(2).transpose(1, 2), x.shape[-2:]
  102. @register_notrace_function
  103. def reshape_pre_pool(
  104. x,
  105. feat_size: List[int],
  106. has_cls_token: bool = True
  107. ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
  108. H, W = feat_size
  109. if has_cls_token:
  110. cls_tok, x = x[:, :, :1, :], x[:, :, 1:, :]
  111. else:
  112. cls_tok = None
  113. x = x.reshape(-1, H, W, x.shape[-1]).permute(0, 3, 1, 2).contiguous()
  114. return x, cls_tok
  115. @register_notrace_function
  116. def reshape_post_pool(
  117. x,
  118. num_heads: int,
  119. cls_tok: Optional[torch.Tensor] = None
  120. ) -> Tuple[torch.Tensor, List[int]]:
  121. feat_size = [x.shape[2], x.shape[3]]
  122. L_pooled = x.shape[2] * x.shape[3]
  123. x = x.reshape(-1, num_heads, x.shape[1], L_pooled).transpose(2, 3)
  124. if cls_tok is not None:
  125. x = torch.cat((cls_tok, x), dim=2)
  126. return x, feat_size
  127. @register_notrace_function
  128. def cal_rel_pos_type(
  129. attn: torch.Tensor,
  130. q: torch.Tensor,
  131. has_cls_token: bool,
  132. q_size: List[int],
  133. k_size: List[int],
  134. rel_pos_h: torch.Tensor,
  135. rel_pos_w: torch.Tensor,
  136. ):
  137. """
  138. Spatial Relative Positional Embeddings.
  139. """
  140. sp_idx = 1 if has_cls_token else 0
  141. q_h, q_w = q_size
  142. k_h, k_w = k_size
  143. # Scale up rel pos if shapes for q and k are different.
  144. q_h_ratio = max(k_h / q_h, 1.0)
  145. k_h_ratio = max(q_h / k_h, 1.0)
  146. dist_h = (
  147. torch.arange(q_h, device=q.device, dtype=torch.long).unsqueeze(-1) * q_h_ratio -
  148. torch.arange(k_h, device=q.device, dtype=torch.long).unsqueeze(0) * k_h_ratio
  149. )
  150. dist_h += (k_h - 1) * k_h_ratio
  151. q_w_ratio = max(k_w / q_w, 1.0)
  152. k_w_ratio = max(q_w / k_w, 1.0)
  153. dist_w = (
  154. torch.arange(q_w, device=q.device, dtype=torch.long).unsqueeze(-1) * q_w_ratio -
  155. torch.arange(k_w, device=q.device, dtype=torch.long).unsqueeze(0) * k_w_ratio
  156. )
  157. dist_w += (k_w - 1) * k_w_ratio
  158. rel_h = rel_pos_h[dist_h.long()]
  159. rel_w = rel_pos_w[dist_w.long()]
  160. B, n_head, q_N, dim = q.shape
  161. r_q = q[:, :, sp_idx:].reshape(B, n_head, q_h, q_w, dim)
  162. rel_h = torch.einsum("byhwc,hkc->byhwk", r_q, rel_h)
  163. rel_w = torch.einsum("byhwc,wkc->byhwk", r_q, rel_w)
  164. attn[:, :, sp_idx:, sp_idx:] = (
  165. attn[:, :, sp_idx:, sp_idx:].view(B, -1, q_h, q_w, k_h, k_w)
  166. + rel_h.unsqueeze(-1)
  167. + rel_w.unsqueeze(-2)
  168. ).view(B, -1, q_h * q_w, k_h * k_w)
  169. return attn
  170. class MultiScaleAttentionPoolFirst(nn.Module):
  171. def __init__(
  172. self,
  173. dim: int,
  174. dim_out: int,
  175. feat_size: Tuple[int, int],
  176. num_heads: int = 8,
  177. qkv_bias: bool = True,
  178. mode: str = "conv",
  179. kernel_q: Tuple[int, int] = (1, 1),
  180. kernel_kv: Tuple[int, int] = (1, 1),
  181. stride_q: Tuple[int, int] = (1, 1),
  182. stride_kv: Tuple[int, int] = (1, 1),
  183. has_cls_token: bool = True,
  184. rel_pos_type: str = 'spatial',
  185. residual_pooling: bool = True,
  186. norm_layer: Type[nn.Module] = nn.LayerNorm,
  187. device=None,
  188. dtype=None,
  189. ):
  190. dd = {'device': device, 'dtype': dtype}
  191. super().__init__()
  192. self.num_heads = num_heads
  193. self.dim_out = dim_out
  194. self.head_dim = dim_out // num_heads
  195. self.scale = self.head_dim ** -0.5
  196. self.has_cls_token = has_cls_token
  197. padding_q = tuple([int(q // 2) for q in kernel_q])
  198. padding_kv = tuple([int(kv // 2) for kv in kernel_kv])
  199. self.q = nn.Linear(dim, dim_out, bias=qkv_bias, **dd)
  200. self.k = nn.Linear(dim, dim_out, bias=qkv_bias, **dd)
  201. self.v = nn.Linear(dim, dim_out, bias=qkv_bias, **dd)
  202. self.proj = nn.Linear(dim_out, dim_out, **dd)
  203. # Skip pooling with kernel and stride size of (1, 1, 1).
  204. if prod(kernel_q) == 1 and prod(stride_q) == 1:
  205. kernel_q = None
  206. if prod(kernel_kv) == 1 and prod(stride_kv) == 1:
  207. kernel_kv = None
  208. self.mode = mode
  209. self.unshared = mode == 'conv_unshared'
  210. self.pool_q, self.pool_k, self.pool_v = None, None, None
  211. self.norm_q, self.norm_k, self.norm_v = None, None, None
  212. if mode in ("avg", "max"):
  213. pool_op = nn.MaxPool2d if mode == "max" else nn.AvgPool2d
  214. if kernel_q:
  215. self.pool_q = pool_op(kernel_q, stride_q, padding_q)
  216. if kernel_kv:
  217. self.pool_k = pool_op(kernel_kv, stride_kv, padding_kv)
  218. self.pool_v = pool_op(kernel_kv, stride_kv, padding_kv)
  219. elif mode == "conv" or mode == "conv_unshared":
  220. dim_conv = dim // num_heads if mode == "conv" else dim
  221. if kernel_q:
  222. self.pool_q = nn.Conv2d(
  223. dim_conv,
  224. dim_conv,
  225. kernel_q,
  226. stride=stride_q,
  227. padding=padding_q,
  228. groups=dim_conv,
  229. bias=False,
  230. **dd,
  231. )
  232. self.norm_q = norm_layer(dim_conv, **dd)
  233. if kernel_kv:
  234. self.pool_k = nn.Conv2d(
  235. dim_conv,
  236. dim_conv,
  237. kernel_kv,
  238. stride=stride_kv,
  239. padding=padding_kv,
  240. groups=dim_conv,
  241. bias=False,
  242. **dd,
  243. )
  244. self.norm_k = norm_layer(dim_conv, **dd)
  245. self.pool_v = nn.Conv2d(
  246. dim_conv,
  247. dim_conv,
  248. kernel_kv,
  249. stride=stride_kv,
  250. padding=padding_kv,
  251. groups=dim_conv,
  252. bias=False,
  253. **dd,
  254. )
  255. self.norm_v = norm_layer(dim_conv, **dd)
  256. else:
  257. raise NotImplementedError(f"Unsupported model {mode}")
  258. # relative pos embedding
  259. self.rel_pos_type = rel_pos_type
  260. if self.rel_pos_type == 'spatial':
  261. assert feat_size[0] == feat_size[1]
  262. size = feat_size[0]
  263. q_size = size // stride_q[1] if len(stride_q) > 0 else size
  264. kv_size = size // stride_kv[1] if len(stride_kv) > 0 else size
  265. rel_sp_dim = 2 * max(q_size, kv_size) - 1
  266. self.rel_pos_h = nn.Parameter(torch.zeros(rel_sp_dim, self.head_dim, **dd))
  267. self.rel_pos_w = nn.Parameter(torch.zeros(rel_sp_dim, self.head_dim, **dd))
  268. trunc_normal_tf_(self.rel_pos_h, std=0.02)
  269. trunc_normal_tf_(self.rel_pos_w, std=0.02)
  270. self.residual_pooling = residual_pooling
  271. def forward(self, x, feat_size: List[int]):
  272. B, N, _ = x.shape
  273. fold_dim = 1 if self.unshared else self.num_heads
  274. x = x.reshape(B, N, fold_dim, -1).permute(0, 2, 1, 3)
  275. q = k = v = x
  276. if self.pool_q is not None:
  277. q, q_tok = reshape_pre_pool(q, feat_size, self.has_cls_token)
  278. q = self.pool_q(q)
  279. q, q_size = reshape_post_pool(q, self.num_heads, q_tok)
  280. else:
  281. q_size = feat_size
  282. if self.norm_q is not None:
  283. q = self.norm_q(q)
  284. if self.pool_k is not None:
  285. k, k_tok = reshape_pre_pool(k, feat_size, self.has_cls_token)
  286. k = self.pool_k(k)
  287. k, k_size = reshape_post_pool(k, self.num_heads, k_tok)
  288. else:
  289. k_size = feat_size
  290. if self.norm_k is not None:
  291. k = self.norm_k(k)
  292. if self.pool_v is not None:
  293. v, v_tok = reshape_pre_pool(v, feat_size, self.has_cls_token)
  294. v = self.pool_v(v)
  295. v, v_size = reshape_post_pool(v, self.num_heads, v_tok)
  296. else:
  297. v_size = feat_size
  298. if self.norm_v is not None:
  299. v = self.norm_v(v)
  300. q_N = q_size[0] * q_size[1] + int(self.has_cls_token)
  301. q = q.transpose(1, 2).reshape(B, q_N, -1)
  302. q = self.q(q).reshape(B, q_N, self.num_heads, -1).transpose(1, 2)
  303. k_N = k_size[0] * k_size[1] + int(self.has_cls_token)
  304. k = k.transpose(1, 2).reshape(B, k_N, -1)
  305. k = self.k(k).reshape(B, k_N, self.num_heads, -1)
  306. v_N = v_size[0] * v_size[1] + int(self.has_cls_token)
  307. v = v.transpose(1, 2).reshape(B, v_N, -1)
  308. v = self.v(v).reshape(B, v_N, self.num_heads, -1).transpose(1, 2)
  309. attn = (q * self.scale) @ k
  310. if self.rel_pos_type == 'spatial':
  311. attn = cal_rel_pos_type(
  312. attn,
  313. q,
  314. self.has_cls_token,
  315. q_size,
  316. k_size,
  317. self.rel_pos_h,
  318. self.rel_pos_w,
  319. )
  320. attn = attn.softmax(dim=-1)
  321. x = attn @ v
  322. if self.residual_pooling:
  323. x = x + q
  324. x = x.transpose(1, 2).reshape(B, -1, self.dim_out)
  325. x = self.proj(x)
  326. return x, q_size
  327. class MultiScaleAttention(nn.Module):
  328. def __init__(
  329. self,
  330. dim: int,
  331. dim_out: int,
  332. feat_size: Tuple[int, int],
  333. num_heads: int = 8,
  334. qkv_bias: bool = True,
  335. mode: str = "conv",
  336. kernel_q: Tuple[int, int] = (1, 1),
  337. kernel_kv: Tuple[int, int] = (1, 1),
  338. stride_q: Tuple[int, int] = (1, 1),
  339. stride_kv: Tuple[int, int] = (1, 1),
  340. has_cls_token: bool = True,
  341. rel_pos_type: str = 'spatial',
  342. residual_pooling: bool = True,
  343. norm_layer: Type[nn.Module] = nn.LayerNorm,
  344. device=None,
  345. dtype=None,
  346. ):
  347. dd = {'device': device, 'dtype': dtype}
  348. super().__init__()
  349. self.num_heads = num_heads
  350. self.dim_out = dim_out
  351. self.head_dim = dim_out // num_heads
  352. self.scale = self.head_dim ** -0.5
  353. self.has_cls_token = has_cls_token
  354. padding_q = tuple([int(q // 2) for q in kernel_q])
  355. padding_kv = tuple([int(kv // 2) for kv in kernel_kv])
  356. self.qkv = nn.Linear(dim, dim_out * 3, bias=qkv_bias, **dd)
  357. self.proj = nn.Linear(dim_out, dim_out, **dd)
  358. # Skip pooling with kernel and stride size of (1, 1, 1).
  359. if prod(kernel_q) == 1 and prod(stride_q) == 1:
  360. kernel_q = None
  361. if prod(kernel_kv) == 1 and prod(stride_kv) == 1:
  362. kernel_kv = None
  363. self.mode = mode
  364. self.unshared = mode == 'conv_unshared'
  365. self.norm_q, self.norm_k, self.norm_v = None, None, None
  366. self.pool_q, self.pool_k, self.pool_v = None, None, None
  367. if mode in ("avg", "max"):
  368. pool_op = nn.MaxPool2d if mode == "max" else nn.AvgPool2d
  369. if kernel_q:
  370. self.pool_q = pool_op(kernel_q, stride_q, padding_q)
  371. if kernel_kv:
  372. self.pool_k = pool_op(kernel_kv, stride_kv, padding_kv)
  373. self.pool_v = pool_op(kernel_kv, stride_kv, padding_kv)
  374. elif mode == "conv" or mode == "conv_unshared":
  375. dim_conv = dim_out // num_heads if mode == "conv" else dim_out
  376. if kernel_q:
  377. self.pool_q = nn.Conv2d(
  378. dim_conv,
  379. dim_conv,
  380. kernel_q,
  381. stride=stride_q,
  382. padding=padding_q,
  383. groups=dim_conv,
  384. bias=False,
  385. **dd,
  386. )
  387. self.norm_q = norm_layer(dim_conv, **dd)
  388. if kernel_kv:
  389. self.pool_k = nn.Conv2d(
  390. dim_conv,
  391. dim_conv,
  392. kernel_kv,
  393. stride=stride_kv,
  394. padding=padding_kv,
  395. groups=dim_conv,
  396. bias=False,
  397. **dd,
  398. )
  399. self.norm_k = norm_layer(dim_conv, **dd)
  400. self.pool_v = nn.Conv2d(
  401. dim_conv,
  402. dim_conv,
  403. kernel_kv,
  404. stride=stride_kv,
  405. padding=padding_kv,
  406. groups=dim_conv,
  407. bias=False,
  408. **dd,
  409. )
  410. self.norm_v = norm_layer(dim_conv, **dd)
  411. else:
  412. raise NotImplementedError(f"Unsupported model {mode}")
  413. # relative pos embedding
  414. self.rel_pos_type = rel_pos_type
  415. if self.rel_pos_type == 'spatial':
  416. assert feat_size[0] == feat_size[1]
  417. size = feat_size[0]
  418. q_size = size // stride_q[1] if len(stride_q) > 0 else size
  419. kv_size = size // stride_kv[1] if len(stride_kv) > 0 else size
  420. rel_sp_dim = 2 * max(q_size, kv_size) - 1
  421. self.rel_pos_h = nn.Parameter(torch.zeros(rel_sp_dim, self.head_dim, **dd))
  422. self.rel_pos_w = nn.Parameter(torch.zeros(rel_sp_dim, self.head_dim, **dd))
  423. trunc_normal_tf_(self.rel_pos_h, std=0.02)
  424. trunc_normal_tf_(self.rel_pos_w, std=0.02)
  425. self.residual_pooling = residual_pooling
  426. def forward(self, x, feat_size: List[int]):
  427. B, N, _ = x.shape
  428. qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, -1).permute(2, 0, 3, 1, 4)
  429. q, k, v = qkv.unbind(dim=0)
  430. if self.pool_q is not None:
  431. q, q_tok = reshape_pre_pool(q, feat_size, self.has_cls_token)
  432. q = self.pool_q(q)
  433. q, q_size = reshape_post_pool(q, self.num_heads, q_tok)
  434. else:
  435. q_size = feat_size
  436. if self.norm_q is not None:
  437. q = self.norm_q(q)
  438. if self.pool_k is not None:
  439. k, k_tok = reshape_pre_pool(k, feat_size, self.has_cls_token)
  440. k = self.pool_k(k)
  441. k, k_size = reshape_post_pool(k, self.num_heads, k_tok)
  442. else:
  443. k_size = feat_size
  444. if self.norm_k is not None:
  445. k = self.norm_k(k)
  446. if self.pool_v is not None:
  447. v, v_tok = reshape_pre_pool(v, feat_size, self.has_cls_token)
  448. v = self.pool_v(v)
  449. v, _ = reshape_post_pool(v, self.num_heads, v_tok)
  450. if self.norm_v is not None:
  451. v = self.norm_v(v)
  452. attn = (q * self.scale) @ k.transpose(-2, -1)
  453. if self.rel_pos_type == 'spatial':
  454. attn = cal_rel_pos_type(
  455. attn,
  456. q,
  457. self.has_cls_token,
  458. q_size,
  459. k_size,
  460. self.rel_pos_h,
  461. self.rel_pos_w,
  462. )
  463. attn = attn.softmax(dim=-1)
  464. x = attn @ v
  465. if self.residual_pooling:
  466. x = x + q
  467. x = x.transpose(1, 2).reshape(B, -1, self.dim_out)
  468. x = self.proj(x)
  469. return x, q_size
  470. class MultiScaleBlock(nn.Module):
  471. def __init__(
  472. self,
  473. dim: int,
  474. dim_out: int,
  475. num_heads: int,
  476. feat_size: Tuple[int, int],
  477. mlp_ratio: float = 4.0,
  478. qkv_bias: bool = True,
  479. drop_path: float = 0.0,
  480. norm_layer: Type[nn.Module] = nn.LayerNorm,
  481. kernel_q: Tuple[int, int] = (1, 1),
  482. kernel_kv: Tuple[int, int] = (1, 1),
  483. stride_q: Tuple[int, int] = (1, 1),
  484. stride_kv: Tuple[int, int] = (1, 1),
  485. mode: str = "conv",
  486. has_cls_token: bool = True,
  487. expand_attn: bool = False,
  488. pool_first: bool = False,
  489. rel_pos_type: str = 'spatial',
  490. residual_pooling: bool = True,
  491. device=None,
  492. dtype=None,
  493. ):
  494. dd = {'device': device, 'dtype': dtype}
  495. super().__init__()
  496. proj_needed = dim != dim_out
  497. self.dim = dim
  498. self.dim_out = dim_out
  499. self.has_cls_token = has_cls_token
  500. self.norm1 = norm_layer(dim, **dd)
  501. self.shortcut_proj_attn = nn.Linear(dim, dim_out, **dd) if proj_needed and expand_attn else None
  502. if stride_q and prod(stride_q) > 1:
  503. kernel_skip = [s + 1 if s > 1 else s for s in stride_q]
  504. stride_skip = stride_q
  505. padding_skip = [int(skip // 2) for skip in kernel_skip]
  506. self.shortcut_pool_attn = nn.MaxPool2d(kernel_skip, stride_skip, padding_skip)
  507. else:
  508. self.shortcut_pool_attn = None
  509. att_dim = dim_out if expand_attn else dim
  510. attn_layer = MultiScaleAttentionPoolFirst if pool_first else MultiScaleAttention
  511. self.attn = attn_layer(
  512. dim,
  513. att_dim,
  514. num_heads=num_heads,
  515. feat_size=feat_size,
  516. qkv_bias=qkv_bias,
  517. kernel_q=kernel_q,
  518. kernel_kv=kernel_kv,
  519. stride_q=stride_q,
  520. stride_kv=stride_kv,
  521. norm_layer=norm_layer,
  522. has_cls_token=has_cls_token,
  523. mode=mode,
  524. rel_pos_type=rel_pos_type,
  525. residual_pooling=residual_pooling,
  526. **dd,
  527. )
  528. self.drop_path1 = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()
  529. self.norm2 = norm_layer(att_dim, **dd)
  530. mlp_dim_out = dim_out
  531. self.shortcut_proj_mlp = nn.Linear(dim, dim_out, **dd) if proj_needed and not expand_attn else None
  532. self.mlp = Mlp(
  533. in_features=att_dim,
  534. hidden_features=int(att_dim * mlp_ratio),
  535. out_features=mlp_dim_out,
  536. **dd,
  537. )
  538. self.drop_path2 = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()
  539. def _shortcut_pool(self, x, feat_size: List[int]):
  540. if self.shortcut_pool_attn is None:
  541. return x
  542. if self.has_cls_token:
  543. cls_tok, x = x[:, :1, :], x[:, 1:, :]
  544. else:
  545. cls_tok = None
  546. B, L, C = x.shape
  547. H, W = feat_size
  548. x = x.reshape(B, H, W, C).permute(0, 3, 1, 2).contiguous()
  549. x = self.shortcut_pool_attn(x)
  550. x = x.reshape(B, C, -1).transpose(1, 2)
  551. if cls_tok is not None:
  552. x = torch.cat((cls_tok, x), dim=1)
  553. return x
  554. def forward(self, x, feat_size: List[int]):
  555. x_norm = self.norm1(x)
  556. # NOTE as per the original impl, this seems odd, but shortcut uses un-normalized input if no proj
  557. x_shortcut = x if self.shortcut_proj_attn is None else self.shortcut_proj_attn(x_norm)
  558. x_shortcut = self._shortcut_pool(x_shortcut, feat_size)
  559. x, feat_size_new = self.attn(x_norm, feat_size)
  560. x = x_shortcut + self.drop_path1(x)
  561. x_norm = self.norm2(x)
  562. x_shortcut = x if self.shortcut_proj_mlp is None else self.shortcut_proj_mlp(x_norm)
  563. x = x_shortcut + self.drop_path2(self.mlp(x_norm))
  564. return x, feat_size_new
  565. class MultiScaleVitStage(nn.Module):
  566. def __init__(
  567. self,
  568. dim: int,
  569. dim_out: int,
  570. depth: int,
  571. num_heads: int,
  572. feat_size: Tuple[int, int],
  573. mlp_ratio: float = 4.0,
  574. qkv_bias: bool = True,
  575. kernel_q: Tuple[int, int] = (1, 1),
  576. kernel_kv: Tuple[int, int] = (1, 1),
  577. stride_q: Tuple[int, int] = (1, 1),
  578. stride_kv: Tuple[int, int] = (1, 1),
  579. mode: str = "conv",
  580. has_cls_token: bool = True,
  581. expand_attn: bool = False,
  582. pool_first: bool = False,
  583. rel_pos_type: str = 'spatial',
  584. residual_pooling: bool = True,
  585. norm_layer: Type[nn.Module] = nn.LayerNorm,
  586. drop_path: Union[float, List[float]] = 0.0,
  587. device=None,
  588. dtype=None,
  589. ):
  590. dd = {'device': device, 'dtype': dtype}
  591. super().__init__()
  592. self.grad_checkpointing = False
  593. self.blocks = nn.ModuleList()
  594. if expand_attn:
  595. out_dims = (dim_out,) * depth
  596. else:
  597. out_dims = (dim,) * (depth - 1) + (dim_out,)
  598. for i in range(depth):
  599. attention_block = MultiScaleBlock(
  600. dim=dim,
  601. dim_out=out_dims[i],
  602. num_heads=num_heads,
  603. feat_size=feat_size,
  604. mlp_ratio=mlp_ratio,
  605. qkv_bias=qkv_bias,
  606. kernel_q=kernel_q,
  607. kernel_kv=kernel_kv,
  608. stride_q=stride_q if i == 0 else (1, 1),
  609. stride_kv=stride_kv,
  610. mode=mode,
  611. has_cls_token=has_cls_token,
  612. pool_first=pool_first,
  613. rel_pos_type=rel_pos_type,
  614. residual_pooling=residual_pooling,
  615. expand_attn=expand_attn,
  616. norm_layer=norm_layer,
  617. drop_path=drop_path[i] if isinstance(drop_path, (list, tuple)) else drop_path,
  618. **dd,
  619. )
  620. dim = out_dims[i]
  621. self.blocks.append(attention_block)
  622. if i == 0:
  623. feat_size = tuple([size // stride for size, stride in zip(feat_size, stride_q)])
  624. self.feat_size = feat_size
  625. def forward(self, x, feat_size: List[int]):
  626. for blk in self.blocks:
  627. if self.grad_checkpointing and not torch.jit.is_scripting():
  628. x, feat_size = checkpoint(blk, x, feat_size)
  629. else:
  630. x, feat_size = blk(x, feat_size)
  631. return x, feat_size
  632. class MultiScaleVit(nn.Module):
  633. """
  634. Improved Multiscale Vision Transformers for Classification and Detection
  635. Yanghao Li*, Chao-Yuan Wu*, Haoqi Fan, Karttikeya Mangalam, Bo Xiong, Jitendra Malik,
  636. Christoph Feichtenhofer*
  637. https://arxiv.org/abs/2112.01526
  638. Multiscale Vision Transformers
  639. Haoqi Fan*, Bo Xiong*, Karttikeya Mangalam*, Yanghao Li*, Zhicheng Yan, Jitendra Malik,
  640. Christoph Feichtenhofer*
  641. https://arxiv.org/abs/2104.11227
  642. """
  643. def __init__(
  644. self,
  645. cfg: MultiScaleVitCfg,
  646. img_size: Tuple[int, int] = (224, 224),
  647. in_chans: int = 3,
  648. global_pool: Optional[str] = None,
  649. num_classes: int = 1000,
  650. drop_path_rate: float = 0.,
  651. drop_rate: float = 0.,
  652. device=None,
  653. dtype=None,
  654. ):
  655. super().__init__()
  656. dd = {'device': device, 'dtype': dtype}
  657. img_size = to_2tuple(img_size)
  658. norm_layer = partial(get_norm_layer(cfg.norm_layer), eps=cfg.norm_eps)
  659. self.num_classes = num_classes
  660. self.in_chans = in_chans
  661. self.drop_rate = drop_rate
  662. if global_pool is None:
  663. global_pool = 'token' if cfg.use_cls_token else 'avg'
  664. self.global_pool = global_pool
  665. self.depths = tuple(cfg.depths)
  666. self.expand_attn = cfg.expand_attn
  667. embed_dim = cfg.embed_dim[0]
  668. self.patch_embed = PatchEmbed(
  669. dim_in=in_chans,
  670. dim_out=embed_dim,
  671. kernel=cfg.patch_kernel,
  672. stride=cfg.patch_stride,
  673. padding=cfg.patch_padding,
  674. **dd,
  675. )
  676. patch_dims = (img_size[0] // cfg.patch_stride[0], img_size[1] // cfg.patch_stride[1])
  677. num_patches = prod(patch_dims)
  678. if cfg.use_cls_token:
  679. self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim, **dd))
  680. self.num_prefix_tokens = 1
  681. pos_embed_dim = num_patches + 1
  682. else:
  683. self.num_prefix_tokens = 0
  684. self.cls_token = None
  685. pos_embed_dim = num_patches
  686. if cfg.use_abs_pos:
  687. self.pos_embed = nn.Parameter(torch.zeros(1, pos_embed_dim, embed_dim, **dd))
  688. else:
  689. self.pos_embed = None
  690. num_stages = len(cfg.embed_dim)
  691. feat_size = patch_dims
  692. curr_stride = max(cfg.patch_stride)
  693. dpr = calculate_drop_path_rates(drop_path_rate, cfg.depths, stagewise=True)
  694. self.stages = nn.ModuleList()
  695. self.feature_info = []
  696. for i in range(num_stages):
  697. if cfg.expand_attn:
  698. dim_out = cfg.embed_dim[i]
  699. else:
  700. dim_out = cfg.embed_dim[min(i + 1, num_stages - 1)]
  701. stage = MultiScaleVitStage(
  702. dim=embed_dim,
  703. dim_out=dim_out,
  704. depth=cfg.depths[i],
  705. num_heads=cfg.num_heads[i],
  706. feat_size=feat_size,
  707. mlp_ratio=cfg.mlp_ratio,
  708. qkv_bias=cfg.qkv_bias,
  709. mode=cfg.mode,
  710. pool_first=cfg.pool_first,
  711. expand_attn=cfg.expand_attn,
  712. kernel_q=cfg.kernel_qkv,
  713. kernel_kv=cfg.kernel_qkv,
  714. stride_q=cfg.stride_q[i],
  715. stride_kv=cfg.stride_kv[i],
  716. has_cls_token=cfg.use_cls_token,
  717. rel_pos_type=cfg.rel_pos_type,
  718. residual_pooling=cfg.residual_pooling,
  719. norm_layer=norm_layer,
  720. drop_path=dpr[i],
  721. **dd,
  722. )
  723. curr_stride *= max(cfg.stride_q[i])
  724. self.feature_info += [dict(module=f'block.{i}', num_chs=dim_out, reduction=curr_stride)]
  725. embed_dim = dim_out
  726. feat_size = stage.feat_size
  727. self.stages.append(stage)
  728. self.num_features = self.head_hidden_size = embed_dim
  729. self.norm = norm_layer(embed_dim, **dd)
  730. self.head = nn.Sequential(OrderedDict([
  731. ('drop', nn.Dropout(self.drop_rate)),
  732. ('fc', nn.Linear(self.num_features, num_classes, **dd) if num_classes > 0 else nn.Identity())
  733. ]))
  734. if self.pos_embed is not None:
  735. trunc_normal_tf_(self.pos_embed, std=0.02)
  736. if self.cls_token is not None:
  737. trunc_normal_tf_(self.cls_token, std=0.02)
  738. self.apply(self._init_weights)
  739. def _init_weights(self, m):
  740. if isinstance(m, nn.Linear):
  741. trunc_normal_tf_(m.weight, std=0.02)
  742. if isinstance(m, nn.Linear) and m.bias is not None:
  743. nn.init.constant_(m.bias, 0.0)
  744. @torch.jit.ignore
  745. def no_weight_decay(self):
  746. return {k for k, _ in self.named_parameters()
  747. if any(n in k for n in ["pos_embed", "rel_pos_h", "rel_pos_w", "cls_token"])}
  748. @torch.jit.ignore
  749. def group_matcher(self, coarse=False):
  750. matcher = dict(
  751. stem=r'^patch_embed', # stem and embed
  752. blocks=[(r'^stages\.(\d+)', None), (r'^norm', (99999,))]
  753. )
  754. return matcher
  755. @torch.jit.ignore
  756. def set_grad_checkpointing(self, enable=True):
  757. for s in self.stages:
  758. s.grad_checkpointing = enable
  759. @torch.jit.ignore
  760. def get_classifier(self) -> nn.Module:
  761. return self.head.fc
  762. def reset_classifier(self, num_classes: int, global_pool: Optional[str] = None):
  763. self.num_classes = num_classes
  764. if global_pool is not None:
  765. self.global_pool = global_pool
  766. device = self.head.fc.weight.device if hasattr(self.head.fc, 'weight') else None
  767. dtype = self.head.fc.weight.dtype if hasattr(self.head.fc, 'weight') else None
  768. self.head = nn.Sequential(OrderedDict([
  769. ('drop', nn.Dropout(self.drop_rate)),
  770. ('fc', nn.Linear(self.num_features, num_classes, device=device, dtype=dtype) if num_classes > 0 else nn.Identity())
  771. ]))
  772. def forward_intermediates(
  773. self,
  774. x: torch.Tensor,
  775. indices: Optional[Union[int, List[int]]] = None,
  776. norm: bool = False,
  777. stop_early: bool = False,
  778. output_fmt: str = 'NCHW',
  779. intermediates_only: bool = False,
  780. ) -> Union[List[torch.Tensor], Tuple[torch.Tensor, List[torch.Tensor]]]:
  781. """ Forward features that returns intermediates.
  782. Args:
  783. x: Input image tensor
  784. indices: Take last n blocks if int, all if None, select matching indices if sequence
  785. norm: Apply norm layer to all intermediates
  786. stop_early: Stop iterating over blocks when last desired intermediate hit
  787. output_fmt: Shape of intermediate feature outputs
  788. intermediates_only: Only return intermediate features
  789. Returns:
  790. """
  791. assert output_fmt in ('NCHW', 'NLC'), 'Output shape must be NCHW or NLC.'
  792. reshape = output_fmt == 'NCHW'
  793. intermediates = []
  794. take_indices, max_index = feature_take_indices(len(self.stages), indices)
  795. # FIXME slice block/pos_block if < max
  796. # forward pass
  797. x, feat_size = self.patch_embed(x)
  798. B = x.shape[0]
  799. if self.cls_token is not None:
  800. cls_tokens = self.cls_token.expand(B, -1, -1)
  801. x = torch.cat((cls_tokens, x), dim=1)
  802. if self.pos_embed is not None:
  803. x = x + self.pos_embed
  804. last_idx = len(self.stages) - 1
  805. for feat_idx, stage in enumerate(self.stages):
  806. x, feat_size = stage(x, feat_size)
  807. if feat_idx in take_indices:
  808. if norm and feat_idx == last_idx:
  809. x_inter = self.norm(x) # applying final norm last intermediate
  810. else:
  811. x_inter = x
  812. if reshape:
  813. if self.cls_token is not None:
  814. # possible to allow return of class tokens, TBD
  815. x_inter = x_inter[:, 1:]
  816. x_inter = x_inter.reshape(B, feat_size[0], feat_size[1], -1).permute(0, 3, 1, 2)
  817. intermediates.append(x_inter)
  818. if intermediates_only:
  819. return intermediates
  820. if feat_idx == last_idx:
  821. x = self.norm(x)
  822. return x, intermediates
  823. def prune_intermediate_layers(
  824. self,
  825. indices: Union[int, List[int]] = 1,
  826. prune_norm: bool = False,
  827. prune_head: bool = True,
  828. ):
  829. """ Prune layers not required for specified intermediates.
  830. """
  831. take_indices, max_index = feature_take_indices(len(self.stages), indices)
  832. # FIXME add stage pruning
  833. # self.stages = self.stages[:max_index] # truncate blocks w/ stem as idx 0
  834. if prune_norm:
  835. self.norm = nn.Identity()
  836. if prune_head:
  837. self.reset_classifier(0, '')
  838. return take_indices
  839. def forward_features(self, x):
  840. x, feat_size = self.patch_embed(x)
  841. B, N, C = x.shape
  842. if self.cls_token is not None:
  843. cls_tokens = self.cls_token.expand(B, -1, -1)
  844. x = torch.cat((cls_tokens, x), dim=1)
  845. if self.pos_embed is not None:
  846. x = x + self.pos_embed
  847. for stage in self.stages:
  848. x, feat_size = stage(x, feat_size)
  849. x = self.norm(x)
  850. return x
  851. def forward_head(self, x, pre_logits: bool = False):
  852. if self.global_pool:
  853. if self.global_pool == 'avg':
  854. x = x[:, self.num_prefix_tokens:].mean(1)
  855. else:
  856. x = x[:, 0]
  857. return x if pre_logits else self.head(x)
  858. def forward(self, x):
  859. x = self.forward_features(x)
  860. x = self.forward_head(x)
  861. return x
  862. def checkpoint_filter_fn(state_dict, model):
  863. if 'stages.0.blocks.0.norm1.weight' in state_dict:
  864. # native checkpoint, look for rel_pos interpolations
  865. for k in state_dict.keys():
  866. if 'rel_pos' in k:
  867. rel_pos = state_dict[k]
  868. dest_rel_pos_shape = model.state_dict()[k].shape
  869. if rel_pos.shape[0] != dest_rel_pos_shape[0]:
  870. rel_pos_resized = torch.nn.functional.interpolate(
  871. rel_pos.reshape(1, rel_pos.shape[0], -1).permute(0, 2, 1),
  872. size=dest_rel_pos_shape[0],
  873. mode="linear",
  874. )
  875. state_dict[k] = rel_pos_resized.reshape(-1, dest_rel_pos_shape[0]).permute(1, 0)
  876. return state_dict
  877. import re
  878. if 'model_state' in state_dict:
  879. state_dict = state_dict['model_state']
  880. depths = getattr(model, 'depths', None)
  881. expand_attn = getattr(model, 'expand_attn', True)
  882. assert depths is not None, 'model requires depth attribute to remap checkpoints'
  883. depth_map = {}
  884. block_idx = 0
  885. for stage_idx, d in enumerate(depths):
  886. depth_map.update({i: (stage_idx, i - block_idx) for i in range(block_idx, block_idx + d)})
  887. block_idx += d
  888. out_dict = {}
  889. for k, v in state_dict.items():
  890. k = re.sub(
  891. r'blocks\.(\d+)',
  892. lambda x: f'stages.{depth_map[int(x.group(1))][0]}.blocks.{depth_map[int(x.group(1))][1]}',
  893. k)
  894. if expand_attn:
  895. k = re.sub(r'stages\.(\d+).blocks\.(\d+).proj', f'stages.\\1.blocks.\\2.shortcut_proj_attn', k)
  896. else:
  897. k = re.sub(r'stages\.(\d+).blocks\.(\d+).proj', f'stages.\\1.blocks.\\2.shortcut_proj_mlp', k)
  898. if 'head' in k:
  899. k = k.replace('head.projection', 'head.fc')
  900. out_dict[k] = v
  901. return out_dict
  902. model_cfgs = dict(
  903. mvitv2_tiny=MultiScaleVitCfg(
  904. depths=(1, 2, 5, 2),
  905. ),
  906. mvitv2_small=MultiScaleVitCfg(
  907. depths=(1, 2, 11, 2),
  908. ),
  909. mvitv2_base=MultiScaleVitCfg(
  910. depths=(2, 3, 16, 3),
  911. ),
  912. mvitv2_large=MultiScaleVitCfg(
  913. depths=(2, 6, 36, 4),
  914. embed_dim=144,
  915. num_heads=2,
  916. expand_attn=False,
  917. ),
  918. mvitv2_small_cls=MultiScaleVitCfg(
  919. depths=(1, 2, 11, 2),
  920. use_cls_token=True,
  921. ),
  922. mvitv2_base_cls=MultiScaleVitCfg(
  923. depths=(2, 3, 16, 3),
  924. use_cls_token=True,
  925. ),
  926. mvitv2_large_cls=MultiScaleVitCfg(
  927. depths=(2, 6, 36, 4),
  928. embed_dim=144,
  929. num_heads=2,
  930. use_cls_token=True,
  931. expand_attn=True,
  932. ),
  933. mvitv2_huge_cls=MultiScaleVitCfg(
  934. depths=(4, 8, 60, 8),
  935. embed_dim=192,
  936. num_heads=3,
  937. use_cls_token=True,
  938. expand_attn=True,
  939. ),
  940. )
  941. def _create_mvitv2(variant, cfg_variant=None, pretrained=False, **kwargs):
  942. out_indices = kwargs.pop('out_indices', 4)
  943. return build_model_with_cfg(
  944. MultiScaleVit,
  945. variant,
  946. pretrained,
  947. model_cfg=model_cfgs[variant] if not cfg_variant else model_cfgs[cfg_variant],
  948. pretrained_filter_fn=checkpoint_filter_fn,
  949. feature_cfg=dict(out_indices=out_indices, feature_cls='getter'),
  950. **kwargs,
  951. )
  952. def _cfg(url='', **kwargs):
  953. return {
  954. 'url': url,
  955. 'num_classes': 1000, 'input_size': (3, 224, 224), 'pool_size': None,
  956. 'crop_pct': .9, 'interpolation': 'bicubic',
  957. 'mean': IMAGENET_DEFAULT_MEAN, 'std': IMAGENET_DEFAULT_STD,
  958. 'first_conv': 'patch_embed.proj', 'classifier': 'head.fc',
  959. 'fixed_input_size': True,
  960. 'license': 'apache-2.0',
  961. **kwargs
  962. }
  963. default_cfgs = generate_default_cfgs({
  964. 'mvitv2_tiny.fb_in1k': _cfg(
  965. url='https://dl.fbaipublicfiles.com/mvit/mvitv2_models/MViTv2_T_in1k.pyth',
  966. hf_hub_id='timm/'),
  967. 'mvitv2_small.fb_in1k': _cfg(url='https://dl.fbaipublicfiles.com/mvit/mvitv2_models/MViTv2_S_in1k.pyth',
  968. hf_hub_id='timm/'),
  969. 'mvitv2_base.fb_in1k': _cfg(url='https://dl.fbaipublicfiles.com/mvit/mvitv2_models/MViTv2_B_in1k.pyth',
  970. hf_hub_id='timm/'),
  971. 'mvitv2_large.fb_in1k': _cfg(url='https://dl.fbaipublicfiles.com/mvit/mvitv2_models/MViTv2_L_in1k.pyth',
  972. hf_hub_id='timm/'),
  973. 'mvitv2_small_cls': _cfg(url=''),
  974. 'mvitv2_base_cls.fb_inw21k': _cfg(
  975. url='https://dl.fbaipublicfiles.com/mvit/mvitv2_models/MViTv2_B_in21k.pyth',
  976. hf_hub_id='timm/',
  977. num_classes=19168),
  978. 'mvitv2_large_cls.fb_inw21k': _cfg(
  979. url='https://dl.fbaipublicfiles.com/mvit/mvitv2_models/MViTv2_L_in21k.pyth',
  980. hf_hub_id='timm/',
  981. num_classes=19168),
  982. 'mvitv2_huge_cls.fb_inw21k': _cfg(
  983. url='https://dl.fbaipublicfiles.com/mvit/mvitv2_models/MViTv2_H_in21k.pyth',
  984. hf_hub_id='timm/',
  985. num_classes=19168),
  986. })
  987. @register_model
  988. def mvitv2_tiny(pretrained=False, **kwargs) -> MultiScaleVit:
  989. return _create_mvitv2('mvitv2_tiny', pretrained=pretrained, **kwargs)
  990. @register_model
  991. def mvitv2_small(pretrained=False, **kwargs) -> MultiScaleVit:
  992. return _create_mvitv2('mvitv2_small', pretrained=pretrained, **kwargs)
  993. @register_model
  994. def mvitv2_base(pretrained=False, **kwargs) -> MultiScaleVit:
  995. return _create_mvitv2('mvitv2_base', pretrained=pretrained, **kwargs)
  996. @register_model
  997. def mvitv2_large(pretrained=False, **kwargs) -> MultiScaleVit:
  998. return _create_mvitv2('mvitv2_large', pretrained=pretrained, **kwargs)
  999. @register_model
  1000. def mvitv2_small_cls(pretrained=False, **kwargs) -> MultiScaleVit:
  1001. return _create_mvitv2('mvitv2_small_cls', pretrained=pretrained, **kwargs)
  1002. @register_model
  1003. def mvitv2_base_cls(pretrained=False, **kwargs) -> MultiScaleVit:
  1004. return _create_mvitv2('mvitv2_base_cls', pretrained=pretrained, **kwargs)
  1005. @register_model
  1006. def mvitv2_large_cls(pretrained=False, **kwargs) -> MultiScaleVit:
  1007. return _create_mvitv2('mvitv2_large_cls', pretrained=pretrained, **kwargs)
  1008. @register_model
  1009. def mvitv2_huge_cls(pretrained=False, **kwargs) -> MultiScaleVit:
  1010. return _create_mvitv2('mvitv2_huge_cls', pretrained=pretrained, **kwargs)