| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458 |
- """ ConViT Model
- @article{d2021convit,
- title={ConViT: Improving Vision Transformers with Soft Convolutional Inductive Biases},
- author={d'Ascoli, St{\'e}phane and Touvron, Hugo and Leavitt, Matthew and Morcos, Ari and Biroli, Giulio and Sagun, Levent},
- journal={arXiv preprint arXiv:2103.10697},
- year={2021}
- }
- Paper link: https://arxiv.org/abs/2103.10697
- Original code: https://github.com/facebookresearch/convit, original copyright below
- Modifications and additions for timm hacked together by / Copyright 2021, Ross Wightman
- """
- # Copyright (c) 2015-present, Facebook, Inc.
- # All rights reserved.
- #
- # This source code is licensed under the CC-by-NC license found in the
- # LICENSE file in the root directory of this source tree.
- #
- '''These modules are adapted from those of timm, see
- https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/vision_transformer.py
- '''
- from typing import Optional, Union, Type, Any
- import torch
- import torch.nn as nn
- from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
- from timm.layers import DropPath, calculate_drop_path_rates, trunc_normal_, PatchEmbed, Mlp, LayerNorm, HybridEmbed
- from ._builder import build_model_with_cfg
- from ._features_fx import register_notrace_module
- from ._registry import register_model, generate_default_cfgs
- __all__ = ['ConVit']
@register_notrace_module  # reason: FX can't symbolically trace control flow in forward method
class GPSA(nn.Module):
    """Gated Positional Self-Attention (GPSA) from ConViT.

    Each head blends two attention maps:
      * a content map ``softmax(q @ k^T * scale)``, and
      * a positional map ``softmax(pos_proj(rel_indices))`` built from relative patch offsets.
    The blend weight is ``sigmoid(gating_param)`` per head, and the mix is renormalized so every
    row sums to one.

    ``rel_indices`` is cached as a plain tensor attribute, not a registered buffer. Because of
    that, ``Module.to()`` / ``.half()`` do not touch it. It is therefore rebuilt lazily whenever
    the token count, device or dtype no longer matches the layer's weights.

    Args:
        dim: Token embedding dimension, split evenly across heads (``dim // num_heads`` each).
        num_heads: Number of attention heads.
        qkv_bias: If True, add a bias to the query/key and value projections.
        attn_drop: Dropout probability applied to the attention weights.
        proj_drop: Dropout probability applied after the output projection.
        locality_strength: Multiplier applied to the positional projection in ``local_init``.
        device: Device used to create parameters.
        dtype: Dtype used to create parameters.
    """

    def __init__(
            self,
            dim: int,
            num_heads: int = 8,
            qkv_bias: bool = False,
            attn_drop: float = 0.,
            proj_drop: float = 0.,
            locality_strength: float = 1.,
            device=None,
            dtype=None,
    ):
        dd = {'device': device, 'dtype': dtype}
        super().__init__()
        self.num_heads = num_heads
        self.dim = dim
        head_dim = dim // num_heads
        self.scale = head_dim ** -0.5
        self.locality_strength = locality_strength

        self.qk = nn.Linear(dim, dim * 2, bias=qkv_bias, **dd)
        self.v = nn.Linear(dim, dim, bias=qkv_bias, **dd)

        self.attn_drop = nn.Dropout(attn_drop)
        self.proj = nn.Linear(dim, dim, **dd)
        # Maps every (dx, dy, dx**2 + dy**2) relative offset to one positional score per head.
        self.pos_proj = nn.Linear(3, num_heads, **dd)
        self.proj_drop = nn.Dropout(proj_drop)
        self.gating_param = nn.Parameter(torch.ones(self.num_heads, **dd))
        # Placeholder for a 1-token grid, replaced on first use. A real tensor (rather than None)
        # is needed as a silly torchscript hack.
        self.rel_indices: torch.Tensor = torch.zeros(1, 1, 1, 3, **dd)

    def forward(self, x):
        """Apply gated positional self-attention to ``x`` of shape ``(B, N, C)``."""
        B, N, C = x.shape
        attn = self.get_attention(x)  # also refreshes the rel_indices cache for N tokens
        v = self.v(x).reshape(B, N, self.num_heads, C // self.num_heads).permute(0, 2, 1, 3)
        x = (attn @ v).transpose(1, 2).reshape(B, N, C)
        x = self.proj(x)
        x = self.proj_drop(x)
        return x

    def _update_rel_indices(self, num_patches: int) -> None:
        """Rebuild ``rel_indices`` if its token count, device or dtype no longer matches ``qk``."""
        weight = self.qk.weight
        rel_indices = self.rel_indices
        if (
                rel_indices is None
                or rel_indices.shape[1] != num_patches
                or rel_indices.device != weight.device
                or rel_indices.dtype != weight.dtype
        ):
            self.rel_indices = self.get_rel_indices(num_patches)

    def get_attention(self, x):
        """Return the blended attention weights, shape ``(B, num_heads, N, N)``, for ``x`` of shape ``(B, N, C)``."""
        B, N, C = x.shape
        # Refresh here (not only in forward) so get_attention_map also sees a valid table.
        self._update_rel_indices(N)
        qk = self.qk(x).reshape(B, N, 2, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
        q, k = qk[0], qk[1]
        pos_score = self.rel_indices.expand(B, -1, -1, -1)
        pos_score = self.pos_proj(pos_score).permute(0, 3, 1, 2)
        patch_score = (q @ k.transpose(-2, -1)) * self.scale
        patch_score = patch_score.softmax(dim=-1)
        pos_score = pos_score.softmax(dim=-1)

        gating = self.gating_param.view(1, -1, 1, 1)
        attn = (1. - torch.sigmoid(gating)) * patch_score + torch.sigmoid(gating) * pos_score
        attn /= attn.sum(dim=-1).unsqueeze(-1)  # renormalize rows to sum to one
        attn = self.attn_drop(attn)
        return attn

    def get_attention_map(self, x, return_map=False):
        """Per-head mean attention distance (in patch units), averaged over the batch.

        Args:
            x: Input tokens, shape ``(B, N, C)``.
            return_map: If True, also return the batch-averaged attention map ``(num_heads, N, N)``.

        Returns:
            ``dist`` of shape ``(num_heads,)``, or ``(dist, attn_map)`` when ``return_map`` is True.
        """
        attn_map = self.get_attention(x).mean(0)  # average over batch
        distances = self.rel_indices.squeeze()[:, :, -1] ** .5  # Euclidean offset length
        dist = torch.einsum('nm,hnm->h', (distances, attn_map)) / distances.size(0)
        if return_map:
            return dist, attn_map
        return dist

    def local_init(self):
        """Initialize the layer to mimic a convolution (ConViT "local" initialization).

        ``v`` becomes the identity. The first ``kernel_size ** 2`` heads, where
        ``kernel_size = int(sqrt(num_heads))``, get positional weights that give the score
        ``-(dx**2 + dy**2) + 2*dx*ox + 2*dy*oy``, scaled by ``locality_strength``. That score peaks
        at the offset ``(ox, oy) = (h2 - center, h1 - center)`` assigned to the head. Any
        remaining heads keep their current weights.
        """
        self.v.weight.data.copy_(torch.eye(self.dim))
        locality_distance = 1  # fixed; original code considered max(1, 1 / locality_strength ** .5)

        kernel_size = int(self.num_heads ** .5)
        center = (kernel_size - 1) / 2 if kernel_size % 2 == 0 else kernel_size // 2
        for h1 in range(kernel_size):
            for h2 in range(kernel_size):
                position = h1 + kernel_size * h2
                self.pos_proj.weight.data[position, 2] = -1
                self.pos_proj.weight.data[position, 1] = 2 * (h1 - center) * locality_distance
                self.pos_proj.weight.data[position, 0] = 2 * (h2 - center) * locality_distance
        self.pos_proj.weight.data *= self.locality_strength

    def get_rel_indices(self, num_patches: int) -> torch.Tensor:
        """Build the relative-offset table for a square grid of ``num_patches`` tokens.

        Token ``p`` sits at row ``p // S`` and column ``p % S``, where ``S = sqrt(num_patches)``.
        The result has shape ``(1, num_patches, num_patches, 3)``. Its last axis holds
        ``(dx, dy, dx**2 + dy**2)``, with ``dx``/``dy`` the column/row offset from query to key.
        The table is placed on the device and dtype of ``qk``.

        Raises:
            ValueError: If ``num_patches`` is not a perfect square.
        """
        img_size = int(num_patches ** .5)
        if img_size * img_size != num_patches:
            raise ValueError(
                'GPSA expects a square grid of patch tokens, got num_patches=' + str(num_patches))
        ind = (
            torch.arange(img_size, dtype=torch.float32).view(1, -1)
            - torch.arange(img_size, dtype=torch.float32).view(-1, 1)
        )
        indx = ind.repeat(img_size, img_size)
        indy = ind.repeat_interleave(img_size, dim=0).repeat_interleave(img_size, dim=1)
        indd = indx ** 2 + indy ** 2
        rel_indices = torch.stack([indx, indy, indd], dim=-1).unsqueeze(0)
        return rel_indices.to(device=self.qk.weight.device, dtype=self.qk.weight.dtype)
class MHSA(nn.Module):
    """Plain multi-head self-attention with a fused query/key/value projection.

    Args:
        dim: Token embedding dimension, split evenly across heads.
        num_heads: Number of attention heads.
        qkv_bias: If True, the fused qkv projection has a bias.
        attn_drop: Dropout probability on the attention weights.
        proj_drop: Dropout probability after the output projection.
        device: Device used to create parameters.
        dtype: Dtype used to create parameters.
    """

    def __init__(
            self,
            dim: int,
            num_heads: int = 8,
            qkv_bias: bool = False,
            attn_drop: float = 0.,
            proj_drop: float = 0.,
            device=None,
            dtype=None,
    ):
        factory_kwargs = dict(device=device, dtype=dtype)
        super().__init__()
        self.num_heads = num_heads
        self.scale = (dim // num_heads) ** -0.5
        self.qkv = nn.Linear(dim, 3 * dim, bias=qkv_bias, **factory_kwargs)
        self.attn_drop = nn.Dropout(attn_drop)
        self.proj = nn.Linear(dim, dim, **factory_kwargs)
        self.proj_drop = nn.Dropout(proj_drop)

    def _split_qkv(self, x):
        """Project ``x`` (B, N, C) and return q, k, v, each shaped (B, num_heads, N, C // num_heads)."""
        batch, tokens, channels = x.shape
        fused = self.qkv(x).reshape(batch, tokens, 3, self.num_heads, channels // self.num_heads)
        return fused.permute(2, 0, 3, 1, 4).unbind(0)

    def get_attention_map(self, x, return_map=False):
        """Per-head mean attention distance over a square token grid, averaged over the batch.

        The token count ``N`` is treated as an ``int(sqrt(N))`` x ``int(sqrt(N))`` grid. The
        distance matrix only lines up with the ``(N, N)`` attention map when ``N`` is a perfect
        square.

        Returns:
            ``dist`` of shape ``(num_heads,)``, or ``(dist, attn_map)`` when ``return_map`` is True.
        """
        n_tokens = x.shape[1]
        query, key, _ = self._split_qkv(x)
        scores = (query @ key.transpose(-2, -1)) * self.scale
        attn_map = scores.softmax(dim=-1).mean(0)

        side = int(n_tokens ** .5)
        offsets = (
            torch.arange(side, dtype=torch.float32).view(1, -1)
            - torch.arange(side, dtype=torch.float32).view(-1, 1)
        )
        delta_x = offsets.repeat(side, side)
        delta_y = offsets.repeat_interleave(side, dim=0).repeat_interleave(side, dim=1)
        distances = (delta_x ** 2 + delta_y ** 2) ** .5
        distances = distances.to(attn_map.device, attn_map.dtype)

        dist = torch.einsum('nm,hnm->h', (distances, attn_map)) / n_tokens
        return (dist, attn_map) if return_map else dist

    def forward(self, x):
        """Standard scaled dot-product self-attention on ``x`` of shape (B, N, C)."""
        batch, tokens, channels = x.shape
        query, key, value = self._split_qkv(x)
        weights = self.attn_drop(((query @ key.transpose(-2, -1)) * self.scale).softmax(dim=-1))
        mixed = (weights @ value).transpose(1, 2).reshape(batch, tokens, channels)
        return self.proj_drop(self.proj(mixed))
class Block(nn.Module):
    """Pre-norm transformer block used by ConViT.

    Computes ``x + drop_path(attn(norm1(x)))`` followed by ``x + drop_path(mlp(norm2(x)))``.
    ``attn`` is a ``GPSA`` layer when ``use_gpsa`` is True and an ``MHSA`` layer otherwise.

    Args:
        dim: Token embedding dimension.
        num_heads: Number of attention heads.
        mlp_ratio: MLP hidden width as a multiple of ``dim``.
        qkv_bias: If True, attention projections get a bias.
        proj_drop: Dropout after the attention projection and inside the MLP.
        attn_drop: Dropout on the attention weights.
        drop_path: Stochastic-depth rate; 0 disables it.
        act_layer: Activation class used inside the MLP.
        norm_layer: Normalization layer class.
        use_gpsa: Select GPSA (True) or plain MHSA (False) attention.
        locality_strength: Passed to GPSA only.
        device: Device used to create parameters.
        dtype: Dtype used to create parameters.
    """

    def __init__(
            self,
            dim: int,
            num_heads: int,
            mlp_ratio: float = 4.,
            qkv_bias: bool = False,
            proj_drop: float = 0.,
            attn_drop: float = 0.,
            drop_path: float = 0.,
            act_layer: Type[nn.Module] = nn.GELU,
            norm_layer: Type[nn.Module] = LayerNorm,
            use_gpsa: bool = True,
            locality_strength: float = 1.,
            device=None,
            dtype=None,
    ):
        super().__init__()
        factory_kwargs = dict(device=device, dtype=dtype)
        shared_attn_kwargs = dict(
            num_heads=num_heads,
            qkv_bias=qkv_bias,
            attn_drop=attn_drop,
            proj_drop=proj_drop,
            **factory_kwargs,
        )

        self.norm1 = norm_layer(dim, **factory_kwargs)
        self.use_gpsa = use_gpsa
        if use_gpsa:
            self.attn = GPSA(dim, locality_strength=locality_strength, **shared_attn_kwargs)
        else:
            self.attn = MHSA(dim, **shared_attn_kwargs)

        if drop_path > 0.:
            self.drop_path = DropPath(drop_path)
        else:
            self.drop_path = nn.Identity()

        self.norm2 = norm_layer(dim, **factory_kwargs)
        self.mlp = Mlp(
            in_features=dim,
            hidden_features=int(dim * mlp_ratio),
            act_layer=act_layer,
            drop=proj_drop,
            **factory_kwargs,
        )

    def forward(self, x):
        attn_out = self.attn(self.norm1(x))
        x = x + self.drop_path(attn_out)
        mlp_out = self.mlp(self.norm2(x))
        return x + self.drop_path(mlp_out)
class ConVit(nn.Module):
    """ Vision Transformer with support for patch or hybrid CNN input stage

    ConViT layout:
      * The first ``local_up_to_layer`` blocks use gated positional self-attention (GPSA) over
        the patch tokens only.
      * The class token is prepended just before the first plain multi-head self-attention
        block. If every block is GPSA, it is prepended after the last block, so the output
        layout is always ``[cls, patches...]``.

    Note that ``embed_dim`` is a per-head width: the model width is ``embed_dim * num_heads``.

    Args:
        img_size: Input image size.
        patch_size: Patch size for ``PatchEmbed`` (unused when ``hybrid_backbone`` is given).
        in_chans: Number of input image channels.
        num_classes: Classifier outputs; 0 replaces the head with ``nn.Identity``.
        global_pool: ``'token'`` (class token), ``'avg'`` (mean of patch tokens) or ``''``.
        embed_dim: Per-head embedding width, multiplied by ``num_heads``.
        depth: Number of transformer blocks.
        num_heads: Attention heads per block.
        mlp_ratio: MLP hidden width relative to the model width.
        qkv_bias: If True, attention projections get a bias.
        drop_rate: Dropout before the classifier head.
        pos_drop_rate: Dropout after the position embedding is added.
        proj_drop_rate: Dropout after attention/MLP projections.
        attn_drop_rate: Dropout on attention weights.
        drop_path_rate: Stochastic-depth rate; per-block rates come from ``calculate_drop_path_rates``.
        hybrid_backbone: Optional CNN backbone used through ``HybridEmbed``.
        norm_layer: Normalization layer class.
        local_up_to_layer: Number of leading GPSA blocks.
        locality_strength: Locality strength passed to the GPSA blocks.
        use_pos_embed: If True, add a learned absolute position embedding to the patch tokens.
        device: Device used to create parameters.
        dtype: Dtype used to create parameters.
    """

    def __init__(
            self,
            img_size: int = 224,
            patch_size: int = 16,
            in_chans: int = 3,
            num_classes: int = 1000,
            global_pool: str = 'token',
            embed_dim: int = 768,
            depth: int = 12,
            num_heads: int = 12,
            mlp_ratio: float = 4.,
            qkv_bias: bool = False,
            drop_rate: float = 0.,
            pos_drop_rate: float = 0.,
            proj_drop_rate: float = 0.,
            attn_drop_rate: float = 0.,
            drop_path_rate: float = 0.,
            hybrid_backbone: Optional[Any] = None,
            norm_layer: Type[nn.Module] = LayerNorm,
            local_up_to_layer: int = 3,
            locality_strength: float = 1.,
            use_pos_embed: bool = True,
            device=None,
            dtype=None,
    ):
        super().__init__()
        dd = {'device': device, 'dtype': dtype}
        assert global_pool in ('', 'avg', 'token')
        embed_dim *= num_heads  # embed_dim is given per head
        self.num_classes = num_classes
        self.in_chans = in_chans
        self.global_pool = global_pool
        self.local_up_to_layer = local_up_to_layer
        self.num_features = self.head_hidden_size = self.embed_dim = embed_dim  # for consistency with other models
        self.locality_strength = locality_strength
        self.use_pos_embed = use_pos_embed

        if hybrid_backbone is not None:
            self.patch_embed = HybridEmbed(
                hybrid_backbone,
                img_size=img_size,
                in_chans=in_chans,
                embed_dim=embed_dim,
                **dd,
            )
        else:
            self.patch_embed = PatchEmbed(
                img_size=img_size,
                patch_size=patch_size,
                in_chans=in_chans,
                embed_dim=embed_dim,
                **dd,
            )
        num_patches = self.patch_embed.num_patches
        self.num_patches = num_patches

        self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim, **dd))
        self.pos_drop = nn.Dropout(p=pos_drop_rate)

        if self.use_pos_embed:
            self.pos_embed = nn.Parameter(torch.zeros(1, num_patches, embed_dim, **dd))
            trunc_normal_(self.pos_embed, std=.02)

        dpr = calculate_drop_path_rates(drop_path_rate, depth)  # stochastic depth decay rule
        self.blocks = nn.ModuleList([
            Block(
                dim=embed_dim,
                num_heads=num_heads,
                mlp_ratio=mlp_ratio,
                qkv_bias=qkv_bias,
                proj_drop=proj_drop_rate,
                attn_drop=attn_drop_rate,
                drop_path=dpr[i],
                norm_layer=norm_layer,
                use_gpsa=i < local_up_to_layer,
                locality_strength=locality_strength,
                **dd,
            ) for i in range(depth)])
        self.norm = norm_layer(embed_dim, **dd)

        # Classifier head
        self.feature_info = [dict(num_chs=embed_dim, reduction=0, module='head')]
        self.head_drop = nn.Dropout(drop_rate)
        self.head = nn.Linear(embed_dim, num_classes, **dd) if num_classes > 0 else nn.Identity()

        trunc_normal_(self.cls_token, std=.02)
        self.apply(self._init_weights)
        # GPSA layers override the generic init with their convolution-like local init.
        for m in self.modules():
            if hasattr(m, 'local_init'):
                m.local_init()

    def _init_weights(self, m):
        """Truncated-normal Linear weights with zero bias; LayerNorm set to weight=1, bias=0."""
        if isinstance(m, nn.Linear):
            trunc_normal_(m.weight, std=.02)
            if m.bias is not None:
                nn.init.constant_(m.bias, 0)
        elif isinstance(m, nn.LayerNorm):
            # Guard against LayerNorm built with elementwise_affine=False (no weight/bias).
            if m.bias is not None:
                nn.init.constant_(m.bias, 0)
            if m.weight is not None:
                nn.init.constant_(m.weight, 1.0)

    @torch.jit.ignore
    def no_weight_decay(self):
        return {'pos_embed', 'cls_token'}

    @torch.jit.ignore
    def group_matcher(self, coarse=False):
        return dict(
            stem=r'^cls_token|pos_embed|patch_embed',  # stem and embed
            blocks=[(r'^blocks\.(\d+)', None), (r'^norm', (99999,))]
        )

    @torch.jit.ignore
    def set_grad_checkpointing(self, enable=True):
        assert not enable, 'gradient checkpointing not supported'

    @torch.jit.ignore
    def get_classifier(self) -> nn.Module:
        return self.head

    def reset_classifier(self, num_classes: int, global_pool: Optional[str] = None):
        """Replace the classifier head, and optionally the pooling mode.

        The new head is created on the same device/dtype as the model's parameters. It can
        therefore be used straight away on a model that was already moved (e.g. ``.cuda()``).
        """
        self.num_classes = num_classes
        if global_pool is not None:
            assert global_pool in ('', 'token', 'avg')
            self.global_pool = global_pool
        if num_classes > 0:
            self.head = nn.Linear(
                self.embed_dim,
                num_classes,
                device=self.cls_token.device,
                dtype=self.cls_token.dtype,
            )
        else:
            self.head = nn.Identity()

    def forward_features(self, x):
        """Embed patches and run all blocks.

        Returns normalized tokens laid out as ``[cls, patches...]``.
        """
        x = self.patch_embed(x)
        if self.use_pos_embed:
            x = x + self.pos_embed
        x = self.pos_drop(x)
        cls_tokens = self.cls_token.expand(x.shape[0], -1, -1)
        cls_added = False
        for u, blk in enumerate(self.blocks):
            if u == self.local_up_to_layer:
                # GPSA blocks see patch tokens only; the class token joins before the first MHSA block.
                x = torch.cat((cls_tokens, x), dim=1)
                cls_added = True
            x = blk(x)
        if not cls_added:
            # All blocks were GPSA: still emit the [cls, patches] layout forward_head expects.
            x = torch.cat((cls_tokens, x), dim=1)
        x = self.norm(x)
        return x

    def forward_head(self, x, pre_logits: bool = False):
        """Pool tokens ('avg': mean of patch tokens, 'token': class token), then apply the head."""
        if self.global_pool:
            x = x[:, 1:].mean(dim=1) if self.global_pool == 'avg' else x[:, 0]
        x = self.head_drop(x)
        return x if pre_logits else self.head(x)

    def forward(self, x):
        x = self.forward_features(x)
        x = self.forward_head(x)
        return x
- def _create_convit(variant, pretrained=False, **kwargs):
- if kwargs.get('features_only', None):
- raise RuntimeError('features_only not implemented for Vision Transformer models.')
- return build_model_with_cfg(ConVit, variant, pretrained, **kwargs)
def _cfg(url='', **kwargs):
    """Return the default pretrained-config dict for a ConViT variant.

    Any ``kwargs`` are merged last: they override matching defaults and add new keys.
    """
    cfg = dict(
        url=url,
        num_classes=1000,
        input_size=(3, 224, 224),
        pool_size=None,
        mean=IMAGENET_DEFAULT_MEAN,
        std=IMAGENET_DEFAULT_STD,
        fixed_input_size=True,
        first_conv='patch_embed.proj',
        classifier='head',
        license='apache-2.0',
    )
    cfg.update(kwargs)
    return cfg
# ConViT pretrained configs. The 'fb_in1k' tag presumably marks the original Facebook
# ImageNet-1k weights; all are fetched from the HF hub under 'timm/'.
default_cfgs = generate_default_cfgs({
    f'{name}.fb_in1k': _cfg(hf_hub_id='timm/')
    for name in ('convit_tiny', 'convit_small', 'convit_base')
})
@register_model
def convit_tiny(pretrained=False, **kwargs) -> ConVit:
    """ConViT-Tiny: 4 heads x 48 dims per head (model width 192); blocks 0-9 use GPSA.

    ``kwargs`` override these defaults before being forwarded to ``_create_convit``.
    """
    model_args = {'local_up_to_layer': 10, 'locality_strength': 1.0, 'embed_dim': 48, 'num_heads': 4}
    model_args.update(kwargs)
    return _create_convit(variant='convit_tiny', pretrained=pretrained, **model_args)
@register_model
def convit_small(pretrained=False, **kwargs) -> ConVit:
    """ConViT-Small: 9 heads x 48 dims per head (model width 432); blocks 0-9 use GPSA.

    ``kwargs`` override these defaults before being forwarded to ``_create_convit``.
    """
    model_args = {'local_up_to_layer': 10, 'locality_strength': 1.0, 'embed_dim': 48, 'num_heads': 9}
    model_args.update(kwargs)
    return _create_convit(variant='convit_small', pretrained=pretrained, **model_args)
@register_model
def convit_base(pretrained=False, **kwargs) -> ConVit:
    """ConViT-Base: 16 heads x 48 dims per head (model width 768); blocks 0-9 use GPSA.

    ``kwargs`` override these defaults before being forwarded to ``_create_convit``.
    """
    model_args = {'local_up_to_layer': 10, 'locality_strength': 1.0, 'embed_dim': 48, 'num_heads': 16}
    model_args.update(kwargs)
    return _create_convit(variant='convit_base', pretrained=pretrained, **model_args)
|