tiny_vit.py 21 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581
  1. # LICENSE HEADER MANAGED BY add-license-header
  2. #
  3. # Copyright 2018 Kornia Team
  4. #
  5. # Licensed under the Apache License, Version 2.0 (the "License");
  6. # you may not use this file except in compliance with the License.
  7. # You may obtain a copy of the License at
  8. #
  9. # http://www.apache.org/licenses/LICENSE-2.0
  10. #
  11. # Unless required by applicable law or agreed to in writing, software
  12. # distributed under the License is distributed on an "AS IS" BASIS,
  13. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  14. # See the License for the specific language governing permissions and
  15. # limitations under the License.
  16. #
  17. # https://github.com/microsoft/Cream/blob/8dc38822b99fff8c262c585a32a4f09ac504d693/TinyViT/models/tiny_vit.py
  18. # https://github.com/ChaoningZhang/MobileSAM/blob/01ea8d0f5590082f0c1ceb0a3e2272593f20154b/mobile_sam/modeling/tiny_vit_sam.py
  19. from __future__ import annotations
  20. import warnings
  21. from typing import Any, Optional, Sequence
  22. import torch
  23. import torch.nn.functional as F
  24. from torch import nn
  25. from torch.utils import checkpoint
  26. from kornia.contrib.models.common import DropPath, LayerNorm2d, window_partition, window_unpartition
  27. from kornia.core import Module, Tensor
  28. from kornia.core.check import KORNIA_CHECK
  29. def _make_pair(x: int | tuple[int, int]) -> tuple[int, int]:
  30. return (x, x) if isinstance(x, int) else x
  31. class ConvBN(nn.Sequential):
  32. def __init__(
  33. self,
  34. in_channels: int,
  35. out_channels: int,
  36. kernel_size: int,
  37. stride: int = 1,
  38. padding: int = 0,
  39. groups: int = 1,
  40. activation: type[Module] = nn.Identity,
  41. ) -> None:
  42. super().__init__()
  43. self.c = nn.Conv2d(in_channels, out_channels, kernel_size, stride, padding, groups=groups, bias=False)
  44. self.bn = nn.BatchNorm2d(out_channels)
  45. self.act = activation()
  46. class PatchEmbed(nn.Sequential):
  47. def __init__(self, in_channels: int, embed_dim: int, activation: type[Module] = nn.GELU) -> None:
  48. super().__init__()
  49. self.seq = nn.Sequential(
  50. ConvBN(in_channels, embed_dim // 2, 3, 2, 1), activation(), ConvBN(embed_dim // 2, embed_dim, 3, 2, 1)
  51. )
  52. class MBConv(Module):
  53. def __init__(
  54. self,
  55. in_channels: int,
  56. out_channels: int,
  57. expansion_ratio: float,
  58. activation: type[Module] = nn.GELU,
  59. drop_path: float = 0.0,
  60. ) -> None:
  61. super().__init__()
  62. hidden_channels = int(in_channels * expansion_ratio)
  63. self.conv1 = ConvBN(in_channels, hidden_channels, 1, activation=activation) # point-wise
  64. self.conv2 = ConvBN(hidden_channels, hidden_channels, 3, 1, 1, hidden_channels, activation) # depth-wise
  65. self.conv3 = ConvBN(hidden_channels, out_channels, 1)
  66. self.drop_path = DropPath(drop_path)
  67. self.act = activation()
  68. def forward(self, x: Tensor) -> Tensor:
  69. return self.act(x + self.drop_path(self.conv3(self.conv2(self.conv1(x)))))
  70. class PatchMerging(Module):
  71. def __init__(
  72. self,
  73. input_resolution: int | tuple[int, int],
  74. dim: int,
  75. out_dim: int,
  76. stride: int,
  77. activation: type[Module] = nn.GELU,
  78. ) -> None:
  79. KORNIA_CHECK(stride in (1, 2), "stride must be either 1 or 2")
  80. super().__init__()
  81. self.input_resolution = _make_pair(input_resolution)
  82. self.conv1 = ConvBN(dim, out_dim, 1, activation=activation)
  83. self.conv2 = ConvBN(out_dim, out_dim, 3, stride, 1, groups=out_dim, activation=activation)
  84. self.conv3 = ConvBN(out_dim, out_dim, 1)
  85. def forward(self, x: Tensor) -> Tensor:
  86. if x.ndim == 3:
  87. x = x.transpose(1, 2).unflatten(2, self.input_resolution) # (B, H * W, C) -> (B, C, H, W)
  88. x = self.conv3(self.conv2(self.conv1(x)))
  89. x = x.flatten(2).transpose(1, 2) # (B, C, H, W) -> (B, H * W, C)
  90. return x
  91. class ConvLayer(Module):
  92. def __init__(
  93. self,
  94. dim: int,
  95. depth: int,
  96. activation: type[Module] = nn.GELU,
  97. drop_path: float | list[float] = 0.0,
  98. downsample: Optional[Module] = None,
  99. use_checkpoint: bool = False,
  100. conv_expand_ratio: float = 4.0,
  101. ) -> None:
  102. super().__init__()
  103. self.use_checkpoint = use_checkpoint
  104. # build blocks
  105. if not isinstance(drop_path, list):
  106. drop_path = [drop_path] * depth
  107. self.blocks = nn.ModuleList(
  108. [MBConv(dim, dim, conv_expand_ratio, activation, drop_path[i]) for i in range(depth)]
  109. )
  110. # patch merging layer
  111. self.downsample = downsample
  112. def forward(self, x: Tensor) -> Tensor:
  113. for blk in self.blocks:
  114. x = checkpoint.checkpoint(blk, x) if self.use_checkpoint else blk(x)
  115. if self.downsample is not None:
  116. x = self.downsample(x)
  117. return x
  118. class MLP(nn.Sequential):
  119. def __init__(
  120. self,
  121. in_features: int,
  122. hidden_features: int,
  123. out_features: int,
  124. activation: type[Module] = nn.GELU,
  125. drop: float = 0.0,
  126. ) -> None:
  127. super().__init__()
  128. self.norm = nn.LayerNorm(in_features)
  129. self.fc1 = nn.Linear(in_features, hidden_features)
  130. self.act1 = activation()
  131. self.drop1 = nn.Dropout(drop)
  132. self.fc2 = nn.Linear(hidden_features, out_features)
  133. self.drop2 = nn.Dropout(drop)
  134. # NOTE: differences from image_encoder.Attention:
  135. # - different relative position encoding mechanism (separable/decomposed vs joint)
  136. # - this impl supports attn_ratio (increase output size for value), though it is not used
  137. class Attention(Module):
  138. def __init__(
  139. self,
  140. dim: int,
  141. key_dim: int,
  142. num_heads: int = 8,
  143. attn_ratio: float = 4.0,
  144. resolution: tuple[int, int] = (14, 14),
  145. ) -> None:
  146. super().__init__()
  147. self.num_heads = num_heads
  148. self.scale = key_dim**-0.5
  149. self.key_dim = key_dim
  150. self.nh_kd = key_dim * num_heads
  151. self.d = int(attn_ratio * key_dim)
  152. self.dh = int(attn_ratio * key_dim) * num_heads
  153. self.attn_ratio = attn_ratio
  154. h = self.dh + self.nh_kd * 2
  155. self.norm = nn.LayerNorm(dim)
  156. self.qkv = nn.Linear(dim, h)
  157. self.proj = nn.Linear(self.dh, dim)
  158. indices, attn_offset_size = self.build_attention_bias(resolution)
  159. self.attention_biases = nn.Parameter(torch.zeros(num_heads, attn_offset_size))
  160. self.register_buffer("attention_bias_idxs", indices, persistent=False)
  161. self.attention_bias_idxs: Tensor
  162. self.ab: Optional[Tensor] = None
  163. @staticmethod
  164. def build_attention_bias(resolution: tuple[int, int]) -> tuple[Tensor, int]:
  165. H, W = resolution
  166. rows = torch.arange(H)
  167. cols = torch.arange(W)
  168. rr = rows.repeat_interleave(W)
  169. cc = cols.repeat(H)
  170. dr = (rr[:, None] - rr[None, :]).abs()
  171. dc = (cc[:, None] - cc[None, :]).abs()
  172. keys = dr * W + dc
  173. unique_keys, inverse = torch.unique(keys, return_inverse=True)
  174. indices = inverse.view(H * W, H * W)
  175. attn_offset_size = unique_keys.numel()
  176. return indices, attn_offset_size
  177. # is this really necessary?
  178. @torch.no_grad()
  179. def train(self, mode: bool = True) -> Attention:
  180. super().train(mode)
  181. self.ab = None if (mode and self.ab is not None) else self.attention_biases[:, self.attention_bias_idxs]
  182. return self
  183. def forward(self, x: Tensor) -> Tensor:
  184. B, N, _ = x.shape
  185. x = self.norm(x)
  186. qkv = self.qkv(x)
  187. qkv = qkv.view(B, N, self.num_heads, -1).permute(0, 2, 1, 3)
  188. q, k, v = qkv.split([self.key_dim, self.key_dim, self.d], dim=3)
  189. bias = self.attention_biases[:, self.attention_bias_idxs] if self.training else self.ab
  190. attn = (q @ k.transpose(-2, -1)) * self.scale + bias
  191. attn = attn.softmax(dim=-1)
  192. x = (attn @ v).transpose(1, 2).reshape(B, N, self.dh)
  193. x = self.proj(x)
  194. return x
  195. class TinyViTBlock(Module):
  196. def __init__(
  197. self,
  198. dim: int,
  199. input_resolution: int | tuple[int, int],
  200. num_heads: int,
  201. window_size: int = 7,
  202. mlp_ratio: float = 4.0,
  203. drop: float = 0.0,
  204. drop_path: float = 0.0,
  205. local_conv_size: int = 3,
  206. activation: type[Module] = nn.GELU,
  207. ) -> None:
  208. KORNIA_CHECK(dim % num_heads == 0, "dim must be divislbe by num_heads")
  209. super().__init__()
  210. self.input_resolution = _make_pair(input_resolution)
  211. self.window_size = window_size
  212. head_dim = dim // num_heads
  213. self.attn = Attention(dim, head_dim, num_heads, 1.0, (window_size, window_size))
  214. self.drop_path1 = DropPath(drop_path)
  215. self.local_conv = ConvBN(dim, dim, local_conv_size, 1, local_conv_size // 2, dim)
  216. self.mlp = MLP(dim, int(dim * mlp_ratio), dim, activation, drop)
  217. self.drop_path2 = DropPath(drop_path)
  218. def forward(self, x: Tensor) -> Tensor:
  219. H, W = self.input_resolution
  220. B, L, C = x.shape
  221. res_x = x
  222. x = x.view(B, H, W, C)
  223. x, pad_hw = window_partition(x, self.window_size) # (B * num_windows, window_size, window_size, C)
  224. x = self.attn(x.flatten(1, 2))
  225. x = window_unpartition(x, self.window_size, pad_hw, (H, W))
  226. x = x.view(B, L, C)
  227. x = res_x + self.drop_path1(x)
  228. x = x.transpose(1, 2).reshape(B, C, H, W)
  229. x = self.local_conv(x)
  230. x = x.view(B, C, L).transpose(1, 2)
  231. x = x + self.drop_path2(self.mlp(x))
  232. return x
  233. class BasicLayer(Module):
  234. def __init__(
  235. self,
  236. dim: int,
  237. input_resolution: int | tuple[int, int],
  238. depth: int,
  239. num_heads: int,
  240. window_size: int,
  241. mlp_ratio: float = 4.0,
  242. drop: float = 0.0,
  243. drop_path: float | list[float] = 0.0,
  244. downsample: Optional[Module] = None,
  245. use_checkpoint: bool = False,
  246. local_conv_size: int = 3,
  247. activation: type[Module] = nn.GELU,
  248. ) -> None:
  249. super().__init__()
  250. self.use_checkpoint = use_checkpoint
  251. self.blocks = nn.ModuleList(
  252. [
  253. TinyViTBlock(
  254. dim,
  255. input_resolution,
  256. num_heads,
  257. window_size,
  258. mlp_ratio,
  259. drop,
  260. drop_path[i] if isinstance(drop_path, list) else drop_path,
  261. local_conv_size,
  262. activation,
  263. )
  264. for i in range(depth)
  265. ]
  266. )
  267. # patch merging layer
  268. self.downsample = downsample
  269. def forward(self, x: Tensor) -> Tensor:
  270. for blk in self.blocks:
  271. x = checkpoint.checkpoint(blk, x) if self.use_checkpoint else blk(x)
  272. if self.downsample is not None:
  273. x = self.downsample(x)
  274. return x
  275. class TinyViT(Module):
  276. """TinyViT model, as described in https://arxiv.org/abs/2207.10666.
  277. Args:
  278. img_size: Size of input image.
  279. in_chans: Number of input image's channels.
  280. num_classes: Number of output classes.
  281. embed_dims: List of embedding dimensions.
  282. depths: List of block count for each downsampling stage
  283. num_heads: List of attention heads used in self-attention for each downsampling stage.
  284. window_sizes: List of self-attention's window size for each downsampling stage.
  285. mlp_ratio: Ratio of MLP dimension to embedding dimension in self-attention.
  286. drop_rate: Dropout rate.
  287. drop_path_rate: Stochastic depth rate.
  288. use_checkpoint: Whether to use activation checkpointing to trade compute for memory.
  289. mbconv_expand_ratio: Expansion ratio used in MBConv block.
  290. local_conv_size: Kernel size of convolution used in TinyViTBlock
  291. activation: activation function.
  292. mobile_same: Whether to use modifications for MobileSAM.
  293. """
  294. def __init__(
  295. self,
  296. img_size: int = 224,
  297. in_chans: int = 3,
  298. num_classes: int = 1000,
  299. embed_dims: Sequence[int] = (96, 192, 384, 768),
  300. depths: Sequence[int] = (2, 2, 6, 2),
  301. num_heads: Sequence[int] = (3, 6, 12, 24),
  302. window_sizes: Sequence[int] = (7, 7, 14, 7),
  303. mlp_ratio: float = 4.0,
  304. drop_rate: float = 0.0,
  305. drop_path_rate: float = 0.0,
  306. use_checkpoint: bool = False,
  307. mbconv_expand_ratio: float = 4.0,
  308. local_conv_size: int = 3,
  309. # layer_lr_decay: float = 1.0,
  310. activation: type[Module] = nn.GELU,
  311. mobile_sam: bool = False,
  312. ) -> None:
  313. super().__init__()
  314. self.img_size = img_size
  315. self.mobile_sam = mobile_sam
  316. self.neck: Optional[Module]
  317. if mobile_sam:
  318. # MobileSAM adjusts the stride to match the total stride of other ViT backbones
  319. # used in the original SAM (stride 16)
  320. strides = [2, 2, 1, 1]
  321. self.neck = nn.Sequential(
  322. nn.Conv2d(embed_dims[-1], 256, 1, bias=False),
  323. LayerNorm2d(256),
  324. nn.Conv2d(256, 256, 3, 1, 1, bias=False),
  325. LayerNorm2d(256),
  326. )
  327. else:
  328. strides = [2, 2, 2, 1]
  329. self.neck = None
  330. self.patch_embed = PatchEmbed(in_chans, embed_dims[0], activation)
  331. input_resolution = img_size // 4
  332. # NOTE: if we don't support training, this might be unimportant
  333. # stochastic depth decay rule
  334. dpr = [x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))]
  335. # build layers
  336. n_layers = len(depths)
  337. layers = []
  338. for i_layer, (embed_dim, depth, num_heads_i, window_size, stride) in enumerate(
  339. zip(embed_dims, depths, num_heads, window_sizes, strides)
  340. ):
  341. out_dim = embed_dims[min(i_layer + 1, len(embed_dims) - 1)]
  342. downsample = (
  343. PatchMerging(input_resolution, embed_dim, out_dim, stride, activation)
  344. if (i_layer < n_layers - 1)
  345. else None
  346. )
  347. kwargs: dict[str, Any] = {
  348. "dim": embed_dim,
  349. "depth": depth,
  350. "drop_path": dpr[sum(depths[:i_layer]) : sum(depths[: i_layer + 1])],
  351. "downsample": downsample,
  352. "use_checkpoint": use_checkpoint,
  353. "activation": activation,
  354. }
  355. layer: ConvLayer | BasicLayer
  356. if i_layer == 0:
  357. layer = ConvLayer(conv_expand_ratio=mbconv_expand_ratio, **kwargs)
  358. else:
  359. layer = BasicLayer(
  360. input_resolution=input_resolution,
  361. num_heads=num_heads_i,
  362. window_size=window_size,
  363. mlp_ratio=mlp_ratio,
  364. drop=drop_rate,
  365. local_conv_size=local_conv_size,
  366. **kwargs,
  367. )
  368. layers.append(layer)
  369. input_resolution //= stride
  370. self.layers = nn.Sequential(*layers)
  371. self.feat_size = input_resolution # final feature map size
  372. # Classifier head
  373. # NOTE: this is redundant for MobileSAM, but we still need it
  374. # to load pre-trained weights with strict=True
  375. # TODO: enable strict=False, or host our own weights
  376. self.norm_head = nn.LayerNorm(embed_dims[-1])
  377. self.head = nn.Linear(embed_dims[-1], num_classes)
  378. def forward(self, x: Tensor) -> Tensor:
  379. """Classify images if ``mobile_sam=False``, produce feature maps if ``mobile_sam=True``."""
  380. x = self.patch_embed(x)
  381. x = self.layers(x)
  382. if self.mobile_sam:
  383. # MobileSAM
  384. x = x.unflatten(1, (self.feat_size, self.feat_size)).permute(0, 3, 1, 2)
  385. x = self.neck(x) # type: ignore
  386. else:
  387. # classification
  388. x = x.mean(1)
  389. x = self.head(self.norm_head(x))
  390. return x
  391. @staticmethod
  392. def from_config(variant: str, pretrained: bool | str = False, **kwargs: Any) -> TinyViT:
  393. """Create a TinyViT model from pre-defined variants.
  394. Args:
  395. variant: TinyViT variant. Possible values: ``'5m'``, ``'11m'``, ``'21m'``.
  396. pretrained: whether to use pre-trained weights. Possible values: ``False``, ``True``, ``'in22k'``,
  397. ``'in1k'``. For TinyViT-21M (``variant='21m'``), ``'in1k_384'``, ``'in1k_512'`` are also available.
  398. **kwargs: other keyword arguments that will be passed to :class:`TinyViT`.
  399. .. note::
  400. When ``img_size`` is different from the pre-trained size, bicubic interpolation will be performed on
  401. attention biases. When using ``pretrained=True``, ImageNet-1k checkpoint (``'in1k'``) is used.
  402. For feature extraction or fine-tuning, ImageNet-22k checkpoint (``'in22k'``) is preferred.
  403. """
  404. KORNIA_CHECK(variant in ("5m", "11m", "21m"), "Only variant 5m, 11m, and 21m are supported")
  405. return {"5m": _tiny_vit_5m, "11m": _tiny_vit_11m, "21m": _tiny_vit_21m}[variant](pretrained, **kwargs)
  406. def _load_pretrained(model: TinyViT, url: str) -> TinyViT:
  407. model_state_dict = model.state_dict()
  408. state_dict = torch.hub.load_state_dict_from_url(url)
  409. # official checkpoint has "model" key
  410. if "model" in state_dict:
  411. state_dict = state_dict["model"]
  412. # https://github.com/microsoft/Cream/blob/8dc38822b99fff8c262c585a32a4f09ac504d693/TinyViT/utils.py#L163
  413. # bicubic interpolate attention biases
  414. ab_keys = [k for k in state_dict.keys() if "attention_biases" in k]
  415. for k in ab_keys:
  416. n_heads1, L1 = state_dict[k].shape
  417. n_heads2, L2 = model_state_dict[k].shape
  418. KORNIA_CHECK(n_heads1 == n_heads2, f"Fail to load {k}. Pre-trained checkpoint should have num_heads={n_heads1}")
  419. if L1 != L2:
  420. S1 = int(L1**0.5)
  421. S2 = int(L2**0.5)
  422. attention_biases = state_dict[k].view(1, n_heads1, S1, S1)
  423. attention_biases = F.interpolate(attention_biases, size=(S2, S2), mode="bicubic")
  424. state_dict[k] = attention_biases.view(n_heads2, L2)
  425. if state_dict["head.weight"].shape[0] != model.head.out_features:
  426. msg = "Number of classes does not match pre-trained checkpoint's. Resetting classification head to zeros"
  427. warnings.warn(msg, stacklevel=1)
  428. state_dict["head.weight"] = torch.zeros_like(model.head.weight)
  429. state_dict["head.bias"] = torch.zeros_like(model.head.bias)
  430. model.load_state_dict(state_dict)
  431. return model
  432. def _tiny_vit_5m(pretrained: bool | str = False, **kwargs: Any) -> TinyViT:
  433. model = TinyViT(
  434. embed_dims=[64, 128, 160, 320],
  435. depths=[2, 2, 6, 2],
  436. num_heads=[2, 4, 5, 10],
  437. window_sizes=[7, 7, 14, 7],
  438. drop_path_rate=0.0,
  439. **kwargs,
  440. )
  441. if pretrained:
  442. if pretrained is True:
  443. pretrained = "in1k"
  444. url = {
  445. "in22k": (
  446. "https://github.com/wkcn/TinyViT-model-zoo/releases/download/checkpoints/tiny_vit_5m_22k_distill.pth"
  447. ),
  448. "in1k": "https://github.com/wkcn/TinyViT-model-zoo/releases/download/checkpoints/tiny_vit_5m_22kto1k_distill.pth",
  449. }[pretrained]
  450. model = _load_pretrained(model, url)
  451. return model
  452. def _tiny_vit_11m(pretrained: bool | str = False, **kwargs: Any) -> TinyViT:
  453. model = TinyViT(
  454. embed_dims=[64, 128, 256, 448],
  455. depths=[2, 2, 6, 2],
  456. num_heads=[2, 4, 8, 14],
  457. window_sizes=[7, 7, 14, 7],
  458. drop_path_rate=0.1,
  459. **kwargs,
  460. )
  461. if pretrained:
  462. if pretrained is True:
  463. pretrained = "in1k"
  464. url = {
  465. "in22k": (
  466. "https://github.com/wkcn/TinyViT-model-zoo/releases/download/checkpoints/tiny_vit_11m_22k_distill.pth"
  467. ),
  468. "in1k": "https://github.com/wkcn/TinyViT-model-zoo/releases/download/checkpoints/tiny_vit_11m_22kto1k_distill.pth",
  469. }[pretrained]
  470. model = _load_pretrained(model, url)
  471. return model
  472. def _tiny_vit_21m(pretrained: bool | str = False, **kwargs: Any) -> TinyViT:
  473. model = TinyViT(
  474. embed_dims=[96, 192, 384, 576],
  475. depths=[2, 2, 6, 2],
  476. num_heads=[3, 6, 12, 18],
  477. window_sizes=[7, 7, 14, 7],
  478. drop_path_rate=0.2,
  479. **kwargs,
  480. )
  481. if pretrained:
  482. if pretrained is True:
  483. pretrained = "in1k"
  484. img_size = kwargs.get("img_size", 224)
  485. if img_size >= 384:
  486. pretrained = "in1k_384"
  487. if img_size >= 512:
  488. pretrained = "in1k_512"
  489. url = {
  490. "in22k": (
  491. "https://github.com/wkcn/TinyViT-model-zoo/releases/download/checkpoints/tiny_vit_21m_22k_distill.pth"
  492. ),
  493. "in1k": "https://github.com/wkcn/TinyViT-model-zoo/releases/download/checkpoints/tiny_vit_21m_22kto1k_distill.pth",
  494. "in1k_384": "https://github.com/wkcn/TinyViT-model-zoo/releases/download/checkpoints/tiny_vit_21m_22kto1k_384_distill.pth",
  495. "in1k_512": "https://github.com/wkcn/TinyViT-model-zoo/releases/download/checkpoints/tiny_vit_21m_22kto1k_512_distill.pth",
  496. }[pretrained]
  497. model = _load_pretrained(model, url)
  498. return model