crossvit.py 25 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654
  1. """ CrossViT Model
  2. @inproceedings{
  3. chen2021crossvit,
  4. title={{CrossViT: Cross-Attention Multi-Scale Vision Transformer for Image Classification}},
  5. author={Chun-Fu (Richard) Chen and Quanfu Fan and Rameswar Panda},
  6. booktitle={International Conference on Computer Vision (ICCV)},
  7. year={2021}
  8. }
  9. Paper link: https://arxiv.org/abs/2103.14899
  10. Original code: https://github.com/IBM/CrossViT/blob/main/models/crossvit.py
  11. NOTE: model names have been renamed from originals to represent actual input res all *_224 -> *_240 and *_384 -> *_408
  12. Modifications and additions for timm hacked together by / Copyright 2021, Ross Wightman
  13. Modified from Timm. https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/vision_transformer.py
  14. """
  15. # Copyright IBM All Rights Reserved.
  16. # SPDX-License-Identifier: Apache-2.0
  17. from functools import partial
  18. from typing import List, Optional, Tuple, Type, Union
  19. import torch
  20. import torch.nn as nn
  21. from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
  22. from timm.layers import DropPath, calculate_drop_path_rates, to_2tuple, trunc_normal_, _assert
  23. from ._builder import build_model_with_cfg
  24. from ._features_fx import register_notrace_function
  25. from ._registry import register_model, generate_default_cfgs
  26. from .vision_transformer import Block
  27. __all__ = ['CrossVit'] # model_registry will add each entrypoint fn to this
  28. class PatchEmbed(nn.Module):
  29. """ Image to Patch Embedding
  30. """
  31. def __init__(
  32. self,
  33. img_size: Union[int, Tuple[int, int]] = 224,
  34. patch_size: int = 16,
  35. in_chans: int = 3,
  36. embed_dim: int = 768,
  37. multi_conv: bool = False,
  38. device=None,
  39. dtype=None,
  40. ):
  41. dd = {'device': device, 'dtype': dtype}
  42. super().__init__()
  43. img_size = to_2tuple(img_size)
  44. patch_size = to_2tuple(patch_size)
  45. num_patches = (img_size[1] // patch_size[1]) * (img_size[0] // patch_size[0])
  46. self.img_size = img_size
  47. self.patch_size = patch_size
  48. self.num_patches = num_patches
  49. if multi_conv:
  50. if patch_size[0] == 12:
  51. self.proj = nn.Sequential(
  52. nn.Conv2d(in_chans, embed_dim // 4, kernel_size=7, stride=4, padding=3, **dd),
  53. nn.ReLU(inplace=True),
  54. nn.Conv2d(embed_dim // 4, embed_dim // 2, kernel_size=3, stride=3, padding=0, **dd),
  55. nn.ReLU(inplace=True),
  56. nn.Conv2d(embed_dim // 2, embed_dim, kernel_size=3, stride=1, padding=1, **dd),
  57. )
  58. elif patch_size[0] == 16:
  59. self.proj = nn.Sequential(
  60. nn.Conv2d(in_chans, embed_dim // 4, kernel_size=7, stride=4, padding=3, **dd),
  61. nn.ReLU(inplace=True),
  62. nn.Conv2d(embed_dim // 4, embed_dim // 2, kernel_size=3, stride=2, padding=1, **dd),
  63. nn.ReLU(inplace=True),
  64. nn.Conv2d(embed_dim // 2, embed_dim, kernel_size=3, stride=2, padding=1, **dd),
  65. )
  66. else:
  67. self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size, **dd)
  68. def forward(self, x):
  69. B, C, H, W = x.shape
  70. # FIXME look at relaxing size constraints
  71. _assert(H == self.img_size[0],
  72. f"Input image size ({H}*{W}) doesn't match model ({self.img_size[0]}*{self.img_size[1]}).")
  73. _assert(W == self.img_size[1],
  74. f"Input image size ({H}*{W}) doesn't match model ({self.img_size[0]}*{self.img_size[1]}).")
  75. x = self.proj(x).flatten(2).transpose(1, 2)
  76. return x
  77. class CrossAttention(nn.Module):
  78. def __init__(
  79. self,
  80. dim: int,
  81. num_heads: int = 8,
  82. qkv_bias: bool = False,
  83. attn_drop: float = 0.,
  84. proj_drop: float = 0.,
  85. device=None,
  86. dtype=None,
  87. ):
  88. dd = {'device': device, 'dtype': dtype}
  89. super().__init__()
  90. self.num_heads = num_heads
  91. head_dim = dim // num_heads
  92. # NOTE scale factor was wrong in my original version, can set manually to be compat with prev weights
  93. self.scale = head_dim ** -0.5
  94. self.wq = nn.Linear(dim, dim, bias=qkv_bias, **dd)
  95. self.wk = nn.Linear(dim, dim, bias=qkv_bias, **dd)
  96. self.wv = nn.Linear(dim, dim, bias=qkv_bias, **dd)
  97. self.attn_drop = nn.Dropout(attn_drop)
  98. self.proj = nn.Linear(dim, dim, **dd)
  99. self.proj_drop = nn.Dropout(proj_drop)
  100. def forward(self, x):
  101. B, N, C = x.shape
  102. # B1C -> B1H(C/H) -> BH1(C/H)
  103. q = self.wq(x[:, 0:1, ...]).reshape(B, 1, self.num_heads, C // self.num_heads).permute(0, 2, 1, 3)
  104. # BNC -> BNH(C/H) -> BHN(C/H)
  105. k = self.wk(x).reshape(B, N, self.num_heads, C // self.num_heads).permute(0, 2, 1, 3)
  106. # BNC -> BNH(C/H) -> BHN(C/H)
  107. v = self.wv(x).reshape(B, N, self.num_heads, C // self.num_heads).permute(0, 2, 1, 3)
  108. attn = (q @ k.transpose(-2, -1)) * self.scale # BH1(C/H) @ BH(C/H)N -> BH1N
  109. attn = attn.softmax(dim=-1)
  110. attn = self.attn_drop(attn)
  111. x = (attn @ v).transpose(1, 2).reshape(B, 1, C) # (BH1N @ BHN(C/H)) -> BH1(C/H) -> B1H(C/H) -> B1C
  112. x = self.proj(x)
  113. x = self.proj_drop(x)
  114. return x
  115. class CrossAttentionBlock(nn.Module):
  116. def __init__(
  117. self,
  118. dim: int,
  119. num_heads: int,
  120. mlp_ratio: float = 4.,
  121. qkv_bias: bool = False,
  122. proj_drop: float = 0.,
  123. attn_drop: float = 0.,
  124. drop_path: float = 0.,
  125. act_layer: Type[nn.Module] = nn.GELU,
  126. norm_layer: Type[nn.Module] = nn.LayerNorm,
  127. device=None,
  128. dtype=None,
  129. ):
  130. dd = {'device': device, 'dtype': dtype}
  131. super().__init__()
  132. self.norm1 = norm_layer(dim, **dd)
  133. self.attn = CrossAttention(
  134. dim,
  135. num_heads=num_heads,
  136. qkv_bias=qkv_bias,
  137. attn_drop=attn_drop,
  138. proj_drop=proj_drop,
  139. **dd,
  140. )
  141. # NOTE: drop path for stochastic depth, we shall see if this is better than dropout here
  142. self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
  143. def forward(self, x):
  144. x = x[:, 0:1, ...] + self.drop_path(self.attn(self.norm1(x)))
  145. return x
  146. class MultiScaleBlock(nn.Module):
  147. def __init__(
  148. self,
  149. dim: Tuple[int, ...],
  150. patches: Tuple[int, ...],
  151. depth: Tuple[int, ...],
  152. num_heads: Tuple[int, ...],
  153. mlp_ratio: Tuple[float, ...],
  154. qkv_bias: bool = False,
  155. proj_drop: float = 0.,
  156. attn_drop: float = 0.,
  157. drop_path: Union[List[float], float] = 0.,
  158. act_layer: Type[nn.Module] = nn.GELU,
  159. norm_layer: Type[nn.Module] = nn.LayerNorm,
  160. device=None,
  161. dtype=None,
  162. ):
  163. dd = {'device': device, 'dtype': dtype}
  164. super().__init__()
  165. num_branches = len(dim)
  166. self.num_branches = num_branches
  167. # different branch could have different embedding size, the first one is the base
  168. self.blocks = nn.ModuleList()
  169. for d in range(num_branches):
  170. tmp = []
  171. for i in range(depth[d]):
  172. tmp.append(Block(
  173. dim=dim[d],
  174. num_heads=num_heads[d],
  175. mlp_ratio=mlp_ratio[d],
  176. qkv_bias=qkv_bias,
  177. proj_drop=proj_drop,
  178. attn_drop=attn_drop,
  179. drop_path=drop_path[i],
  180. norm_layer=norm_layer,
  181. **dd,
  182. ))
  183. if len(tmp) != 0:
  184. self.blocks.append(nn.Sequential(*tmp))
  185. if len(self.blocks) == 0:
  186. self.blocks = None
  187. self.projs = nn.ModuleList()
  188. for d in range(num_branches):
  189. if dim[d] == dim[(d + 1) % num_branches] and False:
  190. tmp = [nn.Identity()]
  191. else:
  192. tmp = [norm_layer(dim[d], **dd), act_layer(), nn.Linear(dim[d], dim[(d + 1) % num_branches], **dd)]
  193. self.projs.append(nn.Sequential(*tmp))
  194. self.fusion = nn.ModuleList()
  195. for d in range(num_branches):
  196. d_ = (d + 1) % num_branches
  197. nh = num_heads[d_]
  198. if depth[-1] == 0: # backward capability:
  199. self.fusion.append(
  200. CrossAttentionBlock(
  201. dim=dim[d_],
  202. num_heads=nh,
  203. mlp_ratio=mlp_ratio[d],
  204. qkv_bias=qkv_bias,
  205. proj_drop=proj_drop,
  206. attn_drop=attn_drop,
  207. drop_path=drop_path[-1],
  208. norm_layer=norm_layer,
  209. **dd,
  210. ))
  211. else:
  212. tmp = []
  213. for _ in range(depth[-1]):
  214. tmp.append(CrossAttentionBlock(
  215. dim=dim[d_],
  216. num_heads=nh,
  217. mlp_ratio=mlp_ratio[d],
  218. qkv_bias=qkv_bias,
  219. proj_drop=proj_drop,
  220. attn_drop=attn_drop,
  221. drop_path=drop_path[-1],
  222. norm_layer=norm_layer,
  223. **dd,
  224. ))
  225. self.fusion.append(nn.Sequential(*tmp))
  226. self.revert_projs = nn.ModuleList()
  227. for d in range(num_branches):
  228. if dim[(d + 1) % num_branches] == dim[d] and False:
  229. tmp = [nn.Identity()]
  230. else:
  231. tmp = [norm_layer(dim[(d + 1) % num_branches], **dd), act_layer(),
  232. nn.Linear(dim[(d + 1) % num_branches], dim[d], **dd)]
  233. self.revert_projs.append(nn.Sequential(*tmp))
  234. def forward(self, x: List[torch.Tensor]) -> List[torch.Tensor]:
  235. outs_b = []
  236. for i, block in enumerate(self.blocks):
  237. outs_b.append(block(x[i]))
  238. # only take the cls token out
  239. proj_cls_token = torch.jit.annotate(List[torch.Tensor], [])
  240. for i, proj in enumerate(self.projs):
  241. proj_cls_token.append(proj(outs_b[i][:, 0:1, ...]))
  242. # cross attention
  243. outs = []
  244. for i, (fusion, revert_proj) in enumerate(zip(self.fusion, self.revert_projs)):
  245. tmp = torch.cat((proj_cls_token[i], outs_b[(i + 1) % self.num_branches][:, 1:, ...]), dim=1)
  246. tmp = fusion(tmp)
  247. reverted_proj_cls_token = revert_proj(tmp[:, 0:1, ...])
  248. tmp = torch.cat((reverted_proj_cls_token, outs_b[i][:, 1:, ...]), dim=1)
  249. outs.append(tmp)
  250. return outs
  251. def _compute_num_patches(img_size, patches):
  252. return [i[0] // p * i[1] // p for i, p in zip(img_size, patches)]
  253. @register_notrace_function
  254. def scale_image(x, ss: Tuple[int, int], crop_scale: bool = False): # annotations for torchscript
  255. """
  256. Pulled out of CrossViT.forward_features to bury conditional logic in a leaf node for FX tracing.
  257. Args:
  258. x (Tensor): input image
  259. ss (tuple[int, int]): height and width to scale to
  260. crop_scale (bool): whether to crop instead of interpolate to achieve the desired scale. Defaults to False
  261. Returns:
  262. Tensor: the "scaled" image batch tensor
  263. """
  264. H, W = x.shape[-2:]
  265. if H != ss[0] or W != ss[1]:
  266. if crop_scale and ss[0] <= H and ss[1] <= W:
  267. cu, cl = int(round((H - ss[0]) / 2.)), int(round((W - ss[1]) / 2.))
  268. x = x[:, :, cu:cu + ss[0], cl:cl + ss[1]]
  269. else:
  270. x = torch.nn.functional.interpolate(x, size=ss, mode='bicubic', align_corners=False)
  271. return x
  272. class CrossVit(nn.Module):
  273. """ Vision Transformer with support for patch or hybrid CNN input stage
  274. """
  275. def __init__(
  276. self,
  277. img_size: int = 224,
  278. img_scale: Tuple[float, ...] = (1.0, 1.0),
  279. patch_size: Tuple[int, ...] = (8, 16),
  280. in_chans: int = 3,
  281. num_classes: int = 1000,
  282. embed_dim: Tuple[int, ...] = (192, 384),
  283. depth: Tuple[Tuple[int, ...], ...] = ((1, 3, 1), (1, 3, 1), (1, 3, 1)),
  284. num_heads: Tuple[int, ...] = (6, 12),
  285. mlp_ratio: Tuple[float, ...] = (2., 2., 4.),
  286. multi_conv: bool = False,
  287. crop_scale: bool = False,
  288. qkv_bias: bool = True,
  289. drop_rate: float = 0.,
  290. pos_drop_rate: float = 0.,
  291. proj_drop_rate: float = 0.,
  292. attn_drop_rate: float = 0.,
  293. drop_path_rate: float = 0.,
  294. norm_layer: Type[nn.Module] = partial(nn.LayerNorm, eps=1e-6),
  295. global_pool: str = 'token',
  296. device=None,
  297. dtype=None,
  298. ):
  299. super().__init__()
  300. dd = {'device': device, 'dtype': dtype}
  301. assert global_pool in ('token', 'avg')
  302. self.num_classes = num_classes
  303. self.in_chans = in_chans
  304. self.global_pool = global_pool
  305. self.img_size = to_2tuple(img_size)
  306. img_scale = to_2tuple(img_scale)
  307. self.img_size_scaled = [tuple([int(sj * si) for sj in self.img_size]) for si in img_scale]
  308. self.crop_scale = crop_scale # crop instead of interpolate for scale
  309. num_patches = _compute_num_patches(self.img_size_scaled, patch_size)
  310. self.num_branches = len(patch_size)
  311. self.embed_dim = embed_dim
  312. self.num_features = self.head_hidden_size = sum(embed_dim)
  313. self.patch_embed = nn.ModuleList()
  314. # hard-coded for torch jit script
  315. for i in range(self.num_branches):
  316. setattr(self, f'pos_embed_{i}', nn.Parameter(torch.zeros(1, 1 + num_patches[i], embed_dim[i], **dd)))
  317. setattr(self, f'cls_token_{i}', nn.Parameter(torch.zeros(1, 1, embed_dim[i], **dd)))
  318. for im_s, p, d in zip(self.img_size_scaled, patch_size, embed_dim):
  319. self.patch_embed.append(
  320. PatchEmbed(
  321. img_size=im_s,
  322. patch_size=p,
  323. in_chans=in_chans,
  324. embed_dim=d,
  325. multi_conv=multi_conv,
  326. **dd,
  327. ))
  328. self.pos_drop = nn.Dropout(p=pos_drop_rate)
  329. total_depth = sum([sum(x[-2:]) for x in depth])
  330. dpr = calculate_drop_path_rates(drop_path_rate, total_depth) # stochastic depth decay rule
  331. dpr_ptr = 0
  332. self.blocks = nn.ModuleList()
  333. for idx, block_cfg in enumerate(depth):
  334. curr_depth = max(block_cfg[:-1]) + block_cfg[-1]
  335. dpr_ = dpr[dpr_ptr:dpr_ptr + curr_depth]
  336. blk = MultiScaleBlock(
  337. embed_dim,
  338. num_patches,
  339. block_cfg,
  340. num_heads=num_heads,
  341. mlp_ratio=mlp_ratio,
  342. qkv_bias=qkv_bias,
  343. proj_drop=proj_drop_rate,
  344. attn_drop=attn_drop_rate,
  345. drop_path=dpr_,
  346. norm_layer=norm_layer,
  347. **dd,
  348. )
  349. dpr_ptr += curr_depth
  350. self.blocks.append(blk)
  351. self.norm = nn.ModuleList([norm_layer(embed_dim[i], **dd) for i in range(self.num_branches)])
  352. self.head_drop = nn.Dropout(drop_rate)
  353. self.head = nn.ModuleList([
  354. nn.Linear(embed_dim[i], num_classes, **dd) if num_classes > 0 else nn.Identity()
  355. for i in range(self.num_branches)])
  356. for i in range(self.num_branches):
  357. trunc_normal_(getattr(self, f'pos_embed_{i}'), std=.02)
  358. trunc_normal_(getattr(self, f'cls_token_{i}'), std=.02)
  359. self.apply(self._init_weights)
  360. def _init_weights(self, m):
  361. if isinstance(m, nn.Linear):
  362. trunc_normal_(m.weight, std=.02)
  363. if isinstance(m, nn.Linear) and m.bias is not None:
  364. nn.init.constant_(m.bias, 0)
  365. elif isinstance(m, nn.LayerNorm):
  366. nn.init.constant_(m.bias, 0)
  367. nn.init.constant_(m.weight, 1.0)
  368. @torch.jit.ignore
  369. def no_weight_decay(self):
  370. out = set()
  371. for i in range(self.num_branches):
  372. out.add(f'cls_token_{i}')
  373. pe = getattr(self, f'pos_embed_{i}', None)
  374. if pe is not None and pe.requires_grad:
  375. out.add(f'pos_embed_{i}')
  376. return out
  377. @torch.jit.ignore
  378. def group_matcher(self, coarse=False):
  379. return dict(
  380. stem=r'^cls_token|pos_embed|patch_embed', # stem and embed
  381. blocks=[(r'^blocks\.(\d+)', None), (r'^norm', (99999,))]
  382. )
  383. @torch.jit.ignore
  384. def set_grad_checkpointing(self, enable=True):
  385. assert not enable, 'gradient checkpointing not supported'
  386. @torch.jit.ignore
  387. def get_classifier(self) -> nn.Module:
  388. return self.head
  389. def reset_classifier(self, num_classes: int, global_pool: Optional[str] = None):
  390. self.num_classes = num_classes
  391. if global_pool is not None:
  392. assert global_pool in ('token', 'avg')
  393. self.global_pool = global_pool
  394. device = self.head[0].weight.device if hasattr(self.head[0], 'weight') else None
  395. dtype = self.head[0].weight.dtype if hasattr(self.head[0], 'weight') else None
  396. dd = {'device': device, 'dtype': dtype}
  397. self.head = nn.ModuleList([
  398. nn.Linear(self.embed_dim[i], num_classes, **dd) if num_classes > 0 else nn.Identity()
  399. for i in range(self.num_branches)
  400. ])
  401. def forward_features(self, x) -> List[torch.Tensor]:
  402. B = x.shape[0]
  403. xs = []
  404. for i, patch_embed in enumerate(self.patch_embed):
  405. x_ = x
  406. ss = self.img_size_scaled[i]
  407. x_ = scale_image(x_, ss, self.crop_scale)
  408. x_ = patch_embed(x_)
  409. cls_tokens = self.cls_token_0 if i == 0 else self.cls_token_1 # hard-coded for torch jit script
  410. cls_tokens = cls_tokens.expand(B, -1, -1)
  411. x_ = torch.cat((cls_tokens, x_), dim=1)
  412. pos_embed = self.pos_embed_0 if i == 0 else self.pos_embed_1 # hard-coded for torch jit script
  413. x_ = x_ + pos_embed
  414. x_ = self.pos_drop(x_)
  415. xs.append(x_)
  416. for i, blk in enumerate(self.blocks):
  417. xs = blk(xs)
  418. # NOTE: was before branch token section, move to here to assure all branch token are before layer norm
  419. xs = [norm(xs[i]) for i, norm in enumerate(self.norm)]
  420. return xs
  421. def forward_head(self, xs: List[torch.Tensor], pre_logits: bool = False) -> torch.Tensor:
  422. xs = [x[:, 1:].mean(dim=1) for x in xs] if self.global_pool == 'avg' else [x[:, 0] for x in xs]
  423. xs = [self.head_drop(x) for x in xs]
  424. if pre_logits or isinstance(self.head[0], nn.Identity):
  425. return torch.cat([x for x in xs], dim=1)
  426. return torch.mean(torch.stack([head(xs[i]) for i, head in enumerate(self.head)], dim=0), dim=0)
  427. def forward(self, x):
  428. xs = self.forward_features(x)
  429. x = self.forward_head(xs)
  430. return x
  431. def _create_crossvit(variant, pretrained=False, **kwargs):
  432. if kwargs.get('features_only', None):
  433. raise RuntimeError('features_only not implemented for Vision Transformer models.')
  434. def pretrained_filter_fn(state_dict):
  435. new_state_dict = {}
  436. for key in state_dict.keys():
  437. if 'pos_embed' in key or 'cls_token' in key:
  438. new_key = key.replace(".", "_")
  439. else:
  440. new_key = key
  441. new_state_dict[new_key] = state_dict[key]
  442. return new_state_dict
  443. return build_model_with_cfg(
  444. CrossVit,
  445. variant,
  446. pretrained,
  447. pretrained_filter_fn=pretrained_filter_fn,
  448. **kwargs,
  449. )
  450. def _cfg(url='', **kwargs):
  451. return {
  452. 'url': url,
  453. 'num_classes': 1000, 'input_size': (3, 240, 240), 'pool_size': None, 'crop_pct': 0.875,
  454. 'mean': IMAGENET_DEFAULT_MEAN, 'std': IMAGENET_DEFAULT_STD, 'fixed_input_size': True,
  455. 'first_conv': ('patch_embed.0.proj', 'patch_embed.1.proj'),
  456. 'classifier': ('head.0', 'head.1'),
  457. 'license': 'apache-2.0',
  458. **kwargs
  459. }
  460. default_cfgs = generate_default_cfgs({
  461. 'crossvit_15_240.in1k': _cfg(hf_hub_id='timm/'),
  462. 'crossvit_15_dagger_240.in1k': _cfg(
  463. hf_hub_id='timm/',
  464. first_conv=('patch_embed.0.proj.0', 'patch_embed.1.proj.0'),
  465. ),
  466. 'crossvit_15_dagger_408.in1k': _cfg(
  467. hf_hub_id='timm/',
  468. input_size=(3, 408, 408), first_conv=('patch_embed.0.proj.0', 'patch_embed.1.proj.0'), crop_pct=1.0,
  469. ),
  470. 'crossvit_18_240.in1k': _cfg(hf_hub_id='timm/'),
  471. 'crossvit_18_dagger_240.in1k': _cfg(
  472. hf_hub_id='timm/',
  473. first_conv=('patch_embed.0.proj.0', 'patch_embed.1.proj.0'),
  474. ),
  475. 'crossvit_18_dagger_408.in1k': _cfg(
  476. hf_hub_id='timm/',
  477. input_size=(3, 408, 408), first_conv=('patch_embed.0.proj.0', 'patch_embed.1.proj.0'), crop_pct=1.0,
  478. ),
  479. 'crossvit_9_240.in1k': _cfg(hf_hub_id='timm/'),
  480. 'crossvit_9_dagger_240.in1k': _cfg(
  481. hf_hub_id='timm/',
  482. first_conv=('patch_embed.0.proj.0', 'patch_embed.1.proj.0'),
  483. ),
  484. 'crossvit_base_240.in1k': _cfg(hf_hub_id='timm/'),
  485. 'crossvit_small_240.in1k': _cfg(hf_hub_id='timm/'),
  486. 'crossvit_tiny_240.in1k': _cfg(hf_hub_id='timm/'),
  487. })
  488. @register_model
  489. def crossvit_tiny_240(pretrained=False, **kwargs) -> CrossVit:
  490. model_args = dict(
  491. img_scale=(1.0, 224/240), patch_size=[12, 16], embed_dim=[96, 192], depth=[[1, 4, 0], [1, 4, 0], [1, 4, 0]],
  492. num_heads=[3, 3], mlp_ratio=[4, 4, 1])
  493. model = _create_crossvit(variant='crossvit_tiny_240', pretrained=pretrained, **dict(model_args, **kwargs))
  494. return model
  495. @register_model
  496. def crossvit_small_240(pretrained=False, **kwargs) -> CrossVit:
  497. model_args = dict(
  498. img_scale=(1.0, 224/240), patch_size=[12, 16], embed_dim=[192, 384], depth=[[1, 4, 0], [1, 4, 0], [1, 4, 0]],
  499. num_heads=[6, 6], mlp_ratio=[4, 4, 1])
  500. model = _create_crossvit(variant='crossvit_small_240', pretrained=pretrained, **dict(model_args, **kwargs))
  501. return model
  502. @register_model
  503. def crossvit_base_240(pretrained=False, **kwargs) -> CrossVit:
  504. model_args = dict(
  505. img_scale=(1.0, 224/240), patch_size=[12, 16], embed_dim=[384, 768], depth=[[1, 4, 0], [1, 4, 0], [1, 4, 0]],
  506. num_heads=[12, 12], mlp_ratio=[4, 4, 1])
  507. model = _create_crossvit(variant='crossvit_base_240', pretrained=pretrained, **dict(model_args, **kwargs))
  508. return model
  509. @register_model
  510. def crossvit_9_240(pretrained=False, **kwargs) -> CrossVit:
  511. model_args = dict(
  512. img_scale=(1.0, 224/240), patch_size=[12, 16], embed_dim=[128, 256], depth=[[1, 3, 0], [1, 3, 0], [1, 3, 0]],
  513. num_heads=[4, 4], mlp_ratio=[3, 3, 1])
  514. model = _create_crossvit(variant='crossvit_9_240', pretrained=pretrained, **dict(model_args, **kwargs))
  515. return model
  516. @register_model
  517. def crossvit_15_240(pretrained=False, **kwargs) -> CrossVit:
  518. model_args = dict(
  519. img_scale=(1.0, 224/240), patch_size=[12, 16], embed_dim=[192, 384], depth=[[1, 5, 0], [1, 5, 0], [1, 5, 0]],
  520. num_heads=[6, 6], mlp_ratio=[3, 3, 1])
  521. model = _create_crossvit(variant='crossvit_15_240', pretrained=pretrained, **dict(model_args, **kwargs))
  522. return model
  523. @register_model
  524. def crossvit_18_240(pretrained=False, **kwargs) -> CrossVit:
  525. model_args = dict(
  526. img_scale=(1.0, 224 / 240), patch_size=[12, 16], embed_dim=[224, 448], depth=[[1, 6, 0], [1, 6, 0], [1, 6, 0]],
  527. num_heads=[7, 7], mlp_ratio=[3, 3, 1], **kwargs)
  528. model = _create_crossvit(variant='crossvit_18_240', pretrained=pretrained, **dict(model_args, **kwargs))
  529. return model
  530. @register_model
  531. def crossvit_9_dagger_240(pretrained=False, **kwargs) -> CrossVit:
  532. model_args = dict(
  533. img_scale=(1.0, 224 / 240), patch_size=[12, 16], embed_dim=[128, 256], depth=[[1, 3, 0], [1, 3, 0], [1, 3, 0]],
  534. num_heads=[4, 4], mlp_ratio=[3, 3, 1], multi_conv=True)
  535. model = _create_crossvit(variant='crossvit_9_dagger_240', pretrained=pretrained, **dict(model_args, **kwargs))
  536. return model
  537. @register_model
  538. def crossvit_15_dagger_240(pretrained=False, **kwargs) -> CrossVit:
  539. model_args = dict(
  540. img_scale=(1.0, 224/240), patch_size=[12, 16], embed_dim=[192, 384], depth=[[1, 5, 0], [1, 5, 0], [1, 5, 0]],
  541. num_heads=[6, 6], mlp_ratio=[3, 3, 1], multi_conv=True)
  542. model = _create_crossvit(variant='crossvit_15_dagger_240', pretrained=pretrained, **dict(model_args, **kwargs))
  543. return model
  544. @register_model
  545. def crossvit_15_dagger_408(pretrained=False, **kwargs) -> CrossVit:
  546. model_args = dict(
  547. img_scale=(1.0, 384/408), patch_size=[12, 16], embed_dim=[192, 384], depth=[[1, 5, 0], [1, 5, 0], [1, 5, 0]],
  548. num_heads=[6, 6], mlp_ratio=[3, 3, 1], multi_conv=True)
  549. model = _create_crossvit(variant='crossvit_15_dagger_408', pretrained=pretrained, **dict(model_args, **kwargs))
  550. return model
  551. @register_model
  552. def crossvit_18_dagger_240(pretrained=False, **kwargs) -> CrossVit:
  553. model_args = dict(
  554. img_scale=(1.0, 224/240), patch_size=[12, 16], embed_dim=[224, 448], depth=[[1, 6, 0], [1, 6, 0], [1, 6, 0]],
  555. num_heads=[7, 7], mlp_ratio=[3, 3, 1], multi_conv=True)
  556. model = _create_crossvit(variant='crossvit_18_dagger_240', pretrained=pretrained, **dict(model_args, **kwargs))
  557. return model
  558. @register_model
  559. def crossvit_18_dagger_408(pretrained=False, **kwargs) -> CrossVit:
  560. model_args = dict(
  561. img_scale=(1.0, 384/408), patch_size=[12, 16], embed_dim=[224, 448], depth=[[1, 6, 0], [1, 6, 0], [1, 6, 0]],
  562. num_heads=[7, 7], mlp_ratio=[3, 3, 1], multi_conv=True)
  563. model = _create_crossvit(variant='crossvit_18_dagger_408', pretrained=pretrained, **dict(model_args, **kwargs))
  564. return model