vision_transformer_hybrid.py 19 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461
""" Hybrid Vision Transformer (ViT) in PyTorch

A PyTorch implementation of the Hybrid Vision Transformers as described in:

'An Image Is Worth 16 x 16 Words: Transformers for Image Recognition at Scale'
    - https://arxiv.org/abs/2010.11929

'How to train your ViT? Data, Augmentation, and Regularization in Vision Transformers'
    - https://arxiv.org/abs/2106.10270

NOTE These hybrid model definitions depend on code in vision_transformer.py.
They were moved here to keep file sizes sane.

Hacked together by / Copyright 2020, Ross Wightman
"""
  11. from functools import partial
  12. from typing import Dict, Tuple, Type, Union
  13. import torch
  14. import torch.nn as nn
  15. from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
  16. from timm.layers import StdConv2dSame, StdConv2d, ConvNormAct, to_ntuple, HybridEmbed
  17. from ._builder import build_model_with_cfg
  18. from ._registry import generate_default_cfgs, register_model, register_model_deprecations
  19. from .resnet import resnet26d, resnet50d
  20. from .resnetv2 import ResNetV2, create_resnetv2_stem
  21. from .vision_transformer import VisionTransformer
  22. class ConvStem(nn.Sequential):
  23. def __init__(
  24. self,
  25. in_chans: int = 3,
  26. depth: int = 3,
  27. channels: Union[int, Tuple[int, ...]] = 64,
  28. kernel_size: Union[int, Tuple[int, ...]] = 3,
  29. stride: Union[int, Tuple[int, ...]] = (2, 2, 2),
  30. padding: Union[str, int, Tuple[int, ...]] = "",
  31. norm_layer: Type[nn.Module] = nn.BatchNorm2d,
  32. act_layer: Type[nn.Module] = nn.ReLU,
  33. device=None,
  34. dtype=None,
  35. ):
  36. dd = {'device': device, 'dtype': dtype}
  37. super().__init__()
  38. if isinstance(channels, int):
  39. # a default tiered channel strategy
  40. channels = tuple([channels // 2**i for i in range(depth)][::-1])
  41. kernel_size = to_ntuple(depth)(kernel_size)
  42. padding = to_ntuple(depth)(padding)
  43. assert depth == len(stride) == len(kernel_size) == len(channels)
  44. in_chs = in_chans
  45. for i in range(len(channels)):
  46. last_conv = i == len(channels) - 1
  47. self.add_module(f'{i}', ConvNormAct(
  48. in_chs,
  49. channels[i],
  50. kernel_size=kernel_size[i],
  51. stride=stride[i],
  52. padding=padding[i],
  53. bias=last_conv,
  54. apply_norm=not last_conv,
  55. apply_act=not last_conv,
  56. norm_layer=norm_layer,
  57. act_layer=act_layer,
  58. **dd,
  59. ))
  60. in_chs = channels[i]
  61. def _dd_from_kwargs(**kwargs):
  62. return {'device': kwargs.get('device', None), 'dtype': kwargs.get('dtype', None)}
  63. def _resnetv2(layers=(3, 4, 9), **kwargs):
  64. """ ResNet-V2 backbone helper"""
  65. padding_same = kwargs.get('padding_same', True)
  66. stem_type = 'same' if padding_same else ''
  67. conv_layer = partial(StdConv2dSame, eps=1e-8) if padding_same else partial(StdConv2d, eps=1e-8)
  68. if len(layers):
  69. backbone = ResNetV2(
  70. layers=layers,
  71. num_classes=0,
  72. global_pool='',
  73. in_chans=kwargs.get('in_chans', 3),
  74. preact=False,
  75. stem_type=stem_type,
  76. conv_layer=conv_layer,
  77. **_dd_from_kwargs(**kwargs),
  78. )
  79. else:
  80. backbone = create_resnetv2_stem(
  81. kwargs.get('in_chans', 3),
  82. stem_type=stem_type,
  83. preact=False,
  84. conv_layer=conv_layer,
  85. **_dd_from_kwargs(**kwargs),
  86. )
  87. return backbone
  88. def _convert_mobileclip(state_dict, model, prefix='image_encoder.model.'):
  89. out = {}
  90. for k, v in state_dict.items():
  91. if not k.startswith(prefix):
  92. continue
  93. k = k.replace(prefix, '')
  94. k = k.replace('patch_emb.', 'patch_embed.backbone.')
  95. k = k.replace('block.conv', 'conv')
  96. k = k.replace('block.norm', 'bn')
  97. k = k.replace('post_transformer_norm.', 'norm.')
  98. k = k.replace('pre_norm_mha.0', 'norm1')
  99. k = k.replace('pre_norm_mha.1', 'attn')
  100. k = k.replace('pre_norm_ffn.0', 'norm2')
  101. k = k.replace('pre_norm_ffn.1', 'mlp.fc1')
  102. k = k.replace('pre_norm_ffn.4', 'mlp.fc2')
  103. k = k.replace('qkv_proj.', 'qkv.')
  104. k = k.replace('out_proj.', 'proj.')
  105. k = k.replace('transformer.', 'blocks.')
  106. if k == 'pos_embed.pos_embed.pos_embed':
  107. k = 'pos_embed'
  108. v = v.squeeze(0)
  109. if 'classifier.proj' in k:
  110. bias_k = k.replace('classifier.proj', 'head.bias')
  111. k = k.replace('classifier.proj', 'head.weight')
  112. v = v.T
  113. out[bias_k] = torch.zeros(v.shape[0])
  114. out[k] = v
  115. return out
  116. def checkpoint_filter_fn(
  117. state_dict: Dict[str, torch.Tensor],
  118. model: VisionTransformer,
  119. interpolation: str = 'bicubic',
  120. antialias: bool = True,
  121. ) -> Dict[str, torch.Tensor]:
  122. from .vision_transformer import checkpoint_filter_fn as _filter_fn
  123. if 'image_encoder.model.patch_emb.0.block.conv.weight' in state_dict:
  124. state_dict = _convert_mobileclip(state_dict, model)
  125. return _filter_fn(state_dict, model, interpolation=interpolation, antialias=antialias)
  126. def _create_vision_transformer_hybrid(variant, backbone, embed_args=None, pretrained=False, **kwargs):
  127. out_indices = kwargs.pop('out_indices', 3)
  128. embed_args = embed_args or {}
  129. embed_layer = partial(HybridEmbed, backbone=backbone, **embed_args)
  130. kwargs.setdefault('embed_layer', embed_layer)
  131. kwargs.setdefault('patch_size', 1) # default patch size for hybrid models if not set
  132. return build_model_with_cfg(
  133. VisionTransformer,
  134. variant,
  135. pretrained,
  136. pretrained_filter_fn=checkpoint_filter_fn,
  137. feature_cfg=dict(out_indices=out_indices, feature_cls='getter'),
  138. **kwargs,
  139. )
  140. def _cfg(url='', **kwargs):
  141. return {
  142. 'url': url,
  143. 'num_classes': 1000, 'input_size': (3, 224, 224), 'pool_size': None,
  144. 'crop_pct': .9, 'interpolation': 'bicubic', 'fixed_input_size': True,
  145. 'mean': (0.5, 0.5, 0.5), 'std': (0.5, 0.5, 0.5),
  146. 'first_conv': 'patch_embed.backbone.stem.conv', 'classifier': 'head',
  147. 'license': 'apache-2.0',
  148. **kwargs
  149. }
# Pretrained configs for every '<model_name>.<pretrained_tag>' defined in this module.
# Each entry starts from `_cfg` defaults (224x224 input, 0.5 mean/std, 1000 classes,
# apache-2.0 license) and overrides only what differs.
default_cfgs = generate_default_cfgs({

    # hybrid in-1k models (weights from official JAX impl where they exist)
    # The vit_tiny_r_s16_p8_* models use a stem-only backbone (`_resnetv2(layers=())`),
    # hence the different `first_conv` path.
    'vit_tiny_r_s16_p8_224.augreg_in21k_ft_in1k': _cfg(
        url='https://storage.googleapis.com/vit_models/augreg/R_Ti_16-i21k-300ep-lr_0.001-aug_none-wd_0.03-do_0.0-sd_0.0--imagenet2012-steps_20k-lr_0.03-res_224.npz',
        hf_hub_id='timm/',
        custom_load=True,
        first_conv='patch_embed.backbone.conv'),
    'vit_tiny_r_s16_p8_384.augreg_in21k_ft_in1k': _cfg(
        url='https://storage.googleapis.com/vit_models/augreg/R_Ti_16-i21k-300ep-lr_0.001-aug_none-wd_0.03-do_0.0-sd_0.0--imagenet2012-steps_20k-lr_0.03-res_384.npz',
        hf_hub_id='timm/',
        first_conv='patch_embed.backbone.conv', input_size=(3, 384, 384), crop_pct=1.0, custom_load=True),
    'vit_small_r26_s32_224.augreg_in21k_ft_in1k': _cfg(
        url='https://storage.googleapis.com/vit_models/augreg/R26_S_32-i21k-300ep-lr_0.001-aug_light0-wd_0.03-do_0.1-sd_0.1--imagenet2012-steps_20k-lr_0.03-res_224.npz',
        hf_hub_id='timm/',
        custom_load=True,
    ),
    'vit_small_r26_s32_384.augreg_in21k_ft_in1k': _cfg(
        url='https://storage.googleapis.com/vit_models/augreg/R26_S_32-i21k-300ep-lr_0.001-aug_medium2-wd_0.03-do_0.0-sd_0.0--imagenet2012-steps_20k-lr_0.03-res_384.npz',
        hf_hub_id='timm/',
        input_size=(3, 384, 384), crop_pct=1.0, custom_load=True),
    'vit_base_r26_s32_224.untrained': _cfg(),
    'vit_base_r50_s16_384.orig_in21k_ft_in1k': _cfg(
        url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-vitjx/jx_vit_base_resnet50_384-9fd3c705.pth',
        hf_hub_id='timm/',
        input_size=(3, 384, 384), crop_pct=1.0),
    'vit_large_r50_s32_224.augreg_in21k_ft_in1k': _cfg(
        url='https://storage.googleapis.com/vit_models/augreg/R50_L_32-i21k-300ep-lr_0.001-aug_medium1-wd_0.1-do_0.1-sd_0.1--imagenet2012-steps_20k-lr_0.01-res_224.npz',
        hf_hub_id='timm/',
        custom_load=True,
    ),
    'vit_large_r50_s32_384.augreg_in21k_ft_in1k': _cfg(
        url='https://storage.googleapis.com/vit_models/augreg/R50_L_32-i21k-300ep-lr_0.001-aug_medium2-wd_0.1-do_0.0-sd_0.0--imagenet2012-steps_20k-lr_0.01-res_384.npz',
        hf_hub_id='timm/',
        input_size=(3, 384, 384), crop_pct=1.0, custom_load=True,
    ),

    # hybrid in-21k models (weights from official Google JAX impl where they exist)
    # 21843 = number of ImageNet-21k classes used by these heads.
    'vit_tiny_r_s16_p8_224.augreg_in21k': _cfg(
        url='https://storage.googleapis.com/vit_models/augreg/R_Ti_16-i21k-300ep-lr_0.001-aug_none-wd_0.03-do_0.0-sd_0.0.npz',
        hf_hub_id='timm/',
        num_classes=21843, crop_pct=0.9, first_conv='patch_embed.backbone.conv', custom_load=True),
    'vit_small_r26_s32_224.augreg_in21k': _cfg(
        url='https://storage.googleapis.com/vit_models/augreg/R26_S_32-i21k-300ep-lr_0.001-aug_medium2-wd_0.03-do_0.0-sd_0.0.npz',
        hf_hub_id='timm/',
        num_classes=21843, crop_pct=0.9, custom_load=True),
    # num_classes=0: presumably a headless release. The old GitHub URL is kept for provenance
    # only; weights resolve through `hf_hub_id`.
    'vit_base_r50_s16_224.orig_in21k': _cfg(
        #url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-vitjx/jx_vit_base_resnet50_224_in21k-6f7c7740.pth',
        hf_hub_id='timm/',
        num_classes=0, crop_pct=0.9),
    'vit_large_r50_s32_224.augreg_in21k': _cfg(
        url='https://storage.googleapis.com/vit_models/augreg/R50_L_32-i21k-300ep-lr_0.001-aug_medium2-wd_0.1-do_0.0-sd_0.0.npz',
        hf_hub_id='timm/',
        num_classes=21843, crop_pct=0.9, custom_load=True),

    # hybrid models (using timm resnet backbones)
    # These use ImageNet mean/std and the timm ResNet-D deep stem, whose first conv is `conv1.0`.
    'vit_small_resnet26d_224.untrained': _cfg(
        mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD, first_conv='patch_embed.backbone.conv1.0'),
    'vit_small_resnet50d_s16_224.untrained': _cfg(
        mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD, first_conv='patch_embed.backbone.conv1.0'),
    'vit_base_resnet26d_224.untrained': _cfg(
        mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD, first_conv='patch_embed.backbone.conv1.0'),
    'vit_base_resnet50d_224.untrained': _cfg(
        mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD, first_conv='patch_embed.backbone.conv1.0'),

    # MobileCLIP image encoders (ConvStem backbone). num_classes=512 is the output width of
    # the head; mean 0 / std 1 makes input normalization an identity.
    'vit_base_mci_224.apple_mclip_lt': _cfg(
        hf_hub_id='apple/mobileclip_b_lt_timm',
        url='https://docs-assets.developer.apple.com/ml-research/datasets/mobileclip/mobileclip_blt.pt',
        license='apple-amlr',
        num_classes=512,
        mean=(0., 0., 0.), std=(1., 1., 1.), first_conv='patch_embed.backbone.0.conv',
    ),
    'vit_base_mci_224.apple_mclip': _cfg(
        hf_hub_id='apple/mobileclip_b_timm',
        url='https://docs-assets.developer.apple.com/ml-research/datasets/mobileclip/mobileclip_b.pt',
        num_classes=512,
        license='apple-amlr',
        mean=(0., 0., 0.), std=(1., 1., 1.), first_conv='patch_embed.backbone.0.conv',
    ),
    'vit_base_mci_224.apple_mclip2_dfndr2b': _cfg(
        hf_hub_id='timm/',
        num_classes=512,
        mean=(0., 0., 0.), std=(1., 1., 1.), first_conv='patch_embed.backbone.0.conv',
        license='apple-amlr'
    ),
})
  232. @register_model
  233. def vit_tiny_r_s16_p8_224(pretrained=False, **kwargs) -> VisionTransformer:
  234. """ R+ViT-Ti/S16 w/ 8x8 patch hybrid @ 224 x 224.
  235. """
  236. backbone = _resnetv2(layers=(), **kwargs)
  237. model_args = dict(patch_size=8, embed_dim=192, depth=12, num_heads=3)
  238. model = _create_vision_transformer_hybrid(
  239. 'vit_tiny_r_s16_p8_224', backbone=backbone, pretrained=pretrained, **dict(model_args, **kwargs))
  240. return model
  241. @register_model
  242. def vit_tiny_r_s16_p8_384(pretrained=False, **kwargs) -> VisionTransformer:
  243. """ R+ViT-Ti/S16 w/ 8x8 patch hybrid @ 384 x 384.
  244. """
  245. backbone = _resnetv2(layers=(), **kwargs)
  246. model_args = dict(patch_size=8, embed_dim=192, depth=12, num_heads=3)
  247. model = _create_vision_transformer_hybrid(
  248. 'vit_tiny_r_s16_p8_384', backbone=backbone, pretrained=pretrained, **dict(model_args, **kwargs))
  249. return model
  250. @register_model
  251. def vit_small_r26_s32_224(pretrained=False, **kwargs) -> VisionTransformer:
  252. """ R26+ViT-S/S32 hybrid.
  253. """
  254. backbone = _resnetv2((2, 2, 2, 2), **kwargs)
  255. model_args = dict(embed_dim=384, depth=12, num_heads=6)
  256. model = _create_vision_transformer_hybrid(
  257. 'vit_small_r26_s32_224', backbone=backbone, pretrained=pretrained, **dict(model_args, **kwargs))
  258. return model
  259. @register_model
  260. def vit_small_r26_s32_384(pretrained=False, **kwargs) -> VisionTransformer:
  261. """ R26+ViT-S/S32 hybrid.
  262. """
  263. backbone = _resnetv2((2, 2, 2, 2), **kwargs)
  264. model_args = dict(embed_dim=384, depth=12, num_heads=6)
  265. model = _create_vision_transformer_hybrid(
  266. 'vit_small_r26_s32_384', backbone=backbone, pretrained=pretrained, **dict(model_args, **kwargs))
  267. return model
  268. @register_model
  269. def vit_base_r26_s32_224(pretrained=False, **kwargs) -> VisionTransformer:
  270. """ R26+ViT-B/S32 hybrid.
  271. """
  272. backbone = _resnetv2((2, 2, 2, 2), **kwargs)
  273. model_args = dict(embed_dim=768, depth=12, num_heads=12)
  274. model = _create_vision_transformer_hybrid(
  275. 'vit_base_r26_s32_224', backbone=backbone, pretrained=pretrained, **dict(model_args, **kwargs))
  276. return model
  277. @register_model
  278. def vit_base_r50_s16_224(pretrained=False, **kwargs) -> VisionTransformer:
  279. """ R50+ViT-B/S16 hybrid from original paper (https://arxiv.org/abs/2010.11929).
  280. """
  281. backbone = _resnetv2((3, 4, 9), **kwargs)
  282. model_args = dict(embed_dim=768, depth=12, num_heads=12)
  283. model = _create_vision_transformer_hybrid(
  284. 'vit_base_r50_s16_224', backbone=backbone, pretrained=pretrained, **dict(model_args, **kwargs))
  285. return model
  286. @register_model
  287. def vit_base_r50_s16_384(pretrained=False, **kwargs) -> VisionTransformer:
  288. """ R50+ViT-B/16 hybrid from original paper (https://arxiv.org/abs/2010.11929).
  289. ImageNet-1k weights fine-tuned from in21k @ 384x384, source https://github.com/google-research/vision_transformer.
  290. """
  291. backbone = _resnetv2((3, 4, 9), **kwargs)
  292. model_args = dict(embed_dim=768, depth=12, num_heads=12)
  293. model = _create_vision_transformer_hybrid(
  294. 'vit_base_r50_s16_384', backbone=backbone, pretrained=pretrained, **dict(model_args, **kwargs))
  295. return model
  296. @register_model
  297. def vit_large_r50_s32_224(pretrained=False, **kwargs) -> VisionTransformer:
  298. """ R50+ViT-L/S32 hybrid.
  299. """
  300. backbone = _resnetv2((3, 4, 6, 3), **kwargs)
  301. model_args = dict(embed_dim=1024, depth=24, num_heads=16)
  302. model = _create_vision_transformer_hybrid(
  303. 'vit_large_r50_s32_224', backbone=backbone, pretrained=pretrained, **dict(model_args, **kwargs))
  304. return model
  305. @register_model
  306. def vit_large_r50_s32_384(pretrained=False, **kwargs) -> VisionTransformer:
  307. """ R50+ViT-L/S32 hybrid.
  308. """
  309. backbone = _resnetv2((3, 4, 6, 3), **kwargs)
  310. model_args = dict(embed_dim=1024, depth=24, num_heads=16)
  311. model = _create_vision_transformer_hybrid(
  312. 'vit_large_r50_s32_384', backbone=backbone, pretrained=pretrained, **dict(model_args, **kwargs))
  313. return model
  314. @register_model
  315. def vit_small_resnet26d_224(pretrained=False, **kwargs) -> VisionTransformer:
  316. """ Custom ViT small hybrid w/ ResNet26D stride 32. No pretrained weights.
  317. """
  318. backbone = resnet26d(
  319. pretrained=pretrained,
  320. in_chans=kwargs.get('in_chans', 3),
  321. features_only=True,
  322. out_indices=[4],
  323. **_dd_from_kwargs(**kwargs),
  324. )
  325. model_args = dict(embed_dim=768, depth=8, num_heads=8, mlp_ratio=3)
  326. model = _create_vision_transformer_hybrid(
  327. 'vit_small_resnet26d_224', backbone=backbone, pretrained=pretrained, **dict(model_args, **kwargs))
  328. return model
  329. @register_model
  330. def vit_small_resnet50d_s16_224(pretrained=False, **kwargs) -> VisionTransformer:
  331. """ Custom ViT small hybrid w/ ResNet50D 3-stages, stride 16. No pretrained weights.
  332. """
  333. backbone = resnet50d(
  334. pretrained=pretrained,
  335. in_chans=kwargs.get('in_chans', 3),
  336. features_only=True,
  337. out_indices=[3],
  338. **_dd_from_kwargs(**kwargs),
  339. )
  340. model_args = dict(embed_dim=768, depth=8, num_heads=8, mlp_ratio=3)
  341. model = _create_vision_transformer_hybrid(
  342. 'vit_small_resnet50d_s16_224', backbone=backbone, pretrained=pretrained, **dict(model_args, **kwargs))
  343. return model
  344. @register_model
  345. def vit_base_resnet26d_224(pretrained=False, **kwargs) -> VisionTransformer:
  346. """ Custom ViT base hybrid w/ ResNet26D stride 32. No pretrained weights.
  347. """
  348. backbone = resnet26d(
  349. pretrained=pretrained,
  350. in_chans=kwargs.get('in_chans', 3),
  351. features_only=True,
  352. out_indices=[4],
  353. **_dd_from_kwargs(**kwargs),
  354. )
  355. model_args = dict(embed_dim=768, depth=12, num_heads=12)
  356. model = _create_vision_transformer_hybrid(
  357. 'vit_base_resnet26d_224', backbone=backbone, pretrained=pretrained, **dict(model_args, **kwargs))
  358. return model
  359. @register_model
  360. def vit_base_resnet50d_224(pretrained=False, **kwargs) -> VisionTransformer:
  361. """ Custom ViT base hybrid w/ ResNet50D stride 32. No pretrained weights.
  362. """
  363. backbone = resnet50d(
  364. pretrained=pretrained,
  365. in_chans=kwargs.get('in_chans', 3),
  366. features_only=True,
  367. out_indices=[4],
  368. **_dd_from_kwargs(**kwargs),
  369. )
  370. model_args = dict(embed_dim=768, depth=12, num_heads=12)
  371. model = _create_vision_transformer_hybrid(
  372. 'vit_base_resnet50d_224', backbone=backbone, pretrained=pretrained, **dict(model_args, **kwargs))
  373. return model
  374. @register_model
  375. def vit_base_mci_224(pretrained=False, **kwargs) -> VisionTransformer:
  376. """ Custom ViT base hybrid w/ ResNet50D stride 32. No pretrained weights.
  377. """
  378. backbone = ConvStem(
  379. channels=(768//4, 768//4, 768),
  380. stride=(4, 2, 2),
  381. kernel_size=(4, 2, 2),
  382. padding=0,
  383. in_chans=kwargs.get('in_chans', 3),
  384. act_layer=nn.GELU,
  385. **_dd_from_kwargs(**kwargs),
  386. )
  387. model_args = dict(embed_dim=768, depth=12, num_heads=12, no_embed_class=True)
  388. model = _create_vision_transformer_hybrid(
  389. 'vit_base_mci_224', backbone=backbone, embed_args=dict(proj=False),
  390. pretrained=pretrained, **dict(model_args, **kwargs)
  391. )
  392. return model
  393. register_model_deprecations(__name__, {
  394. 'vit_tiny_r_s16_p8_224_in21k': 'vit_tiny_r_s16_p8_224.augreg_in21k',
  395. 'vit_small_r26_s32_224_in21k': 'vit_small_r26_s32_224.augreg_in21k',
  396. 'vit_base_r50_s16_224_in21k': 'vit_base_r50_s16_224.orig_in21k',
  397. 'vit_base_resnet50_224_in21k': 'vit_base_r50_s16_224.orig_in21k',
  398. 'vit_large_r50_s32_224_in21k': 'vit_large_r50_s32_224.augreg_in21k',
  399. 'vit_base_resnet50_384': 'vit_base_r50_s16_384.orig_in21k_ft_in1k'
  400. })