| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423 |
- """ DeiT - Data-efficient Image Transformers
- DeiT model defs and weights from https://github.com/facebookresearch/deit, original copyright below
- paper: `DeiT: Data-efficient Image Transformers` - https://arxiv.org/abs/2012.12877
- paper: `DeiT III: Revenge of the ViT` - https://arxiv.org/abs/2204.07118
- Modifications copyright 2021, Ross Wightman
- """
- # Copyright (c) 2015-present, Facebook, Inc.
- # All rights reserved.
from functools import partial
from typing import Optional, Tuple, Type

import torch
from torch import nn as nn

from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
from timm.layers import resample_abs_pos_embed
from timm.models.vision_transformer import VisionTransformer, trunc_normal_, checkpoint_filter_fn
from ._builder import build_model_with_cfg
from ._registry import generate_default_cfgs, register_model, register_model_deprecations
# Kept as a mutable list: model_registry will add each entrypoint fn to this.
__all__ = ['VisionTransformerDistilled']
class VisionTransformerDistilled(VisionTransformer):
    """ Vision Transformer w/ Distillation Token and Head

    Distillation token & head support for `DeiT: Data-efficient Image Transformers`
        - https://arxiv.org/abs/2012.12877

    On top of the base ``VisionTransformer`` this class adds a learnable ``dist_token``
    (so there are two prefix tokens: class + distillation) and a second classifier,
    ``head_dist``, applied to the distillation token.
    """

    def __init__(self, *args, **kwargs):
        """Build the base ViT, then add the distillation token, position embedding and head.

        Args:
            *args: positional arguments forwarded unchanged to ``VisionTransformer.__init__``.
            **kwargs: keyword arguments forwarded to ``VisionTransformer.__init__``, except
                ``weight_init`` which is consumed here. ``device`` / ``dtype``, when passed as
                keywords, are additionally applied to the parameters created in this method.
        """
        weight_init = kwargs.pop('weight_init', '')
        # The parent is always told to skip weight init; it is run below, once the
        # distillation-specific parameters exist.
        super().__init__(*args, **kwargs, weight_init='skip')
        assert self.global_pool in ('token',), 'distilled models only support global_pool="token"'
        dd = {'device': kwargs.get('device', None), 'dtype': kwargs.get('dtype', None)}

        self.num_prefix_tokens = 2
        self.dist_token = nn.Parameter(torch.empty(1, 1, self.embed_dim, **dd))

        # Size the table to match how `_pos_embed` consumes it: with `no_embed_class` the
        # embedding is added to the patch tokens only (before the prefix tokens are
        # concatenated), so it must not carry entries for the prefix tokens.
        embed_len = self.patch_embed.num_patches
        if not self.no_embed_class:
            embed_len += self.num_prefix_tokens
        # Allocated uninitialised here; presumably filled in by the parent's init_weights — verify.
        self.pos_embed = nn.Parameter(torch.empty(1, embed_len, self.embed_dim, **dd))

        self.head_dist = nn.Linear(self.embed_dim, self.num_classes, **dd) if self.num_classes > 0 else nn.Identity()
        self.distilled_training = False  # must set this True to train w/ distillation token

        # Mode used by later init_weights() calls that do not pass an explicit mode.
        self.weight_init_mode = 'reset' if weight_init == 'skip' else weight_init
        # TODO: skip init when on meta device when safe to do so
        if weight_init != 'skip':
            self.init_weights(needs_reset=False)

    def init_weights(self, mode='', needs_reset=True):
        """Initialise the distillation token, then delegate to the parent's ``init_weights``.

        Args:
            mode: init mode passed to the parent; falls back to ``self.weight_init_mode`` when empty.
            needs_reset: forwarded unchanged to the parent's ``init_weights``.
        """
        mode = mode or self.weight_init_mode
        trunc_normal_(self.dist_token, std=.02)
        super().init_weights(mode=mode, needs_reset=needs_reset)

    @torch.jit.ignore
    def group_matcher(self, coarse=False):
        """Return the regex grouping for stem parameters and blocks (``coarse`` is unused here)."""
        return dict(
            stem=r'^cls_token|pos_embed|patch_embed|dist_token',
            blocks=[
                (r'^blocks\.(\d+)', None),
                (r'^norm', (99999,))]  # final norm w/ last block
        )

    @torch.jit.ignore
    def get_classifier(self) -> Tuple[nn.Module, nn.Module]:
        """Return both classifiers as ``(head, head_dist)``."""
        return self.head, self.head_dist

    def reset_classifier(self, num_classes: int, global_pool: Optional[str] = None):
        """Replace both classifier heads with freshly constructed ones.

        Args:
            num_classes: output size of the new heads; ``0`` (or less) installs ``nn.Identity``.
            global_pool: accepted for signature compatibility but not used by this class.
        """
        # NOTE(review): the new layers use PyTorch's default device / dtype, not the model's.
        self.num_classes = num_classes
        self.head = nn.Linear(self.embed_dim, num_classes) if num_classes > 0 else nn.Identity()
        self.head_dist = nn.Linear(self.embed_dim, self.num_classes) if num_classes > 0 else nn.Identity()

    @torch.jit.ignore
    def set_distilled_training(self, enable=True):
        """Toggle whether ``forward_head`` returns separate outputs for the two heads in training."""
        self.distilled_training = enable

    def _pos_embed(self, x):
        """Add position embeddings and prepend the class and distillation tokens.

        With ``dynamic_img_size`` the input is unpacked as ``(B, H, W, C)``, the stored
        embedding is resampled to the ``(H, W)`` grid and ``x`` is flattened to ``(B, H*W, C)``.
        The result is passed through ``self.pos_drop``.
        """
        if self.dynamic_img_size:
            B, H, W, C = x.shape
            prev_grid_size = self.patch_embed.grid_size
            pos_embed = resample_abs_pos_embed(
                self.pos_embed,
                new_size=(H, W),
                old_size=prev_grid_size,
                num_prefix_tokens=0 if self.no_embed_class else self.num_prefix_tokens,
            )
            x = x.view(B, -1, C)
        else:
            pos_embed = self.pos_embed

        cls_tokens = self.cls_token.expand(x.shape[0], -1, -1)
        dist_tokens = self.dist_token.expand(x.shape[0], -1, -1)
        if self.no_embed_class:
            # deit-3, updated JAX (big vision)
            # position embedding does not overlap with class token, add then concat
            x = x + pos_embed
            x = torch.cat((cls_tokens, dist_tokens, x), dim=1)
        else:
            # original timm, JAX, and deit vit impl
            # pos_embed has entry for class token, concat then add
            x = torch.cat((cls_tokens, dist_tokens, x), dim=1)
            x = x + pos_embed
        return self.pos_drop(x)

    def forward_head(self, x, pre_logits: bool = False) -> torch.Tensor:
        """Apply the two classifier heads to the class and distillation tokens.

        Args:
            x: token sequence whose index 0 / 1 along dim 1 are the class / distillation tokens.
            pre_logits: when True, return the mean of the two tokens without applying the heads.

        Returns:
            The mean of the two head outputs, except when ``distilled_training`` is enabled,
            the module is in training mode and the code is not being scripted: then the
            tuple ``(head_out, head_dist_out)`` is returned. The tuple branch is never taken
            under scripting, so the annotation stays a single tensor.
        """
        x, x_dist = x[:, 0], x[:, 1]
        if pre_logits:
            return (x + x_dist) / 2
        x = self.head(x)
        x_dist = self.head_dist(x_dist)
        if self.distilled_training and self.training and not torch.jit.is_scripting():
            # only return separate classification predictions when training in distilled mode
            return x, x_dist
        else:
            # during standard train / finetune, inference average the classifier predictions
            return (x + x_dist) / 2
def _create_deit(variant, pretrained=False, distilled=False, **kwargs):
    """Instantiate a DeiT variant through ``build_model_with_cfg``.

    Args:
        variant: model variant name handed to ``build_model_with_cfg``.
        pretrained: forwarded positionally to ``build_model_with_cfg``.
        distilled: build ``VisionTransformerDistilled`` when True, else ``VisionTransformer``.
        **kwargs: remaining model arguments; ``out_indices`` (default 3) is removed from
            them and placed in the feature config instead.
    """
    feature_cfg = {
        'out_indices': kwargs.pop('out_indices', 3),
        'feature_cls': 'getter',
    }
    if distilled:
        model_cls = VisionTransformerDistilled
    else:
        model_cls = VisionTransformer
    filter_fn = partial(checkpoint_filter_fn, adapt_layer_scale=True)
    return build_model_with_cfg(
        model_cls,
        variant,
        pretrained,
        pretrained_filter_fn=filter_fn,
        feature_cfg=feature_cfg,
        **kwargs,
    )
def _cfg(url='', **kwargs):
    """Return the default pretrained-config dict for this file, with ``kwargs`` applied on top.

    Keys present in ``kwargs`` replace the defaults below; unknown keys are appended.
    """
    cfg = dict(
        url=url,
        num_classes=1000,
        input_size=(3, 224, 224),
        pool_size=None,
        crop_pct=.9,
        interpolation='bicubic',
        fixed_input_size=True,
        mean=IMAGENET_DEFAULT_MEAN,
        std=IMAGENET_DEFAULT_STD,
        first_conv='patch_embed.proj',
        classifier='head',
        license='apache-2.0',
    )
    cfg.update(kwargs)
    return cfg
def _deit_pretrained_cfgs():
    """Assemble the ``'<model>.<tag>'`` -> config table passed to ``generate_default_cfgs``.

    Entry order is kept exactly as listed; the ``fb_in1k`` tag of each DeiT-3 model comes
    before its ``fb_in22k_ft_in1k`` tag.
    """
    # Keyword groups repeated across many entries.
    hub = {'hf_hub_id': 'timm/'}
    res_384 = {'input_size': (3, 384, 384), 'crop_pct': 1.0}
    full_crop = {'crop_pct': 1.0}
    two_heads = {'classifier': ('head', 'head_dist')}

    return {
        # deit models (FB weights)
        'deit_tiny_patch16_224.fb_in1k': _cfg(
            url='https://dl.fbaipublicfiles.com/deit/deit_tiny_patch16_224-a1311bcf.pth',
            **hub),
        'deit_small_patch16_224.fb_in1k': _cfg(
            url='https://dl.fbaipublicfiles.com/deit/deit_small_patch16_224-cd65a155.pth',
            **hub),
        'deit_base_patch16_224.fb_in1k': _cfg(
            url='https://dl.fbaipublicfiles.com/deit/deit_base_patch16_224-b5f2ef4d.pth',
            **hub),
        'deit_base_patch16_384.fb_in1k': _cfg(
            url='https://dl.fbaipublicfiles.com/deit/deit_base_patch16_384-8de9b5d1.pth',
            **hub, **res_384),

        # distilled deit models: two classifier heads
        'deit_tiny_distilled_patch16_224.fb_in1k': _cfg(
            url='https://dl.fbaipublicfiles.com/deit/deit_tiny_distilled_patch16_224-b40b3cf7.pth',
            **hub, **two_heads),
        'deit_small_distilled_patch16_224.fb_in1k': _cfg(
            url='https://dl.fbaipublicfiles.com/deit/deit_small_distilled_patch16_224-649709d9.pth',
            **hub, **two_heads),
        'deit_base_distilled_patch16_224.fb_in1k': _cfg(
            url='https://dl.fbaipublicfiles.com/deit/deit_base_distilled_patch16_224-df68dfff.pth',
            **hub, **two_heads),
        'deit_base_distilled_patch16_384.fb_in1k': _cfg(
            url='https://dl.fbaipublicfiles.com/deit/deit_base_distilled_patch16_384-d0272ac0.pth',
            **hub, **res_384, **two_heads),

        # deit3 models, fb_in1k tag
        'deit3_small_patch16_224.fb_in1k': _cfg(
            url='https://dl.fbaipublicfiles.com/deit/deit_3_small_224_1k.pth',
            **hub),
        'deit3_small_patch16_384.fb_in1k': _cfg(
            url='https://dl.fbaipublicfiles.com/deit/deit_3_small_384_1k.pth',
            **hub, **res_384),
        'deit3_medium_patch16_224.fb_in1k': _cfg(
            url='https://dl.fbaipublicfiles.com/deit/deit_3_medium_224_1k.pth',
            **hub),
        'deit3_base_patch16_224.fb_in1k': _cfg(
            url='https://dl.fbaipublicfiles.com/deit/deit_3_base_224_1k.pth',
            **hub),
        'deit3_base_patch16_384.fb_in1k': _cfg(
            url='https://dl.fbaipublicfiles.com/deit/deit_3_base_384_1k.pth',
            **hub, **res_384),
        'deit3_large_patch16_224.fb_in1k': _cfg(
            url='https://dl.fbaipublicfiles.com/deit/deit_3_large_224_1k.pth',
            **hub),
        'deit3_large_patch16_384.fb_in1k': _cfg(
            url='https://dl.fbaipublicfiles.com/deit/deit_3_large_384_1k.pth',
            **hub, **res_384),
        'deit3_huge_patch14_224.fb_in1k': _cfg(
            url='https://dl.fbaipublicfiles.com/deit/deit_3_huge_224_1k.pth',
            **hub),

        # deit3 models, fb_in22k_ft_in1k tag
        'deit3_small_patch16_224.fb_in22k_ft_in1k': _cfg(
            url='https://dl.fbaipublicfiles.com/deit/deit_3_small_224_21k.pth',
            **hub, **full_crop),
        'deit3_small_patch16_384.fb_in22k_ft_in1k': _cfg(
            url='https://dl.fbaipublicfiles.com/deit/deit_3_small_384_21k.pth',
            **hub, **res_384),
        'deit3_medium_patch16_224.fb_in22k_ft_in1k': _cfg(
            url='https://dl.fbaipublicfiles.com/deit/deit_3_medium_224_21k.pth',
            **hub, **full_crop),
        'deit3_base_patch16_224.fb_in22k_ft_in1k': _cfg(
            url='https://dl.fbaipublicfiles.com/deit/deit_3_base_224_21k.pth',
            **hub, **full_crop),
        'deit3_base_patch16_384.fb_in22k_ft_in1k': _cfg(
            url='https://dl.fbaipublicfiles.com/deit/deit_3_base_384_21k.pth',
            **hub, **res_384),
        'deit3_large_patch16_224.fb_in22k_ft_in1k': _cfg(
            url='https://dl.fbaipublicfiles.com/deit/deit_3_large_224_21k.pth',
            **hub, **full_crop),
        'deit3_large_patch16_384.fb_in22k_ft_in1k': _cfg(
            url='https://dl.fbaipublicfiles.com/deit/deit_3_large_384_21k.pth',
            **hub, **res_384),
        'deit3_huge_patch14_224.fb_in22k_ft_in1k': _cfg(
            url='https://dl.fbaipublicfiles.com/deit/deit_3_huge_224_21k_v1.pth',
            **hub, **full_crop),
    }


default_cfgs = generate_default_cfgs(_deit_pretrained_cfgs())
@register_model
def deit_tiny_patch16_224(pretrained=False, **kwargs) -> VisionTransformer:
    """ DeiT-tiny @ 224x224, architecture from the DeiT paper (https://arxiv.org/abs/2012.12877).

    ImageNet-1k weights from https://github.com/facebookresearch/deit.
    Entries in ``kwargs`` override the architecture defaults set below.
    """
    arch_cfg = {'patch_size': 16, 'embed_dim': 192, 'depth': 12, 'num_heads': 3}
    arch_cfg.update(kwargs)
    return _create_deit('deit_tiny_patch16_224', pretrained=pretrained, **arch_cfg)
@register_model
def deit_small_patch16_224(pretrained=False, **kwargs) -> VisionTransformer:
    """ DeiT-small @ 224x224, architecture from the DeiT paper (https://arxiv.org/abs/2012.12877).

    ImageNet-1k weights from https://github.com/facebookresearch/deit.
    Entries in ``kwargs`` override the architecture defaults set below.
    """
    arch_cfg = {'patch_size': 16, 'embed_dim': 384, 'depth': 12, 'num_heads': 6}
    arch_cfg.update(kwargs)
    return _create_deit('deit_small_patch16_224', pretrained=pretrained, **arch_cfg)
@register_model
def deit_base_patch16_224(pretrained=False, **kwargs) -> VisionTransformer:
    """ DeiT-base @ 224x224, architecture from the DeiT paper (https://arxiv.org/abs/2012.12877).

    ImageNet-1k weights from https://github.com/facebookresearch/deit.
    Entries in ``kwargs`` override the architecture defaults set below.
    """
    arch_cfg = {'patch_size': 16, 'embed_dim': 768, 'depth': 12, 'num_heads': 12}
    arch_cfg.update(kwargs)
    return _create_deit('deit_base_patch16_224', pretrained=pretrained, **arch_cfg)
@register_model
def deit_base_patch16_384(pretrained=False, **kwargs) -> VisionTransformer:
    """ DeiT-base @ 384x384, architecture from the DeiT paper (https://arxiv.org/abs/2012.12877).

    ImageNet-1k weights from https://github.com/facebookresearch/deit.
    Entries in ``kwargs`` override the architecture defaults set below.
    """
    arch_cfg = {'patch_size': 16, 'embed_dim': 768, 'depth': 12, 'num_heads': 12}
    arch_cfg.update(kwargs)
    return _create_deit('deit_base_patch16_384', pretrained=pretrained, **arch_cfg)
@register_model
def deit_tiny_distilled_patch16_224(pretrained=False, **kwargs) -> VisionTransformerDistilled:
    """ Distilled DeiT-tiny @ 224x224, from the DeiT paper (https://arxiv.org/abs/2012.12877).

    ImageNet-1k weights from https://github.com/facebookresearch/deit.
    Entries in ``kwargs`` override the architecture defaults set below.
    """
    arch_cfg = {'patch_size': 16, 'embed_dim': 192, 'depth': 12, 'num_heads': 3}
    arch_cfg.update(kwargs)
    return _create_deit(
        'deit_tiny_distilled_patch16_224', pretrained=pretrained, distilled=True, **arch_cfg)
@register_model
def deit_small_distilled_patch16_224(pretrained=False, **kwargs) -> VisionTransformerDistilled:
    """ Distilled DeiT-small @ 224x224, from the DeiT paper (https://arxiv.org/abs/2012.12877).

    ImageNet-1k weights from https://github.com/facebookresearch/deit.
    Entries in ``kwargs`` override the architecture defaults set below.
    """
    arch_cfg = {'patch_size': 16, 'embed_dim': 384, 'depth': 12, 'num_heads': 6}
    arch_cfg.update(kwargs)
    return _create_deit(
        'deit_small_distilled_patch16_224', pretrained=pretrained, distilled=True, **arch_cfg)
@register_model
def deit_base_distilled_patch16_224(pretrained=False, **kwargs) -> VisionTransformerDistilled:
    """ Distilled DeiT-base @ 224x224, from the DeiT paper (https://arxiv.org/abs/2012.12877).

    ImageNet-1k weights from https://github.com/facebookresearch/deit.
    Entries in ``kwargs`` override the architecture defaults set below.
    """
    arch_cfg = {'patch_size': 16, 'embed_dim': 768, 'depth': 12, 'num_heads': 12}
    arch_cfg.update(kwargs)
    return _create_deit(
        'deit_base_distilled_patch16_224', pretrained=pretrained, distilled=True, **arch_cfg)
@register_model
def deit_base_distilled_patch16_384(pretrained=False, **kwargs) -> VisionTransformerDistilled:
    """ Distilled DeiT-base @ 384x384, from the DeiT paper (https://arxiv.org/abs/2012.12877).

    ImageNet-1k weights from https://github.com/facebookresearch/deit.
    Entries in ``kwargs`` override the architecture defaults set below.
    """
    arch_cfg = {'patch_size': 16, 'embed_dim': 768, 'depth': 12, 'num_heads': 12}
    arch_cfg.update(kwargs)
    return _create_deit(
        'deit_base_distilled_patch16_384', pretrained=pretrained, distilled=True, **arch_cfg)
@register_model
def deit3_small_patch16_224(pretrained=False, **kwargs) -> VisionTransformer:
    """ DeiT-3 small @ 224x224, from the DeiT III paper (https://arxiv.org/abs/2204.07118).

    ImageNet-1k weights from https://github.com/facebookresearch/deit.
    Entries in ``kwargs`` override the architecture defaults set below.
    """
    arch_cfg = {
        'patch_size': 16, 'embed_dim': 384, 'depth': 12, 'num_heads': 6,
        'no_embed_class': True, 'init_values': 1e-6,
    }
    arch_cfg.update(kwargs)
    return _create_deit('deit3_small_patch16_224', pretrained=pretrained, **arch_cfg)
@register_model
def deit3_small_patch16_384(pretrained=False, **kwargs) -> VisionTransformer:
    """ DeiT-3 small @ 384x384, from the DeiT III paper (https://arxiv.org/abs/2204.07118).

    ImageNet-1k weights from https://github.com/facebookresearch/deit.
    Entries in ``kwargs`` override the architecture defaults set below.
    """
    arch_cfg = {
        'patch_size': 16, 'embed_dim': 384, 'depth': 12, 'num_heads': 6,
        'no_embed_class': True, 'init_values': 1e-6,
    }
    arch_cfg.update(kwargs)
    return _create_deit('deit3_small_patch16_384', pretrained=pretrained, **arch_cfg)
@register_model
def deit3_medium_patch16_224(pretrained=False, **kwargs) -> VisionTransformer:
    """ DeiT-3 medium @ 224x224 (https://arxiv.org/abs/2012.12877).

    ImageNet-1k weights from https://github.com/facebookresearch/deit.
    Entries in ``kwargs`` override the architecture defaults set below.
    """
    arch_cfg = {
        'patch_size': 16, 'embed_dim': 512, 'depth': 12, 'num_heads': 8,
        'no_embed_class': True, 'init_values': 1e-6,
    }
    arch_cfg.update(kwargs)
    return _create_deit('deit3_medium_patch16_224', pretrained=pretrained, **arch_cfg)
@register_model
def deit3_base_patch16_224(pretrained=False, **kwargs) -> VisionTransformer:
    """ DeiT-3 base @ 224x224, from the DeiT III paper (https://arxiv.org/abs/2204.07118).

    ImageNet-1k weights from https://github.com/facebookresearch/deit.
    Entries in ``kwargs`` override the architecture defaults set below.
    """
    arch_cfg = {
        'patch_size': 16, 'embed_dim': 768, 'depth': 12, 'num_heads': 12,
        'no_embed_class': True, 'init_values': 1e-6,
    }
    arch_cfg.update(kwargs)
    return _create_deit('deit3_base_patch16_224', pretrained=pretrained, **arch_cfg)
@register_model
def deit3_base_patch16_384(pretrained=False, **kwargs) -> VisionTransformer:
    """ DeiT-3 base @ 384x384, from the DeiT III paper (https://arxiv.org/abs/2204.07118).

    ImageNet-1k weights from https://github.com/facebookresearch/deit.
    Entries in ``kwargs`` override the architecture defaults set below.
    """
    arch_cfg = {
        'patch_size': 16, 'embed_dim': 768, 'depth': 12, 'num_heads': 12,
        'no_embed_class': True, 'init_values': 1e-6,
    }
    arch_cfg.update(kwargs)
    return _create_deit('deit3_base_patch16_384', pretrained=pretrained, **arch_cfg)
@register_model
def deit3_large_patch16_224(pretrained=False, **kwargs) -> VisionTransformer:
    """ DeiT-3 large @ 224x224, from the DeiT III paper (https://arxiv.org/abs/2204.07118).

    ImageNet-1k weights from https://github.com/facebookresearch/deit.
    Entries in ``kwargs`` override the architecture defaults set below.
    """
    arch_cfg = {
        'patch_size': 16, 'embed_dim': 1024, 'depth': 24, 'num_heads': 16,
        'no_embed_class': True, 'init_values': 1e-6,
    }
    arch_cfg.update(kwargs)
    return _create_deit('deit3_large_patch16_224', pretrained=pretrained, **arch_cfg)
@register_model
def deit3_large_patch16_384(pretrained=False, **kwargs) -> VisionTransformer:
    """ DeiT-3 large @ 384x384, from the DeiT III paper (https://arxiv.org/abs/2204.07118).

    ImageNet-1k weights from https://github.com/facebookresearch/deit.
    Entries in ``kwargs`` override the architecture defaults set below.
    """
    arch_cfg = {
        'patch_size': 16, 'embed_dim': 1024, 'depth': 24, 'num_heads': 16,
        'no_embed_class': True, 'init_values': 1e-6,
    }
    arch_cfg.update(kwargs)
    return _create_deit('deit3_large_patch16_384', pretrained=pretrained, **arch_cfg)
@register_model
def deit3_huge_patch14_224(pretrained=False, **kwargs) -> VisionTransformer:
    """ DeiT-3 huge model (patch size 14) @ 224x224 from paper (https://arxiv.org/abs/2204.07118).
    ImageNet-1k weights from https://github.com/facebookresearch/deit.

    Args:
        pretrained: forwarded to ``_create_deit``.
        **kwargs: override or extend the architecture arguments set below.
    """
    model_args = dict(patch_size=14, embed_dim=1280, depth=32, num_heads=16, no_embed_class=True, init_values=1e-6)
    model = _create_deit('deit3_huge_patch14_224', pretrained=pretrained, **dict(model_args, **kwargs))
    return model
# Old `*_in21ft1k` model names mapped to the `<model>.fb_in22k_ft_in1k` names used in default_cfgs.
register_model_deprecations(
    __name__,
    dict([
        ('deit3_small_patch16_224_in21ft1k', 'deit3_small_patch16_224.fb_in22k_ft_in1k'),
        ('deit3_small_patch16_384_in21ft1k', 'deit3_small_patch16_384.fb_in22k_ft_in1k'),
        ('deit3_medium_patch16_224_in21ft1k', 'deit3_medium_patch16_224.fb_in22k_ft_in1k'),
        ('deit3_base_patch16_224_in21ft1k', 'deit3_base_patch16_224.fb_in22k_ft_in1k'),
        ('deit3_base_patch16_384_in21ft1k', 'deit3_base_patch16_384.fb_in22k_ft_in1k'),
        ('deit3_large_patch16_224_in21ft1k', 'deit3_large_patch16_224.fb_in22k_ft_in1k'),
        ('deit3_large_patch16_384_in21ft1k', 'deit3_large_patch16_384.fb_in22k_ft_in1k'),
        ('deit3_huge_patch14_224_in21ft1k', 'deit3_huge_patch14_224.fb_in22k_ft_in1k'),
    ]),
)
|