cait.py 23 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632
  1. """ Class-Attention in Image Transformers (CaiT)
  2. Paper: 'Going deeper with Image Transformers' - https://arxiv.org/abs/2103.17239
  3. Original code and weights from https://github.com/facebookresearch/deit, copyright below
  4. Modifications and additions for timm hacked together by / Copyright 2021, Ross Wightman
  5. """
  6. # Copyright (c) 2015-present, Facebook, Inc.
  7. # All rights reserved.
  8. from functools import partial
  9. from typing import List, Optional, Tuple, Union, Type, Any
  10. import torch
  11. import torch.nn as nn
  12. from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
  13. from timm.layers import PatchEmbed, Mlp, DropPath, trunc_normal_, use_fused_attn
  14. from ._builder import build_model_with_cfg
  15. from ._features import feature_take_indices
  16. from ._manipulate import checkpoint, checkpoint_seq
  17. from ._registry import register_model, generate_default_cfgs
  18. __all__ = ['Cait', 'ClassAttn', 'LayerScaleBlockClassAttn', 'LayerScaleBlock', 'TalkingHeadAttn']
  19. class ClassAttn(nn.Module):
  20. # taken from https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/vision_transformer.py
  21. # with slight modifications to do CA
  22. fused_attn: torch.jit.Final[bool]
  23. def __init__(
  24. self,
  25. dim: int,
  26. num_heads: int = 8,
  27. qkv_bias: bool = False,
  28. attn_drop: float = 0.,
  29. proj_drop: float = 0.,
  30. device=None,
  31. dtype=None,
  32. ):
  33. super().__init__()
  34. dd = {'device': device, 'dtype': dtype}
  35. self.num_heads = num_heads
  36. head_dim = dim // num_heads
  37. self.scale = head_dim ** -0.5
  38. self.fused_attn = use_fused_attn()
  39. self.q = nn.Linear(dim, dim, bias=qkv_bias, **dd)
  40. self.k = nn.Linear(dim, dim, bias=qkv_bias, **dd)
  41. self.v = nn.Linear(dim, dim, bias=qkv_bias, **dd)
  42. self.attn_drop = nn.Dropout(attn_drop)
  43. self.proj = nn.Linear(dim, dim, **dd)
  44. self.proj_drop = nn.Dropout(proj_drop)
  45. def forward(self, x):
  46. B, N, C = x.shape
  47. q = self.q(x[:, 0]).unsqueeze(1).reshape(B, 1, self.num_heads, C // self.num_heads).permute(0, 2, 1, 3)
  48. k = self.k(x).reshape(B, N, self.num_heads, C // self.num_heads).permute(0, 2, 1, 3)
  49. v = self.v(x).reshape(B, N, self.num_heads, C // self.num_heads).permute(0, 2, 1, 3)
  50. if self.fused_attn:
  51. x_cls = torch.nn.functional.scaled_dot_product_attention(
  52. q, k, v,
  53. dropout_p=self.attn_drop.p if self.training else 0.,
  54. )
  55. else:
  56. q = q * self.scale
  57. attn = q @ k.transpose(-2, -1)
  58. attn = attn.softmax(dim=-1)
  59. attn = self.attn_drop(attn)
  60. x_cls = attn @ v
  61. x_cls = x_cls.transpose(1, 2).reshape(B, 1, C)
  62. x_cls = self.proj(x_cls)
  63. x_cls = self.proj_drop(x_cls)
  64. return x_cls
  65. class LayerScaleBlockClassAttn(nn.Module):
  66. # taken from https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/vision_transformer.py
  67. # with slight modifications to add CA and LayerScale
  68. def __init__(
  69. self,
  70. dim: int,
  71. num_heads: int,
  72. mlp_ratio: float = 4.,
  73. qkv_bias: bool = False,
  74. proj_drop: float = 0.,
  75. attn_drop: float = 0.,
  76. drop_path: float = 0.,
  77. act_layer: Type[nn.Module] = nn.GELU,
  78. norm_layer: Type[nn.Module] = nn.LayerNorm,
  79. attn_block: Type[nn.Module] = ClassAttn,
  80. mlp_block: Type[nn.Module] = Mlp,
  81. init_values: float = 1e-4,
  82. device=None,
  83. dtype=None,
  84. ):
  85. super().__init__()
  86. dd = {'device': device, 'dtype': dtype}
  87. self.norm1 = norm_layer(dim, **dd)
  88. self.attn = attn_block(
  89. dim,
  90. num_heads=num_heads,
  91. qkv_bias=qkv_bias,
  92. attn_drop=attn_drop,
  93. proj_drop=proj_drop,
  94. **dd,
  95. )
  96. self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
  97. self.norm2 = norm_layer(dim, **dd)
  98. mlp_hidden_dim = int(dim * mlp_ratio)
  99. self.mlp = mlp_block(
  100. in_features=dim,
  101. hidden_features=mlp_hidden_dim,
  102. act_layer=act_layer,
  103. drop=proj_drop,
  104. **dd,
  105. )
  106. self.gamma_1 = nn.Parameter(init_values * torch.ones(dim, **dd))
  107. self.gamma_2 = nn.Parameter(init_values * torch.ones(dim, **dd))
  108. def forward(self, x, x_cls):
  109. u = torch.cat((x_cls, x), dim=1)
  110. x_cls = x_cls + self.drop_path(self.gamma_1 * self.attn(self.norm1(u)))
  111. x_cls = x_cls + self.drop_path(self.gamma_2 * self.mlp(self.norm2(x_cls)))
  112. return x_cls
  113. class TalkingHeadAttn(nn.Module):
  114. # taken from https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/vision_transformer.py
  115. # with slight modifications to add Talking Heads Attention (https://arxiv.org/pdf/2003.02436v1.pdf)
  116. def __init__(
  117. self,
  118. dim: int,
  119. num_heads: int = 8,
  120. qkv_bias: bool = False,
  121. attn_drop: float = 0.,
  122. proj_drop: float = 0.,
  123. device=None,
  124. dtype=None,
  125. ):
  126. super().__init__()
  127. dd = {'device': device, 'dtype': dtype}
  128. self.num_heads = num_heads
  129. head_dim = dim // num_heads
  130. self.scale = head_dim ** -0.5
  131. self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias, **dd)
  132. self.attn_drop = nn.Dropout(attn_drop)
  133. self.proj = nn.Linear(dim, dim, **dd)
  134. self.proj_l = nn.Linear(num_heads, num_heads, **dd)
  135. self.proj_w = nn.Linear(num_heads, num_heads, **dd)
  136. self.proj_drop = nn.Dropout(proj_drop)
  137. def forward(self, x):
  138. B, N, C = x.shape
  139. qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
  140. q, k, v = qkv[0] * self.scale, qkv[1], qkv[2]
  141. attn = q @ k.transpose(-2, -1)
  142. attn = self.proj_l(attn.permute(0, 2, 3, 1)).permute(0, 3, 1, 2)
  143. attn = attn.softmax(dim=-1)
  144. attn = self.proj_w(attn.permute(0, 2, 3, 1)).permute(0, 3, 1, 2)
  145. attn = self.attn_drop(attn)
  146. x = (attn @ v).transpose(1, 2).reshape(B, N, C)
  147. x = self.proj(x)
  148. x = self.proj_drop(x)
  149. return x
  150. class LayerScaleBlock(nn.Module):
  151. # taken from https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/vision_transformer.py
  152. # with slight modifications to add layerScale
  153. def __init__(
  154. self,
  155. dim: int,
  156. num_heads: int,
  157. mlp_ratio: float = 4.,
  158. qkv_bias: bool = False,
  159. proj_drop: float = 0.,
  160. attn_drop: float = 0.,
  161. drop_path: float = 0.,
  162. act_layer: Type[nn.Module] = nn.GELU,
  163. norm_layer: Type[nn.Module] = nn.LayerNorm,
  164. attn_block: Type[nn.Module] = TalkingHeadAttn,
  165. mlp_block: Type[nn.Module] = Mlp,
  166. init_values: float = 1e-4,
  167. device=None,
  168. dtype=None,
  169. ):
  170. super().__init__()
  171. dd = {'device': device, 'dtype': dtype}
  172. self.norm1 = norm_layer(dim, **dd)
  173. self.attn = attn_block(
  174. dim,
  175. num_heads=num_heads,
  176. qkv_bias=qkv_bias,
  177. attn_drop=attn_drop,
  178. proj_drop=proj_drop,
  179. **dd,
  180. )
  181. self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
  182. self.norm2 = norm_layer(dim, **dd)
  183. mlp_hidden_dim = int(dim * mlp_ratio)
  184. self.mlp = mlp_block(
  185. in_features=dim,
  186. hidden_features=mlp_hidden_dim,
  187. act_layer=act_layer,
  188. drop=proj_drop,
  189. **dd,
  190. )
  191. self.gamma_1 = nn.Parameter(init_values * torch.ones(dim, **dd))
  192. self.gamma_2 = nn.Parameter(init_values * torch.ones(dim, **dd))
  193. def forward(self, x):
  194. x = x + self.drop_path(self.gamma_1 * self.attn(self.norm1(x)))
  195. x = x + self.drop_path(self.gamma_2 * self.mlp(self.norm2(x)))
  196. return x
  197. class Cait(nn.Module):
  198. # taken from https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/vision_transformer.py
  199. # with slight modifications to adapt to our cait models
  200. def __init__(
  201. self,
  202. img_size: int = 224,
  203. patch_size: int = 16,
  204. in_chans: int = 3,
  205. num_classes: int = 1000,
  206. global_pool: str = 'token',
  207. embed_dim: int = 768,
  208. depth: int = 12,
  209. num_heads: int = 12,
  210. mlp_ratio: float = 4.,
  211. qkv_bias: bool = True,
  212. drop_rate: float = 0.,
  213. pos_drop_rate: float = 0.,
  214. proj_drop_rate: float = 0.,
  215. attn_drop_rate: float = 0.,
  216. drop_path_rate: float = 0.,
  217. block_layers: Type[nn.Module] = LayerScaleBlock,
  218. block_layers_token: Type[nn.Module] = LayerScaleBlockClassAttn,
  219. patch_layer: Type[nn.Module] = PatchEmbed,
  220. norm_layer: Type[nn.Module] = partial(nn.LayerNorm, eps=1e-6),
  221. act_layer: Type[nn.Module] = nn.GELU,
  222. attn_block: Type[nn.Module] = TalkingHeadAttn,
  223. mlp_block: Type[nn.Module] = Mlp,
  224. init_values: float = 1e-4,
  225. attn_block_token_only: Type[nn.Module] = ClassAttn,
  226. mlp_block_token_only: Type[nn.Module] = Mlp,
  227. depth_token_only: int = 2,
  228. mlp_ratio_token_only: float = 4.0,
  229. device=None,
  230. dtype=None,
  231. ):
  232. super().__init__()
  233. dd = {'device': device, 'dtype': dtype}
  234. assert global_pool in ('', 'token', 'avg')
  235. self.num_classes = num_classes
  236. self.in_chans = in_chans
  237. self.global_pool = global_pool
  238. self.num_features = self.head_hidden_size = self.embed_dim = embed_dim
  239. self.grad_checkpointing = False
  240. self.patch_embed = patch_layer(
  241. img_size=img_size,
  242. patch_size=patch_size,
  243. in_chans=in_chans,
  244. embed_dim=embed_dim,
  245. **dd,
  246. )
  247. num_patches = self.patch_embed.num_patches
  248. r = self.patch_embed.feat_ratio() if hasattr(self.patch_embed, 'feat_ratio') else patch_size
  249. self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim, **dd))
  250. self.pos_embed = nn.Parameter(torch.zeros(1, num_patches, embed_dim, **dd))
  251. self.pos_drop = nn.Dropout(p=pos_drop_rate)
  252. dpr = [drop_path_rate for i in range(depth)]
  253. self.blocks = nn.Sequential(*[block_layers(
  254. dim=embed_dim,
  255. num_heads=num_heads,
  256. mlp_ratio=mlp_ratio,
  257. qkv_bias=qkv_bias,
  258. proj_drop=proj_drop_rate,
  259. attn_drop=attn_drop_rate,
  260. drop_path=dpr[i],
  261. norm_layer=norm_layer,
  262. act_layer=act_layer,
  263. attn_block=attn_block,
  264. mlp_block=mlp_block,
  265. init_values=init_values,
  266. **dd,
  267. ) for i in range(depth)])
  268. self.feature_info = [dict(num_chs=embed_dim, reduction=r, module=f'blocks.{i}') for i in range(depth)]
  269. self.blocks_token_only = nn.ModuleList([block_layers_token(
  270. dim=embed_dim,
  271. num_heads=num_heads,
  272. mlp_ratio=mlp_ratio_token_only,
  273. qkv_bias=qkv_bias,
  274. norm_layer=norm_layer,
  275. act_layer=act_layer,
  276. attn_block=attn_block_token_only,
  277. mlp_block=mlp_block_token_only,
  278. init_values=init_values,
  279. **dd,
  280. ) for _ in range(depth_token_only)])
  281. self.norm = norm_layer(embed_dim, **dd)
  282. self.head_drop = nn.Dropout(drop_rate)
  283. self.head = nn.Linear(embed_dim, num_classes, **dd) if num_classes > 0 else nn.Identity()
  284. trunc_normal_(self.pos_embed, std=.02)
  285. trunc_normal_(self.cls_token, std=.02)
  286. self.apply(self._init_weights)
  287. def _init_weights(self, m):
  288. if isinstance(m, nn.Linear):
  289. trunc_normal_(m.weight, std=.02)
  290. if isinstance(m, nn.Linear) and m.bias is not None:
  291. nn.init.constant_(m.bias, 0)
  292. elif isinstance(m, nn.LayerNorm):
  293. nn.init.constant_(m.bias, 0)
  294. nn.init.constant_(m.weight, 1.0)
  295. @torch.jit.ignore
  296. def no_weight_decay(self):
  297. return {'pos_embed', 'cls_token'}
  298. @torch.jit.ignore
  299. def set_grad_checkpointing(self, enable=True):
  300. self.grad_checkpointing = enable
  301. @torch.jit.ignore
  302. def group_matcher(self, coarse=False):
  303. def _matcher(name):
  304. if any([name.startswith(n) for n in ('cls_token', 'pos_embed', 'patch_embed')]):
  305. return 0
  306. elif name.startswith('blocks.'):
  307. return int(name.split('.')[1]) + 1
  308. elif name.startswith('blocks_token_only.'):
  309. # overlap token only blocks with last blocks
  310. to_offset = len(self.blocks) - len(self.blocks_token_only) + 1
  311. return int(name.split('.')[1]) + to_offset
  312. elif name.startswith('norm.'):
  313. return len(self.blocks)
  314. else:
  315. return float('inf')
  316. return _matcher
  317. @torch.jit.ignore
  318. def get_classifier(self) -> nn.Module:
  319. return self.head
  320. def reset_classifier(self, num_classes: int, global_pool: Optional[str] = None):
  321. self.num_classes = num_classes
  322. if global_pool is not None:
  323. assert global_pool in ('', 'token', 'avg')
  324. self.global_pool = global_pool
  325. self.head = nn.Linear(self.num_features, num_classes) if num_classes > 0 else nn.Identity()
  326. def forward_intermediates(
  327. self,
  328. x: torch.Tensor,
  329. indices: Optional[Union[int, List[int]]] = None,
  330. norm: bool = False,
  331. stop_early: bool = False,
  332. output_fmt: str = 'NCHW',
  333. intermediates_only: bool = False,
  334. ) -> Union[List[torch.Tensor], Tuple[torch.Tensor, List[torch.Tensor]]]:
  335. """ Forward features that returns intermediates.
  336. Args:
  337. x: Input image tensor
  338. indices: Take last n blocks if int, all if None, select matching indices if sequence
  339. norm: Apply norm layer to all intermediates
  340. stop_early: Stop iterating over blocks when last desired intermediate hit
  341. output_fmt: Shape of intermediate feature outputs
  342. intermediates_only: Only return intermediate features
  343. """
  344. assert output_fmt in ('NCHW', 'NLC'), 'Output format must be one of NCHW or NLC.'
  345. reshape = output_fmt == 'NCHW'
  346. intermediates = []
  347. take_indices, max_index = feature_take_indices(len(self.blocks), indices)
  348. # forward pass
  349. B, _, height, width = x.shape
  350. x = self.patch_embed(x)
  351. x = x + self.pos_embed
  352. x = self.pos_drop(x)
  353. if torch.jit.is_scripting() or not stop_early: # can't slice blocks in torchscript
  354. blocks = self.blocks
  355. else:
  356. blocks = self.blocks[:max_index + 1]
  357. for i, blk in enumerate(blocks):
  358. if self.grad_checkpointing and not torch.jit.is_scripting():
  359. x = checkpoint(blk, x)
  360. else:
  361. x = blk(x)
  362. if i in take_indices:
  363. # normalize intermediates with final norm layer if enabled
  364. intermediates.append(self.norm(x) if norm else x)
  365. # process intermediates
  366. if reshape:
  367. # reshape to BCHW output format
  368. H, W = self.patch_embed.dynamic_feat_size((height, width))
  369. intermediates = [y.reshape(B, H, W, -1).permute(0, 3, 1, 2).contiguous() for y in intermediates]
  370. if intermediates_only:
  371. return intermediates
  372. # NOTE not supporting return of class tokens
  373. cls_tokens = self.cls_token.expand(x.shape[0], -1, -1)
  374. for i, blk in enumerate(self.blocks_token_only):
  375. cls_tokens = blk(x, cls_tokens)
  376. x = torch.cat((cls_tokens, x), dim=1)
  377. x = self.norm(x)
  378. return x, intermediates
  379. def prune_intermediate_layers(
  380. self,
  381. indices: Union[int, List[int]] = 1,
  382. prune_norm: bool = False,
  383. prune_head: bool = True,
  384. ):
  385. """ Prune layers not required for specified intermediates.
  386. """
  387. take_indices, max_index = feature_take_indices(len(self.blocks), indices)
  388. self.blocks = self.blocks[:max_index + 1] # truncate blocks
  389. if prune_norm:
  390. self.norm = nn.Identity()
  391. if prune_head:
  392. self.blocks_token_only = nn.ModuleList() # prune token blocks with head
  393. self.reset_classifier(0, '')
  394. return take_indices
  395. def forward_features(self, x):
  396. x = self.patch_embed(x)
  397. x = x + self.pos_embed
  398. x = self.pos_drop(x)
  399. if self.grad_checkpointing and not torch.jit.is_scripting():
  400. x = checkpoint_seq(self.blocks, x)
  401. else:
  402. x = self.blocks(x)
  403. cls_tokens = self.cls_token.expand(x.shape[0], -1, -1)
  404. for i, blk in enumerate(self.blocks_token_only):
  405. cls_tokens = blk(x, cls_tokens)
  406. x = torch.cat((cls_tokens, x), dim=1)
  407. x = self.norm(x)
  408. return x
  409. def forward_head(self, x, pre_logits: bool = False):
  410. if self.global_pool:
  411. x = x[:, 1:].mean(dim=1) if self.global_pool == 'avg' else x[:, 0]
  412. x = self.head_drop(x)
  413. return x if pre_logits else self.head(x)
  414. def forward(self, x):
  415. x = self.forward_features(x)
  416. x = self.forward_head(x)
  417. return x
  418. def checkpoint_filter_fn(state_dict, model=None):
  419. if 'model' in state_dict:
  420. state_dict = state_dict['model']
  421. checkpoint_no_module = {}
  422. for k, v in state_dict.items():
  423. checkpoint_no_module[k.replace('module.', '')] = v
  424. return checkpoint_no_module
  425. def _create_cait(variant, pretrained=False, **kwargs):
  426. out_indices = kwargs.pop('out_indices', 3)
  427. model = build_model_with_cfg(
  428. Cait,
  429. variant,
  430. pretrained,
  431. pretrained_filter_fn=checkpoint_filter_fn,
  432. feature_cfg=dict(out_indices=out_indices, feature_cls='getter'),
  433. **kwargs,
  434. )
  435. return model
  436. def _cfg(url='', **kwargs):
  437. return {
  438. 'url': url,
  439. 'num_classes': 1000, 'input_size': (3, 384, 384), 'pool_size': None,
  440. 'crop_pct': 1.0, 'interpolation': 'bicubic', 'fixed_input_size': True,
  441. 'mean': IMAGENET_DEFAULT_MEAN, 'std': IMAGENET_DEFAULT_STD,
  442. 'first_conv': 'patch_embed.proj', 'classifier': 'head',
  443. 'license': 'apache-2.0',
  444. **kwargs
  445. }
  446. default_cfgs = generate_default_cfgs({
  447. 'cait_xxs24_224.fb_dist_in1k': _cfg(
  448. hf_hub_id='timm/',
  449. url='https://dl.fbaipublicfiles.com/deit/XXS24_224.pth',
  450. input_size=(3, 224, 224),
  451. ),
  452. 'cait_xxs24_384.fb_dist_in1k': _cfg(
  453. hf_hub_id='timm/',
  454. url='https://dl.fbaipublicfiles.com/deit/XXS24_384.pth',
  455. ),
  456. 'cait_xxs36_224.fb_dist_in1k': _cfg(
  457. hf_hub_id='timm/',
  458. url='https://dl.fbaipublicfiles.com/deit/XXS36_224.pth',
  459. input_size=(3, 224, 224),
  460. ),
  461. 'cait_xxs36_384.fb_dist_in1k': _cfg(
  462. hf_hub_id='timm/',
  463. url='https://dl.fbaipublicfiles.com/deit/XXS36_384.pth',
  464. ),
  465. 'cait_xs24_384.fb_dist_in1k': _cfg(
  466. hf_hub_id='timm/',
  467. url='https://dl.fbaipublicfiles.com/deit/XS24_384.pth',
  468. ),
  469. 'cait_s24_224.fb_dist_in1k': _cfg(
  470. hf_hub_id='timm/',
  471. url='https://dl.fbaipublicfiles.com/deit/S24_224.pth',
  472. input_size=(3, 224, 224),
  473. ),
  474. 'cait_s24_384.fb_dist_in1k': _cfg(
  475. hf_hub_id='timm/',
  476. url='https://dl.fbaipublicfiles.com/deit/S24_384.pth',
  477. ),
  478. 'cait_s36_384.fb_dist_in1k': _cfg(
  479. hf_hub_id='timm/',
  480. url='https://dl.fbaipublicfiles.com/deit/S36_384.pth',
  481. ),
  482. 'cait_m36_384.fb_dist_in1k': _cfg(
  483. hf_hub_id='timm/',
  484. url='https://dl.fbaipublicfiles.com/deit/M36_384.pth',
  485. ),
  486. 'cait_m48_448.fb_dist_in1k': _cfg(
  487. hf_hub_id='timm/',
  488. url='https://dl.fbaipublicfiles.com/deit/M48_448.pth',
  489. input_size=(3, 448, 448),
  490. ),
  491. })
  492. @register_model
  493. def cait_xxs24_224(pretrained=False, **kwargs) -> Cait:
  494. model_args = dict(patch_size=16, embed_dim=192, depth=24, num_heads=4, init_values=1e-5)
  495. model = _create_cait('cait_xxs24_224', pretrained=pretrained, **dict(model_args, **kwargs))
  496. return model
  497. @register_model
  498. def cait_xxs24_384(pretrained=False, **kwargs) -> Cait:
  499. model_args = dict(patch_size=16, embed_dim=192, depth=24, num_heads=4, init_values=1e-5)
  500. model = _create_cait('cait_xxs24_384', pretrained=pretrained, **dict(model_args, **kwargs))
  501. return model
  502. @register_model
  503. def cait_xxs36_224(pretrained=False, **kwargs) -> Cait:
  504. model_args = dict(patch_size=16, embed_dim=192, depth=36, num_heads=4, init_values=1e-5)
  505. model = _create_cait('cait_xxs36_224', pretrained=pretrained, **dict(model_args, **kwargs))
  506. return model
  507. @register_model
  508. def cait_xxs36_384(pretrained=False, **kwargs) -> Cait:
  509. model_args = dict(patch_size=16, embed_dim=192, depth=36, num_heads=4, init_values=1e-5)
  510. model = _create_cait('cait_xxs36_384', pretrained=pretrained, **dict(model_args, **kwargs))
  511. return model
  512. @register_model
  513. def cait_xs24_384(pretrained=False, **kwargs) -> Cait:
  514. model_args = dict(patch_size=16, embed_dim=288, depth=24, num_heads=6, init_values=1e-5)
  515. model = _create_cait('cait_xs24_384', pretrained=pretrained, **dict(model_args, **kwargs))
  516. return model
  517. @register_model
  518. def cait_s24_224(pretrained=False, **kwargs) -> Cait:
  519. model_args = dict(patch_size=16, embed_dim=384, depth=24, num_heads=8, init_values=1e-5)
  520. model = _create_cait('cait_s24_224', pretrained=pretrained, **dict(model_args, **kwargs))
  521. return model
  522. @register_model
  523. def cait_s24_384(pretrained=False, **kwargs) -> Cait:
  524. model_args = dict(patch_size=16, embed_dim=384, depth=24, num_heads=8, init_values=1e-5)
  525. model = _create_cait('cait_s24_384', pretrained=pretrained, **dict(model_args, **kwargs))
  526. return model
  527. @register_model
  528. def cait_s36_384(pretrained=False, **kwargs) -> Cait:
  529. model_args = dict(patch_size=16, embed_dim=384, depth=36, num_heads=8, init_values=1e-6)
  530. model = _create_cait('cait_s36_384', pretrained=pretrained, **dict(model_args, **kwargs))
  531. return model
  532. @register_model
  533. def cait_m36_384(pretrained=False, **kwargs) -> Cait:
  534. model_args = dict(patch_size=16, embed_dim=768, depth=36, num_heads=16, init_values=1e-6)
  535. model = _create_cait('cait_m36_384', pretrained=pretrained, **dict(model_args, **kwargs))
  536. return model
  537. @register_model
  538. def cait_m48_448(pretrained=False, **kwargs) -> Cait:
  539. model_args = dict(patch_size=16, embed_dim=768, depth=48, num_heads=16, init_values=1e-6)
  540. model = _create_cait('cait_m48_448', pretrained=pretrained, **dict(model_args, **kwargs))
  541. return model