| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461 |
- """ Hybrid Vision Transformer (ViT) in PyTorch
- A PyTorch implementation of the Hybrid Vision Transformers as described in:
- 'An Image Is Worth 16 x 16 Words: Transformers for Image Recognition at Scale'
- - https://arxiv.org/abs/2010.11929
- `How to train your ViT? Data, Augmentation, and Regularization in Vision Transformers`
- - https://arxiv.org/abs/2106.10270
- NOTE These hybrid model definitions depend on code in vision_transformer.py.
- They were moved here to keep file sizes sane.
- Hacked together by / Copyright 2020, Ross Wightman
- """
- from functools import partial
- from typing import Dict, Tuple, Type, Union
- import torch
- import torch.nn as nn
- from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
- from timm.layers import StdConv2dSame, StdConv2d, ConvNormAct, to_ntuple, HybridEmbed
- from ._builder import build_model_with_cfg
- from ._registry import generate_default_cfgs, register_model, register_model_deprecations
- from .resnet import resnet26d, resnet50d
- from .resnetv2 import ResNetV2, create_resnetv2_stem
- from .vision_transformer import VisionTransformer
class ConvStem(nn.Sequential):
    """Sequential stack of ``ConvNormAct`` layers used as a convolutional stem.

    Layers are registered as submodules named ``'0'``, ``'1'``, ... in order.
    Every layer except the last is built with ``bias=False`` and with norm and
    activation applied; the last layer is built with ``bias=True`` and with
    both norm and activation disabled.

    Args:
        in_chans: Number of input channels of the first layer.
        depth: Number of layers in the stem.
        channels: Output channels per layer. An int ``c`` is expanded to the
            tiered schedule ``(c // 2**(depth - 1), ..., c // 2, c)``.
        kernel_size: Kernel size per layer, normalized with ``to_ntuple(depth)``.
        stride: Stride per layer, normalized with ``to_ntuple(depth)`` so that a
            single int is accepted as the annotation promises.
        padding: Padding per layer, normalized with ``to_ntuple(depth)`` and
            forwarded to ``ConvNormAct``.
        norm_layer: Norm layer type forwarded to ``ConvNormAct``.
        act_layer: Activation layer type forwarded to ``ConvNormAct``.
        device: Forwarded to each ``ConvNormAct``.
        dtype: Forwarded to each ``ConvNormAct``.

    Raises:
        AssertionError: If ``stride``, ``kernel_size`` and ``channels`` do not
            all have exactly ``depth`` entries after normalization.
    """

    def __init__(
            self,
            in_chans: int = 3,
            depth: int = 3,
            channels: Union[int, Tuple[int, ...]] = 64,
            kernel_size: Union[int, Tuple[int, ...]] = 3,
            stride: Union[int, Tuple[int, ...]] = (2, 2, 2),
            padding: Union[str, int, Tuple[int, ...]] = "",
            norm_layer: Type[nn.Module] = nn.BatchNorm2d,
            act_layer: Type[nn.Module] = nn.ReLU,
            device=None,
            dtype=None,
    ):
        dd = {'device': device, 'dtype': dtype}
        super().__init__()

        if isinstance(channels, int):
            # a default tiered channel strategy: halve the width per step going
            # backwards from the final layer
            channels = tuple([channels // 2**i for i in range(depth)][::-1])

        kernel_size = to_ntuple(depth)(kernel_size)
        # Previously stride was the only per-layer argument left un-normalized,
        # so a plain int crashed on len() below despite being allowed by the
        # type annotation. Tuples pass through unchanged.
        stride = to_ntuple(depth)(stride)
        padding = to_ntuple(depth)(padding)
        assert depth == len(stride) == len(kernel_size) == len(channels)

        in_chs = in_chans
        last_idx = len(channels) - 1
        for i, out_chs in enumerate(channels):
            # The final layer is a bare biased conv: no norm, no activation.
            last_conv = i == last_idx
            self.add_module(f'{i}', ConvNormAct(
                in_chs,
                out_chs,
                kernel_size=kernel_size[i],
                stride=stride[i],
                padding=padding[i],
                bias=last_conv,
                apply_norm=not last_conv,
                apply_act=not last_conv,
                norm_layer=norm_layer,
                act_layer=act_layer,
                **dd,
            ))
            in_chs = out_chs
- def _dd_from_kwargs(**kwargs):
- return {'device': kwargs.get('device', None), 'dtype': kwargs.get('dtype', None)}
def _resnetv2(layers=(3, 4, 9), **kwargs):
    """Build the ResNet-V2 backbone (or only its stem) for a hybrid model.

    Args:
        layers: Per-stage block counts passed to ``ResNetV2``. An empty
            sequence selects a stem-only backbone built by
            ``create_resnetv2_stem``.
        **kwargs: Only ``padding_same`` (default ``True``), ``in_chans``
            (default ``3``), ``device`` and ``dtype`` are read here; any other
            key is ignored.

    Returns:
        The constructed backbone.
    """
    if kwargs.get('padding_same', True):
        stem_type, conv_cls = 'same', StdConv2dSame
    else:
        stem_type, conv_cls = '', StdConv2d
    conv_layer = partial(conv_cls, eps=1e-8)
    in_chans = kwargs.get('in_chans', 3)
    factory_kwargs = _dd_from_kwargs(**kwargs)

    if len(layers) == 0:
        # No residual stages requested: the stem alone is the backbone.
        return create_resnetv2_stem(
            in_chans,
            stem_type=stem_type,
            preact=False,
            conv_layer=conv_layer,
            **factory_kwargs,
        )

    # Headless feature extractor: no classifier and no global pooling.
    return ResNetV2(
        layers=layers,
        num_classes=0,
        global_pool='',
        in_chans=in_chans,
        preact=False,
        stem_type=stem_type,
        conv_layer=conv_layer,
        **factory_kwargs,
    )
- def _convert_mobileclip(state_dict, model, prefix='image_encoder.model.'):
- out = {}
- for k, v in state_dict.items():
- if not k.startswith(prefix):
- continue
- k = k.replace(prefix, '')
- k = k.replace('patch_emb.', 'patch_embed.backbone.')
- k = k.replace('block.conv', 'conv')
- k = k.replace('block.norm', 'bn')
- k = k.replace('post_transformer_norm.', 'norm.')
- k = k.replace('pre_norm_mha.0', 'norm1')
- k = k.replace('pre_norm_mha.1', 'attn')
- k = k.replace('pre_norm_ffn.0', 'norm2')
- k = k.replace('pre_norm_ffn.1', 'mlp.fc1')
- k = k.replace('pre_norm_ffn.4', 'mlp.fc2')
- k = k.replace('qkv_proj.', 'qkv.')
- k = k.replace('out_proj.', 'proj.')
- k = k.replace('transformer.', 'blocks.')
- if k == 'pos_embed.pos_embed.pos_embed':
- k = 'pos_embed'
- v = v.squeeze(0)
- if 'classifier.proj' in k:
- bias_k = k.replace('classifier.proj', 'head.bias')
- k = k.replace('classifier.proj', 'head.weight')
- v = v.T
- out[bias_k] = torch.zeros(v.shape[0])
- out[k] = v
- return out
def checkpoint_filter_fn(
        state_dict: Dict[str, torch.Tensor],
        model: VisionTransformer,
        interpolation: str = 'bicubic',
        antialias: bool = True,
) -> Dict[str, torch.Tensor]:
    """Adapt a checkpoint for a hybrid ViT before it is loaded.

    A state dict containing the MobileCLIP stem conv weight key is first
    remapped with ``_convert_mobileclip``. The result (or the untouched input)
    is then passed through the ``checkpoint_filter_fn`` from
    ``vision_transformer``.

    Args:
        state_dict: Checkpoint tensors keyed by parameter name.
        model: Target model, forwarded to both filter steps.
        interpolation: Forwarded to the base ViT filter.
        antialias: Forwarded to the base ViT filter.

    Returns:
        The filtered state dict produced by the base ViT filter.
    """
    # Imported locally and aliased because it shares this function's name.
    from .vision_transformer import checkpoint_filter_fn as _vit_filter_fn

    # The presence of this key is what marks the original MobileCLIP layout.
    mobileclip_marker = 'image_encoder.model.patch_emb.0.block.conv.weight'
    if mobileclip_marker in state_dict:
        state_dict = _convert_mobileclip(state_dict, model)

    return _vit_filter_fn(state_dict, model, interpolation=interpolation, antialias=antialias)
def _create_vision_transformer_hybrid(variant, backbone, embed_args=None, pretrained=False, **kwargs):
    """Build a ``VisionTransformer`` whose patch embedding wraps ``backbone``.

    Args:
        variant: Model variant name passed to ``build_model_with_cfg``.
        backbone: Backbone module bound into the ``HybridEmbed`` embed layer.
        embed_args: Optional extra keyword arguments for ``HybridEmbed``.
        pretrained: Passed to ``build_model_with_cfg``.
        **kwargs: Remaining model arguments. ``out_indices`` (default ``3``) is
            removed and placed in the feature config; ``embed_layer`` and
            ``patch_size`` are only filled in when the caller did not set them.

    Returns:
        Whatever ``build_model_with_cfg`` returns for ``VisionTransformer``.
    """
    feature_cfg = dict(out_indices=kwargs.pop('out_indices', 3), feature_cls='getter')

    hybrid_embed = partial(HybridEmbed, backbone=backbone, **(embed_args or {}))
    kwargs.setdefault('embed_layer', hybrid_embed)
    # default patch size for hybrid models if not set
    kwargs.setdefault('patch_size', 1)

    return build_model_with_cfg(
        VisionTransformer,
        variant,
        pretrained,
        pretrained_filter_fn=checkpoint_filter_fn,
        feature_cfg=feature_cfg,
        **kwargs,
    )
- def _cfg(url='', **kwargs):
- return {
- 'url': url,
- 'num_classes': 1000, 'input_size': (3, 224, 224), 'pool_size': None,
- 'crop_pct': .9, 'interpolation': 'bicubic', 'fixed_input_size': True,
- 'mean': (0.5, 0.5, 0.5), 'std': (0.5, 0.5, 0.5),
- 'first_conv': 'patch_embed.backbone.stem.conv', 'classifier': 'head',
- 'license': 'apache-2.0',
- **kwargs
- }
# Pretrained configs keyed by '<architecture>.<pretrained tag>'. Every entry
# starts from the _cfg() defaults and overrides only what differs.
default_cfgs = generate_default_cfgs({
    # hybrid in-1k models (weights from official JAX impl where they exist)
    'vit_tiny_r_s16_p8_224.augreg_in21k_ft_in1k': _cfg(
        url='https://storage.googleapis.com/vit_models/augreg/R_Ti_16-i21k-300ep-lr_0.001-aug_none-wd_0.03-do_0.0-sd_0.0--imagenet2012-steps_20k-lr_0.03-res_224.npz',
        hf_hub_id='timm/',
        custom_load=True,
        first_conv='patch_embed.backbone.conv',
    ),
    'vit_tiny_r_s16_p8_384.augreg_in21k_ft_in1k': _cfg(
        url='https://storage.googleapis.com/vit_models/augreg/R_Ti_16-i21k-300ep-lr_0.001-aug_none-wd_0.03-do_0.0-sd_0.0--imagenet2012-steps_20k-lr_0.03-res_384.npz',
        hf_hub_id='timm/',
        first_conv='patch_embed.backbone.conv',
        input_size=(3, 384, 384),
        crop_pct=1.0,
        custom_load=True,
    ),
    'vit_small_r26_s32_224.augreg_in21k_ft_in1k': _cfg(
        url='https://storage.googleapis.com/vit_models/augreg/R26_S_32-i21k-300ep-lr_0.001-aug_light0-wd_0.03-do_0.1-sd_0.1--imagenet2012-steps_20k-lr_0.03-res_224.npz',
        hf_hub_id='timm/',
        custom_load=True,
    ),
    'vit_small_r26_s32_384.augreg_in21k_ft_in1k': _cfg(
        url='https://storage.googleapis.com/vit_models/augreg/R26_S_32-i21k-300ep-lr_0.001-aug_medium2-wd_0.03-do_0.0-sd_0.0--imagenet2012-steps_20k-lr_0.03-res_384.npz',
        hf_hub_id='timm/',
        input_size=(3, 384, 384),
        crop_pct=1.0,
        custom_load=True,
    ),
    'vit_base_r26_s32_224.untrained': _cfg(),
    'vit_base_r50_s16_384.orig_in21k_ft_in1k': _cfg(
        url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-vitjx/jx_vit_base_resnet50_384-9fd3c705.pth',
        hf_hub_id='timm/',
        input_size=(3, 384, 384),
        crop_pct=1.0,
    ),
    'vit_large_r50_s32_224.augreg_in21k_ft_in1k': _cfg(
        url='https://storage.googleapis.com/vit_models/augreg/R50_L_32-i21k-300ep-lr_0.001-aug_medium1-wd_0.1-do_0.1-sd_0.1--imagenet2012-steps_20k-lr_0.01-res_224.npz',
        hf_hub_id='timm/',
        custom_load=True,
    ),
    'vit_large_r50_s32_384.augreg_in21k_ft_in1k': _cfg(
        url='https://storage.googleapis.com/vit_models/augreg/R50_L_32-i21k-300ep-lr_0.001-aug_medium2-wd_0.1-do_0.0-sd_0.0--imagenet2012-steps_20k-lr_0.01-res_384.npz',
        hf_hub_id='timm/',
        input_size=(3, 384, 384),
        crop_pct=1.0,
        custom_load=True,
    ),

    # hybrid in-21k models (weights from official Google JAX impl where they exist)
    'vit_tiny_r_s16_p8_224.augreg_in21k': _cfg(
        url='https://storage.googleapis.com/vit_models/augreg/R_Ti_16-i21k-300ep-lr_0.001-aug_none-wd_0.03-do_0.0-sd_0.0.npz',
        hf_hub_id='timm/',
        num_classes=21843,
        crop_pct=0.9,
        first_conv='patch_embed.backbone.conv',
        custom_load=True,
    ),
    'vit_small_r26_s32_224.augreg_in21k': _cfg(
        url='https://storage.googleapis.com/vit_models/augreg/R26_S_32-i21k-300ep-lr_0.001-aug_medium2-wd_0.03-do_0.0-sd_0.0.npz',
        hf_hub_id='timm/',
        num_classes=21843,
        crop_pct=0.9,
        custom_load=True,
    ),
    'vit_base_r50_s16_224.orig_in21k': _cfg(
        #url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-vitjx/jx_vit_base_resnet50_224_in21k-6f7c7740.pth',
        hf_hub_id='timm/',
        num_classes=0,
        crop_pct=0.9,
    ),
    'vit_large_r50_s32_224.augreg_in21k': _cfg(
        url='https://storage.googleapis.com/vit_models/augreg/R50_L_32-i21k-300ep-lr_0.001-aug_medium2-wd_0.1-do_0.0-sd_0.0.npz',
        hf_hub_id='timm/',
        num_classes=21843,
        crop_pct=0.9,
        custom_load=True,
    ),

    # hybrid models (using timm resnet backbones)
    'vit_small_resnet26d_224.untrained': _cfg(
        mean=IMAGENET_DEFAULT_MEAN,
        std=IMAGENET_DEFAULT_STD,
        first_conv='patch_embed.backbone.conv1.0',
    ),
    'vit_small_resnet50d_s16_224.untrained': _cfg(
        mean=IMAGENET_DEFAULT_MEAN,
        std=IMAGENET_DEFAULT_STD,
        first_conv='patch_embed.backbone.conv1.0',
    ),
    'vit_base_resnet26d_224.untrained': _cfg(
        mean=IMAGENET_DEFAULT_MEAN,
        std=IMAGENET_DEFAULT_STD,
        first_conv='patch_embed.backbone.conv1.0',
    ),
    'vit_base_resnet50d_224.untrained': _cfg(
        mean=IMAGENET_DEFAULT_MEAN,
        std=IMAGENET_DEFAULT_STD,
        first_conv='patch_embed.backbone.conv1.0',
    ),

    # ConvStem hybrids with Apple MobileCLIP weights (512-dim output head,
    # inputs normalized with mean 0 / std 1)
    'vit_base_mci_224.apple_mclip_lt': _cfg(
        hf_hub_id='apple/mobileclip_b_lt_timm',
        url='https://docs-assets.developer.apple.com/ml-research/datasets/mobileclip/mobileclip_blt.pt',
        license='apple-amlr',
        num_classes=512,
        mean=(0., 0., 0.),
        std=(1., 1., 1.),
        first_conv='patch_embed.backbone.0.conv',
    ),
    'vit_base_mci_224.apple_mclip': _cfg(
        hf_hub_id='apple/mobileclip_b_timm',
        url='https://docs-assets.developer.apple.com/ml-research/datasets/mobileclip/mobileclip_b.pt',
        num_classes=512,
        license='apple-amlr',
        mean=(0., 0., 0.),
        std=(1., 1., 1.),
        first_conv='patch_embed.backbone.0.conv',
    ),
    'vit_base_mci_224.apple_mclip2_dfndr2b': _cfg(
        hf_hub_id='timm/',
        num_classes=512,
        mean=(0., 0., 0.),
        std=(1., 1., 1.),
        first_conv='patch_embed.backbone.0.conv',
        license='apple-amlr',
    ),
})
@register_model
def vit_tiny_r_s16_p8_224(pretrained=False, **kwargs) -> VisionTransformer:
    """R+ViT-Ti/S16 hybrid with 8x8 patches @ 224 x 224."""
    # An empty layer spec yields a stem-only ResNet-V2 backbone.
    stem = _resnetv2(layers=(), **kwargs)
    # Caller-supplied kwargs take precedence over the variant defaults.
    vit_kwargs = {'patch_size': 8, 'embed_dim': 192, 'depth': 12, 'num_heads': 3, **kwargs}
    return _create_vision_transformer_hybrid(
        'vit_tiny_r_s16_p8_224', backbone=stem, pretrained=pretrained, **vit_kwargs)
@register_model
def vit_tiny_r_s16_p8_384(pretrained=False, **kwargs) -> VisionTransformer:
    """R+ViT-Ti/S16 hybrid with 8x8 patches @ 384 x 384."""
    # An empty layer spec yields a stem-only ResNet-V2 backbone.
    stem = _resnetv2(layers=(), **kwargs)
    # Caller-supplied kwargs take precedence over the variant defaults.
    vit_kwargs = {'patch_size': 8, 'embed_dim': 192, 'depth': 12, 'num_heads': 3, **kwargs}
    return _create_vision_transformer_hybrid(
        'vit_tiny_r_s16_p8_384', backbone=stem, pretrained=pretrained, **vit_kwargs)
@register_model
def vit_small_r26_s32_224(pretrained=False, **kwargs) -> VisionTransformer:
    """R26+ViT-S/S32 hybrid (registered with 224 x 224 pretrained configs)."""
    resnet = _resnetv2((2, 2, 2, 2), **kwargs)
    # Caller-supplied kwargs take precedence over the variant defaults.
    vit_kwargs = {'embed_dim': 384, 'depth': 12, 'num_heads': 6, **kwargs}
    return _create_vision_transformer_hybrid(
        'vit_small_r26_s32_224', backbone=resnet, pretrained=pretrained, **vit_kwargs)
@register_model
def vit_small_r26_s32_384(pretrained=False, **kwargs) -> VisionTransformer:
    """R26+ViT-S/S32 hybrid (registered with a 384 x 384 pretrained config)."""
    resnet = _resnetv2((2, 2, 2, 2), **kwargs)
    # Caller-supplied kwargs take precedence over the variant defaults.
    vit_kwargs = {'embed_dim': 384, 'depth': 12, 'num_heads': 6, **kwargs}
    return _create_vision_transformer_hybrid(
        'vit_small_r26_s32_384', backbone=resnet, pretrained=pretrained, **vit_kwargs)
@register_model
def vit_base_r26_s32_224(pretrained=False, **kwargs) -> VisionTransformer:
    """R26+ViT-B/S32 hybrid."""
    resnet = _resnetv2((2, 2, 2, 2), **kwargs)
    # Caller-supplied kwargs take precedence over the variant defaults.
    vit_kwargs = {'embed_dim': 768, 'depth': 12, 'num_heads': 12, **kwargs}
    return _create_vision_transformer_hybrid(
        'vit_base_r26_s32_224', backbone=resnet, pretrained=pretrained, **vit_kwargs)
@register_model
def vit_base_r50_s16_224(pretrained=False, **kwargs) -> VisionTransformer:
    """R50+ViT-B/S16 hybrid from the original paper (https://arxiv.org/abs/2010.11929)."""
    # Three-stage ResNet-V2 backbone.
    resnet = _resnetv2((3, 4, 9), **kwargs)
    # Caller-supplied kwargs take precedence over the variant defaults.
    vit_kwargs = {'embed_dim': 768, 'depth': 12, 'num_heads': 12, **kwargs}
    return _create_vision_transformer_hybrid(
        'vit_base_r50_s16_224', backbone=resnet, pretrained=pretrained, **vit_kwargs)
@register_model
def vit_base_r50_s16_384(pretrained=False, **kwargs) -> VisionTransformer:
    """R50+ViT-B/16 hybrid from the original paper (https://arxiv.org/abs/2010.11929).

    ImageNet-1k weights fine-tuned from in21k @ 384x384,
    source https://github.com/google-research/vision_transformer.
    """
    # Three-stage ResNet-V2 backbone.
    resnet = _resnetv2((3, 4, 9), **kwargs)
    # Caller-supplied kwargs take precedence over the variant defaults.
    vit_kwargs = {'embed_dim': 768, 'depth': 12, 'num_heads': 12, **kwargs}
    return _create_vision_transformer_hybrid(
        'vit_base_r50_s16_384', backbone=resnet, pretrained=pretrained, **vit_kwargs)
@register_model
def vit_large_r50_s32_224(pretrained=False, **kwargs) -> VisionTransformer:
    """R50+ViT-L/S32 hybrid (registered with 224 x 224 pretrained configs)."""
    resnet = _resnetv2((3, 4, 6, 3), **kwargs)
    # Caller-supplied kwargs take precedence over the variant defaults.
    vit_kwargs = {'embed_dim': 1024, 'depth': 24, 'num_heads': 16, **kwargs}
    return _create_vision_transformer_hybrid(
        'vit_large_r50_s32_224', backbone=resnet, pretrained=pretrained, **vit_kwargs)
@register_model
def vit_large_r50_s32_384(pretrained=False, **kwargs) -> VisionTransformer:
    """R50+ViT-L/S32 hybrid (registered with a 384 x 384 pretrained config)."""
    resnet = _resnetv2((3, 4, 6, 3), **kwargs)
    # Caller-supplied kwargs take precedence over the variant defaults.
    vit_kwargs = {'embed_dim': 1024, 'depth': 24, 'num_heads': 16, **kwargs}
    return _create_vision_transformer_hybrid(
        'vit_large_r50_s32_384', backbone=resnet, pretrained=pretrained, **vit_kwargs)
@register_model
def vit_small_resnet26d_224(pretrained=False, **kwargs) -> VisionTransformer:
    """Custom ViT small hybrid w/ ResNet26D stride 32. No pretrained weights.

    ``pretrained`` is forwarded to the ResNet26D backbone constructor as well
    as to the hybrid model builder.
    """
    cnn = resnet26d(
        pretrained=pretrained,
        in_chans=kwargs.get('in_chans', 3),
        features_only=True,
        out_indices=[4],
        **_dd_from_kwargs(**kwargs),
    )
    # Caller-supplied kwargs take precedence over the variant defaults.
    vit_kwargs = {'embed_dim': 768, 'depth': 8, 'num_heads': 8, 'mlp_ratio': 3, **kwargs}
    return _create_vision_transformer_hybrid(
        'vit_small_resnet26d_224', backbone=cnn, pretrained=pretrained, **vit_kwargs)
@register_model
def vit_small_resnet50d_s16_224(pretrained=False, **kwargs) -> VisionTransformer:
    """Custom ViT small hybrid w/ ResNet50D 3-stages, stride 16. No pretrained weights.

    ``pretrained`` is forwarded to the ResNet50D backbone constructor as well
    as to the hybrid model builder.
    """
    cnn = resnet50d(
        pretrained=pretrained,
        in_chans=kwargs.get('in_chans', 3),
        features_only=True,
        out_indices=[3],
        **_dd_from_kwargs(**kwargs),
    )
    # Caller-supplied kwargs take precedence over the variant defaults.
    vit_kwargs = {'embed_dim': 768, 'depth': 8, 'num_heads': 8, 'mlp_ratio': 3, **kwargs}
    return _create_vision_transformer_hybrid(
        'vit_small_resnet50d_s16_224', backbone=cnn, pretrained=pretrained, **vit_kwargs)
@register_model
def vit_base_resnet26d_224(pretrained=False, **kwargs) -> VisionTransformer:
    """Custom ViT base hybrid w/ ResNet26D stride 32. No pretrained weights.

    ``pretrained`` is forwarded to the ResNet26D backbone constructor as well
    as to the hybrid model builder.
    """
    cnn = resnet26d(
        pretrained=pretrained,
        in_chans=kwargs.get('in_chans', 3),
        features_only=True,
        out_indices=[4],
        **_dd_from_kwargs(**kwargs),
    )
    # Caller-supplied kwargs take precedence over the variant defaults.
    vit_kwargs = {'embed_dim': 768, 'depth': 12, 'num_heads': 12, **kwargs}
    return _create_vision_transformer_hybrid(
        'vit_base_resnet26d_224', backbone=cnn, pretrained=pretrained, **vit_kwargs)
@register_model
def vit_base_resnet50d_224(pretrained=False, **kwargs) -> VisionTransformer:
    """Custom ViT base hybrid w/ ResNet50D stride 32. No pretrained weights.

    ``pretrained`` is forwarded to the ResNet50D backbone constructor as well
    as to the hybrid model builder.
    """
    cnn = resnet50d(
        pretrained=pretrained,
        in_chans=kwargs.get('in_chans', 3),
        features_only=True,
        out_indices=[4],
        **_dd_from_kwargs(**kwargs),
    )
    # Caller-supplied kwargs take precedence over the variant defaults.
    vit_kwargs = {'embed_dim': 768, 'depth': 12, 'num_heads': 12, **kwargs}
    return _create_vision_transformer_hybrid(
        'vit_base_resnet50d_224', backbone=cnn, pretrained=pretrained, **vit_kwargs)
@register_model
def vit_base_mci_224(pretrained=False, **kwargs) -> VisionTransformer:
    """ViT base hybrid with a 3-layer ``ConvStem`` patch embedding.

    The stem uses kernel sizes equal to its strides (4, 2, 2) with zero
    padding and GELU activations, and its last layer outputs ``embed_dim``
    channels. The pretrained tags registered for this architecture in this
    file are the Apple MobileCLIP checkpoints (``apple_mclip*``).

    Args:
        pretrained: Passed to the hybrid model builder.
        **kwargs: Overrides for the variant defaults. ``in_chans``, ``device``
            and ``dtype`` are also read when building the stem.

    Returns:
        The constructed ``VisionTransformer``.
    """
    embed_dim = 768
    backbone = ConvStem(
        channels=(embed_dim // 4, embed_dim // 4, embed_dim),
        stride=(4, 2, 2),
        kernel_size=(4, 2, 2),
        padding=0,
        in_chans=kwargs.get('in_chans', 3),
        act_layer=nn.GELU,
        **_dd_from_kwargs(**kwargs),
    )
    model_args = dict(embed_dim=embed_dim, depth=12, num_heads=12, no_embed_class=True)
    # NOTE(review): proj=False presumably makes HybridEmbed use the stem output
    # directly, since the stem already emits embed_dim channels -- confirm
    # against HybridEmbed.
    model = _create_vision_transformer_hybrid(
        'vit_base_mci_224',
        backbone=backbone,
        embed_args=dict(proj=False),
        pretrained=pretrained,
        **dict(model_args, **kwargs),
    )
    return model
# Map deprecated model names to their current '<architecture>.<tag>' names.
register_model_deprecations(__name__, {
    # former `_in21k` suffixed names
    'vit_tiny_r_s16_p8_224_in21k': 'vit_tiny_r_s16_p8_224.augreg_in21k',
    'vit_small_r26_s32_224_in21k': 'vit_small_r26_s32_224.augreg_in21k',
    'vit_base_r50_s16_224_in21k': 'vit_base_r50_s16_224.orig_in21k',
    'vit_large_r50_s32_224_in21k': 'vit_large_r50_s32_224.augreg_in21k',
    # former `vit_base_resnet50_*` names
    'vit_base_resnet50_224_in21k': 'vit_base_r50_s16_224.orig_in21k',
    'vit_base_resnet50_384': 'vit_base_r50_s16_384.orig_in21k_ft_in1k',
})
|