deit.py 19 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423
  1. """ DeiT - Data-efficient Image Transformers
  2. DeiT model defs and weights from https://github.com/facebookresearch/deit, original copyright below
  3. paper: `DeiT: Data-efficient Image Transformers` - https://arxiv.org/abs/2012.12877
  4. paper: `DeiT III: Revenge of the ViT` - https://arxiv.org/abs/2204.07118
  5. Modifications copyright 2021, Ross Wightman
  6. """
  7. # Copyright (c) 2015-present, Facebook, Inc.
  8. # All rights reserved.
from functools import partial
from typing import Optional, Tuple, Type

import torch
from torch import nn as nn

from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
from timm.layers import resample_abs_pos_embed
from timm.models.vision_transformer import VisionTransformer, trunc_normal_, checkpoint_filter_fn

from ._builder import build_model_with_cfg
from ._registry import generate_default_cfgs, register_model, register_model_deprecations
  18. __all__ = ['VisionTransformerDistilled'] # model_registry will add each entrypoint fn to this
  19. class VisionTransformerDistilled(VisionTransformer):
  20. """ Vision Transformer w/ Distillation Token and Head
  21. Distillation token & head support for `DeiT: Data-efficient Image Transformers`
  22. - https://arxiv.org/abs/2012.12877
  23. """
  24. def __init__(self, *args, **kwargs):
  25. weight_init = kwargs.pop('weight_init', '')
  26. super().__init__(*args, **kwargs, weight_init='skip')
  27. assert self.global_pool in ('token',)
  28. dd = {'device': kwargs.get('device', None), 'dtype': kwargs.get('dtype', None)}
  29. self.num_prefix_tokens = 2
  30. self.dist_token = nn.Parameter(torch.empty(1, 1, self.embed_dim, **dd))
  31. self.pos_embed = nn.Parameter(
  32. torch.empty(1, self.patch_embed.num_patches + self.num_prefix_tokens, self.embed_dim, **dd))
  33. self.head_dist = nn.Linear(self.embed_dim, self.num_classes, **dd) if self.num_classes > 0 else nn.Identity()
  34. self.distilled_training = False # must set this True to train w/ distillation token
  35. self.weight_init_mode = 'reset' if weight_init == 'skip' else weight_init
  36. # TODO: skip init when on meta device when safe to do so
  37. if weight_init != 'skip':
  38. self.init_weights(needs_reset=False)
  39. def init_weights(self, mode='', needs_reset=True):
  40. mode = mode or self.weight_init_mode
  41. trunc_normal_(self.dist_token, std=.02)
  42. super().init_weights(mode=mode, needs_reset=needs_reset)
  43. @torch.jit.ignore
  44. def group_matcher(self, coarse=False):
  45. return dict(
  46. stem=r'^cls_token|pos_embed|patch_embed|dist_token',
  47. blocks=[
  48. (r'^blocks\.(\d+)', None),
  49. (r'^norm', (99999,))] # final norm w/ last block
  50. )
  51. @torch.jit.ignore
  52. def get_classifier(self) -> nn.Module:
  53. return self.head, self.head_dist
  54. def reset_classifier(self, num_classes: int, global_pool: Optional[str] = None):
  55. self.num_classes = num_classes
  56. self.head = nn.Linear(self.embed_dim, num_classes) if num_classes > 0 else nn.Identity()
  57. self.head_dist = nn.Linear(self.embed_dim, self.num_classes) if num_classes > 0 else nn.Identity()
  58. @torch.jit.ignore
  59. def set_distilled_training(self, enable=True):
  60. self.distilled_training = enable
  61. def _pos_embed(self, x):
  62. if self.dynamic_img_size:
  63. B, H, W, C = x.shape
  64. prev_grid_size = self.patch_embed.grid_size
  65. pos_embed = resample_abs_pos_embed(
  66. self.pos_embed,
  67. new_size=(H, W),
  68. old_size=prev_grid_size,
  69. num_prefix_tokens=0 if self.no_embed_class else self.num_prefix_tokens,
  70. )
  71. x = x.view(B, -1, C)
  72. else:
  73. pos_embed = self.pos_embed
  74. if self.no_embed_class:
  75. # deit-3, updated JAX (big vision)
  76. # position embedding does not overlap with class token, add then concat
  77. x = x + pos_embed
  78. x = torch.cat((
  79. self.cls_token.expand(x.shape[0], -1, -1),
  80. self.dist_token.expand(x.shape[0], -1, -1),
  81. x),
  82. dim=1)
  83. else:
  84. # original timm, JAX, and deit vit impl
  85. # pos_embed has entry for class token, concat then add
  86. x = torch.cat((
  87. self.cls_token.expand(x.shape[0], -1, -1),
  88. self.dist_token.expand(x.shape[0], -1, -1),
  89. x),
  90. dim=1)
  91. x = x + pos_embed
  92. return self.pos_drop(x)
  93. def forward_head(self, x, pre_logits: bool = False) -> torch.Tensor:
  94. x, x_dist = x[:, 0], x[:, 1]
  95. if pre_logits:
  96. return (x + x_dist) / 2
  97. x = self.head(x)
  98. x_dist = self.head_dist(x_dist)
  99. if self.distilled_training and self.training and not torch.jit.is_scripting():
  100. # only return separate classification predictions when training in distilled mode
  101. return x, x_dist
  102. else:
  103. # during standard train / finetune, inference average the classifier predictions
  104. return (x + x_dist) / 2
  105. def _create_deit(variant, pretrained=False, distilled=False, **kwargs):
  106. out_indices = kwargs.pop('out_indices', 3)
  107. model_cls = VisionTransformerDistilled if distilled else VisionTransformer
  108. model = build_model_with_cfg(
  109. model_cls,
  110. variant,
  111. pretrained,
  112. pretrained_filter_fn=partial(checkpoint_filter_fn, adapt_layer_scale=True),
  113. feature_cfg=dict(out_indices=out_indices, feature_cls='getter'),
  114. **kwargs,
  115. )
  116. return model
  117. def _cfg(url='', **kwargs):
  118. return {
  119. 'url': url,
  120. 'num_classes': 1000, 'input_size': (3, 224, 224), 'pool_size': None,
  121. 'crop_pct': .9, 'interpolation': 'bicubic', 'fixed_input_size': True,
  122. 'mean': IMAGENET_DEFAULT_MEAN, 'std': IMAGENET_DEFAULT_STD,
  123. 'first_conv': 'patch_embed.proj', 'classifier': 'head',
  124. 'license': 'apache-2.0',
  125. **kwargs
  126. }
# Pretrained configs keyed '<model name>.<pretrained tag>'. Tags used below:
#   fb_in1k          - Facebook (facebookresearch/deit) ImageNet-1k weights
#   fb_in22k_ft_in1k - Facebook DeiT III weights, ImageNet-22k pretrain fine-tuned on ImageNet-1k
#                      (as the tag name encodes; the checkpoint URLs end in `_21k`)
# Distilled variants set classifier=('head', 'head_dist') because VisionTransformerDistilled
# carries two classifier heads. 384 variants override input_size and use crop_pct=1.0.
default_cfgs = generate_default_cfgs({
    # deit models (FB weights)
    'deit_tiny_patch16_224.fb_in1k': _cfg(
        hf_hub_id='timm/',
        url='https://dl.fbaipublicfiles.com/deit/deit_tiny_patch16_224-a1311bcf.pth'),
    'deit_small_patch16_224.fb_in1k': _cfg(
        hf_hub_id='timm/',
        url='https://dl.fbaipublicfiles.com/deit/deit_small_patch16_224-cd65a155.pth'),
    'deit_base_patch16_224.fb_in1k': _cfg(
        hf_hub_id='timm/',
        url='https://dl.fbaipublicfiles.com/deit/deit_base_patch16_224-b5f2ef4d.pth'),
    'deit_base_patch16_384.fb_in1k': _cfg(
        hf_hub_id='timm/',
        url='https://dl.fbaipublicfiles.com/deit/deit_base_patch16_384-8de9b5d1.pth',
        input_size=(3, 384, 384), crop_pct=1.0),
    # distilled deit models (two classifier heads)
    'deit_tiny_distilled_patch16_224.fb_in1k': _cfg(
        hf_hub_id='timm/',
        url='https://dl.fbaipublicfiles.com/deit/deit_tiny_distilled_patch16_224-b40b3cf7.pth',
        classifier=('head', 'head_dist')),
    'deit_small_distilled_patch16_224.fb_in1k': _cfg(
        hf_hub_id='timm/',
        url='https://dl.fbaipublicfiles.com/deit/deit_small_distilled_patch16_224-649709d9.pth',
        classifier=('head', 'head_dist')),
    'deit_base_distilled_patch16_224.fb_in1k': _cfg(
        hf_hub_id='timm/',
        url='https://dl.fbaipublicfiles.com/deit/deit_base_distilled_patch16_224-df68dfff.pth',
        classifier=('head', 'head_dist')),
    'deit_base_distilled_patch16_384.fb_in1k': _cfg(
        hf_hub_id='timm/',
        url='https://dl.fbaipublicfiles.com/deit/deit_base_distilled_patch16_384-d0272ac0.pth',
        input_size=(3, 384, 384), crop_pct=1.0,
        classifier=('head', 'head_dist')),
    # deit3 models, ImageNet-1k weights
    'deit3_small_patch16_224.fb_in1k': _cfg(
        hf_hub_id='timm/',
        url='https://dl.fbaipublicfiles.com/deit/deit_3_small_224_1k.pth'),
    'deit3_small_patch16_384.fb_in1k': _cfg(
        hf_hub_id='timm/',
        url='https://dl.fbaipublicfiles.com/deit/deit_3_small_384_1k.pth',
        input_size=(3, 384, 384), crop_pct=1.0),
    'deit3_medium_patch16_224.fb_in1k': _cfg(
        hf_hub_id='timm/',
        url='https://dl.fbaipublicfiles.com/deit/deit_3_medium_224_1k.pth'),
    'deit3_base_patch16_224.fb_in1k': _cfg(
        hf_hub_id='timm/',
        url='https://dl.fbaipublicfiles.com/deit/deit_3_base_224_1k.pth'),
    'deit3_base_patch16_384.fb_in1k': _cfg(
        hf_hub_id='timm/',
        url='https://dl.fbaipublicfiles.com/deit/deit_3_base_384_1k.pth',
        input_size=(3, 384, 384), crop_pct=1.0),
    'deit3_large_patch16_224.fb_in1k': _cfg(
        hf_hub_id='timm/',
        url='https://dl.fbaipublicfiles.com/deit/deit_3_large_224_1k.pth'),
    'deit3_large_patch16_384.fb_in1k': _cfg(
        hf_hub_id='timm/',
        url='https://dl.fbaipublicfiles.com/deit/deit_3_large_384_1k.pth',
        input_size=(3, 384, 384), crop_pct=1.0),
    'deit3_huge_patch14_224.fb_in1k': _cfg(
        hf_hub_id='timm/',
        url='https://dl.fbaipublicfiles.com/deit/deit_3_huge_224_1k.pth'),
    # deit3 models, ImageNet-22k pretrain fine-tuned on ImageNet-1k (224 variants use crop_pct=1.0)
    'deit3_small_patch16_224.fb_in22k_ft_in1k': _cfg(
        hf_hub_id='timm/',
        url='https://dl.fbaipublicfiles.com/deit/deit_3_small_224_21k.pth',
        crop_pct=1.0),
    'deit3_small_patch16_384.fb_in22k_ft_in1k': _cfg(
        hf_hub_id='timm/',
        url='https://dl.fbaipublicfiles.com/deit/deit_3_small_384_21k.pth',
        input_size=(3, 384, 384), crop_pct=1.0),
    'deit3_medium_patch16_224.fb_in22k_ft_in1k': _cfg(
        hf_hub_id='timm/',
        url='https://dl.fbaipublicfiles.com/deit/deit_3_medium_224_21k.pth',
        crop_pct=1.0),
    'deit3_base_patch16_224.fb_in22k_ft_in1k': _cfg(
        hf_hub_id='timm/',
        url='https://dl.fbaipublicfiles.com/deit/deit_3_base_224_21k.pth',
        crop_pct=1.0),
    'deit3_base_patch16_384.fb_in22k_ft_in1k': _cfg(
        hf_hub_id='timm/',
        url='https://dl.fbaipublicfiles.com/deit/deit_3_base_384_21k.pth',
        input_size=(3, 384, 384), crop_pct=1.0),
    'deit3_large_patch16_224.fb_in22k_ft_in1k': _cfg(
        hf_hub_id='timm/',
        url='https://dl.fbaipublicfiles.com/deit/deit_3_large_224_21k.pth',
        crop_pct=1.0),
    'deit3_large_patch16_384.fb_in22k_ft_in1k': _cfg(
        hf_hub_id='timm/',
        url='https://dl.fbaipublicfiles.com/deit/deit_3_large_384_21k.pth',
        input_size=(3, 384, 384), crop_pct=1.0),
    'deit3_huge_patch14_224.fb_in22k_ft_in1k': _cfg(
        hf_hub_id='timm/',
        url='https://dl.fbaipublicfiles.com/deit/deit_3_huge_224_21k_v1.pth',
        crop_pct=1.0),
})
  219. @register_model
  220. def deit_tiny_patch16_224(pretrained=False, **kwargs) -> VisionTransformer:
  221. """ DeiT-tiny model @ 224x224 from paper (https://arxiv.org/abs/2012.12877).
  222. ImageNet-1k weights from https://github.com/facebookresearch/deit.
  223. """
  224. model_args = dict(patch_size=16, embed_dim=192, depth=12, num_heads=3)
  225. model = _create_deit('deit_tiny_patch16_224', pretrained=pretrained, **dict(model_args, **kwargs))
  226. return model
  227. @register_model
  228. def deit_small_patch16_224(pretrained=False, **kwargs) -> VisionTransformer:
  229. """ DeiT-small model @ 224x224 from paper (https://arxiv.org/abs/2012.12877).
  230. ImageNet-1k weights from https://github.com/facebookresearch/deit.
  231. """
  232. model_args = dict(patch_size=16, embed_dim=384, depth=12, num_heads=6)
  233. model = _create_deit('deit_small_patch16_224', pretrained=pretrained, **dict(model_args, **kwargs))
  234. return model
  235. @register_model
  236. def deit_base_patch16_224(pretrained=False, **kwargs) -> VisionTransformer:
  237. """ DeiT base model @ 224x224 from paper (https://arxiv.org/abs/2012.12877).
  238. ImageNet-1k weights from https://github.com/facebookresearch/deit.
  239. """
  240. model_args = dict(patch_size=16, embed_dim=768, depth=12, num_heads=12)
  241. model = _create_deit('deit_base_patch16_224', pretrained=pretrained, **dict(model_args, **kwargs))
  242. return model
  243. @register_model
  244. def deit_base_patch16_384(pretrained=False, **kwargs) -> VisionTransformer:
  245. """ DeiT base model @ 384x384 from paper (https://arxiv.org/abs/2012.12877).
  246. ImageNet-1k weights from https://github.com/facebookresearch/deit.
  247. """
  248. model_args = dict(patch_size=16, embed_dim=768, depth=12, num_heads=12)
  249. model = _create_deit('deit_base_patch16_384', pretrained=pretrained, **dict(model_args, **kwargs))
  250. return model
  251. @register_model
  252. def deit_tiny_distilled_patch16_224(pretrained=False, **kwargs) -> VisionTransformerDistilled:
  253. """ DeiT-tiny distilled model @ 224x224 from paper (https://arxiv.org/abs/2012.12877).
  254. ImageNet-1k weights from https://github.com/facebookresearch/deit.
  255. """
  256. model_args = dict(patch_size=16, embed_dim=192, depth=12, num_heads=3)
  257. model = _create_deit(
  258. 'deit_tiny_distilled_patch16_224', pretrained=pretrained, distilled=True, **dict(model_args, **kwargs))
  259. return model
  260. @register_model
  261. def deit_small_distilled_patch16_224(pretrained=False, **kwargs) -> VisionTransformerDistilled:
  262. """ DeiT-small distilled model @ 224x224 from paper (https://arxiv.org/abs/2012.12877).
  263. ImageNet-1k weights from https://github.com/facebookresearch/deit.
  264. """
  265. model_args = dict(patch_size=16, embed_dim=384, depth=12, num_heads=6)
  266. model = _create_deit(
  267. 'deit_small_distilled_patch16_224', pretrained=pretrained, distilled=True, **dict(model_args, **kwargs))
  268. return model
  269. @register_model
  270. def deit_base_distilled_patch16_224(pretrained=False, **kwargs) -> VisionTransformerDistilled:
  271. """ DeiT-base distilled model @ 224x224 from paper (https://arxiv.org/abs/2012.12877).
  272. ImageNet-1k weights from https://github.com/facebookresearch/deit.
  273. """
  274. model_args = dict(patch_size=16, embed_dim=768, depth=12, num_heads=12)
  275. model = _create_deit(
  276. 'deit_base_distilled_patch16_224', pretrained=pretrained, distilled=True, **dict(model_args, **kwargs))
  277. return model
  278. @register_model
  279. def deit_base_distilled_patch16_384(pretrained=False, **kwargs) -> VisionTransformerDistilled:
  280. """ DeiT-base distilled model @ 384x384 from paper (https://arxiv.org/abs/2012.12877).
  281. ImageNet-1k weights from https://github.com/facebookresearch/deit.
  282. """
  283. model_args = dict(patch_size=16, embed_dim=768, depth=12, num_heads=12)
  284. model = _create_deit(
  285. 'deit_base_distilled_patch16_384', pretrained=pretrained, distilled=True, **dict(model_args, **kwargs))
  286. return model
  287. @register_model
  288. def deit3_small_patch16_224(pretrained=False, **kwargs) -> VisionTransformer:
  289. """ DeiT-3 small model @ 224x224 from paper (https://arxiv.org/abs/2204.07118).
  290. ImageNet-1k weights from https://github.com/facebookresearch/deit.
  291. """
  292. model_args = dict(patch_size=16, embed_dim=384, depth=12, num_heads=6, no_embed_class=True, init_values=1e-6)
  293. model = _create_deit('deit3_small_patch16_224', pretrained=pretrained, **dict(model_args, **kwargs))
  294. return model
  295. @register_model
  296. def deit3_small_patch16_384(pretrained=False, **kwargs) -> VisionTransformer:
  297. """ DeiT-3 small model @ 384x384 from paper (https://arxiv.org/abs/2204.07118).
  298. ImageNet-1k weights from https://github.com/facebookresearch/deit.
  299. """
  300. model_args = dict(patch_size=16, embed_dim=384, depth=12, num_heads=6, no_embed_class=True, init_values=1e-6)
  301. model = _create_deit('deit3_small_patch16_384', pretrained=pretrained, **dict(model_args, **kwargs))
  302. return model
  303. @register_model
  304. def deit3_medium_patch16_224(pretrained=False, **kwargs) -> VisionTransformer:
  305. """ DeiT-3 medium model @ 224x224 (https://arxiv.org/abs/2012.12877).
  306. ImageNet-1k weights from https://github.com/facebookresearch/deit.
  307. """
  308. model_args = dict(patch_size=16, embed_dim=512, depth=12, num_heads=8, no_embed_class=True, init_values=1e-6)
  309. model = _create_deit('deit3_medium_patch16_224', pretrained=pretrained, **dict(model_args, **kwargs))
  310. return model
  311. @register_model
  312. def deit3_base_patch16_224(pretrained=False, **kwargs) -> VisionTransformer:
  313. """ DeiT-3 base model @ 224x224 from paper (https://arxiv.org/abs/2204.07118).
  314. ImageNet-1k weights from https://github.com/facebookresearch/deit.
  315. """
  316. model_args = dict(patch_size=16, embed_dim=768, depth=12, num_heads=12, no_embed_class=True, init_values=1e-6)
  317. model = _create_deit('deit3_base_patch16_224', pretrained=pretrained, **dict(model_args, **kwargs))
  318. return model
  319. @register_model
  320. def deit3_base_patch16_384(pretrained=False, **kwargs) -> VisionTransformer:
  321. """ DeiT-3 base model @ 384x384 from paper (https://arxiv.org/abs/2204.07118).
  322. ImageNet-1k weights from https://github.com/facebookresearch/deit.
  323. """
  324. model_args = dict(patch_size=16, embed_dim=768, depth=12, num_heads=12, no_embed_class=True, init_values=1e-6)
  325. model = _create_deit('deit3_base_patch16_384', pretrained=pretrained, **dict(model_args, **kwargs))
  326. return model
  327. @register_model
  328. def deit3_large_patch16_224(pretrained=False, **kwargs) -> VisionTransformer:
  329. """ DeiT-3 large model @ 224x224 from paper (https://arxiv.org/abs/2204.07118).
  330. ImageNet-1k weights from https://github.com/facebookresearch/deit.
  331. """
  332. model_args = dict(patch_size=16, embed_dim=1024, depth=24, num_heads=16, no_embed_class=True, init_values=1e-6)
  333. model = _create_deit('deit3_large_patch16_224', pretrained=pretrained, **dict(model_args, **kwargs))
  334. return model
  335. @register_model
  336. def deit3_large_patch16_384(pretrained=False, **kwargs) -> VisionTransformer:
  337. """ DeiT-3 large model @ 384x384 from paper (https://arxiv.org/abs/2204.07118).
  338. ImageNet-1k weights from https://github.com/facebookresearch/deit.
  339. """
  340. model_args = dict(patch_size=16, embed_dim=1024, depth=24, num_heads=16, no_embed_class=True, init_values=1e-6)
  341. model = _create_deit('deit3_large_patch16_384', pretrained=pretrained, **dict(model_args, **kwargs))
  342. return model
  343. @register_model
  344. def deit3_huge_patch14_224(pretrained=False, **kwargs) -> VisionTransformer:
  345. """ DeiT-3 base model @ 384x384 from paper (https://arxiv.org/abs/2204.07118).
  346. ImageNet-1k weights from https://github.com/facebookresearch/deit.
  347. """
  348. model_args = dict(patch_size=14, embed_dim=1280, depth=32, num_heads=16, no_embed_class=True, init_values=1e-6)
  349. model = _create_deit('deit3_huge_patch14_224', pretrained=pretrained, **dict(model_args, **kwargs))
  350. return model
# Register the legacy `<arch>_in21ft1k` model names as deprecated aliases of the
# `<arch>.fb_in22k_ft_in1k` pretrained tags defined in default_cfgs (alias handling itself
# lives in ._registry.register_model_deprecations).
register_model_deprecations(__name__, {
    'deit3_small_patch16_224_in21ft1k': 'deit3_small_patch16_224.fb_in22k_ft_in1k',
    'deit3_small_patch16_384_in21ft1k': 'deit3_small_patch16_384.fb_in22k_ft_in1k',
    'deit3_medium_patch16_224_in21ft1k': 'deit3_medium_patch16_224.fb_in22k_ft_in1k',
    'deit3_base_patch16_224_in21ft1k': 'deit3_base_patch16_224.fb_in22k_ft_in1k',
    'deit3_base_patch16_384_in21ft1k': 'deit3_base_patch16_384.fb_in22k_ft_in1k',
    'deit3_large_patch16_224_in21ft1k': 'deit3_large_patch16_224.fb_in22k_ft_in1k',
    'deit3_large_patch16_384_in21ft1k': 'deit3_large_patch16_384.fb_in22k_ft_in1k',
    'deit3_huge_patch14_224_in21ft1k': 'deit3_huge_patch14_224.fb_in22k_ft_in1k'
})