| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656657658659660661662663664665666667668669670671672673674675676677678679680681682683684685686687688689690691692693694695696697698699700701702
703704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792793794795796797798799800801802803804805806807808809810811812813814815816817818819820821822823824825826827828829830831832833834835836837838839840841842843844845846847848849850851852853854855856857858859860861862863864865866867868869870871872873874875876877878879880881882883884885886887888889890891892893894895896897898899900901902903904905906907908909910911912913914915916917918919920921922923924925926927928929930931932933934935936937938939940941942943944945946947948949950951952953954955956957958959960961962963964965966967968969970971972973974975976977978979980981982983984985986987988989990991992993994995996997998999100010011002100310041005100610071008100910101011101210131014101510161017101810191020102110221023102410251026102710281029103010311032103310341035103610371038103910401041104210431044104510461047104810491050105110521053105410551056105710581059106010611062106310641065106610671068106910701071107210731074107510761077107810791080108110821083108410851086108710881089 |
- """ Cross-Covariance Image Transformer (XCiT) in PyTorch
- Paper:
- - https://arxiv.org/abs/2106.09681
- Same as the official implementation, with some minor adaptations, original copyright below
- - https://github.com/facebookresearch/xcit/blob/master/xcit.py
- Modifications and additions for timm hacked together by / Copyright 2021, Ross Wightman
- """
- # Copyright (c) 2015-present, Facebook, Inc.
- # All rights reserved.
- import math
- from functools import partial
- from typing import List, Optional, Tuple, Union, Type, Any
- import torch
- import torch.nn as nn
- from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
- from timm.layers import DropPath, trunc_normal_, to_2tuple, use_fused_attn, Mlp
- from ._builder import build_model_with_cfg
- from ._features import feature_take_indices
- from ._features_fx import register_notrace_module
- from ._manipulate import checkpoint
- from ._registry import register_model, generate_default_cfgs, register_model_deprecations
- from .cait import ClassAttn
- __all__ = ['Xcit'] # model_registry will add each entrypoint fn to this
@register_notrace_module  # reason: FX can't symbolically trace torch.arange in forward method
class PositionalEncodingFourier(nn.Module):
    """Fourier-feature 2D positional encoding projected to the embedding width.

    Sin/cos features are computed from normalized row and column coordinates (the kernel used in
    the "Attention is all you Need" paper) and then mixed by a 1x1 convolution.
    Based on the official XCiT code
        - https://github.com/facebookresearch/xcit/blob/master/xcit.py
    """

    def __init__(
            self,
            hidden_dim: int = 32,
            dim: int = 768,
            temperature: float = 10000,
            device=None,
            dtype=None,
    ):
        """
        Args:
            hidden_dim: number of Fourier features per spatial axis
            dim: number of output channels produced by the 1x1 projection
            temperature: base of the geometric frequency progression
            device: device for the projection parameters
            dtype: dtype for the projection parameters
        """
        super().__init__()
        factory_kwargs = dict(device=device, dtype=dtype)
        self.hidden_dim = hidden_dim
        self.dim = dim
        self.temperature = temperature
        self.scale = 2 * math.pi
        self.eps = 1e-6
        # y and x features are concatenated, hence 2 * hidden_dim input channels
        self.token_projection = nn.Conv2d(hidden_dim * 2, dim, kernel_size=1, **factory_kwargs)

    def forward(self, B: int, H: int, W: int):
        """Return the positional encoding for a `H` x `W` token grid, repeated `B` times: (B, dim, H, W)."""
        weight = self.token_projection.weight
        device = weight.device
        dtype = weight.dtype

        # 1-based row / column coordinates, each laid out as (1, H, W)
        rows = torch.arange(1, H + 1, device=device).to(torch.float32).unsqueeze(1).repeat(1, 1, W)
        cols = torch.arange(1, W + 1, device=device).to(torch.float32).repeat(1, H, 1)
        # divide by the last coordinate so the values lie in (0, 2 * pi)
        rows = rows / (rows[:, -1:, :] + self.eps) * self.scale
        cols = cols / (cols[:, :, -1:] + self.eps) * self.scale

        # consecutive feature pairs share one frequency (sin for the even slot, cos for the odd one)
        feat_idx = torch.arange(self.hidden_dim, device=device).to(torch.float32)
        freqs = self.temperature ** (2 * torch.div(feat_idx, 2, rounding_mode='floor') / self.hidden_dim)

        angle_x = cols[:, :, :, None] / freqs
        angle_y = rows[:, :, :, None] / freqs
        emb_x = torch.stack([angle_x[:, :, :, 0::2].sin(), angle_x[:, :, :, 1::2].cos()], dim=4).flatten(3)
        emb_y = torch.stack([angle_y[:, :, :, 0::2].sin(), angle_y[:, :, :, 1::2].cos()], dim=4).flatten(3)

        # (1, H, W, 2 * hidden_dim) -> (1, 2 * hidden_dim, H, W) -> projected to `dim` channels
        pos = torch.cat((emb_y, emb_x), dim=3).permute(0, 3, 1, 2)
        pos = self.token_projection(pos.to(dtype))
        return pos.repeat(B, 1, 1, 1)  # (B, C, H, W)
def conv3x3(in_planes, out_planes, stride=1, device=None, dtype=None):
    """3x3 convolution + batch norm

    Returns a two-element ``nn.Sequential``: a bias-free 3x3 convolution with padding 1 and the
    given stride, followed by ``BatchNorm2d`` over ``out_planes`` channels.
    """
    conv = nn.Conv2d(
        in_planes,
        out_planes,
        kernel_size=3,
        stride=stride,
        padding=1,
        bias=False,
        device=device,
        dtype=dtype,
    )
    norm = nn.BatchNorm2d(out_planes, device=device, dtype=dtype)
    return torch.nn.Sequential(conv, norm)
class ConvPatchEmbed(nn.Module):
    """Image to Patch Embedding using multiple convolutional layers

    The stem is a stack of stride-2 `conv3x3` (conv + batch norm) stages separated by activations:
    four stages for patch size 16, three stages for patch size 8.
    """

    def __init__(
            self,
            img_size: int = 224,
            patch_size: int = 16,
            in_chans: int = 3,
            embed_dim: int = 768,
            act_layer: Type[nn.Module] = nn.GELU,
            device=None,
            dtype=None,
    ):
        """
        Args:
            img_size: input image size (int or 2-tuple, normalized via `to_2tuple`)
            patch_size: total downsampling factor of the stem, must be 8 or 16
            in_chans: number of input image channels
            embed_dim: number of output channels of the last stage
            act_layer: activation class inserted between the conv stages
            device: device for parameters
            dtype: dtype for parameters

        Raises:
            ValueError: if `patch_size` is not 8 or 16.
        """
        dd = {'device': device, 'dtype': dtype}
        super().__init__()
        img_size = to_2tuple(img_size)
        num_patches = (img_size[1] // patch_size) * (img_size[0] // patch_size)
        self.img_size = img_size
        self.patch_size = patch_size
        self.num_patches = num_patches

        # Output width of each stride-2 stage; the final stage always emits `embed_dim` channels.
        if patch_size == 16:
            stage_widths = [embed_dim // 8, embed_dim // 4, embed_dim // 2, embed_dim]
        elif patch_size == 8:
            stage_widths = [embed_dim // 4, embed_dim // 2, embed_dim]
        else:
            # The original `raise('...')` raised a plain str, which itself fails with
            # "TypeError: exceptions must derive from BaseException" and hides the message.
            raise ValueError('For convolutional projection, patch size has to be in [8, 16]')

        # Layout is conv, act, conv, act, ..., conv so the sequential indices (and therefore the
        # checkpoint keys such as `proj.0.0.weight`) are the same as in the hand-written stacks.
        layers = []
        prev_chs = in_chans
        for stage_idx, chs in enumerate(stage_widths):
            if stage_idx > 0:
                layers.append(act_layer())
            layers.append(conv3x3(prev_chs, chs, 2, **dd))
            prev_chs = chs
        self.proj = torch.nn.Sequential(*layers)

    def forward(self, x):
        """Embed an image batch; returns tokens of shape (B, N, C) and the patch grid size (Hp, Wp)."""
        x = self.proj(x)
        Hp, Wp = x.shape[2], x.shape[3]
        x = x.flatten(2).transpose(1, 2)  # (B, N, C)
        return x, (Hp, Wp)
class LPI(nn.Module):
    """
    Local Patch Interaction module that allows explicit communication between tokens in 3x3 windows to augment the
    implicit communication performed by the block diagonal scatter attention. Implemented using 2 layers of separable
    3x3 convolutions with GeLU and BatchNorm2d
    """

    def __init__(
            self,
            in_features: int,
            out_features: Optional[int] = None,
            act_layer: Type[nn.Module] = nn.GELU,
            kernel_size: int = 3,
            device=None,
            dtype=None,
    ):
        """
        Args:
            in_features: number of token channels
            out_features: channels produced by the second conv, defaults to `in_features`
            act_layer: activation class applied after the first conv
            kernel_size: kernel size of both convs ('same' padding of kernel_size // 2)
            device: device for parameters
            dtype: dtype for parameters
        """
        super().__init__()
        factory_kwargs = dict(device=device, dtype=dtype)
        if not out_features:
            out_features = in_features

        # both convs share kernel size and padding
        conv_kwargs = dict(kernel_size=kernel_size, padding=kernel_size // 2, **factory_kwargs)
        self.conv1 = torch.nn.Conv2d(in_features, in_features, groups=in_features, **conv_kwargs)
        self.act = act_layer()
        self.bn = nn.BatchNorm2d(in_features, **factory_kwargs)
        self.conv2 = torch.nn.Conv2d(in_features, out_features, groups=out_features, **conv_kwargs)

    def forward(self, x, H: int, W: int):
        """Apply conv -> act -> bn -> conv on the (H, W) token grid; input and output are (B, N, C)."""
        B, N, C = x.shape
        grid = x.permute(0, 2, 1).reshape(B, C, H, W)
        grid = self.conv2(self.bn(self.act(self.conv1(grid))))
        return grid.reshape(B, C, N).permute(0, 2, 1)
class ClassAttentionBlock(nn.Module):
    """Class Attention Layer as in CaiT https://arxiv.org/abs/2103.17239"""

    def __init__(
            self,
            dim: int,
            num_heads: int,
            mlp_ratio: float = 4.,
            qkv_bias: bool = False,
            proj_drop: float = 0.,
            attn_drop: float = 0.,
            drop_path: float = 0.,
            act_layer: Type[nn.Module] = nn.GELU,
            norm_layer: Type[nn.Module] = nn.LayerNorm,
            eta: Optional[float] = 1.,
            tokens_norm: bool = False,
            device=None,
            dtype=None,
    ):
        """
        Args:
            dim: token embedding dimension
            num_heads: number of attention heads
            mlp_ratio: MLP hidden width as a multiple of `dim`
            qkv_bias: enable bias in the attention projections
            proj_drop: dropout rate passed to the attention projection and the MLP
            attn_drop: attention dropout rate
            drop_path: stochastic depth rate (Identity when 0)
            act_layer: MLP activation class
            norm_layer: normalization layer class
            eta: LayerScale init value, no LayerScale (plain 1.0 factors) when None
            tokens_norm: normalize all tokens with `norm2` if True, otherwise only the class token
            device: device for parameters
            dtype: dtype for parameters
        """
        super().__init__()
        factory_kwargs = dict(device=device, dtype=dtype)

        def _make_drop_path():
            return DropPath(drop_path) if drop_path > 0. else nn.Identity()

        self.norm1 = norm_layer(dim, **factory_kwargs)
        self.attn = ClassAttn(
            dim,
            num_heads=num_heads,
            qkv_bias=qkv_bias,
            attn_drop=attn_drop,
            proj_drop=proj_drop,
            **factory_kwargs,
        )
        self.drop_path1 = _make_drop_path()

        self.norm2 = norm_layer(dim, **factory_kwargs)
        self.mlp = Mlp(
            in_features=dim,
            hidden_features=int(dim * mlp_ratio),
            act_layer=act_layer,
            drop=proj_drop,
            **factory_kwargs,
        )
        self.drop_path2 = _make_drop_path()

        if eta is None:
            # no LayerScale: plain scalar factors
            self.gamma1, self.gamma2 = 1.0, 1.0
        else:
            # LayerScale initialization
            self.gamma1 = nn.Parameter(eta * torch.ones(dim, **factory_kwargs))
            self.gamma2 = nn.Parameter(eta * torch.ones(dim, **factory_kwargs))

        # See https://github.com/rwightman/pytorch-image-models/pull/747#issuecomment-877795721
        self.tokens_norm = tokens_norm

    def forward(self, x):
        """Update the class token (index 0) via class attention and an MLP; x is (B, 1 + N, C)."""
        normed = self.norm1(x)
        # attention only produces a new class token; patch tokens pass through as their normed values
        attended = torch.cat([self.attn(normed), normed[:, 1:]], dim=1)
        x = x + self.drop_path1(self.gamma1 * attended)

        if self.tokens_norm:
            x = self.norm2(x)
        else:
            x = torch.cat([self.norm2(x[:, :1]), x[:, 1:]], dim=1)

        # the MLP is applied to the class token only
        cls_update = self.gamma2 * self.mlp(x[:, :1])
        branch = torch.cat([cls_update, x[:, 1:]], dim=1)
        return x + self.drop_path2(branch)
- class XCA(nn.Module):
- fused_attn: torch.jit.Final[bool]
- """ Cross-Covariance Attention (XCA)
- Operation where the channels are updated using a weighted sum. The weights are obtained from the (softmax
- normalized) Cross-covariance matrix (Q^T \\cdot K \\in d_h \\times d_h)
- """
- def __init__(
- self,
- dim: int,
- num_heads: int = 8,
- qkv_bias: bool = False,
- attn_drop: float = 0.,
- proj_drop: float = 0.,
- device=None,
- dtype=None,
- ):
- dd = {'device': device, 'dtype': dtype}
- super().__init__()
- self.num_heads = num_heads
- self.fused_attn = use_fused_attn(experimental=True)
- self.temperature = nn.Parameter(torch.ones(num_heads, 1, 1, **dd))
- self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias, **dd)
- self.attn_drop = nn.Dropout(attn_drop)
- self.proj = nn.Linear(dim, dim, **dd)
- self.proj_drop = nn.Dropout(proj_drop)
- def forward(self, x):
- B, N, C = x.shape
- # Result of next line is (qkv, B, num (H)eads, (C')hannels per head, N)
- qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 4, 1)
- q, k, v = qkv.unbind(0) # make torchscript happy (cannot use tensor as tuple)
- if self.fused_attn:
- q = torch.nn.functional.normalize(q, dim=-1) * self.temperature
- k = torch.nn.functional.normalize(k, dim=-1)
- x = torch.nn.functional.scaled_dot_product_attention(q, k, v, scale=1.0)
- else:
- # Paper section 3.2 l2-Normalization and temperature scaling
- q = torch.nn.functional.normalize(q, dim=-1)
- k = torch.nn.functional.normalize(k, dim=-1)
- attn = (q @ k.transpose(-2, -1)) * self.temperature
- attn = attn.softmax(dim=-1)
- attn = self.attn_drop(attn)
- x = attn @ v
- x = x.permute(0, 3, 1, 2).reshape(B, N, C)
- x = self.proj(x)
- x = self.proj_drop(x)
- return x
- @torch.jit.ignore
- def no_weight_decay(self):
- return {'temperature'}
class XCABlock(nn.Module):
    """XCiT block: cross-covariance attention, local patch interaction, then MLP.

    Each of the three branches is pre-normalized, scaled by its own LayerScale parameter and added
    back residually with optional stochastic depth.
    """

    def __init__(
            self,
            dim: int,
            num_heads: int,
            mlp_ratio: float = 4.,
            qkv_bias: bool = False,
            proj_drop: float = 0.,
            attn_drop: float = 0.,
            drop_path: float = 0.,
            act_layer: Type[nn.Module] = nn.GELU,
            norm_layer: Type[nn.Module] = nn.LayerNorm,
            eta: float = 1.,
            device=None,
            dtype=None,
    ):
        """
        Args:
            dim: token embedding dimension
            num_heads: number of attention heads
            mlp_ratio: MLP hidden width as a multiple of `dim`
            qkv_bias: enable bias for the qkv projection
            proj_drop: dropout rate passed to the attention projection and the MLP
            attn_drop: attention dropout rate
            drop_path: stochastic depth rate (Identity when 0)
            act_layer: activation class for the LPI and MLP branches
            norm_layer: normalization layer class
            eta: LayerScale initialization value
            device: device for parameters
            dtype: dtype for parameters
        """
        super().__init__()
        factory_kwargs = dict(device=device, dtype=dtype)

        def _make_drop_path():
            return DropPath(drop_path) if drop_path > 0. else nn.Identity()

        def _make_layer_scale():
            return nn.Parameter(eta * torch.ones(dim, **factory_kwargs))

        # attention branch
        self.norm1 = norm_layer(dim, **factory_kwargs)
        self.attn = XCA(
            dim,
            num_heads=num_heads,
            qkv_bias=qkv_bias,
            attn_drop=attn_drop,
            proj_drop=proj_drop,
            **factory_kwargs,
        )
        self.drop_path1 = _make_drop_path()

        # local patch interaction branch (numbered 3 to match the official weights)
        self.norm3 = norm_layer(dim, **factory_kwargs)
        self.local_mp = LPI(in_features=dim, act_layer=act_layer, **factory_kwargs)
        self.drop_path3 = _make_drop_path()

        # MLP branch
        self.norm2 = norm_layer(dim, **factory_kwargs)
        self.mlp = Mlp(
            in_features=dim,
            hidden_features=int(dim * mlp_ratio),
            act_layer=act_layer,
            drop=proj_drop,
            **factory_kwargs,
        )
        self.drop_path2 = _make_drop_path()

        # LayerScale parameters, registered in the same 1, 3, 2 order as the branches run
        self.gamma1 = _make_layer_scale()
        self.gamma3 = _make_layer_scale()
        self.gamma2 = _make_layer_scale()

    def forward(self, x, H: int, W: int):
        """Run the three residual branches on tokens x (B, N, C) laid out on an H x W grid."""
        attn_out = self.attn(self.norm1(x))
        x = x + self.drop_path1(self.gamma1 * attn_out)
        # NOTE official code has 3 then 2, so keeping it the same to be consistent with loaded weights
        # See https://github.com/rwightman/pytorch-image-models/pull/747#issuecomment-877795721
        local_out = self.local_mp(self.norm3(x), H, W)
        x = x + self.drop_path3(self.gamma3 * local_out)
        mlp_out = self.mlp(self.norm2(x))
        x = x + self.drop_path2(self.gamma2 * mlp_out)
        return x
class Xcit(nn.Module):
    """
    Based on timm and DeiT code bases
    https://github.com/rwightman/pytorch-image-models/tree/master/timm
    https://github.com/facebookresearch/deit/
    """

    def __init__(
            self,
            img_size: Union[int, Tuple[int, int]] = 224,
            patch_size: int = 16,
            in_chans: int = 3,
            num_classes: int = 1000,
            global_pool: str = 'token',
            embed_dim: int = 768,
            depth: int = 12,
            num_heads: int = 12,
            mlp_ratio: float = 4.,
            qkv_bias: bool = True,
            drop_rate: float = 0.,
            pos_drop_rate: float = 0.,
            proj_drop_rate: float = 0.,
            attn_drop_rate: float = 0.,
            drop_path_rate: float = 0.,
            act_layer: Optional[Type[nn.Module]] = None,
            norm_layer: Optional[Type[nn.Module]] = None,
            cls_attn_layers: int = 2,
            use_pos_embed: bool = True,
            eta: float = 1.,
            tokens_norm: bool = False,
            device=None,
            dtype=None,
    ):
        """
        Args:
            img_size (int, tuple): input image size
            patch_size (int): patch size
            in_chans (int): number of input channels
            num_classes (int): number of classes for classification head
            global_pool (str): pooling applied before the head, one of '', 'avg', 'token'
            embed_dim (int): embedding dimension
            depth (int): depth of transformer
            num_heads (int): number of attention heads
            mlp_ratio (int): ratio of mlp hidden dim to embedding dim
            qkv_bias (bool): enable bias for qkv if True
            drop_rate (float): dropout rate before the classifier head, also used as the projection / MLP
                dropout rate of the class attention blocks
            pos_drop_rate: position embedding dropout rate
            proj_drop_rate (float): projection dropout rate of the XCA blocks
            attn_drop_rate (float): attention dropout rate
            drop_path_rate (float): stochastic depth rate (constant across all XCA blocks)
            act_layer: (nn.Module): activation layer, defaults to nn.GELU
            norm_layer: (nn.Module): normalization layer, defaults to LayerNorm with eps=1e-6
            cls_attn_layers: (int) Depth of Class attention layers
            use_pos_embed: (bool) whether to use positional encoding
            eta: (float) layerscale initialization value
            tokens_norm: (bool) Whether to normalize all tokens or just the cls_token in the CA
            device: device for parameters
            dtype: dtype for parameters

        Notes:
            - Although `layer_norm` is user specifiable, there are hard-coded `BatchNorm2d`s in the local patch
              interaction (class LPI) and the patch embedding (class ConvPatchEmbed)
        """
        super().__init__()
        dd = {'device': device, 'dtype': dtype}
        assert global_pool in ('', 'avg', 'token')
        img_size = to_2tuple(img_size)
        # Both height and width must be divisible (the check used to test img_size[0] twice,
        # so a non-divisible width slipped through).
        assert (img_size[0] % patch_size == 0) and (img_size[1] % patch_size == 0), \
            '`patch_size` should divide image dimensions evenly'
        norm_layer = norm_layer or partial(nn.LayerNorm, eps=1e-6)
        act_layer = act_layer or nn.GELU

        self.num_classes = num_classes
        self.in_chans = in_chans
        self.num_features = self.head_hidden_size = self.embed_dim = embed_dim
        self.global_pool = global_pool
        self.grad_checkpointing = False

        self.patch_embed = ConvPatchEmbed(
            img_size=img_size,
            patch_size=patch_size,
            in_chans=in_chans,
            embed_dim=embed_dim,
            act_layer=act_layer,
            **dd,
        )

        self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim, **dd))
        if use_pos_embed:
            self.pos_embed = PositionalEncodingFourier(dim=embed_dim, **dd)
        else:
            self.pos_embed = None
        self.pos_drop = nn.Dropout(p=pos_drop_rate)

        self.blocks = nn.ModuleList([
            XCABlock(
                dim=embed_dim,
                num_heads=num_heads,
                mlp_ratio=mlp_ratio,
                qkv_bias=qkv_bias,
                proj_drop=proj_drop_rate,
                attn_drop=attn_drop_rate,
                drop_path=drop_path_rate,
                act_layer=act_layer,
                norm_layer=norm_layer,
                eta=eta,
                **dd,
            )
            for _ in range(depth)])
        # every block keeps the patch-grid resolution, so the reduction is the patch size throughout
        self.feature_info = [
            dict(num_chs=embed_dim, reduction=patch_size, module=f'blocks.{i}') for i in range(depth)]

        self.cls_attn_blocks = nn.ModuleList([
            ClassAttentionBlock(
                dim=embed_dim,
                num_heads=num_heads,
                mlp_ratio=mlp_ratio,
                qkv_bias=qkv_bias,
                proj_drop=drop_rate,
                attn_drop=attn_drop_rate,
                act_layer=act_layer,
                norm_layer=norm_layer,
                eta=eta,
                tokens_norm=tokens_norm,
                **dd,
            )
            for _ in range(cls_attn_layers)])

        # Classifier head
        self.norm = norm_layer(embed_dim, **dd)
        self.head_drop = nn.Dropout(drop_rate)
        self.head = nn.Linear(self.num_features, num_classes, **dd) if num_classes > 0 else nn.Identity()

        # Init weights
        trunc_normal_(self.cls_token, std=.02)
        self.apply(self._init_weights)

    def _init_weights(self, m):
        """Truncated-normal init for Linear weights, zero init for their biases."""
        if isinstance(m, nn.Linear):
            trunc_normal_(m.weight, std=.02)
            if m.bias is not None:
                nn.init.constant_(m.bias, 0)

    @torch.jit.ignore
    def no_weight_decay(self):
        """Names of parameters to exclude from weight decay."""
        return {'pos_embed', 'cls_token'}

    @torch.jit.ignore
    def group_matcher(self, coarse=False):
        """Regex patterns that group parameters into stem / blocks / class-attention blocks."""
        return dict(
            stem=r'^cls_token|pos_embed|patch_embed',  # stem and embed
            blocks=r'^blocks\.(\d+)',
            cls_attn_blocks=[(r'^cls_attn_blocks\.(\d+)', None), (r'^norm', (99999,))]
        )

    @torch.jit.ignore
    def set_grad_checkpointing(self, enable=True):
        """Enable or disable gradient checkpointing of the transformer blocks."""
        self.grad_checkpointing = enable

    @torch.jit.ignore
    def get_classifier(self) -> nn.Module:
        """Return the classification head module."""
        return self.head

    def reset_classifier(self, num_classes: int, global_pool: Optional[str] = None):
        """Replace the head for `num_classes` outputs (Identity when 0), optionally changing the pooling."""
        self.num_classes = num_classes
        if global_pool is not None:
            assert global_pool in ('', 'avg', 'token')
            self.global_pool = global_pool
        # keep the new head on the device / dtype of the old one when it has weights
        device = self.head.weight.device if hasattr(self.head, 'weight') else None
        dtype = self.head.weight.dtype if hasattr(self.head, 'weight') else None
        if num_classes > 0:
            self.head = nn.Linear(self.num_features, num_classes, device=device, dtype=dtype)
        else:
            self.head = nn.Identity()

    def forward_intermediates(
            self,
            x: torch.Tensor,
            indices: Optional[Union[int, List[int]]] = None,
            norm: bool = False,
            stop_early: bool = False,
            output_fmt: str = 'NCHW',
            intermediates_only: bool = False,
    ) -> Union[List[torch.Tensor], Tuple[torch.Tensor, List[torch.Tensor]]]:
        """ Forward features that returns intermediates.

        Args:
            x: Input image tensor
            indices: Take last n blocks if int, all if None, select matching indices if sequence
            norm: Apply norm layer to all intermediates
            stop_early: Stop iterating over blocks when last desired intermediate hit
            output_fmt: Shape of intermediate feature outputs
            intermediates_only: Only return intermediate features
        Returns:
            The list of intermediate features if `intermediates_only`, otherwise a tuple of the final
            normalized tokens (class token included) and the list of intermediate features.
        """
        assert output_fmt in ('NCHW', 'NLC'), 'Output format must be one of NCHW or NLC.'
        reshape = output_fmt == 'NCHW'
        intermediates = []
        take_indices, max_index = feature_take_indices(len(self.blocks), indices)

        # forward pass
        B = x.shape[0]
        x, (Hp, Wp) = self.patch_embed(x)
        if self.pos_embed is not None:
            # `pos_embed` (B, C, Hp, Wp), reshape -> (B, C, N), permute -> (B, N, C)
            pos_encoding = self.pos_embed(B, Hp, Wp).reshape(B, -1, x.shape[1]).permute(0, 2, 1)
            x = x + pos_encoding
        x = self.pos_drop(x)

        if torch.jit.is_scripting() or not stop_early:  # can't slice blocks in torchscript
            blocks = self.blocks
        else:
            blocks = self.blocks[:max_index + 1]
        for i, blk in enumerate(blocks):
            if self.grad_checkpointing and not torch.jit.is_scripting():
                x = checkpoint(blk, x, Hp, Wp)
            else:
                x = blk(x, Hp, Wp)
            if i in take_indices:
                # normalize intermediates with final norm layer if enabled
                intermediates.append(self.norm(x) if norm else x)

        # process intermediates
        if reshape:
            # reshape to BCHW output format
            intermediates = [y.reshape(B, Hp, Wp, -1).permute(0, 3, 1, 2).contiguous() for y in intermediates]

        if intermediates_only:
            return intermediates

        # NOTE not supporting return of class tokens
        x = torch.cat((self.cls_token.expand(B, -1, -1), x), dim=1)
        for blk in self.cls_attn_blocks:
            if self.grad_checkpointing and not torch.jit.is_scripting():
                x = checkpoint(blk, x)
            else:
                x = blk(x)

        x = self.norm(x)

        return x, intermediates

    def prune_intermediate_layers(
            self,
            indices: Union[int, List[int]] = 1,
            prune_norm: bool = False,
            prune_head: bool = True,
    ):
        """ Prune layers not required for specified intermediates.
        """
        take_indices, max_index = feature_take_indices(len(self.blocks), indices)
        self.blocks = self.blocks[:max_index + 1]  # truncate blocks
        if prune_norm:
            self.norm = nn.Identity()
        if prune_head:
            self.cls_attn_blocks = nn.ModuleList()  # prune token blocks with head
            self.reset_classifier(0, '')
        return take_indices

    def forward_features(self, x):
        """Run the backbone and return the normalized token sequence (B, 1 + N, C), class token first."""
        B = x.shape[0]
        # x is (B, N, C). (Hp, Hw) is (height in units of patches, width in units of patches)
        x, (Hp, Wp) = self.patch_embed(x)

        if self.pos_embed is not None:
            # `pos_embed` (B, C, Hp, Wp), reshape -> (B, C, N), permute -> (B, N, C)
            pos_encoding = self.pos_embed(B, Hp, Wp).reshape(B, -1, x.shape[1]).permute(0, 2, 1)
            x = x + pos_encoding
        x = self.pos_drop(x)

        for blk in self.blocks:
            if self.grad_checkpointing and not torch.jit.is_scripting():
                x = checkpoint(blk, x, Hp, Wp)
            else:
                x = blk(x, Hp, Wp)

        # the class token only joins the sequence for the class attention blocks
        x = torch.cat((self.cls_token.expand(B, -1, -1), x), dim=1)

        for blk in self.cls_attn_blocks:
            if self.grad_checkpointing and not torch.jit.is_scripting():
                x = checkpoint(blk, x)
            else:
                x = blk(x)

        x = self.norm(x)
        return x

    def forward_head(self, x, pre_logits: bool = False):
        """Pool the tokens ('avg' over patch tokens, otherwise the class token) and apply the head."""
        if self.global_pool:
            x = x[:, 1:].mean(dim=1) if self.global_pool == 'avg' else x[:, 0]
        x = self.head_drop(x)
        return x if pre_logits else self.head(x)

    def forward(self, x):
        """Full forward pass: features followed by the classification head."""
        x = self.forward_features(x)
        x = self.forward_head(x)
        return x
def checkpoint_filter_fn(state_dict, model):
    """Remap an official XCiT checkpoint to this implementation's parameter names.

    Unwraps a nested 'model' entry, renames (or drops) the positional embedding keys and splits
    the fused class-attention qkv tensors into separate q / k / v entries. The (possibly
    unwrapped) dict is modified in place and returned.
    """
    state_dict = state_dict.get('model', state_dict)

    # For consistency with timm's transformer models while being compatible with official weights source we rename
    # pos_embeder to pos_embed. Also account for use_pos_embed == False
    keep_pos_embed = getattr(model, 'pos_embed', None) is not None
    for key in [k for k in state_dict if k.startswith('pos_embed')]:
        value = state_dict.pop(key)
        if keep_pos_embed:
            state_dict[key.replace('pos_embeder.', 'pos_embed.')] = value

    # timm's implementation of class attention in CaiT is slightly more efficient as it does not compute query vectors
    # for all tokens, just the class token. To use official weights source we must split qkv into q, k, v
    needs_split = (
        'cls_attn_blocks.0.attn.qkv.weight' in state_dict
        and 'cls_attn_blocks.0.attn.q.weight' in model.state_dict()
    )
    if not needs_split:
        return state_dict

    for block_idx in range(len(model.cls_attn_blocks)):
        prefix = f'cls_attn_blocks.{block_idx}.attn.'

        fused_weight = state_dict.pop(prefix + 'qkv.weight')
        fused_weight = fused_weight.reshape(3, -1, fused_weight.shape[-1])
        for part_idx, part_name in enumerate('qkv'):
            state_dict[prefix + part_name + '.weight'] = fused_weight[part_idx]

        fused_bias = state_dict.pop(prefix + 'qkv.bias', None)
        if fused_bias is None:
            continue
        fused_bias = fused_bias.reshape(3, -1)
        for part_idx, part_name in enumerate('qkv'):
            state_dict[prefix + part_name + '.bias'] = fused_bias[part_idx]

    return state_dict
def _create_xcit(variant, pretrained=False, default_cfg=None, **kwargs):
    """Build an `Xcit` variant via `build_model_with_cfg`.

    `out_indices` (default 3) is taken out of `kwargs` and passed on through the feature config;
    the remaining keyword arguments are forwarded to the builder. `default_cfg` is accepted but
    not used here.
    """
    feature_cfg = dict(out_indices=kwargs.pop('out_indices', 3), feature_cls='getter')
    return build_model_with_cfg(
        Xcit,
        variant,
        pretrained,
        pretrained_filter_fn=checkpoint_filter_fn,
        feature_cfg=feature_cfg,
        **kwargs,
    )
def _cfg(url='', **kwargs):
    """Return the default pretrained config for XCiT weights, with `kwargs` overriding or extending it."""
    cfg = dict(
        url=url,
        num_classes=1000,
        input_size=(3, 224, 224),
        pool_size=None,
        crop_pct=1.0,
        interpolation='bicubic',
        fixed_input_size=True,
        mean=IMAGENET_DEFAULT_MEAN,
        std=IMAGENET_DEFAULT_STD,
        first_conv='patch_embed.proj.0.0',
        classifier='head',
        license='apache-2.0',
    )
    cfg.update(kwargs)
    return cfg
def _xcit_pretrained_cfgs():
    """Build the pretrained cfg table for all released XCiT weights.

    For each patch size (16, then 8) and each architecture there are three entries, in this order:
    `_224.fb_in1k`, `_224.fb_dist_in1k` and `_384.fb_dist_in1k` (the last with a 384x384 input size).
    """
    url_root = 'https://dl.fbaipublicfiles.com/xcit/'
    archs = ('nano_12', 'tiny_12', 'tiny_24', 'small_12', 'small_24', 'medium_24', 'large_24')

    cfgs = {}
    for patch_size in (16, 8):
        for arch in archs:
            stem = f'xcit_{arch}_p{patch_size}'
            cfgs[f'{stem}_224.fb_in1k'] = _cfg(
                hf_hub_id='timm/',
                url=f'{url_root}{stem}_224.pth')
            cfgs[f'{stem}_224.fb_dist_in1k'] = _cfg(
                hf_hub_id='timm/',
                url=f'{url_root}{stem}_224_dist.pth')
            cfgs[f'{stem}_384.fb_dist_in1k'] = _cfg(
                hf_hub_id='timm/',
                url=f'{url_root}{stem}_384_dist.pth', input_size=(3, 384, 384))
    return cfgs


default_cfgs = generate_default_cfgs(_xcit_pretrained_cfgs())
@register_model
def xcit_nano_12_p16_224(pretrained=False, **kwargs) -> Xcit:
    """Create the ``xcit_nano_12_p16_224`` model via ``_create_xcit``.

    Args:
        pretrained: Forwarded unchanged to ``_create_xcit``.
        **kwargs: Extra model arguments; any key given here overrides the
            matching default below.

    Returns:
        Whatever ``_create_xcit`` returns for this variant.
    """
    defaults = {
        'patch_size': 16,
        'embed_dim': 128,
        'depth': 12,
        'num_heads': 4,
        'eta': 1.0,
        'tokens_norm': False,
    }
    # Caller kwargs are merged last so they win over the defaults.
    return _create_xcit('xcit_nano_12_p16_224', pretrained=pretrained, **{**defaults, **kwargs})
@register_model
def xcit_nano_12_p16_384(pretrained=False, **kwargs) -> Xcit:
    """Create the ``xcit_nano_12_p16_384`` model via ``_create_xcit``.

    Args:
        pretrained: Forwarded unchanged to ``_create_xcit``.
        **kwargs: Extra model arguments; any key given here overrides the
            matching default below.

    Returns:
        Whatever ``_create_xcit`` returns for this variant.
    """
    defaults = {
        'patch_size': 16,
        'embed_dim': 128,
        'depth': 12,
        'num_heads': 4,
        'eta': 1.0,
        'tokens_norm': False,
        'img_size': 384,
    }
    # Caller kwargs are merged last so they win over the defaults.
    return _create_xcit('xcit_nano_12_p16_384', pretrained=pretrained, **{**defaults, **kwargs})
@register_model
def xcit_tiny_12_p16_224(pretrained=False, **kwargs) -> Xcit:
    """Create the ``xcit_tiny_12_p16_224`` model via ``_create_xcit``.

    Args:
        pretrained: Forwarded unchanged to ``_create_xcit``.
        **kwargs: Extra model arguments; any key given here overrides the
            matching default below.

    Returns:
        Whatever ``_create_xcit`` returns for this variant.
    """
    defaults = {
        'patch_size': 16,
        'embed_dim': 192,
        'depth': 12,
        'num_heads': 4,
        'eta': 1.0,
        'tokens_norm': True,
    }
    # Caller kwargs are merged last so they win over the defaults.
    return _create_xcit('xcit_tiny_12_p16_224', pretrained=pretrained, **{**defaults, **kwargs})
@register_model
def xcit_tiny_12_p16_384(pretrained=False, **kwargs) -> Xcit:
    """Create the ``xcit_tiny_12_p16_384`` model via ``_create_xcit``.

    Args:
        pretrained: Forwarded unchanged to ``_create_xcit``.
        **kwargs: Extra model arguments; any key given here (including
            ``img_size``) overrides the matching default below.

    Returns:
        Whatever ``_create_xcit`` returns for this variant.
    """
    defaults = {
        'patch_size': 16,
        'embed_dim': 192,
        'depth': 12,
        'num_heads': 4,
        'eta': 1.0,
        'tokens_norm': True,
        # Default to the 384 resolution in the variant name, the same way
        # xcit_nano_12_p16_384 does.
        'img_size': 384,
    }
    # Caller kwargs are merged last so they win over the defaults.
    return _create_xcit('xcit_tiny_12_p16_384', pretrained=pretrained, **{**defaults, **kwargs})
@register_model
def xcit_small_12_p16_224(pretrained=False, **kwargs) -> Xcit:
    """Create the ``xcit_small_12_p16_224`` model via ``_create_xcit``.

    Args:
        pretrained: Forwarded unchanged to ``_create_xcit``.
        **kwargs: Extra model arguments; any key given here overrides the
            matching default below.

    Returns:
        Whatever ``_create_xcit`` returns for this variant.
    """
    defaults = {
        'patch_size': 16,
        'embed_dim': 384,
        'depth': 12,
        'num_heads': 8,
        'eta': 1.0,
        'tokens_norm': True,
    }
    # Caller kwargs are merged last so they win over the defaults.
    return _create_xcit('xcit_small_12_p16_224', pretrained=pretrained, **{**defaults, **kwargs})
@register_model
def xcit_small_12_p16_384(pretrained=False, **kwargs) -> Xcit:
    """Create the ``xcit_small_12_p16_384`` model via ``_create_xcit``.

    Args:
        pretrained: Forwarded unchanged to ``_create_xcit``.
        **kwargs: Extra model arguments; any key given here (including
            ``img_size``) overrides the matching default below.

    Returns:
        Whatever ``_create_xcit`` returns for this variant.
    """
    defaults = {
        'patch_size': 16,
        'embed_dim': 384,
        'depth': 12,
        'num_heads': 8,
        'eta': 1.0,
        'tokens_norm': True,
        # Default to the 384 resolution in the variant name, the same way
        # xcit_nano_12_p16_384 does.
        'img_size': 384,
    }
    # Caller kwargs are merged last so they win over the defaults.
    return _create_xcit('xcit_small_12_p16_384', pretrained=pretrained, **{**defaults, **kwargs})
@register_model
def xcit_tiny_24_p16_224(pretrained=False, **kwargs) -> Xcit:
    """Create the ``xcit_tiny_24_p16_224`` model via ``_create_xcit``.

    Args:
        pretrained: Forwarded unchanged to ``_create_xcit``.
        **kwargs: Extra model arguments; any key given here overrides the
            matching default below.

    Returns:
        Whatever ``_create_xcit`` returns for this variant.
    """
    defaults = {
        'patch_size': 16,
        'embed_dim': 192,
        'depth': 24,
        'num_heads': 4,
        'eta': 1e-5,
        'tokens_norm': True,
    }
    # Caller kwargs are merged last so they win over the defaults.
    return _create_xcit('xcit_tiny_24_p16_224', pretrained=pretrained, **{**defaults, **kwargs})
@register_model
def xcit_tiny_24_p16_384(pretrained=False, **kwargs) -> Xcit:
    """Create the ``xcit_tiny_24_p16_384`` model via ``_create_xcit``.

    Args:
        pretrained: Forwarded unchanged to ``_create_xcit``.
        **kwargs: Extra model arguments; any key given here (including
            ``img_size``) overrides the matching default below.

    Returns:
        Whatever ``_create_xcit`` returns for this variant.
    """
    defaults = {
        'patch_size': 16,
        'embed_dim': 192,
        'depth': 24,
        'num_heads': 4,
        'eta': 1e-5,
        'tokens_norm': True,
        # Default to the 384 resolution in the variant name, the same way
        # xcit_nano_12_p16_384 does.
        'img_size': 384,
    }
    # Caller kwargs are merged last so they win over the defaults.
    return _create_xcit('xcit_tiny_24_p16_384', pretrained=pretrained, **{**defaults, **kwargs})
@register_model
def xcit_small_24_p16_224(pretrained=False, **kwargs) -> Xcit:
    """Create the ``xcit_small_24_p16_224`` model via ``_create_xcit``.

    Args:
        pretrained: Forwarded unchanged to ``_create_xcit``.
        **kwargs: Extra model arguments; any key given here overrides the
            matching default below.

    Returns:
        Whatever ``_create_xcit`` returns for this variant.
    """
    defaults = {
        'patch_size': 16,
        'embed_dim': 384,
        'depth': 24,
        'num_heads': 8,
        'eta': 1e-5,
        'tokens_norm': True,
    }
    # Caller kwargs are merged last so they win over the defaults.
    return _create_xcit('xcit_small_24_p16_224', pretrained=pretrained, **{**defaults, **kwargs})
@register_model
def xcit_small_24_p16_384(pretrained=False, **kwargs) -> Xcit:
    """Create the ``xcit_small_24_p16_384`` model via ``_create_xcit``.

    Args:
        pretrained: Forwarded unchanged to ``_create_xcit``.
        **kwargs: Extra model arguments; any key given here (including
            ``img_size``) overrides the matching default below.

    Returns:
        Whatever ``_create_xcit`` returns for this variant.
    """
    defaults = {
        'patch_size': 16,
        'embed_dim': 384,
        'depth': 24,
        'num_heads': 8,
        'eta': 1e-5,
        'tokens_norm': True,
        # Default to the 384 resolution in the variant name, the same way
        # xcit_nano_12_p16_384 does.
        'img_size': 384,
    }
    # Caller kwargs are merged last so they win over the defaults.
    return _create_xcit('xcit_small_24_p16_384', pretrained=pretrained, **{**defaults, **kwargs})
@register_model
def xcit_medium_24_p16_224(pretrained=False, **kwargs) -> Xcit:
    """Create the ``xcit_medium_24_p16_224`` model via ``_create_xcit``.

    Args:
        pretrained: Forwarded unchanged to ``_create_xcit``.
        **kwargs: Extra model arguments; any key given here overrides the
            matching default below.

    Returns:
        Whatever ``_create_xcit`` returns for this variant.
    """
    defaults = {
        'patch_size': 16,
        'embed_dim': 512,
        'depth': 24,
        'num_heads': 8,
        'eta': 1e-5,
        'tokens_norm': True,
    }
    # Caller kwargs are merged last so they win over the defaults.
    return _create_xcit('xcit_medium_24_p16_224', pretrained=pretrained, **{**defaults, **kwargs})
@register_model
def xcit_medium_24_p16_384(pretrained=False, **kwargs) -> Xcit:
    """Create the ``xcit_medium_24_p16_384`` model via ``_create_xcit``.

    Args:
        pretrained: Forwarded unchanged to ``_create_xcit``.
        **kwargs: Extra model arguments; any key given here (including
            ``img_size``) overrides the matching default below.

    Returns:
        Whatever ``_create_xcit`` returns for this variant.
    """
    defaults = {
        'patch_size': 16,
        'embed_dim': 512,
        'depth': 24,
        'num_heads': 8,
        'eta': 1e-5,
        'tokens_norm': True,
        # Default to the 384 resolution in the variant name, the same way
        # xcit_nano_12_p16_384 does.
        'img_size': 384,
    }
    # Caller kwargs are merged last so they win over the defaults.
    return _create_xcit('xcit_medium_24_p16_384', pretrained=pretrained, **{**defaults, **kwargs})
@register_model
def xcit_large_24_p16_224(pretrained=False, **kwargs) -> Xcit:
    """Create the ``xcit_large_24_p16_224`` model via ``_create_xcit``.

    Args:
        pretrained: Forwarded unchanged to ``_create_xcit``.
        **kwargs: Extra model arguments; any key given here overrides the
            matching default below.

    Returns:
        Whatever ``_create_xcit`` returns for this variant.
    """
    defaults = {
        'patch_size': 16,
        'embed_dim': 768,
        'depth': 24,
        'num_heads': 16,
        'eta': 1e-5,
        'tokens_norm': True,
    }
    # Caller kwargs are merged last so they win over the defaults.
    return _create_xcit('xcit_large_24_p16_224', pretrained=pretrained, **{**defaults, **kwargs})
@register_model
def xcit_large_24_p16_384(pretrained=False, **kwargs) -> Xcit:
    """Create the ``xcit_large_24_p16_384`` model via ``_create_xcit``.

    Args:
        pretrained: Forwarded unchanged to ``_create_xcit``.
        **kwargs: Extra model arguments; any key given here (including
            ``img_size``) overrides the matching default below.

    Returns:
        Whatever ``_create_xcit`` returns for this variant.
    """
    defaults = {
        'patch_size': 16,
        'embed_dim': 768,
        'depth': 24,
        'num_heads': 16,
        'eta': 1e-5,
        'tokens_norm': True,
        # Default to the 384 resolution in the variant name, the same way
        # xcit_nano_12_p16_384 does.
        'img_size': 384,
    }
    # Caller kwargs are merged last so they win over the defaults.
    return _create_xcit('xcit_large_24_p16_384', pretrained=pretrained, **{**defaults, **kwargs})
- # Patch size 8x8 models
@register_model
def xcit_nano_12_p8_224(pretrained=False, **kwargs) -> Xcit:
    """Create the ``xcit_nano_12_p8_224`` model via ``_create_xcit``.

    Args:
        pretrained: Forwarded unchanged to ``_create_xcit``.
        **kwargs: Extra model arguments; any key given here overrides the
            matching default below.

    Returns:
        Whatever ``_create_xcit`` returns for this variant.
    """
    defaults = {
        'patch_size': 8,
        'embed_dim': 128,
        'depth': 12,
        'num_heads': 4,
        'eta': 1.0,
        'tokens_norm': False,
    }
    # Caller kwargs are merged last so they win over the defaults.
    return _create_xcit('xcit_nano_12_p8_224', pretrained=pretrained, **{**defaults, **kwargs})
@register_model
def xcit_nano_12_p8_384(pretrained=False, **kwargs) -> Xcit:
    """Create the ``xcit_nano_12_p8_384`` model via ``_create_xcit``.

    Args:
        pretrained: Forwarded unchanged to ``_create_xcit``.
        **kwargs: Extra model arguments; any key given here (including
            ``img_size``) overrides the matching default below.

    Returns:
        Whatever ``_create_xcit`` returns for this variant.
    """
    defaults = {
        'patch_size': 8,
        'embed_dim': 128,
        'depth': 12,
        'num_heads': 4,
        'eta': 1.0,
        'tokens_norm': False,
        # Default to the 384 resolution in the variant name, the same way
        # xcit_nano_12_p16_384 does.
        'img_size': 384,
    }
    # Caller kwargs are merged last so they win over the defaults.
    return _create_xcit('xcit_nano_12_p8_384', pretrained=pretrained, **{**defaults, **kwargs})
@register_model
def xcit_tiny_12_p8_224(pretrained=False, **kwargs) -> Xcit:
    """Create the ``xcit_tiny_12_p8_224`` model via ``_create_xcit``.

    Args:
        pretrained: Forwarded unchanged to ``_create_xcit``.
        **kwargs: Extra model arguments; any key given here overrides the
            matching default below.

    Returns:
        Whatever ``_create_xcit`` returns for this variant.
    """
    defaults = {
        'patch_size': 8,
        'embed_dim': 192,
        'depth': 12,
        'num_heads': 4,
        'eta': 1.0,
        'tokens_norm': True,
    }
    # Caller kwargs are merged last so they win over the defaults.
    return _create_xcit('xcit_tiny_12_p8_224', pretrained=pretrained, **{**defaults, **kwargs})
@register_model
def xcit_tiny_12_p8_384(pretrained=False, **kwargs) -> Xcit:
    """Create the ``xcit_tiny_12_p8_384`` model via ``_create_xcit``.

    Args:
        pretrained: Forwarded unchanged to ``_create_xcit``.
        **kwargs: Extra model arguments; any key given here (including
            ``img_size``) overrides the matching default below.

    Returns:
        Whatever ``_create_xcit`` returns for this variant.
    """
    defaults = {
        'patch_size': 8,
        'embed_dim': 192,
        'depth': 12,
        'num_heads': 4,
        'eta': 1.0,
        'tokens_norm': True,
        # Default to the 384 resolution in the variant name, the same way
        # xcit_nano_12_p16_384 does.
        'img_size': 384,
    }
    # Caller kwargs are merged last so they win over the defaults.
    return _create_xcit('xcit_tiny_12_p8_384', pretrained=pretrained, **{**defaults, **kwargs})
@register_model
def xcit_small_12_p8_224(pretrained=False, **kwargs) -> Xcit:
    """Create the ``xcit_small_12_p8_224`` model via ``_create_xcit``.

    Args:
        pretrained: Forwarded unchanged to ``_create_xcit``.
        **kwargs: Extra model arguments; any key given here overrides the
            matching default below.

    Returns:
        Whatever ``_create_xcit`` returns for this variant.
    """
    defaults = {
        'patch_size': 8,
        'embed_dim': 384,
        'depth': 12,
        'num_heads': 8,
        'eta': 1.0,
        'tokens_norm': True,
    }
    # Caller kwargs are merged last so they win over the defaults.
    return _create_xcit('xcit_small_12_p8_224', pretrained=pretrained, **{**defaults, **kwargs})
@register_model
def xcit_small_12_p8_384(pretrained=False, **kwargs) -> Xcit:
    """Create the ``xcit_small_12_p8_384`` model via ``_create_xcit``.

    Args:
        pretrained: Forwarded unchanged to ``_create_xcit``.
        **kwargs: Extra model arguments; any key given here (including
            ``img_size``) overrides the matching default below.

    Returns:
        Whatever ``_create_xcit`` returns for this variant.
    """
    defaults = {
        'patch_size': 8,
        'embed_dim': 384,
        'depth': 12,
        'num_heads': 8,
        'eta': 1.0,
        'tokens_norm': True,
        # Default to the 384 resolution in the variant name; this agrees
        # with the (3, 384, 384) input_size of the
        # 'xcit_small_12_p8_384.fb_dist_in1k' default cfg and with what
        # xcit_nano_12_p16_384 does.
        'img_size': 384,
    }
    # Caller kwargs are merged last so they win over the defaults.
    return _create_xcit('xcit_small_12_p8_384', pretrained=pretrained, **{**defaults, **kwargs})
@register_model
def xcit_tiny_24_p8_224(pretrained=False, **kwargs) -> Xcit:
    """Create the ``xcit_tiny_24_p8_224`` model via ``_create_xcit``.

    Args:
        pretrained: Forwarded unchanged to ``_create_xcit``.
        **kwargs: Extra model arguments; any key given here overrides the
            matching default below.

    Returns:
        Whatever ``_create_xcit`` returns for this variant.
    """
    defaults = {
        'patch_size': 8,
        'embed_dim': 192,
        'depth': 24,
        'num_heads': 4,
        'eta': 1e-5,
        'tokens_norm': True,
    }
    # Caller kwargs are merged last so they win over the defaults.
    return _create_xcit('xcit_tiny_24_p8_224', pretrained=pretrained, **{**defaults, **kwargs})
@register_model
def xcit_tiny_24_p8_384(pretrained=False, **kwargs) -> Xcit:
    """Create the ``xcit_tiny_24_p8_384`` model via ``_create_xcit``.

    Args:
        pretrained: Forwarded unchanged to ``_create_xcit``.
        **kwargs: Extra model arguments; any key given here (including
            ``img_size``) overrides the matching default below.

    Returns:
        Whatever ``_create_xcit`` returns for this variant.
    """
    defaults = {
        'patch_size': 8,
        'embed_dim': 192,
        'depth': 24,
        'num_heads': 4,
        'eta': 1e-5,
        'tokens_norm': True,
        # Default to the 384 resolution in the variant name; this agrees
        # with the (3, 384, 384) input_size of the
        # 'xcit_tiny_24_p8_384.fb_dist_in1k' default cfg and with what
        # xcit_nano_12_p16_384 does.
        'img_size': 384,
    }
    # Caller kwargs are merged last so they win over the defaults.
    return _create_xcit('xcit_tiny_24_p8_384', pretrained=pretrained, **{**defaults, **kwargs})
@register_model
def xcit_small_24_p8_224(pretrained=False, **kwargs) -> Xcit:
    """Create the ``xcit_small_24_p8_224`` model via ``_create_xcit``.

    Args:
        pretrained: Forwarded unchanged to ``_create_xcit``.
        **kwargs: Extra model arguments; any key given here overrides the
            matching default below.

    Returns:
        Whatever ``_create_xcit`` returns for this variant.
    """
    defaults = {
        'patch_size': 8,
        'embed_dim': 384,
        'depth': 24,
        'num_heads': 8,
        'eta': 1e-5,
        'tokens_norm': True,
    }
    # Caller kwargs are merged last so they win over the defaults.
    return _create_xcit('xcit_small_24_p8_224', pretrained=pretrained, **{**defaults, **kwargs})
@register_model
def xcit_small_24_p8_384(pretrained=False, **kwargs) -> Xcit:
    """Create the ``xcit_small_24_p8_384`` model via ``_create_xcit``.

    Args:
        pretrained: Forwarded unchanged to ``_create_xcit``.
        **kwargs: Extra model arguments; any key given here (including
            ``img_size``) overrides the matching default below.

    Returns:
        Whatever ``_create_xcit`` returns for this variant.
    """
    defaults = {
        'patch_size': 8,
        'embed_dim': 384,
        'depth': 24,
        'num_heads': 8,
        'eta': 1e-5,
        'tokens_norm': True,
        # Default to the 384 resolution in the variant name; this agrees
        # with the (3, 384, 384) input_size of the
        # 'xcit_small_24_p8_384.fb_dist_in1k' default cfg and with what
        # xcit_nano_12_p16_384 does.
        'img_size': 384,
    }
    # Caller kwargs are merged last so they win over the defaults.
    return _create_xcit('xcit_small_24_p8_384', pretrained=pretrained, **{**defaults, **kwargs})
@register_model
def xcit_medium_24_p8_224(pretrained=False, **kwargs) -> Xcit:
    """Create the ``xcit_medium_24_p8_224`` model via ``_create_xcit``.

    Args:
        pretrained: Forwarded unchanged to ``_create_xcit``.
        **kwargs: Extra model arguments; any key given here overrides the
            matching default below.

    Returns:
        Whatever ``_create_xcit`` returns for this variant.
    """
    defaults = {
        'patch_size': 8,
        'embed_dim': 512,
        'depth': 24,
        'num_heads': 8,
        'eta': 1e-5,
        'tokens_norm': True,
    }
    # Caller kwargs are merged last so they win over the defaults.
    return _create_xcit('xcit_medium_24_p8_224', pretrained=pretrained, **{**defaults, **kwargs})
@register_model
def xcit_medium_24_p8_384(pretrained=False, **kwargs) -> Xcit:
    """Create the ``xcit_medium_24_p8_384`` model via ``_create_xcit``.

    Args:
        pretrained: Forwarded unchanged to ``_create_xcit``.
        **kwargs: Extra model arguments; any key given here (including
            ``img_size``) overrides the matching default below.

    Returns:
        Whatever ``_create_xcit`` returns for this variant.
    """
    defaults = {
        'patch_size': 8,
        'embed_dim': 512,
        'depth': 24,
        'num_heads': 8,
        'eta': 1e-5,
        'tokens_norm': True,
        # Default to the 384 resolution in the variant name; this agrees
        # with the (3, 384, 384) input_size of the
        # 'xcit_medium_24_p8_384.fb_dist_in1k' default cfg and with what
        # xcit_nano_12_p16_384 does.
        'img_size': 384,
    }
    # Caller kwargs are merged last so they win over the defaults.
    return _create_xcit('xcit_medium_24_p8_384', pretrained=pretrained, **{**defaults, **kwargs})
@register_model
def xcit_large_24_p8_224(pretrained=False, **kwargs) -> Xcit:
    """Create the ``xcit_large_24_p8_224`` model via ``_create_xcit``.

    Args:
        pretrained: Forwarded unchanged to ``_create_xcit``.
        **kwargs: Extra model arguments; any key given here overrides the
            matching default below.

    Returns:
        Whatever ``_create_xcit`` returns for this variant.
    """
    defaults = {
        'patch_size': 8,
        'embed_dim': 768,
        'depth': 24,
        'num_heads': 16,
        'eta': 1e-5,
        'tokens_norm': True,
    }
    # Caller kwargs are merged last so they win over the defaults.
    return _create_xcit('xcit_large_24_p8_224', pretrained=pretrained, **{**defaults, **kwargs})
@register_model
def xcit_large_24_p8_384(pretrained=False, **kwargs) -> Xcit:
    """Create the ``xcit_large_24_p8_384`` model via ``_create_xcit``.

    Args:
        pretrained: Forwarded unchanged to ``_create_xcit``.
        **kwargs: Extra model arguments; any key given here (including
            ``img_size``) overrides the matching default below.

    Returns:
        Whatever ``_create_xcit`` returns for this variant.
    """
    defaults = {
        'patch_size': 8,
        'embed_dim': 768,
        'depth': 24,
        'num_heads': 16,
        'eta': 1e-5,
        'tokens_norm': True,
        # Default to the 384 resolution in the variant name; this agrees
        # with the (3, 384, 384) input_size of the
        # 'xcit_large_24_p8_384.fb_dist_in1k' default cfg and with what
        # xcit_nano_12_p16_384 does.
        'img_size': 384,
    }
    # Caller kwargs are merged last so they win over the defaults.
    return _create_xcit('xcit_large_24_p8_384', pretrained=pretrained, **{**defaults, **kwargs})
# Every legacy '<name>_dist' entry maps to '<name>.fb_dist_in1k'; the mapping
# is generated from the base names below, in the listed order.
# NOTE(review): presumably keys are the deprecated model names and values
# their replacements -- confirm against register_model_deprecations.
register_model_deprecations(__name__, {
    base_name + '_dist': base_name + '.fb_dist_in1k'
    for base_name in (
        # Patch size 16
        'xcit_nano_12_p16_224',
        'xcit_nano_12_p16_384',
        'xcit_tiny_12_p16_224',
        'xcit_tiny_12_p16_384',
        'xcit_tiny_24_p16_224',
        'xcit_tiny_24_p16_384',
        'xcit_small_12_p16_224',
        'xcit_small_12_p16_384',
        'xcit_small_24_p16_224',
        'xcit_small_24_p16_384',
        'xcit_medium_24_p16_224',
        'xcit_medium_24_p16_384',
        'xcit_large_24_p16_224',
        'xcit_large_24_p16_384',
        # Patch size 8
        'xcit_nano_12_p8_224',
        'xcit_nano_12_p8_384',
        'xcit_tiny_12_p8_224',
        'xcit_tiny_12_p8_384',
        'xcit_tiny_24_p8_224',
        'xcit_tiny_24_p8_384',
        'xcit_small_12_p8_224',
        'xcit_small_12_p8_384',
        'xcit_small_24_p8_224',
        'xcit_small_24_p8_384',
        'xcit_medium_24_p8_224',
        'xcit_medium_24_p8_384',
        'xcit_large_24_p8_224',
        'xcit_large_24_p8_384',
    )
})
|