xcit.py 44 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656657658659660661662663664665666667668669670671672673674675676677678679680681682683684685686687688689690691692693694695696697698699700701702703704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792793794795796797798799800801802803804805806807808809810811812813814815816817818819820821822823824825826827828829830831832833834835836837838839840841842843844845846847848849850851852853854855856857858859860861862863864865866867868869870871872873874875876877878879880881882883884885886887888889890891892893894895896897898899900901902903904905906907908909910911912913914915916917918919920921922923924925926927928929930931932933934935936937938939940941942943944945946947948949950951952953954955956957958959960961962963964965966967968969970971972973974975976977978979980981982983984985986987988989990991992993994995996997998999100010011002100310041005100610071008100910101011101210131014101510161017101810191020102110221023102410251026102710281029103010311032103310341035103610371038103910401041104210431044104510461047104810491050105110521053105410551056105710581059106010611062106310641065106610671068106910701071107210731074107510761077107810791080108110821083108410851086108710881089
  1. """ Cross-Covariance Image Transformer (XCiT) in PyTorch
  2. Paper:
  3. - https://arxiv.org/abs/2106.09681
  4. Same as the official implementation, with some minor adaptations, original copyright below
  5. - https://github.com/facebookresearch/xcit/blob/master/xcit.py
  6. Modifications and additions for timm hacked together by / Copyright 2021, Ross Wightman
  7. """
  8. # Copyright (c) 2015-present, Facebook, Inc.
  9. # All rights reserved.
  10. import math
  11. from functools import partial
  12. from typing import List, Optional, Tuple, Union, Type, Any
  13. import torch
  14. import torch.nn as nn
  15. from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
  16. from timm.layers import DropPath, trunc_normal_, to_2tuple, use_fused_attn, Mlp
  17. from ._builder import build_model_with_cfg
  18. from ._features import feature_take_indices
  19. from ._features_fx import register_notrace_module
  20. from ._manipulate import checkpoint
  21. from ._registry import register_model, generate_default_cfgs, register_model_deprecations
  22. from .cait import ClassAttn
  23. __all__ = ['Xcit'] # model_registry will add each entrypoint fn to this
  24. @register_notrace_module # reason: FX can't symbolically trace torch.arange in forward method
  25. class PositionalEncodingFourier(nn.Module):
  26. """
  27. Positional encoding relying on a fourier kernel matching the one used in the "Attention is all you Need" paper.
  28. Based on the official XCiT code
  29. - https://github.com/facebookresearch/xcit/blob/master/xcit.py
  30. """
  31. def __init__(
  32. self,
  33. hidden_dim: int = 32,
  34. dim: int = 768,
  35. temperature: float = 10000,
  36. device=None,
  37. dtype=None,
  38. ):
  39. dd = {'device': device, 'dtype': dtype}
  40. super().__init__()
  41. self.token_projection = nn.Conv2d(hidden_dim * 2, dim, kernel_size=1, **dd)
  42. self.scale = 2 * math.pi
  43. self.temperature = temperature
  44. self.hidden_dim = hidden_dim
  45. self.dim = dim
  46. self.eps = 1e-6
  47. def forward(self, B: int, H: int, W: int):
  48. device = self.token_projection.weight.device
  49. dtype = self.token_projection.weight.dtype
  50. y_embed = torch.arange(1, H + 1, device=device).to(torch.float32).unsqueeze(1).repeat(1, 1, W)
  51. x_embed = torch.arange(1, W + 1, device=device).to(torch.float32).repeat(1, H, 1)
  52. y_embed = y_embed / (y_embed[:, -1:, :] + self.eps) * self.scale
  53. x_embed = x_embed / (x_embed[:, :, -1:] + self.eps) * self.scale
  54. dim_t = torch.arange(self.hidden_dim, device=device).to(torch.float32)
  55. dim_t = self.temperature ** (2 * torch.div(dim_t, 2, rounding_mode='floor') / self.hidden_dim)
  56. pos_x = x_embed[:, :, :, None] / dim_t
  57. pos_y = y_embed[:, :, :, None] / dim_t
  58. pos_x = torch.stack([pos_x[:, :, :, 0::2].sin(), pos_x[:, :, :, 1::2].cos()], dim=4).flatten(3)
  59. pos_y = torch.stack([pos_y[:, :, :, 0::2].sin(), pos_y[:, :, :, 1::2].cos()], dim=4).flatten(3)
  60. pos = torch.cat((pos_y, pos_x), dim=3).permute(0, 3, 1, 2)
  61. pos = self.token_projection(pos.to(dtype))
  62. return pos.repeat(B, 1, 1, 1) # (B, C, H, W)
  63. def conv3x3(in_planes, out_planes, stride=1, device=None, dtype=None):
  64. """3x3 convolution + batch norm"""
  65. dd = {'device': device, 'dtype': dtype}
  66. return torch.nn.Sequential(
  67. nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride, padding=1, bias=False, **dd),
  68. nn.BatchNorm2d(out_planes, **dd)
  69. )
  70. class ConvPatchEmbed(nn.Module):
  71. """Image to Patch Embedding using multiple convolutional layers"""
  72. def __init__(
  73. self,
  74. img_size: int = 224,
  75. patch_size: int = 16,
  76. in_chans: int = 3,
  77. embed_dim: int = 768,
  78. act_layer: Type[nn.Module] = nn.GELU,
  79. device=None,
  80. dtype=None,
  81. ):
  82. dd = {'device': device, 'dtype': dtype}
  83. super().__init__()
  84. img_size = to_2tuple(img_size)
  85. num_patches = (img_size[1] // patch_size) * (img_size[0] // patch_size)
  86. self.img_size = img_size
  87. self.patch_size = patch_size
  88. self.num_patches = num_patches
  89. if patch_size == 16:
  90. self.proj = torch.nn.Sequential(
  91. conv3x3(in_chans, embed_dim // 8, 2, **dd),
  92. act_layer(),
  93. conv3x3(embed_dim // 8, embed_dim // 4, 2, **dd),
  94. act_layer(),
  95. conv3x3(embed_dim // 4, embed_dim // 2, 2, **dd),
  96. act_layer(),
  97. conv3x3(embed_dim // 2, embed_dim, 2, **dd),
  98. )
  99. elif patch_size == 8:
  100. self.proj = torch.nn.Sequential(
  101. conv3x3(in_chans, embed_dim // 4, 2, **dd),
  102. act_layer(),
  103. conv3x3(embed_dim // 4, embed_dim // 2, 2, **dd),
  104. act_layer(),
  105. conv3x3(embed_dim // 2, embed_dim, 2, **dd),
  106. )
  107. else:
  108. raise('For convolutional projection, patch size has to be in [8, 16]')
  109. def forward(self, x):
  110. x = self.proj(x)
  111. Hp, Wp = x.shape[2], x.shape[3]
  112. x = x.flatten(2).transpose(1, 2) # (B, N, C)
  113. return x, (Hp, Wp)
  114. class LPI(nn.Module):
  115. """
  116. Local Patch Interaction module that allows explicit communication between tokens in 3x3 windows to augment the
  117. implicit communication performed by the block diagonal scatter attention. Implemented using 2 layers of separable
  118. 3x3 convolutions with GeLU and BatchNorm2d
  119. """
  120. def __init__(
  121. self,
  122. in_features: int,
  123. out_features: Optional[int] = None,
  124. act_layer: Type[nn.Module] = nn.GELU,
  125. kernel_size: int = 3,
  126. device=None,
  127. dtype=None,
  128. ):
  129. super().__init__()
  130. dd = {'device': device, 'dtype': dtype}
  131. out_features = out_features or in_features
  132. padding = kernel_size // 2
  133. self.conv1 = torch.nn.Conv2d(
  134. in_features, in_features, kernel_size=kernel_size, padding=padding, groups=in_features, **dd)
  135. self.act = act_layer()
  136. self.bn = nn.BatchNorm2d(in_features, **dd)
  137. self.conv2 = torch.nn.Conv2d(
  138. in_features, out_features, kernel_size=kernel_size, padding=padding, groups=out_features, **dd)
  139. def forward(self, x, H: int, W: int):
  140. B, N, C = x.shape
  141. x = x.permute(0, 2, 1).reshape(B, C, H, W)
  142. x = self.conv1(x)
  143. x = self.act(x)
  144. x = self.bn(x)
  145. x = self.conv2(x)
  146. x = x.reshape(B, C, N).permute(0, 2, 1)
  147. return x
  148. class ClassAttentionBlock(nn.Module):
  149. """Class Attention Layer as in CaiT https://arxiv.org/abs/2103.17239"""
  150. def __init__(
  151. self,
  152. dim: int,
  153. num_heads: int,
  154. mlp_ratio: float = 4.,
  155. qkv_bias: bool = False,
  156. proj_drop: float = 0.,
  157. attn_drop: float = 0.,
  158. drop_path: float = 0.,
  159. act_layer: Type[nn.Module] = nn.GELU,
  160. norm_layer: Type[nn.Module] = nn.LayerNorm,
  161. eta: Optional[float] = 1.,
  162. tokens_norm: bool = False,
  163. device=None,
  164. dtype=None,
  165. ):
  166. dd = {'device': device, 'dtype': dtype}
  167. super().__init__()
  168. self.norm1 = norm_layer(dim, **dd)
  169. self.attn = ClassAttn(
  170. dim,
  171. num_heads=num_heads,
  172. qkv_bias=qkv_bias,
  173. attn_drop=attn_drop,
  174. proj_drop=proj_drop,
  175. **dd,
  176. )
  177. self.drop_path1 = DropPath(drop_path) if drop_path > 0. else nn.Identity()
  178. self.norm2 = norm_layer(dim, **dd)
  179. self.mlp = Mlp(
  180. in_features=dim,
  181. hidden_features=int(dim * mlp_ratio),
  182. act_layer=act_layer,
  183. drop=proj_drop,
  184. **dd,
  185. )
  186. self.drop_path2 = DropPath(drop_path) if drop_path > 0. else nn.Identity()
  187. if eta is not None: # LayerScale Initialization (no layerscale when None)
  188. self.gamma1 = nn.Parameter(eta * torch.ones(dim, **dd))
  189. self.gamma2 = nn.Parameter(eta * torch.ones(dim, **dd))
  190. else:
  191. self.gamma1, self.gamma2 = 1.0, 1.0
  192. # See https://github.com/rwightman/pytorch-image-models/pull/747#issuecomment-877795721
  193. self.tokens_norm = tokens_norm
  194. def forward(self, x):
  195. x_norm1 = self.norm1(x)
  196. x_attn = torch.cat([self.attn(x_norm1), x_norm1[:, 1:]], dim=1)
  197. x = x + self.drop_path1(self.gamma1 * x_attn)
  198. if self.tokens_norm:
  199. x = self.norm2(x)
  200. else:
  201. x = torch.cat([self.norm2(x[:, 0:1]), x[:, 1:]], dim=1)
  202. x_res = x
  203. cls_token = x[:, 0:1]
  204. cls_token = self.gamma2 * self.mlp(cls_token)
  205. x = torch.cat([cls_token, x[:, 1:]], dim=1)
  206. x = x_res + self.drop_path2(x)
  207. return x
  208. class XCA(nn.Module):
  209. fused_attn: torch.jit.Final[bool]
  210. """ Cross-Covariance Attention (XCA)
  211. Operation where the channels are updated using a weighted sum. The weights are obtained from the (softmax
  212. normalized) Cross-covariance matrix (Q^T \\cdot K \\in d_h \\times d_h)
  213. """
  214. def __init__(
  215. self,
  216. dim: int,
  217. num_heads: int = 8,
  218. qkv_bias: bool = False,
  219. attn_drop: float = 0.,
  220. proj_drop: float = 0.,
  221. device=None,
  222. dtype=None,
  223. ):
  224. dd = {'device': device, 'dtype': dtype}
  225. super().__init__()
  226. self.num_heads = num_heads
  227. self.fused_attn = use_fused_attn(experimental=True)
  228. self.temperature = nn.Parameter(torch.ones(num_heads, 1, 1, **dd))
  229. self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias, **dd)
  230. self.attn_drop = nn.Dropout(attn_drop)
  231. self.proj = nn.Linear(dim, dim, **dd)
  232. self.proj_drop = nn.Dropout(proj_drop)
  233. def forward(self, x):
  234. B, N, C = x.shape
  235. # Result of next line is (qkv, B, num (H)eads, (C')hannels per head, N)
  236. qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 4, 1)
  237. q, k, v = qkv.unbind(0) # make torchscript happy (cannot use tensor as tuple)
  238. if self.fused_attn:
  239. q = torch.nn.functional.normalize(q, dim=-1) * self.temperature
  240. k = torch.nn.functional.normalize(k, dim=-1)
  241. x = torch.nn.functional.scaled_dot_product_attention(q, k, v, scale=1.0)
  242. else:
  243. # Paper section 3.2 l2-Normalization and temperature scaling
  244. q = torch.nn.functional.normalize(q, dim=-1)
  245. k = torch.nn.functional.normalize(k, dim=-1)
  246. attn = (q @ k.transpose(-2, -1)) * self.temperature
  247. attn = attn.softmax(dim=-1)
  248. attn = self.attn_drop(attn)
  249. x = attn @ v
  250. x = x.permute(0, 3, 1, 2).reshape(B, N, C)
  251. x = self.proj(x)
  252. x = self.proj_drop(x)
  253. return x
  254. @torch.jit.ignore
  255. def no_weight_decay(self):
  256. return {'temperature'}
  257. class XCABlock(nn.Module):
  258. def __init__(
  259. self,
  260. dim: int,
  261. num_heads: int,
  262. mlp_ratio: float = 4.,
  263. qkv_bias: bool = False,
  264. proj_drop: float = 0.,
  265. attn_drop: float = 0.,
  266. drop_path: float = 0.,
  267. act_layer: Type[nn.Module] = nn.GELU,
  268. norm_layer: Type[nn.Module] = nn.LayerNorm,
  269. eta: float = 1.,
  270. device=None,
  271. dtype=None,
  272. ):
  273. dd = {'device': device, 'dtype': dtype}
  274. super().__init__()
  275. self.norm1 = norm_layer(dim, **dd)
  276. self.attn = XCA(
  277. dim,
  278. num_heads=num_heads,
  279. qkv_bias=qkv_bias,
  280. attn_drop=attn_drop,
  281. proj_drop=proj_drop,
  282. **dd,
  283. )
  284. self.drop_path1 = DropPath(drop_path) if drop_path > 0. else nn.Identity()
  285. self.norm3 = norm_layer(dim, **dd)
  286. self.local_mp = LPI(in_features=dim, act_layer=act_layer, **dd)
  287. self.drop_path3 = DropPath(drop_path) if drop_path > 0. else nn.Identity()
  288. self.norm2 = norm_layer(dim, **dd)
  289. self.mlp = Mlp(
  290. in_features=dim,
  291. hidden_features=int(dim * mlp_ratio),
  292. act_layer=act_layer,
  293. drop=proj_drop,
  294. **dd,
  295. )
  296. self.drop_path2 = DropPath(drop_path) if drop_path > 0. else nn.Identity()
  297. self.gamma1 = nn.Parameter(eta * torch.ones(dim, **dd))
  298. self.gamma3 = nn.Parameter(eta * torch.ones(dim, **dd))
  299. self.gamma2 = nn.Parameter(eta * torch.ones(dim, **dd))
  300. def forward(self, x, H: int, W: int):
  301. x = x + self.drop_path1(self.gamma1 * self.attn(self.norm1(x)))
  302. # NOTE official code has 3 then 2, so keeping it the same to be consistent with loaded weights
  303. # See https://github.com/rwightman/pytorch-image-models/pull/747#issuecomment-877795721
  304. x = x + self.drop_path3(self.gamma3 * self.local_mp(self.norm3(x), H, W))
  305. x = x + self.drop_path2(self.gamma2 * self.mlp(self.norm2(x)))
  306. return x
  307. class Xcit(nn.Module):
  308. """
  309. Based on timm and DeiT code bases
  310. https://github.com/rwightman/pytorch-image-models/tree/master/timm
  311. https://github.com/facebookresearch/deit/
  312. """
  313. def __init__(
  314. self,
  315. img_size: Union[int, Tuple[int, int]] = 224,
  316. patch_size: int = 16,
  317. in_chans: int = 3,
  318. num_classes: int = 1000,
  319. global_pool: str = 'token',
  320. embed_dim: int = 768,
  321. depth: int = 12,
  322. num_heads: int = 12,
  323. mlp_ratio: float = 4.,
  324. qkv_bias: bool = True,
  325. drop_rate: float = 0.,
  326. pos_drop_rate: float = 0.,
  327. proj_drop_rate: float = 0.,
  328. attn_drop_rate: float = 0.,
  329. drop_path_rate: float = 0.,
  330. act_layer: Optional[Type[nn.Module]] = None,
  331. norm_layer: Optional[Type[nn.Module]] = None,
  332. cls_attn_layers: int = 2,
  333. use_pos_embed: bool = True,
  334. eta: float = 1.,
  335. tokens_norm: bool = False,
  336. device=None,
  337. dtype=None,
  338. ):
  339. """
  340. Args:
  341. img_size (int, tuple): input image size
  342. patch_size (int): patch size
  343. in_chans (int): number of input channels
  344. num_classes (int): number of classes for classification head
  345. embed_dim (int): embedding dimension
  346. depth (int): depth of transformer
  347. num_heads (int): number of attention heads
  348. mlp_ratio (int): ratio of mlp hidden dim to embedding dim
  349. qkv_bias (bool): enable bias for qkv if True
  350. drop_rate (float): dropout rate after positional embedding, and in XCA/CA projection + MLP
  351. pos_drop_rate: position embedding dropout rate
  352. proj_drop_rate (float): projection dropout rate
  353. attn_drop_rate (float): attention dropout rate
  354. drop_path_rate (float): stochastic depth rate (constant across all layers)
  355. norm_layer: (nn.Module): normalization layer
  356. cls_attn_layers: (int) Depth of Class attention layers
  357. use_pos_embed: (bool) whether to use positional encoding
  358. eta: (float) layerscale initialization value
  359. tokens_norm: (bool) Whether to normalize all tokens or just the cls_token in the CA
  360. Notes:
  361. - Although `layer_norm` is user specifiable, there are hard-coded `BatchNorm2d`s in the local patch
  362. interaction (class LPI) and the patch embedding (class ConvPatchEmbed)
  363. """
  364. super().__init__()
  365. dd = {'device': device, 'dtype': dtype}
  366. assert global_pool in ('', 'avg', 'token')
  367. img_size = to_2tuple(img_size)
  368. assert (img_size[0] % patch_size == 0) and (img_size[0] % patch_size == 0), \
  369. '`patch_size` should divide image dimensions evenly'
  370. norm_layer = norm_layer or partial(nn.LayerNorm, eps=1e-6)
  371. act_layer = act_layer or nn.GELU
  372. self.num_classes = num_classes
  373. self.in_chans = in_chans
  374. self.num_features = self.head_hidden_size = self.embed_dim = embed_dim
  375. self.global_pool = global_pool
  376. self.grad_checkpointing = False
  377. self.patch_embed = ConvPatchEmbed(
  378. img_size=img_size,
  379. patch_size=patch_size,
  380. in_chans=in_chans,
  381. embed_dim=embed_dim,
  382. act_layer=act_layer,
  383. **dd,
  384. )
  385. r = patch_size
  386. self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim, **dd))
  387. if use_pos_embed:
  388. self.pos_embed = PositionalEncodingFourier(dim=embed_dim, **dd)
  389. else:
  390. self.pos_embed = None
  391. self.pos_drop = nn.Dropout(p=pos_drop_rate)
  392. self.blocks = nn.ModuleList([
  393. XCABlock(
  394. dim=embed_dim,
  395. num_heads=num_heads,
  396. mlp_ratio=mlp_ratio,
  397. qkv_bias=qkv_bias,
  398. proj_drop=proj_drop_rate,
  399. attn_drop=attn_drop_rate,
  400. drop_path=drop_path_rate,
  401. act_layer=act_layer,
  402. norm_layer=norm_layer,
  403. eta=eta,
  404. **dd,
  405. )
  406. for _ in range(depth)])
  407. self.feature_info = [dict(num_chs=embed_dim, reduction=r, module=f'blocks.{i}') for i in range(depth)]
  408. self.cls_attn_blocks = nn.ModuleList([
  409. ClassAttentionBlock(
  410. dim=embed_dim,
  411. num_heads=num_heads,
  412. mlp_ratio=mlp_ratio,
  413. qkv_bias=qkv_bias,
  414. proj_drop=drop_rate,
  415. attn_drop=attn_drop_rate,
  416. act_layer=act_layer,
  417. norm_layer=norm_layer,
  418. eta=eta,
  419. tokens_norm=tokens_norm,
  420. **dd,
  421. )
  422. for _ in range(cls_attn_layers)])
  423. # Classifier head
  424. self.norm = norm_layer(embed_dim, **dd)
  425. self.head_drop = nn.Dropout(drop_rate)
  426. self.head = nn.Linear(self.num_features, num_classes, **dd) if num_classes > 0 else nn.Identity()
  427. # Init weights
  428. trunc_normal_(self.cls_token, std=.02)
  429. self.apply(self._init_weights)
  430. def _init_weights(self, m):
  431. if isinstance(m, nn.Linear):
  432. trunc_normal_(m.weight, std=.02)
  433. if isinstance(m, nn.Linear) and m.bias is not None:
  434. nn.init.constant_(m.bias, 0)
  435. @torch.jit.ignore
  436. def no_weight_decay(self):
  437. return {'pos_embed', 'cls_token'}
  438. @torch.jit.ignore
  439. def group_matcher(self, coarse=False):
  440. return dict(
  441. stem=r'^cls_token|pos_embed|patch_embed', # stem and embed
  442. blocks=r'^blocks\.(\d+)',
  443. cls_attn_blocks=[(r'^cls_attn_blocks\.(\d+)', None), (r'^norm', (99999,))]
  444. )
  445. @torch.jit.ignore
  446. def set_grad_checkpointing(self, enable=True):
  447. self.grad_checkpointing = enable
  448. @torch.jit.ignore
  449. def get_classifier(self) -> nn.Module:
  450. return self.head
  451. def reset_classifier(self, num_classes: int, global_pool: Optional[str] = None):
  452. self.num_classes = num_classes
  453. if global_pool is not None:
  454. assert global_pool in ('', 'avg', 'token')
  455. self.global_pool = global_pool
  456. device = self.head.weight.device if hasattr(self.head, 'weight') else None
  457. dtype = self.head.weight.dtype if hasattr(self.head, 'weight') else None
  458. self.head = nn.Linear(self.num_features, num_classes, device=device, dtype=dtype) if num_classes > 0 else nn.Identity()
  459. def forward_intermediates(
  460. self,
  461. x: torch.Tensor,
  462. indices: Optional[Union[int, List[int]]] = None,
  463. norm: bool = False,
  464. stop_early: bool = False,
  465. output_fmt: str = 'NCHW',
  466. intermediates_only: bool = False,
  467. ) -> Union[List[torch.Tensor], Tuple[torch.Tensor, List[torch.Tensor]]]:
  468. """ Forward features that returns intermediates.
  469. Args:
  470. x: Input image tensor
  471. indices: Take last n blocks if int, all if None, select matching indices if sequence
  472. norm: Apply norm layer to all intermediates
  473. stop_early: Stop iterating over blocks when last desired intermediate hit
  474. output_fmt: Shape of intermediate feature outputs
  475. intermediates_only: Only return intermediate features
  476. Returns:
  477. """
  478. assert output_fmt in ('NCHW', 'NLC'), 'Output format must be one of NCHW or NLC.'
  479. reshape = output_fmt == 'NCHW'
  480. intermediates = []
  481. take_indices, max_index = feature_take_indices(len(self.blocks), indices)
  482. # forward pass
  483. B, _, height, width = x.shape
  484. x, (Hp, Wp) = self.patch_embed(x)
  485. if self.pos_embed is not None:
  486. # `pos_embed` (B, C, Hp, Wp), reshape -> (B, C, N), permute -> (B, N, C)
  487. pos_encoding = self.pos_embed(B, Hp, Wp).reshape(B, -1, x.shape[1]).permute(0, 2, 1)
  488. x = x + pos_encoding
  489. x = self.pos_drop(x)
  490. if torch.jit.is_scripting() or not stop_early: # can't slice blocks in torchscript
  491. blocks = self.blocks
  492. else:
  493. blocks = self.blocks[:max_index + 1]
  494. for i, blk in enumerate(blocks):
  495. if self.grad_checkpointing and not torch.jit.is_scripting():
  496. x = checkpoint(blk, x, Hp, Wp)
  497. else:
  498. x = blk(x, Hp, Wp)
  499. if i in take_indices:
  500. # normalize intermediates with final norm layer if enabled
  501. intermediates.append(self.norm(x) if norm else x)
  502. # process intermediates
  503. if reshape:
  504. # reshape to BCHW output format
  505. intermediates = [y.reshape(B, Hp, Wp, -1).permute(0, 3, 1, 2).contiguous() for y in intermediates]
  506. if intermediates_only:
  507. return intermediates
  508. # NOTE not supporting return of class tokens
  509. x = torch.cat((self.cls_token.expand(B, -1, -1), x), dim=1)
  510. for blk in self.cls_attn_blocks:
  511. if self.grad_checkpointing and not torch.jit.is_scripting():
  512. x = checkpoint(blk, x)
  513. else:
  514. x = blk(x)
  515. x = self.norm(x)
  516. return x, intermediates
  517. def prune_intermediate_layers(
  518. self,
  519. indices: Union[int, List[int]] = 1,
  520. prune_norm: bool = False,
  521. prune_head: bool = True,
  522. ):
  523. """ Prune layers not required for specified intermediates.
  524. """
  525. take_indices, max_index = feature_take_indices(len(self.blocks), indices)
  526. self.blocks = self.blocks[:max_index + 1] # truncate blocks
  527. if prune_norm:
  528. self.norm = nn.Identity()
  529. if prune_head:
  530. self.cls_attn_blocks = nn.ModuleList() # prune token blocks with head
  531. self.reset_classifier(0, '')
  532. return take_indices
  533. def forward_features(self, x):
  534. B = x.shape[0]
  535. # x is (B, N, C). (Hp, Hw) is (height in units of patches, width in units of patches)
  536. x, (Hp, Wp) = self.patch_embed(x)
  537. if self.pos_embed is not None:
  538. # `pos_embed` (B, C, Hp, Wp), reshape -> (B, C, N), permute -> (B, N, C)
  539. pos_encoding = self.pos_embed(B, Hp, Wp).reshape(B, -1, x.shape[1]).permute(0, 2, 1)
  540. x = x + pos_encoding
  541. x = self.pos_drop(x)
  542. for blk in self.blocks:
  543. if self.grad_checkpointing and not torch.jit.is_scripting():
  544. x = checkpoint(blk, x, Hp, Wp)
  545. else:
  546. x = blk(x, Hp, Wp)
  547. x = torch.cat((self.cls_token.expand(B, -1, -1), x), dim=1)
  548. for blk in self.cls_attn_blocks:
  549. if self.grad_checkpointing and not torch.jit.is_scripting():
  550. x = checkpoint(blk, x)
  551. else:
  552. x = blk(x)
  553. x = self.norm(x)
  554. return x
  555. def forward_head(self, x, pre_logits: bool = False):
  556. if self.global_pool:
  557. x = x[:, 1:].mean(dim=1) if self.global_pool == 'avg' else x[:, 0]
  558. x = self.head_drop(x)
  559. return x if pre_logits else self.head(x)
  560. def forward(self, x):
  561. x = self.forward_features(x)
  562. x = self.forward_head(x)
  563. return x
  564. def checkpoint_filter_fn(state_dict, model):
  565. if 'model' in state_dict:
  566. state_dict = state_dict['model']
  567. # For consistency with timm's transformer models while being compatible with official weights source we rename
  568. # pos_embeder to pos_embed. Also account for use_pos_embed == False
  569. use_pos_embed = getattr(model, 'pos_embed', None) is not None
  570. pos_embed_keys = [k for k in state_dict if k.startswith('pos_embed')]
  571. for k in pos_embed_keys:
  572. if use_pos_embed:
  573. state_dict[k.replace('pos_embeder.', 'pos_embed.')] = state_dict.pop(k)
  574. else:
  575. del state_dict[k]
  576. # timm's implementation of class attention in CaiT is slightly more efficient as it does not compute query vectors
  577. # for all tokens, just the class token. To use official weights source we must split qkv into q, k, v
  578. if 'cls_attn_blocks.0.attn.qkv.weight' in state_dict and 'cls_attn_blocks.0.attn.q.weight' in model.state_dict():
  579. num_ca_blocks = len(model.cls_attn_blocks)
  580. for i in range(num_ca_blocks):
  581. qkv_weight = state_dict.pop(f'cls_attn_blocks.{i}.attn.qkv.weight')
  582. qkv_weight = qkv_weight.reshape(3, -1, qkv_weight.shape[-1])
  583. for j, subscript in enumerate('qkv'):
  584. state_dict[f'cls_attn_blocks.{i}.attn.{subscript}.weight'] = qkv_weight[j]
  585. qkv_bias = state_dict.pop(f'cls_attn_blocks.{i}.attn.qkv.bias', None)
  586. if qkv_bias is not None:
  587. qkv_bias = qkv_bias.reshape(3, -1)
  588. for j, subscript in enumerate('qkv'):
  589. state_dict[f'cls_attn_blocks.{i}.attn.{subscript}.bias'] = qkv_bias[j]
  590. return state_dict
  591. def _create_xcit(variant, pretrained=False, default_cfg=None, **kwargs):
  592. out_indices = kwargs.pop('out_indices', 3)
  593. model = build_model_with_cfg(
  594. Xcit,
  595. variant,
  596. pretrained,
  597. pretrained_filter_fn=checkpoint_filter_fn,
  598. feature_cfg=dict(out_indices=out_indices, feature_cls='getter'),
  599. **kwargs,
  600. )
  601. return model
  602. def _cfg(url='', **kwargs):
  603. return {
  604. 'url': url,
  605. 'num_classes': 1000, 'input_size': (3, 224, 224), 'pool_size': None,
  606. 'crop_pct': 1.0, 'interpolation': 'bicubic', 'fixed_input_size': True,
  607. 'mean': IMAGENET_DEFAULT_MEAN, 'std': IMAGENET_DEFAULT_STD,
  608. 'first_conv': 'patch_embed.proj.0.0', 'classifier': 'head',
  609. 'license': 'apache-2.0', **kwargs
  610. }
  611. default_cfgs = generate_default_cfgs({
  612. # Patch size 16
  613. 'xcit_nano_12_p16_224.fb_in1k': _cfg(
  614. hf_hub_id='timm/',
  615. url='https://dl.fbaipublicfiles.com/xcit/xcit_nano_12_p16_224.pth'),
  616. 'xcit_nano_12_p16_224.fb_dist_in1k': _cfg(
  617. hf_hub_id='timm/',
  618. url='https://dl.fbaipublicfiles.com/xcit/xcit_nano_12_p16_224_dist.pth'),
  619. 'xcit_nano_12_p16_384.fb_dist_in1k': _cfg(
  620. hf_hub_id='timm/',
  621. url='https://dl.fbaipublicfiles.com/xcit/xcit_nano_12_p16_384_dist.pth', input_size=(3, 384, 384)),
  622. 'xcit_tiny_12_p16_224.fb_in1k': _cfg(
  623. hf_hub_id='timm/',
  624. url='https://dl.fbaipublicfiles.com/xcit/xcit_tiny_12_p16_224.pth'),
  625. 'xcit_tiny_12_p16_224.fb_dist_in1k': _cfg(
  626. hf_hub_id='timm/',
  627. url='https://dl.fbaipublicfiles.com/xcit/xcit_tiny_12_p16_224_dist.pth'),
  628. 'xcit_tiny_12_p16_384.fb_dist_in1k': _cfg(
  629. hf_hub_id='timm/',
  630. url='https://dl.fbaipublicfiles.com/xcit/xcit_tiny_12_p16_384_dist.pth', input_size=(3, 384, 384)),
  631. 'xcit_tiny_24_p16_224.fb_in1k': _cfg(
  632. hf_hub_id='timm/',
  633. url='https://dl.fbaipublicfiles.com/xcit/xcit_tiny_24_p16_224.pth'),
  634. 'xcit_tiny_24_p16_224.fb_dist_in1k': _cfg(
  635. hf_hub_id='timm/',
  636. url='https://dl.fbaipublicfiles.com/xcit/xcit_tiny_24_p16_224_dist.pth'),
  637. 'xcit_tiny_24_p16_384.fb_dist_in1k': _cfg(
  638. hf_hub_id='timm/',
  639. url='https://dl.fbaipublicfiles.com/xcit/xcit_tiny_24_p16_384_dist.pth', input_size=(3, 384, 384)),
  640. 'xcit_small_12_p16_224.fb_in1k': _cfg(
  641. hf_hub_id='timm/',
  642. url='https://dl.fbaipublicfiles.com/xcit/xcit_small_12_p16_224.pth'),
  643. 'xcit_small_12_p16_224.fb_dist_in1k': _cfg(
  644. hf_hub_id='timm/',
  645. url='https://dl.fbaipublicfiles.com/xcit/xcit_small_12_p16_224_dist.pth'),
  646. 'xcit_small_12_p16_384.fb_dist_in1k': _cfg(
  647. hf_hub_id='timm/',
  648. url='https://dl.fbaipublicfiles.com/xcit/xcit_small_12_p16_384_dist.pth', input_size=(3, 384, 384)),
  649. 'xcit_small_24_p16_224.fb_in1k': _cfg(
  650. hf_hub_id='timm/',
  651. url='https://dl.fbaipublicfiles.com/xcit/xcit_small_24_p16_224.pth'),
  652. 'xcit_small_24_p16_224.fb_dist_in1k': _cfg(
  653. hf_hub_id='timm/',
  654. url='https://dl.fbaipublicfiles.com/xcit/xcit_small_24_p16_224_dist.pth'),
  655. 'xcit_small_24_p16_384.fb_dist_in1k': _cfg(
  656. hf_hub_id='timm/',
  657. url='https://dl.fbaipublicfiles.com/xcit/xcit_small_24_p16_384_dist.pth', input_size=(3, 384, 384)),
  658. 'xcit_medium_24_p16_224.fb_in1k': _cfg(
  659. hf_hub_id='timm/',
  660. url='https://dl.fbaipublicfiles.com/xcit/xcit_medium_24_p16_224.pth'),
  661. 'xcit_medium_24_p16_224.fb_dist_in1k': _cfg(
  662. hf_hub_id='timm/',
  663. url='https://dl.fbaipublicfiles.com/xcit/xcit_medium_24_p16_224_dist.pth'),
  664. 'xcit_medium_24_p16_384.fb_dist_in1k': _cfg(
  665. hf_hub_id='timm/',
  666. url='https://dl.fbaipublicfiles.com/xcit/xcit_medium_24_p16_384_dist.pth', input_size=(3, 384, 384)),
  667. 'xcit_large_24_p16_224.fb_in1k': _cfg(
  668. hf_hub_id='timm/',
  669. url='https://dl.fbaipublicfiles.com/xcit/xcit_large_24_p16_224.pth'),
  670. 'xcit_large_24_p16_224.fb_dist_in1k': _cfg(
  671. hf_hub_id='timm/',
  672. url='https://dl.fbaipublicfiles.com/xcit/xcit_large_24_p16_224_dist.pth'),
  673. 'xcit_large_24_p16_384.fb_dist_in1k': _cfg(
  674. hf_hub_id='timm/',
  675. url='https://dl.fbaipublicfiles.com/xcit/xcit_large_24_p16_384_dist.pth', input_size=(3, 384, 384)),
  676. # Patch size 8
  677. 'xcit_nano_12_p8_224.fb_in1k': _cfg(
  678. hf_hub_id='timm/',
  679. url='https://dl.fbaipublicfiles.com/xcit/xcit_nano_12_p8_224.pth'),
  680. 'xcit_nano_12_p8_224.fb_dist_in1k': _cfg(
  681. hf_hub_id='timm/',
  682. url='https://dl.fbaipublicfiles.com/xcit/xcit_nano_12_p8_224_dist.pth'),
  683. 'xcit_nano_12_p8_384.fb_dist_in1k': _cfg(
  684. hf_hub_id='timm/',
  685. url='https://dl.fbaipublicfiles.com/xcit/xcit_nano_12_p8_384_dist.pth', input_size=(3, 384, 384)),
  686. 'xcit_tiny_12_p8_224.fb_in1k': _cfg(
  687. hf_hub_id='timm/',
  688. url='https://dl.fbaipublicfiles.com/xcit/xcit_tiny_12_p8_224.pth'),
  689. 'xcit_tiny_12_p8_224.fb_dist_in1k': _cfg(
  690. hf_hub_id='timm/',
  691. url='https://dl.fbaipublicfiles.com/xcit/xcit_tiny_12_p8_224_dist.pth'),
  692. 'xcit_tiny_12_p8_384.fb_dist_in1k': _cfg(
  693. hf_hub_id='timm/',
  694. url='https://dl.fbaipublicfiles.com/xcit/xcit_tiny_12_p8_384_dist.pth', input_size=(3, 384, 384)),
  695. 'xcit_tiny_24_p8_224.fb_in1k': _cfg(
  696. hf_hub_id='timm/',
  697. url='https://dl.fbaipublicfiles.com/xcit/xcit_tiny_24_p8_224.pth'),
  698. 'xcit_tiny_24_p8_224.fb_dist_in1k': _cfg(
  699. hf_hub_id='timm/',
  700. url='https://dl.fbaipublicfiles.com/xcit/xcit_tiny_24_p8_224_dist.pth'),
  701. 'xcit_tiny_24_p8_384.fb_dist_in1k': _cfg(
  702. hf_hub_id='timm/',
  703. url='https://dl.fbaipublicfiles.com/xcit/xcit_tiny_24_p8_384_dist.pth', input_size=(3, 384, 384)),
  704. 'xcit_small_12_p8_224.fb_in1k': _cfg(
  705. hf_hub_id='timm/',
  706. url='https://dl.fbaipublicfiles.com/xcit/xcit_small_12_p8_224.pth'),
  707. 'xcit_small_12_p8_224.fb_dist_in1k': _cfg(
  708. hf_hub_id='timm/',
  709. url='https://dl.fbaipublicfiles.com/xcit/xcit_small_12_p8_224_dist.pth'),
  710. 'xcit_small_12_p8_384.fb_dist_in1k': _cfg(
  711. hf_hub_id='timm/',
  712. url='https://dl.fbaipublicfiles.com/xcit/xcit_small_12_p8_384_dist.pth', input_size=(3, 384, 384)),
  713. 'xcit_small_24_p8_224.fb_in1k': _cfg(
  714. hf_hub_id='timm/',
  715. url='https://dl.fbaipublicfiles.com/xcit/xcit_small_24_p8_224.pth'),
  716. 'xcit_small_24_p8_224.fb_dist_in1k': _cfg(
  717. hf_hub_id='timm/',
  718. url='https://dl.fbaipublicfiles.com/xcit/xcit_small_24_p8_224_dist.pth'),
  719. 'xcit_small_24_p8_384.fb_dist_in1k': _cfg(
  720. hf_hub_id='timm/',
  721. url='https://dl.fbaipublicfiles.com/xcit/xcit_small_24_p8_384_dist.pth', input_size=(3, 384, 384)),
  722. 'xcit_medium_24_p8_224.fb_in1k': _cfg(
  723. hf_hub_id='timm/',
  724. url='https://dl.fbaipublicfiles.com/xcit/xcit_medium_24_p8_224.pth'),
  725. 'xcit_medium_24_p8_224.fb_dist_in1k': _cfg(
  726. hf_hub_id='timm/',
  727. url='https://dl.fbaipublicfiles.com/xcit/xcit_medium_24_p8_224_dist.pth'),
  728. 'xcit_medium_24_p8_384.fb_dist_in1k': _cfg(
  729. hf_hub_id='timm/',
  730. url='https://dl.fbaipublicfiles.com/xcit/xcit_medium_24_p8_384_dist.pth', input_size=(3, 384, 384)),
  731. 'xcit_large_24_p8_224.fb_in1k': _cfg(
  732. hf_hub_id='timm/',
  733. url='https://dl.fbaipublicfiles.com/xcit/xcit_large_24_p8_224.pth'),
  734. 'xcit_large_24_p8_224.fb_dist_in1k': _cfg(
  735. hf_hub_id='timm/',
  736. url='https://dl.fbaipublicfiles.com/xcit/xcit_large_24_p8_224_dist.pth'),
  737. 'xcit_large_24_p8_384.fb_dist_in1k': _cfg(
  738. hf_hub_id='timm/',
  739. url='https://dl.fbaipublicfiles.com/xcit/xcit_large_24_p8_384_dist.pth', input_size=(3, 384, 384)),
  740. })
  741. @register_model
  742. def xcit_nano_12_p16_224(pretrained=False, **kwargs) -> Xcit:
  743. model_args = dict(
  744. patch_size=16, embed_dim=128, depth=12, num_heads=4, eta=1.0, tokens_norm=False)
  745. model = _create_xcit('xcit_nano_12_p16_224', pretrained=pretrained, **dict(model_args, **kwargs))
  746. return model
  747. @register_model
  748. def xcit_nano_12_p16_384(pretrained=False, **kwargs) -> Xcit:
  749. model_args = dict(
  750. patch_size=16, embed_dim=128, depth=12, num_heads=4, eta=1.0, tokens_norm=False, img_size=384)
  751. model = _create_xcit('xcit_nano_12_p16_384', pretrained=pretrained, **dict(model_args, **kwargs))
  752. return model
  753. @register_model
  754. def xcit_tiny_12_p16_224(pretrained=False, **kwargs) -> Xcit:
  755. model_args = dict(
  756. patch_size=16, embed_dim=192, depth=12, num_heads=4, eta=1.0, tokens_norm=True)
  757. model = _create_xcit('xcit_tiny_12_p16_224', pretrained=pretrained, **dict(model_args, **kwargs))
  758. return model
  759. @register_model
  760. def xcit_tiny_12_p16_384(pretrained=False, **kwargs) -> Xcit:
  761. model_args = dict(
  762. patch_size=16, embed_dim=192, depth=12, num_heads=4, eta=1.0, tokens_norm=True)
  763. model = _create_xcit('xcit_tiny_12_p16_384', pretrained=pretrained, **dict(model_args, **kwargs))
  764. return model
  765. @register_model
  766. def xcit_small_12_p16_224(pretrained=False, **kwargs) -> Xcit:
  767. model_args = dict(
  768. patch_size=16, embed_dim=384, depth=12, num_heads=8, eta=1.0, tokens_norm=True)
  769. model = _create_xcit('xcit_small_12_p16_224', pretrained=pretrained, **dict(model_args, **kwargs))
  770. return model
  771. @register_model
  772. def xcit_small_12_p16_384(pretrained=False, **kwargs) -> Xcit:
  773. model_args = dict(
  774. patch_size=16, embed_dim=384, depth=12, num_heads=8, eta=1.0, tokens_norm=True)
  775. model = _create_xcit('xcit_small_12_p16_384', pretrained=pretrained, **dict(model_args, **kwargs))
  776. return model
  777. @register_model
  778. def xcit_tiny_24_p16_224(pretrained=False, **kwargs) -> Xcit:
  779. model_args = dict(
  780. patch_size=16, embed_dim=192, depth=24, num_heads=4, eta=1e-5, tokens_norm=True)
  781. model = _create_xcit('xcit_tiny_24_p16_224', pretrained=pretrained, **dict(model_args, **kwargs))
  782. return model
  783. @register_model
  784. def xcit_tiny_24_p16_384(pretrained=False, **kwargs) -> Xcit:
  785. model_args = dict(
  786. patch_size=16, embed_dim=192, depth=24, num_heads=4, eta=1e-5, tokens_norm=True)
  787. model = _create_xcit('xcit_tiny_24_p16_384', pretrained=pretrained, **dict(model_args, **kwargs))
  788. return model
  789. @register_model
  790. def xcit_small_24_p16_224(pretrained=False, **kwargs) -> Xcit:
  791. model_args = dict(
  792. patch_size=16, embed_dim=384, depth=24, num_heads=8, eta=1e-5, tokens_norm=True)
  793. model = _create_xcit('xcit_small_24_p16_224', pretrained=pretrained, **dict(model_args, **kwargs))
  794. return model
  795. @register_model
  796. def xcit_small_24_p16_384(pretrained=False, **kwargs) -> Xcit:
  797. model_args = dict(
  798. patch_size=16, embed_dim=384, depth=24, num_heads=8, eta=1e-5, tokens_norm=True)
  799. model = _create_xcit('xcit_small_24_p16_384', pretrained=pretrained, **dict(model_args, **kwargs))
  800. return model
  801. @register_model
  802. def xcit_medium_24_p16_224(pretrained=False, **kwargs) -> Xcit:
  803. model_args = dict(
  804. patch_size=16, embed_dim=512, depth=24, num_heads=8, eta=1e-5, tokens_norm=True)
  805. model = _create_xcit('xcit_medium_24_p16_224', pretrained=pretrained, **dict(model_args, **kwargs))
  806. return model
  807. @register_model
  808. def xcit_medium_24_p16_384(pretrained=False, **kwargs) -> Xcit:
  809. model_args = dict(
  810. patch_size=16, embed_dim=512, depth=24, num_heads=8, eta=1e-5, tokens_norm=True)
  811. model = _create_xcit('xcit_medium_24_p16_384', pretrained=pretrained, **dict(model_args, **kwargs))
  812. return model
  813. @register_model
  814. def xcit_large_24_p16_224(pretrained=False, **kwargs) -> Xcit:
  815. model_args = dict(
  816. patch_size=16, embed_dim=768, depth=24, num_heads=16, eta=1e-5, tokens_norm=True)
  817. model = _create_xcit('xcit_large_24_p16_224', pretrained=pretrained, **dict(model_args, **kwargs))
  818. return model
  819. @register_model
  820. def xcit_large_24_p16_384(pretrained=False, **kwargs) -> Xcit:
  821. model_args = dict(
  822. patch_size=16, embed_dim=768, depth=24, num_heads=16, eta=1e-5, tokens_norm=True)
  823. model = _create_xcit('xcit_large_24_p16_384', pretrained=pretrained, **dict(model_args, **kwargs))
  824. return model
  825. # Patch size 8x8 models
  826. @register_model
  827. def xcit_nano_12_p8_224(pretrained=False, **kwargs) -> Xcit:
  828. model_args = dict(
  829. patch_size=8, embed_dim=128, depth=12, num_heads=4, eta=1.0, tokens_norm=False)
  830. model = _create_xcit('xcit_nano_12_p8_224', pretrained=pretrained, **dict(model_args, **kwargs))
  831. return model
  832. @register_model
  833. def xcit_nano_12_p8_384(pretrained=False, **kwargs) -> Xcit:
  834. model_args = dict(
  835. patch_size=8, embed_dim=128, depth=12, num_heads=4, eta=1.0, tokens_norm=False)
  836. model = _create_xcit('xcit_nano_12_p8_384', pretrained=pretrained, **dict(model_args, **kwargs))
  837. return model
  838. @register_model
  839. def xcit_tiny_12_p8_224(pretrained=False, **kwargs) -> Xcit:
  840. model_args = dict(
  841. patch_size=8, embed_dim=192, depth=12, num_heads=4, eta=1.0, tokens_norm=True)
  842. model = _create_xcit('xcit_tiny_12_p8_224', pretrained=pretrained, **dict(model_args, **kwargs))
  843. return model
  844. @register_model
  845. def xcit_tiny_12_p8_384(pretrained=False, **kwargs) -> Xcit:
  846. model_args = dict(
  847. patch_size=8, embed_dim=192, depth=12, num_heads=4, eta=1.0, tokens_norm=True)
  848. model = _create_xcit('xcit_tiny_12_p8_384', pretrained=pretrained, **dict(model_args, **kwargs))
  849. return model
  850. @register_model
  851. def xcit_small_12_p8_224(pretrained=False, **kwargs) -> Xcit:
  852. model_args = dict(
  853. patch_size=8, embed_dim=384, depth=12, num_heads=8, eta=1.0, tokens_norm=True)
  854. model = _create_xcit('xcit_small_12_p8_224', pretrained=pretrained, **dict(model_args, **kwargs))
  855. return model
  856. @register_model
  857. def xcit_small_12_p8_384(pretrained=False, **kwargs) -> Xcit:
  858. model_args = dict(
  859. patch_size=8, embed_dim=384, depth=12, num_heads=8, eta=1.0, tokens_norm=True)
  860. model = _create_xcit('xcit_small_12_p8_384', pretrained=pretrained, **dict(model_args, **kwargs))
  861. return model
  862. @register_model
  863. def xcit_tiny_24_p8_224(pretrained=False, **kwargs) -> Xcit:
  864. model_args = dict(
  865. patch_size=8, embed_dim=192, depth=24, num_heads=4, eta=1e-5, tokens_norm=True)
  866. model = _create_xcit('xcit_tiny_24_p8_224', pretrained=pretrained, **dict(model_args, **kwargs))
  867. return model
  868. @register_model
  869. def xcit_tiny_24_p8_384(pretrained=False, **kwargs) -> Xcit:
  870. model_args = dict(
  871. patch_size=8, embed_dim=192, depth=24, num_heads=4, eta=1e-5, tokens_norm=True)
  872. model = _create_xcit('xcit_tiny_24_p8_384', pretrained=pretrained, **dict(model_args, **kwargs))
  873. return model
  874. @register_model
  875. def xcit_small_24_p8_224(pretrained=False, **kwargs) -> Xcit:
  876. model_args = dict(
  877. patch_size=8, embed_dim=384, depth=24, num_heads=8, eta=1e-5, tokens_norm=True)
  878. model = _create_xcit('xcit_small_24_p8_224', pretrained=pretrained, **dict(model_args, **kwargs))
  879. return model
  880. @register_model
  881. def xcit_small_24_p8_384(pretrained=False, **kwargs) -> Xcit:
  882. model_args = dict(
  883. patch_size=8, embed_dim=384, depth=24, num_heads=8, eta=1e-5, tokens_norm=True)
  884. model = _create_xcit('xcit_small_24_p8_384', pretrained=pretrained, **dict(model_args, **kwargs))
  885. return model
  886. @register_model
  887. def xcit_medium_24_p8_224(pretrained=False, **kwargs) -> Xcit:
  888. model_args = dict(
  889. patch_size=8, embed_dim=512, depth=24, num_heads=8, eta=1e-5, tokens_norm=True)
  890. model = _create_xcit('xcit_medium_24_p8_224', pretrained=pretrained, **dict(model_args, **kwargs))
  891. return model
  892. @register_model
  893. def xcit_medium_24_p8_384(pretrained=False, **kwargs) -> Xcit:
  894. model_args = dict(
  895. patch_size=8, embed_dim=512, depth=24, num_heads=8, eta=1e-5, tokens_norm=True)
  896. model = _create_xcit('xcit_medium_24_p8_384', pretrained=pretrained, **dict(model_args, **kwargs))
  897. return model
  898. @register_model
  899. def xcit_large_24_p8_224(pretrained=False, **kwargs) -> Xcit:
  900. model_args = dict(
  901. patch_size=8, embed_dim=768, depth=24, num_heads=16, eta=1e-5, tokens_norm=True)
  902. model = _create_xcit('xcit_large_24_p8_224', pretrained=pretrained, **dict(model_args, **kwargs))
  903. return model
  904. @register_model
  905. def xcit_large_24_p8_384(pretrained=False, **kwargs) -> Xcit:
  906. model_args = dict(
  907. patch_size=8, embed_dim=768, depth=24, num_heads=16, eta=1e-5, tokens_norm=True)
  908. model = _create_xcit('xcit_large_24_p8_384', pretrained=pretrained, **dict(model_args, **kwargs))
  909. return model
  910. register_model_deprecations(__name__, {
  911. # Patch size 16
  912. 'xcit_nano_12_p16_224_dist': 'xcit_nano_12_p16_224.fb_dist_in1k',
  913. 'xcit_nano_12_p16_384_dist': 'xcit_nano_12_p16_384.fb_dist_in1k',
  914. 'xcit_tiny_12_p16_224_dist': 'xcit_tiny_12_p16_224.fb_dist_in1k',
  915. 'xcit_tiny_12_p16_384_dist': 'xcit_tiny_12_p16_384.fb_dist_in1k',
  916. 'xcit_tiny_24_p16_224_dist': 'xcit_tiny_24_p16_224.fb_dist_in1k',
  917. 'xcit_tiny_24_p16_384_dist': 'xcit_tiny_24_p16_384.fb_dist_in1k',
  918. 'xcit_small_12_p16_224_dist': 'xcit_small_12_p16_224.fb_dist_in1k',
  919. 'xcit_small_12_p16_384_dist': 'xcit_small_12_p16_384.fb_dist_in1k',
  920. 'xcit_small_24_p16_224_dist': 'xcit_small_24_p16_224.fb_dist_in1k',
  921. 'xcit_small_24_p16_384_dist': 'xcit_small_24_p16_384.fb_dist_in1k',
  922. 'xcit_medium_24_p16_224_dist': 'xcit_medium_24_p16_224.fb_dist_in1k',
  923. 'xcit_medium_24_p16_384_dist': 'xcit_medium_24_p16_384.fb_dist_in1k',
  924. 'xcit_large_24_p16_224_dist': 'xcit_large_24_p16_224.fb_dist_in1k',
  925. 'xcit_large_24_p16_384_dist': 'xcit_large_24_p16_384.fb_dist_in1k',
  926. # Patch size 8
  927. 'xcit_nano_12_p8_224_dist': 'xcit_nano_12_p8_224.fb_dist_in1k',
  928. 'xcit_nano_12_p8_384_dist': 'xcit_nano_12_p8_384.fb_dist_in1k',
  929. 'xcit_tiny_12_p8_224_dist': 'xcit_tiny_12_p8_224.fb_dist_in1k',
  930. 'xcit_tiny_12_p8_384_dist': 'xcit_tiny_12_p8_384.fb_dist_in1k',
  931. 'xcit_tiny_24_p8_224_dist': 'xcit_tiny_24_p8_224.fb_dist_in1k',
  932. 'xcit_tiny_24_p8_384_dist': 'xcit_tiny_24_p8_384.fb_dist_in1k',
  933. 'xcit_small_12_p8_224_dist': 'xcit_small_12_p8_224.fb_dist_in1k',
  934. 'xcit_small_12_p8_384_dist': 'xcit_small_12_p8_384.fb_dist_in1k',
  935. 'xcit_small_24_p8_224_dist': 'xcit_small_24_p8_224.fb_dist_in1k',
  936. 'xcit_small_24_p8_384_dist': 'xcit_small_24_p8_384.fb_dist_in1k',
  937. 'xcit_medium_24_p8_224_dist': 'xcit_medium_24_p8_224.fb_dist_in1k',
  938. 'xcit_medium_24_p8_384_dist': 'xcit_medium_24_p8_384.fb_dist_in1k',
  939. 'xcit_large_24_p8_224_dist': 'xcit_large_24_p8_224.fb_dist_in1k',
  940. 'xcit_large_24_p8_384_dist': 'xcit_large_24_p8_384.fb_dist_in1k',
  941. })