| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656657658659660661662663664665666667668669670671672673674675676677678679680681682683684685686687688689690691692693694695696697698699700701702
703704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792793794795796797798799800801802803804805806807808809810811812813814815816817818819820821822823824825826827828829830831832833834835836837838839840841842843844845846847848849850851852853854855856857858859860861862863864865866867868869870871872873874875876877878879880881882883884885886887888889890891892893894895896897898899900901902903904905906907908909910911912913914915916917918919920921922923924925926927928929930931932933934935936937938939940941942943944945946947948949950951952953954955956957958959960961962963964965966967968969970971972973974975976977978979980981982983984985986987988989990991992993994995996997998999100010011002100310041005100610071008100910101011101210131014101510161017101810191020102110221023102410251026102710281029103010311032103310341035103610371038103910401041104210431044104510461047104810491050105110521053105410551056105710581059106010611062106310641065106610671068106910701071107210731074107510761077107810791080108110821083108410851086108710881089109010911092109310941095109610971098109911001101110211031104110511061107110811091110111111121113111411151116111711181119112011211122112311241125112611271128112911301131113211331134113511361137113811391140114111421143114411451146114711481149115011511152 |
- """ LeViT
- Paper: `LeViT: a Vision Transformer in ConvNet's Clothing for Faster Inference`
- - https://arxiv.org/abs/2104.01136
- @article{graham2021levit,
- title={LeViT: a Vision Transformer in ConvNet's Clothing for Faster Inference},
- author={Benjamin Graham and Alaaeldin El-Nouby and Hugo Touvron and Pierre Stock and Armand Joulin and Herv\'e J\'egou and Matthijs Douze},
- journal={arXiv preprint arXiv:2104.01136},
- year={2021}
- }
- Adapted from official impl at https://github.com/facebookresearch/LeViT, original copyright below.
- This version combines both conv/linear models and fixes torchscript compatibility.
- Modifications and additions for timm hacked together by / Copyright 2021, Ross Wightman
- """
- # Copyright (c) 2015-present, Facebook, Inc.
- # All rights reserved.
- # Modified from
- # https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/vision_transformer.py
- # Copyright 2020 Ross Wightman, Apache-2.0 License
- from collections import OrderedDict
- from functools import partial
- from typing import Dict, List, Optional, Tuple, Type, Union
- import torch
- import torch.nn as nn
- from timm.data import IMAGENET_DEFAULT_STD, IMAGENET_DEFAULT_MEAN
- from timm.layers import to_ntuple, to_2tuple, get_act_layer, DropPath, trunc_normal_, ndgrid
- from ._builder import build_model_with_cfg
- from ._features import feature_take_indices
- from ._manipulate import checkpoint, checkpoint_seq
- from ._registry import generate_default_cfgs, register_model
- __all__ = ['Levit']
class ConvNorm(nn.Module):
    """Bias-free 2D convolution followed by BatchNorm2d.

    The pair can be collapsed into a single biased ``nn.Conv2d`` with
    :meth:`fuse` once the batch-norm running statistics are final.
    """

    def __init__(
            self,
            in_chs: int,
            out_chs: int,
            kernel_size: int = 1,
            stride: int = 1,
            padding: int = 0,
            dilation: int = 1,
            groups: int = 1,
            bn_weight_init: float = 1,
            device=None,
            dtype=None,
    ):
        """
        Args:
            in_chs: Input channels of the convolution.
            out_chs: Output channels of the convolution and batch-norm width.
            kernel_size: Convolution kernel size.
            stride: Convolution stride.
            padding: Convolution padding.
            dilation: Convolution dilation.
            groups: Convolution groups.
            bn_weight_init: Constant the batch-norm weight is initialised to.
            device: Forwarded to the created layers.
            dtype: Forwarded to the created layers.
        """
        dd = {'device': device, 'dtype': dtype}
        super().__init__()
        self.linear = nn.Conv2d(in_chs, out_chs, kernel_size, stride, padding, dilation, groups, bias=False, **dd)
        self.bn = nn.BatchNorm2d(out_chs, **dd)
        nn.init.constant_(self.bn.weight, bn_weight_init)

    @torch.no_grad()
    def fuse(self):
        """Return a new ``nn.Conv2d`` equivalent to conv + batch-norm using running stats.

        The returned layer lives on the same device / dtype as the source weights.
        """
        c, bn = self.linear, self.bn
        std = (bn.running_var + bn.eps) ** 0.5
        w = bn.weight / std
        w = c.weight * w[:, None, None, None]
        b = bn.bias - bn.running_mean * bn.weight / std
        # The conv weight is (out, in / groups, kH, kW); in_channels must be taken from
        # the source conv rather than w.size(1), otherwise grouped convs cannot be fused.
        m = nn.Conv2d(
            c.in_channels, c.out_channels, c.kernel_size, stride=c.stride,
            padding=c.padding, dilation=c.dilation, groups=c.groups,
            device=w.device, dtype=w.dtype)
        m.weight.data.copy_(w)
        m.bias.data.copy_(b)
        return m

    def forward(self, x):
        """Apply the convolution, then batch-norm."""
        return self.bn(self.linear(x))
class LinearNorm(nn.Module):
    """Bias-free linear layer followed by BatchNorm1d over the feature dimension.

    :meth:`fuse` folds the batch-norm running statistics into a single biased ``nn.Linear``.
    """

    def __init__(
            self,
            in_features: int,
            out_features: int,
            bn_weight_init: float = 1,
            device=None,
            dtype=None,
    ):
        """
        Args:
            in_features: Input feature size of the linear layer.
            out_features: Output feature size and batch-norm width.
            bn_weight_init: Constant the batch-norm weight is initialised to.
            device: Forwarded to the created layers.
            dtype: Forwarded to the created layers.
        """
        dd = {'device': device, 'dtype': dtype}
        super().__init__()
        self.linear = nn.Linear(in_features, out_features, bias=False, **dd)
        self.bn = nn.BatchNorm1d(out_features, **dd)
        nn.init.constant_(self.bn.weight, bn_weight_init)

    @torch.no_grad()
    def fuse(self):
        """Return a new ``nn.Linear`` equivalent to linear + batch-norm using running stats.

        The returned layer is created on the same device / dtype as the source weights.
        """
        l, bn = self.linear, self.bn
        std = (bn.running_var + bn.eps) ** 0.5
        w = bn.weight / std
        w = l.weight * w[:, None]
        b = bn.bias - bn.running_mean * bn.weight / std
        m = nn.Linear(w.size(1), w.size(0), device=w.device, dtype=w.dtype)
        m.weight.data.copy_(w)
        m.bias.data.copy_(b)
        return m

    def forward(self, x):
        """Apply the linear layer, then batch-norm with the first two dims flattened together."""
        x = self.linear(x)
        # BatchNorm1d is applied on a 2D (rows, features) view, then the shape is restored.
        return self.bn(x.flatten(0, 1)).reshape_as(x)
class NormLinear(nn.Module):
    """BatchNorm1d, dropout, then a linear layer (used as the classifier head in this file).

    :meth:`fuse` folds the batch-norm running statistics into a single ``nn.Linear``.
    """

    def __init__(
            self,
            in_features: int,
            out_features: int,
            bias: bool = True,
            std: float = 0.02,
            drop: float = 0.,
            device=None,
            dtype=None,
    ):
        """
        Args:
            in_features: Batch-norm width and linear input size.
            out_features: Linear output size.
            bias: Whether the linear layer has a bias.
            std: Std passed to ``trunc_normal_`` for the linear weight init.
            drop: Dropout probability applied between norm and linear.
            device: Forwarded to the created layers.
            dtype: Forwarded to the created layers.
        """
        dd = {'device': device, 'dtype': dtype}
        super().__init__()
        self.bn = nn.BatchNorm1d(in_features, **dd)
        self.drop = nn.Dropout(drop)
        self.linear = nn.Linear(in_features, out_features, bias=bias, **dd)

        trunc_normal_(self.linear.weight, std=std)
        if self.linear.bias is not None:
            nn.init.constant_(self.linear.bias, 0)

    @torch.no_grad()
    def fuse(self):
        """Return a new biased ``nn.Linear`` equivalent to batch-norm + linear using running stats.

        Dropout is not represented in the fused layer. The returned layer is created on
        the same device / dtype as the source weights.
        """
        bn, l = self.bn, self.linear
        std = (bn.running_var + bn.eps) ** 0.5
        scale = bn.weight / std
        shift = bn.bias - bn.running_mean * bn.weight / std
        w = l.weight * scale[None, :]
        # The batch-norm shift is pushed through the linear weight; any existing bias is added on top.
        b = l.weight @ shift
        if l.bias is not None:
            b = b + l.bias
        m = nn.Linear(w.size(1), w.size(0), device=w.device, dtype=w.dtype)
        m.weight.data.copy_(w)
        m.bias.data.copy_(b)
        return m

    def forward(self, x):
        """Apply batch-norm, dropout, then the linear layer."""
        return self.linear(self.drop(self.bn(x)))
class Stem8(nn.Sequential):
    """Convolutional stem of three stride-2 ``ConvNorm`` layers (``self.stride`` is 8).

    Channel widths grow ``out_chs // 4 -> out_chs // 2 -> out_chs`` with an activation
    between consecutive convolutions.
    """

    def __init__(
            self,
            in_chs: int,
            out_chs: int,
            act_layer: Type[nn.Module],
            device=None,
            dtype=None,
    ):
        dd = {'device': device, 'dtype': dtype}
        super().__init__()
        self.stride = 8

        widths = (in_chs, out_chs // 4, out_chs // 2, out_chs)
        num_convs = len(widths) - 1
        for idx in range(1, num_convs + 1):
            self.add_module(f'conv{idx}', ConvNorm(widths[idx - 1], widths[idx], 3, stride=2, padding=1, **dd))
            if idx < num_convs:
                # no activation after the final convolution
                self.add_module(f'act{idx}', act_layer())
class Stem16(nn.Sequential):
    """Convolutional stem of four stride-2 ``ConvNorm`` layers (``self.stride`` is 16).

    Channel widths grow ``out_chs // 8 -> out_chs // 4 -> out_chs // 2 -> out_chs`` with
    an activation between consecutive convolutions.
    """

    def __init__(
            self,
            in_chs: int,
            out_chs: int,
            act_layer: Type[nn.Module],
            device=None,
            dtype=None,
    ):
        dd = {'device': device, 'dtype': dtype}
        super().__init__()
        self.stride = 16

        widths = (in_chs, out_chs // 8, out_chs // 4, out_chs // 2, out_chs)
        num_convs = len(widths) - 1
        for idx in range(1, num_convs + 1):
            self.add_module(f'conv{idx}', ConvNorm(widths[idx - 1], widths[idx], 3, stride=2, padding=1, **dd))
            if idx < num_convs:
                # no activation after the final convolution
                self.add_module(f'act{idx}', act_layer())
class Downsample(nn.Module):
    """Spatially downsample a flattened (B, N, C) token sequence.

    Tokens are viewed as a ``resolution[0] x resolution[1]`` grid and then either strided
    (default) or average-pooled with a 3x3 window when ``use_pool`` is set.
    """

    def __init__(
            self,
            stride: int,
            resolution: Union[int, Tuple[int, int]],
            use_pool: bool = False,
            device=None,
            dtype=None,
    ):
        # device / dtype are accepted for a uniform constructor signature; nothing here uses them.
        super().__init__()
        self.stride = stride
        self.resolution = to_2tuple(resolution)
        if use_pool:
            self.pool = nn.AvgPool2d(3, stride=stride, padding=1, count_include_pad=False)
        else:
            self.pool = None

    def forward(self, x):
        B, N, C = x.shape
        grid = x.view(B, self.resolution[0], self.resolution[1], C)
        if self.pool is None:
            # plain subsampling: keep every stride-th row and column
            grid = grid[:, ::self.stride, ::self.stride]
        else:
            # pooling works channels-first, so permute to NCHW and back
            grid = self.pool(grid.permute(0, 3, 1, 2)).permute(0, 2, 3, 1)
        return grid.reshape(B, -1, C)
class Attention(nn.Module):
    """Multi-head self-attention with a learned per-head bias indexed by relative position.

    Works on (B, C, H, W) inputs when ``use_conv`` is set (``ConvNorm`` projections) and on
    (B, N, C) inputs otherwise (``LinearNorm`` projections). In eval mode the gathered
    bias tensor is cached per device string; the cache is cleared when switching to train mode.
    """
    attention_bias_cache: Dict[str, torch.Tensor]

    def __init__(
            self,
            dim: int,
            key_dim: int,
            num_heads: int = 8,
            attn_ratio: float = 4.,
            resolution: Union[int, Tuple[int, int]] = 14,
            use_conv: bool = False,
            act_layer: Type[nn.Module] = nn.SiLU,
            device=None,
            dtype=None,
    ):
        """
        Args:
            dim: Input and output channel / feature size.
            key_dim: Per-head query / key size.
            num_heads: Number of attention heads.
            attn_ratio: Per-head value size is ``int(attn_ratio * key_dim)``.
            resolution: Token grid size (int or (H, W)) used to build the bias index table.
            use_conv: Use ``ConvNorm`` (NCHW input) instead of ``LinearNorm`` (NLC input).
            act_layer: Activation applied before the output projection.
            device: Forwarded to created parameters / buffers.
            dtype: Forwarded to created parameters.
        """
        dd = {'device': device, 'dtype': dtype}
        super().__init__()
        ln_layer = ConvNorm if use_conv else LinearNorm
        resolution = to_2tuple(resolution)

        self.use_conv = use_conv
        self.num_heads = num_heads
        self.scale = key_dim ** -0.5
        self.key_dim = key_dim
        self.key_attn_dim = key_dim * num_heads
        self.val_dim = int(attn_ratio * key_dim)
        self.val_attn_dim = int(attn_ratio * key_dim) * num_heads
        self.resolution = resolution

        self.qkv = ln_layer(dim, self.val_attn_dim + self.key_attn_dim * 2, **dd)
        self.proj = nn.Sequential(OrderedDict([
            ('act', act_layer()),
            ('ln', ln_layer(self.val_attn_dim, dim, bn_weight_init=0, **dd))
        ]))

        N = resolution[0] * resolution[1]
        # One bias per head for each distinct (|dy|, |dx|) offset; indices range over [0, N).
        self.attention_biases = nn.Parameter(torch.empty(num_heads, N, **dd))
        self.register_buffer(
            'attention_bias_idxs', torch.empty((N, N), device=device, dtype=torch.long), persistent=False)
        self.attention_bias_cache = {}

        # TODO: skip init when on meta device when safe to do so
        self.reset_parameters()

    @torch.no_grad()
    def train(self, mode=True):
        """Set train / eval mode, dropping the eval-time bias cache when entering train mode.

        Returns ``self`` to honour the ``nn.Module.train`` contract (so ``.train()`` and
        ``.eval()`` on this module can be chained).
        """
        super().train(mode)
        if mode and self.attention_bias_cache:
            self.attention_bias_cache = {}  # clear ab cache
        return self

    def reset_parameters(self) -> None:
        """Initialize parameters and buffers."""
        nn.init.zeros_(self.attention_biases)
        self._init_buffers()

    def _compute_attention_bias_idxs(self, device=None):
        """Compute relative position indices for attention bias."""
        pos = torch.stack(ndgrid(
            torch.arange(self.resolution[0], device=device, dtype=torch.long),
            torch.arange(self.resolution[1], device=device, dtype=torch.long),
        )).flatten(1)
        rel_pos = (pos[..., :, None] - pos[..., None, :]).abs()
        # flatten the (|dy|, |dx|) pair into a single index
        rel_pos = (rel_pos[0] * self.resolution[1]) + rel_pos[1]
        return rel_pos

    def _init_buffers(self) -> None:
        """Compute and fill non-persistent buffer values."""
        self.attention_bias_idxs.copy_(
            self._compute_attention_bias_idxs(device=self.attention_bias_idxs.device)
        )
        self.attention_bias_cache = {}

    def init_non_persistent_buffers(self) -> None:
        """Initialize non-persistent buffers."""
        self._init_buffers()

    def get_attention_biases(self, device: torch.device) -> torch.Tensor:
        """Return the per-head bias gathered by ``attention_bias_idxs``.

        Recomputed on every call while training or tracing; otherwise cached under ``str(device)``.
        """
        if torch.jit.is_tracing() or self.training:
            return self.attention_biases[:, self.attention_bias_idxs]
        else:
            device_key = str(device)
            if device_key not in self.attention_bias_cache:
                self.attention_bias_cache[device_key] = self.attention_biases[:, self.attention_bias_idxs]
            return self.attention_bias_cache[device_key]

    def forward(self, x):  # x (B,C,H,W)
        """Apply biased self-attention followed by the activation + output projection."""
        if self.use_conv:
            B, C, H, W = x.shape
            q, k, v = self.qkv(x).view(
                B, self.num_heads, -1, H * W).split([self.key_dim, self.key_dim, self.val_dim], dim=2)

            attn = (q.transpose(-2, -1) @ k) * self.scale + self.get_attention_biases(x.device)
            attn = attn.softmax(dim=-1)

            x = (v @ attn.transpose(-2, -1)).view(B, -1, H, W)
        else:
            B, N, C = x.shape
            q, k, v = self.qkv(x).view(
                B, N, self.num_heads, -1).split([self.key_dim, self.key_dim, self.val_dim], dim=3)
            q = q.permute(0, 2, 1, 3)
            k = k.permute(0, 2, 3, 1)
            v = v.permute(0, 2, 1, 3)

            attn = q @ k * self.scale + self.get_attention_biases(x.device)
            attn = attn.softmax(dim=-1)

            x = (attn @ v).transpose(1, 2).reshape(B, N, self.val_attn_dim)
        x = self.proj(x)
        return x
class AttentionDownsample(nn.Module):
    """Attention block whose queries are taken from a spatially subsampled input.

    Keys / values come from the full-resolution input while queries come from a strided
    (or pooled) copy, so the output has fewer tokens than the input. A learned per-head
    bias indexed by relative position is added to the attention logits. In eval mode the
    gathered bias is cached per device string; the cache is cleared on entering train mode.
    """
    attention_bias_cache: Dict[str, torch.Tensor]

    def __init__(
            self,
            in_dim: int,
            out_dim: int,
            key_dim: int,
            num_heads: int = 8,
            attn_ratio: float = 2.0,
            stride: int = 2,
            resolution: Union[int, Tuple[int, int]] = 14,
            use_conv: bool = False,
            use_pool: bool = False,
            act_layer: Type[nn.Module] = nn.SiLU,
            device=None,
            dtype=None,
    ):
        """
        Args:
            in_dim: Input channel / feature size.
            out_dim: Output channel / feature size.
            key_dim: Per-head query / key size.
            num_heads: Number of attention heads.
            attn_ratio: Per-head value size is ``int(attn_ratio * key_dim)``.
            stride: Spatial stride applied to the query path.
            resolution: Input token grid size (int or (H, W)).
            use_conv: Use ``ConvNorm`` / ``AvgPool2d`` (NCHW input) instead of
                ``LinearNorm`` / ``Downsample`` (NLC input).
            use_pool: Use 3x3 average pooling for the query subsampling instead of striding.
            act_layer: Activation applied before the output projection.
            device: Forwarded to created parameters / buffers.
            dtype: Forwarded to created parameters.
        """
        dd = {'device': device, 'dtype': dtype}
        super().__init__()
        resolution = to_2tuple(resolution)

        self.stride = stride
        self.resolution = resolution
        self.num_heads = num_heads
        self.key_dim = key_dim
        self.key_attn_dim = key_dim * num_heads
        self.val_dim = int(attn_ratio * key_dim)
        self.val_attn_dim = self.val_dim * self.num_heads
        self.scale = key_dim ** -0.5
        self.use_conv = use_conv

        if self.use_conv:
            ln_layer = ConvNorm
            # kernel_size=1 pooling with a stride is plain subsampling
            sub_layer = partial(
                nn.AvgPool2d,
                kernel_size=3 if use_pool else 1, padding=1 if use_pool else 0, count_include_pad=False)
        else:
            ln_layer = LinearNorm
            sub_layer = partial(Downsample, resolution=resolution, use_pool=use_pool, **dd)

        self.kv = ln_layer(in_dim, self.val_attn_dim + self.key_attn_dim, **dd)
        self.q = nn.Sequential(OrderedDict([
            ('down', sub_layer(stride=stride)),
            ('ln', ln_layer(in_dim, self.key_attn_dim, **dd))
        ]))
        self.proj = nn.Sequential(OrderedDict([
            ('act', act_layer()),
            ('ln', ln_layer(self.val_attn_dim, out_dim, **dd))
        ]))

        N_k = resolution[0] * resolution[1]
        N_q = -(-resolution[0] // stride) * -(-resolution[1] // stride)  # ceiling division
        self.attention_biases = nn.Parameter(torch.empty(num_heads, N_k, **dd))
        self.register_buffer(
            'attention_bias_idxs', torch.empty((N_q, N_k), device=device, dtype=torch.long), persistent=False)
        self.attention_bias_cache = {}

        # TODO: skip init when on meta device when safe to do so
        self.reset_parameters()

    @torch.no_grad()
    def train(self, mode=True):
        """Set train / eval mode, dropping the eval-time bias cache when entering train mode.

        Returns ``self`` to honour the ``nn.Module.train`` contract (so ``.train()`` and
        ``.eval()`` on this module can be chained).
        """
        super().train(mode)
        if mode and self.attention_bias_cache:
            self.attention_bias_cache = {}  # clear ab cache
        return self

    def reset_parameters(self) -> None:
        """Initialize parameters and buffers."""
        nn.init.zeros_(self.attention_biases)
        self._init_buffers()

    def _compute_attention_bias_idxs(self, device=None):
        """Compute relative position indices for attention bias."""
        k_pos = torch.stack(ndgrid(
            torch.arange(self.resolution[0], device=device, dtype=torch.long),
            torch.arange(self.resolution[1], device=device, dtype=torch.long),
        )).flatten(1)
        # query positions are the strided subset of the key grid
        q_pos = torch.stack(ndgrid(
            torch.arange(0, self.resolution[0], step=self.stride, device=device, dtype=torch.long),
            torch.arange(0, self.resolution[1], step=self.stride, device=device, dtype=torch.long),
        )).flatten(1)
        rel_pos = (q_pos[..., :, None] - k_pos[..., None, :]).abs()
        rel_pos = (rel_pos[0] * self.resolution[1]) + rel_pos[1]
        return rel_pos

    def _init_buffers(self) -> None:
        """Compute and fill non-persistent buffer values."""
        self.attention_bias_idxs.copy_(
            self._compute_attention_bias_idxs(device=self.attention_bias_idxs.device)
        )
        self.attention_bias_cache = {}

    def init_non_persistent_buffers(self) -> None:
        """Initialize non-persistent buffers."""
        self._init_buffers()

    def get_attention_biases(self, device: torch.device) -> torch.Tensor:
        """Return the per-head bias gathered by ``attention_bias_idxs``.

        Recomputed on every call while training or tracing; otherwise cached under ``str(device)``.
        """
        if torch.jit.is_tracing() or self.training:
            return self.attention_biases[:, self.attention_bias_idxs]
        else:
            device_key = str(device)
            if device_key not in self.attention_bias_cache:
                self.attention_bias_cache[device_key] = self.attention_biases[:, self.attention_bias_idxs]
            return self.attention_bias_cache[device_key]

    def forward(self, x):
        """Attend from subsampled queries to full-resolution keys / values, then project."""
        if self.use_conv:
            B, C, H, W = x.shape
            HH, WW = (H - 1) // self.stride + 1, (W - 1) // self.stride + 1
            k, v = self.kv(x).view(B, self.num_heads, -1, H * W).split([self.key_dim, self.val_dim], dim=2)
            q = self.q(x).view(B, self.num_heads, self.key_dim, -1)

            attn = (q.transpose(-2, -1) @ k) * self.scale + self.get_attention_biases(x.device)
            attn = attn.softmax(dim=-1)

            x = (v @ attn.transpose(-2, -1)).reshape(B, self.val_attn_dim, HH, WW)
        else:
            B, N, C = x.shape
            k, v = self.kv(x).view(B, N, self.num_heads, -1).split([self.key_dim, self.val_dim], dim=3)
            k = k.permute(0, 2, 3, 1)  # BHCN
            v = v.permute(0, 2, 1, 3)  # BHNC
            q = self.q(x).view(B, -1, self.num_heads, self.key_dim).permute(0, 2, 1, 3)

            attn = q @ k * self.scale + self.get_attention_biases(x.device)
            attn = attn.softmax(dim=-1)

            x = (attn @ v).transpose(1, 2).reshape(B, -1, self.val_attn_dim)
        x = self.proj(x)
        return x
class LevitMlp(nn.Module):
    """Two-layer MLP for LeViT built from norm-wrapped projections.

    Uses ``ConvNorm`` layers when ``use_conv`` is set and ``LinearNorm`` layers otherwise.
    The second projection's batch-norm weight is initialised to 0.
    """

    def __init__(
            self,
            in_features: int,
            hidden_features: Optional[int] = None,
            out_features: Optional[int] = None,
            use_conv: bool = False,
            act_layer: Type[nn.Module] = nn.SiLU,
            drop: float = 0.,
            device=None,
            dtype=None,
    ):
        dd = {'device': device, 'dtype': dtype}
        super().__init__()
        # falsy hidden / output sizes fall back to the input size
        out_features = out_features or in_features
        hidden_features = hidden_features or in_features
        if use_conv:
            ln_layer = ConvNorm
        else:
            ln_layer = LinearNorm

        self.ln1 = ln_layer(in_features, hidden_features, **dd)
        self.act = act_layer()
        self.drop = nn.Dropout(drop)
        self.ln2 = ln_layer(hidden_features, out_features, bn_weight_init=0, **dd)

    def forward(self, x):
        hidden = self.drop(self.act(self.ln1(x)))
        return self.ln2(hidden)
class LevitDownsample(nn.Module):
    """Downsampling unit: ``AttentionDownsample`` followed by a residual ``LevitMlp``.

    The attention output is not added residually (its token count differs from the
    input); only the MLP branch has a skip connection.
    """

    def __init__(
            self,
            in_dim: int,
            out_dim: int,
            key_dim: int,
            num_heads: int = 8,
            attn_ratio: float = 4.,
            mlp_ratio: float = 2.,
            act_layer: Type[nn.Module] = nn.SiLU,
            attn_act_layer: Optional[Type[nn.Module]] = None,
            resolution: Union[int, Tuple[int, int]] = 14,
            use_conv: bool = False,
            use_pool: bool = False,
            drop_path: float = 0.,
            device=None,
            dtype=None,
    ):
        dd = {'device': device, 'dtype': dtype}
        super().__init__()
        # the attention activation defaults to the block activation
        attn_act_layer = attn_act_layer or act_layer

        attn_kwargs = dict(
            in_dim=in_dim,
            out_dim=out_dim,
            key_dim=key_dim,
            num_heads=num_heads,
            attn_ratio=attn_ratio,
            act_layer=attn_act_layer,
            resolution=resolution,
            use_conv=use_conv,
            use_pool=use_pool,
        )
        self.attn_downsample = AttentionDownsample(**attn_kwargs, **dd)

        hidden_dim = int(out_dim * mlp_ratio)
        self.mlp = LevitMlp(out_dim, hidden_dim, use_conv=use_conv, act_layer=act_layer, **dd)

        if drop_path > 0.:
            self.drop_path = DropPath(drop_path)
        else:
            self.drop_path = nn.Identity()

    def forward(self, x):
        reduced = self.attn_downsample(x)
        return reduced + self.drop_path(self.mlp(reduced))
class LevitBlock(nn.Module):
    """Residual LeViT block: attention branch followed by MLP branch, each with drop-path."""

    def __init__(
            self,
            dim: int,
            key_dim: int,
            num_heads: int = 8,
            attn_ratio: float = 4.,
            mlp_ratio: float = 2.,
            resolution: Union[int, Tuple[int, int]] = 14,
            use_conv: bool = False,
            act_layer: Type[nn.Module] = nn.SiLU,
            attn_act_layer: Optional[Type[nn.Module]] = None,
            drop_path: float = 0.,
            device=None,
            dtype=None,
    ):
        dd = {'device': device, 'dtype': dtype}
        super().__init__()
        # the attention activation defaults to the block activation
        attn_act_layer = attn_act_layer or act_layer
        use_drop_path = drop_path > 0.

        self.attn = Attention(
            dim=dim,
            key_dim=key_dim,
            num_heads=num_heads,
            attn_ratio=attn_ratio,
            resolution=resolution,
            use_conv=use_conv,
            act_layer=attn_act_layer,
            **dd,
        )
        self.drop_path1 = DropPath(drop_path) if use_drop_path else nn.Identity()

        hidden_dim = int(dim * mlp_ratio)
        self.mlp = LevitMlp(dim, hidden_dim, use_conv=use_conv, act_layer=act_layer, **dd)
        self.drop_path2 = DropPath(drop_path) if use_drop_path else nn.Identity()

    def forward(self, x):
        attn_out = self.drop_path1(self.attn(x))
        x = x + attn_out
        mlp_out = self.drop_path2(self.mlp(x))
        return x + mlp_out
class LevitStage(nn.Module):
    """One LeViT stage: optional ``LevitDownsample`` followed by ``depth`` ``LevitBlock``s.

    When ``downsample`` is a non-empty string the stage starts with a downsample unit and
    the block resolution is halved (rounding up); otherwise ``in_dim`` must equal ``out_dim``.
    """

    def __init__(
            self,
            in_dim: int,
            out_dim: int,
            key_dim: int,
            depth: int = 4,
            num_heads: int = 8,
            attn_ratio: float = 4.0,
            mlp_ratio: float = 4.0,
            act_layer: Type[nn.Module] = nn.SiLU,
            attn_act_layer: Optional[Type[nn.Module]] = None,
            resolution: Union[int, Tuple[int, int]] = 14,
            downsample: str = '',
            use_conv: bool = False,
            drop_path: float = 0.,
            device=None,
            dtype=None,
    ):
        dd = {'device': device, 'dtype': dtype}
        super().__init__()
        resolution = to_2tuple(resolution)

        if not downsample:
            assert in_dim == out_dim
            self.downsample = nn.Identity()
        else:
            # the downsample unit uses fixed attn / mlp ratios and derives its head count from in_dim
            self.downsample = LevitDownsample(
                in_dim,
                out_dim,
                key_dim=key_dim,
                num_heads=in_dim // key_dim,
                attn_ratio=4.,
                mlp_ratio=2.,
                act_layer=act_layer,
                attn_act_layer=attn_act_layer,
                resolution=resolution,
                use_conv=use_conv,
                drop_path=drop_path,
                **dd,
            )
            resolution = [(r - 1) // 2 + 1 for r in resolution]

        block_kwargs = dict(
            num_heads=num_heads,
            attn_ratio=attn_ratio,
            mlp_ratio=mlp_ratio,
            act_layer=act_layer,
            attn_act_layer=attn_act_layer,
            resolution=resolution,
            use_conv=use_conv,
            drop_path=drop_path,
        )
        self.blocks = nn.Sequential(*[
            LevitBlock(out_dim, key_dim, **block_kwargs, **dd)
            for _ in range(depth)
        ])

    def forward(self, x):
        return self.blocks(self.downsample(x))
class Levit(nn.Module):
    """ Vision Transformer with support for patch or hybrid CNN input stage

    The model is a convolutional stem, a sequence of ``LevitStage`` modules and a
    ``NormLinear`` classifier head. With ``use_conv`` the stages run on NCHW tensors,
    otherwise the stem output is flattened to (B, N, C) tokens.

    NOTE: distillation is defaulted to True since pretrained weights use it, will cause problems
          w/ train scripts that don't take tuple outputs.
    """

    def __init__(
            self,
            img_size: Union[int, Tuple[int, int]] = 224,
            in_chans: int = 3,
            num_classes: int = 1000,
            embed_dim: Tuple[int, ...] = (192,),
            key_dim: int = 64,
            depth: Tuple[int, ...] = (12,),
            num_heads: Union[int, Tuple[int, ...]] = (3,),
            attn_ratio: Union[float, Tuple[float, ...]] = 2.,
            mlp_ratio: Union[float, Tuple[float, ...]] = 2.,
            stem_backbone: Optional[nn.Module] = None,
            stem_stride: Optional[int] = None,
            stem_type: str = 's16',
            down_op: str = 'subsample',
            act_layer: str = 'hard_swish',
            attn_act_layer: Optional[str] = None,
            use_conv: bool = False,
            global_pool: str = 'avg',
            drop_rate: float = 0.,
            drop_path_rate: float = 0.,
            device=None,
            dtype=None,
    ):
        """
        Args:
            img_size: Input image size (int or (H, W)).
            in_chans: Number of input image channels.
            num_classes: Classifier outputs; ``<= 0`` replaces the head with identity.
            embed_dim: Channel width per stage.
            key_dim: Per-head query / key size.
            depth: Number of blocks per stage (same length as ``embed_dim``).
            num_heads: Attention heads per stage (scalar is broadcast to all stages).
            attn_ratio: Value-to-key size ratio per stage (scalar is broadcast).
            mlp_ratio: MLP hidden ratio per stage (scalar is broadcast).
            stem_backbone: Optional module used as the stem instead of ``Stem8`` / ``Stem16``.
            stem_stride: Stride of ``stem_backbone``; required (>= 2) when it is given.
            stem_type: ``'s16'`` or ``'s8'`` built-in stem.
            down_op: Passed as the ``downsample`` flag of every stage after the first.
            act_layer: Activation name resolved through ``get_act_layer``.
            attn_act_layer: Attention activation name; defaults to ``act_layer``.
            use_conv: Run stages on NCHW tensors with conv projections.
            global_pool: ``'avg'`` enables mean pooling in ``forward_head``.
            drop_rate: Dropout probability of the classifier head.
            drop_path_rate: Drop-path rate passed to every stage.
            device: Forwarded to created layers.
            dtype: Forwarded to created layers.
        """
        super().__init__()
        if stem_backbone is not None:
            # Validate up front: comparing a missing (None) stride used to fail with an opaque TypeError.
            assert stem_stride is not None and stem_stride >= 2, (
                'stem_stride (>= 2) must be specified when stem_backbone is provided')
        dd = {'device': device, 'dtype': dtype}
        act_layer = get_act_layer(act_layer)
        attn_act_layer = get_act_layer(attn_act_layer or act_layer)
        self.use_conv = use_conv
        self.num_classes = num_classes
        self.in_chans = in_chans
        self.global_pool = global_pool
        self.num_features = self.head_hidden_size = embed_dim[-1]
        self.embed_dim = embed_dim
        self.drop_rate = drop_rate
        self.grad_checkpointing = False
        self.feature_info = []

        num_stages = len(embed_dim)
        assert len(depth) == num_stages
        num_heads = to_ntuple(num_stages)(num_heads)
        attn_ratio = to_ntuple(num_stages)(attn_ratio)
        mlp_ratio = to_ntuple(num_stages)(mlp_ratio)

        if stem_backbone is not None:
            self.stem = stem_backbone
            stride = stem_stride
        else:
            assert stem_type in ('s16', 's8')
            if stem_type == 's16':
                self.stem = Stem16(in_chans, embed_dim[0], act_layer=act_layer, **dd)
            else:
                self.stem = Stem8(in_chans, embed_dim[0], act_layer=act_layer, **dd)
            stride = self.stem.stride
        resolution = tuple([i // p for i, p in zip(to_2tuple(img_size), to_2tuple(stride))])

        in_dim = embed_dim[0]
        stages = []
        for i in range(num_stages):
            # every stage but the first starts with a stride-2 downsample
            stage_stride = 2 if i > 0 else 1
            stages += [LevitStage(
                in_dim,
                embed_dim[i],
                key_dim,
                depth=depth[i],
                num_heads=num_heads[i],
                attn_ratio=attn_ratio[i],
                mlp_ratio=mlp_ratio[i],
                act_layer=act_layer,
                attn_act_layer=attn_act_layer,
                resolution=resolution,
                use_conv=use_conv,
                downsample=down_op if stage_stride == 2 else '',
                drop_path=drop_path_rate,
                **dd,
            )]
            stride *= stage_stride
            resolution = tuple([(r - 1) // stage_stride + 1 for r in resolution])
            self.feature_info += [dict(num_chs=embed_dim[i], reduction=stride, module=f'stages.{i}')]
            in_dim = embed_dim[i]
        self.stages = nn.Sequential(*stages)

        # Classifier head
        self.head = NormLinear(embed_dim[-1], num_classes, drop=drop_rate, **dd) if num_classes > 0 else nn.Identity()

        # TODO: skip init when on meta device when safe to do so
        self.init_weights(needs_reset=False)

    def init_weights(self, needs_reset: bool = True):
        """Apply ``_init_weights`` to every submodule."""
        self.apply(partial(self._init_weights, needs_reset=needs_reset))

    def _init_weights(self, m: nn.Module, needs_reset: bool = True) -> None:
        """Call ``m.reset_parameters()`` when requested and available."""
        if needs_reset and hasattr(m, 'reset_parameters'):
            m.reset_parameters()

    @torch.jit.ignore
    def no_weight_decay(self):
        """Return the state-dict keys containing ``attention_biases``."""
        return {x for x in self.state_dict().keys() if 'attention_biases' in x}

    @torch.jit.ignore
    def group_matcher(self, coarse=False):
        """Return the regex spec used for parameter grouping; ``coarse`` is unused here."""
        # NOTE(review): these patterns name cls_token / pos_embed / patch_embed / blocks / norm,
        # while this class only assigns `stem`, `stages` and `head` -- confirm whether the
        # matcher should target `stem` / `stages` instead. Left unchanged to preserve behavior.
        matcher = dict(
            stem=r'^cls_token|pos_embed|patch_embed',  # stem and embed
            blocks=[(r'^blocks\.(\d+)', None), (r'^norm', (99999,))]
        )
        return matcher

    @torch.jit.ignore
    def set_grad_checkpointing(self, enable=True):
        """Enable or disable gradient checkpointing of the stages."""
        self.grad_checkpointing = enable

    @torch.jit.ignore
    def get_classifier(self) -> nn.Module:
        """Return the classifier head module."""
        return self.head

    def reset_classifier(self, num_classes: int, global_pool: Optional[str] = None):
        """Replace the head with a fresh ``NormLinear`` (or identity when ``num_classes <= 0``).

        ``global_pool`` is only updated when it is not None.
        """
        self.num_classes = num_classes
        if global_pool is not None:
            self.global_pool = global_pool
        self.head = NormLinear(
            self.num_features, num_classes, drop=self.drop_rate) if num_classes > 0 else nn.Identity()

    def forward_intermediates(
            self,
            x: torch.Tensor,
            indices: Optional[Union[int, List[int]]] = None,
            norm: bool = False,
            stop_early: bool = False,
            output_fmt: str = 'NCHW',
            intermediates_only: bool = False,
    ) -> Union[List[torch.Tensor], Tuple[torch.Tensor, List[torch.Tensor]]]:
        """ Forward features that returns intermediates.

        Args:
            x: Input image tensor
            indices: Take last n blocks if int, all if None, select matching indices if sequence
            norm: Apply norm layer to compatible intermediates (not used by this implementation)
            stop_early: Stop iterating over blocks when last desired intermediate hit
            output_fmt: Shape of intermediate feature outputs
            intermediates_only: Only return intermediate features
        Returns:
            The list of selected stage outputs in NCHW layout when ``intermediates_only`` is
            set, otherwise a tuple of (last computed stage output, that list).
        """
        assert output_fmt in ('NCHW',), 'Output shape must be NCHW.'
        intermediates = []
        take_indices, max_index = feature_take_indices(len(self.stages), indices)

        # forward pass
        x = self.stem(x)
        B, C, H, W = x.shape
        if not self.use_conv:
            x = x.flatten(2).transpose(1, 2)

        if torch.jit.is_scripting() or not stop_early:  # can't slice blocks in torchscript
            stages = self.stages
        else:
            stages = self.stages[:max_index + 1]
        for feat_idx, stage in enumerate(stages):
            if self.grad_checkpointing and not torch.jit.is_scripting():
                x = checkpoint(stage, x)
            else:
                x = stage(x)
            if feat_idx in take_indices:
                if self.use_conv:
                    intermediates.append(x)
                else:
                    # token output is restored to a spatial map using the tracked H, W
                    intermediates.append(x.reshape(B, H, W, -1).permute(0, 3, 1, 2))
            # track the resolution the next stage will produce after its stride-2 downsample
            H = (H + 2 - 1) // 2
            W = (W + 2 - 1) // 2

        if intermediates_only:
            return intermediates

        return x, intermediates

    def prune_intermediate_layers(
            self,
            indices: Union[int, List[int]] = 1,
            prune_norm: bool = False,
            prune_head: bool = True,
    ):
        """ Prune layers not required for specified intermediates.

        Truncates ``self.stages`` after the last requested index and, when ``prune_head``
        is set, resets the classifier to identity. ``prune_norm`` is not used here.
        Returns the resolved take indices.
        """
        take_indices, max_index = feature_take_indices(len(self.stages), indices)
        self.stages = self.stages[:max_index + 1]  # truncate blocks w/ stem as idx 0
        if prune_head:
            self.reset_classifier(0, '')
        return take_indices

    def forward_features(self, x):
        """Run the stem and all stages (tokens are flattened first unless ``use_conv``)."""
        x = self.stem(x)
        if not self.use_conv:
            x = x.flatten(2).transpose(1, 2)
        if self.grad_checkpointing and not torch.jit.is_scripting():
            x = checkpoint_seq(self.stages, x)
        else:
            x = self.stages(x)
        return x

    def forward_head(self, x, pre_logits: bool = False):
        """Pool features when ``global_pool == 'avg'`` and apply the head unless ``pre_logits``."""
        if self.global_pool == 'avg':
            x = x.mean(dim=(-2, -1)) if self.use_conv else x.mean(dim=1)
        return x if pre_logits else self.head(x)

    def forward(self, x):
        """Full forward pass: features followed by the classifier head."""
        x = self.forward_features(x)
        x = self.forward_head(x)
        return x
class LevitDistilled(Levit):
    """``Levit`` with a second ``NormLinear`` head (``head_dist``) for distillation.

    Outside of distilled training the two head outputs are averaged.
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # device / dtype are only picked up when passed as keywords
        dd = {key: kwargs.get(key, None) for key in ('device', 'dtype')}
        if self.num_classes > 0:
            self.head_dist = NormLinear(self.num_features, self.num_classes, **dd)
        else:
            self.head_dist = nn.Identity()
        self.distilled_training = False  # must set this True to train w/ distillation token

    @torch.jit.ignore
    def get_classifier(self) -> nn.Module:
        # NOTE(review): annotated as nn.Module but a (head, head_dist) tuple is returned.
        return self.head, self.head_dist

    def reset_classifier(self, num_classes: int, global_pool: Optional[str] = None):
        """Rebuild both heads for ``num_classes`` (identity when ``num_classes <= 0``)."""
        self.num_classes = num_classes
        if global_pool is not None:
            self.global_pool = global_pool
        has_classes = num_classes > 0
        self.head = NormLinear(
            self.num_features, num_classes, drop=self.drop_rate) if has_classes else nn.Identity()
        self.head_dist = NormLinear(self.num_features, num_classes) if has_classes else nn.Identity()

    @torch.jit.ignore
    def set_distilled_training(self, enable=True):
        self.distilled_training = enable

    def forward_head(self, x, pre_logits: bool = False):
        """Pool, then apply both heads; returns a tuple only in distilled training mode."""
        if self.global_pool == 'avg':
            if self.use_conv:
                x = x.mean(dim=(-2, -1))
            else:
                x = x.mean(dim=1)
        if pre_logits:
            return x
        x, x_dist = self.head(x), self.head_dist(x)
        if self.distilled_training and self.training and not torch.jit.is_scripting():
            # only return separate classification predictions when training in distilled mode
            return x, x_dist
        else:
            # during standard train/finetune, inference average the classifier predictions
            return (x + x_dist) / 2
def checkpoint_filter_fn(state_dict, model):
    """Prepare a loaded checkpoint for this model.

    Unwraps a nested ``'model'`` entry when present and drops every key containing
    ``attention_bias_idxs``. ``model`` is accepted but not used.
    """
    state_dict = state_dict.get('model', state_dict)

    # filter out attn biases, should not have been persistent
    filtered = {}
    for name, value in state_dict.items():
        if 'attention_bias_idxs' in name:
            continue
        filtered[name] = value
    return filtered
# Architecture hyper-parameters per LeViT variant, keyed by variant name.
model_cfgs = {
    'levit_128s': {
        'embed_dim': (128, 256, 384), 'key_dim': 16, 'num_heads': (4, 6, 8), 'depth': (2, 3, 4)},
    'levit_128': {
        'embed_dim': (128, 256, 384), 'key_dim': 16, 'num_heads': (4, 8, 12), 'depth': (4, 4, 4)},
    'levit_192': {
        'embed_dim': (192, 288, 384), 'key_dim': 32, 'num_heads': (3, 5, 6), 'depth': (4, 4, 4)},
    'levit_256': {
        'embed_dim': (256, 384, 512), 'key_dim': 32, 'num_heads': (4, 6, 8), 'depth': (4, 4, 4)},
    'levit_384': {
        'embed_dim': (384, 512, 768), 'key_dim': 32, 'num_heads': (6, 9, 12), 'depth': (4, 4, 4)},

    # stride-8 stem experiments
    'levit_384_s8': {
        'embed_dim': (384, 512, 768), 'key_dim': 32, 'num_heads': (6, 9, 12), 'depth': (4, 4, 4),
        'act_layer': 'silu', 'stem_type': 's8'},
    'levit_512_s8': {
        'embed_dim': (512, 640, 896), 'key_dim': 64, 'num_heads': (8, 10, 14), 'depth': (4, 4, 4),
        'act_layer': 'silu', 'stem_type': 's8'},

    # wider experiments
    'levit_512': {
        'embed_dim': (512, 768, 1024), 'key_dim': 64, 'num_heads': (8, 12, 16), 'depth': (4, 4, 4),
        'act_layer': 'silu'},

    # deeper experiments
    'levit_256d': {
        'embed_dim': (256, 384, 512), 'key_dim': 32, 'num_heads': (4, 6, 8), 'depth': (4, 8, 6),
        'act_layer': 'silu'},
    'levit_512d': {
        'embed_dim': (512, 640, 768), 'key_dim': 64, 'num_heads': (8, 10, 12), 'depth': (4, 8, 6),
        'act_layer': 'silu'},
}
def create_levit(variant, cfg_variant=None, pretrained=False, distilled=True, **kwargs):
    """Build a LeViT model for ``variant``.

    Args:
        variant: Model name, e.g. ``'levit_128s'`` or ``'levit_conv_128s'``.
        cfg_variant: Key into ``model_cfgs``. When None it is derived from
            ``variant``: the name itself if it is a config key, otherwise (for names
            containing ``'_conv'``) the name with ``'_conv'`` removed.
        pretrained: Forwarded to ``build_model_with_cfg``.
        distilled: Use ``LevitDistilled`` as the model class when True, ``Levit``
            when False.
        **kwargs: Overrides merged on top of the selected config. ``out_indices``
            is popped and passed through ``feature_cfg`` instead.

    Returns:
        The model produced by ``build_model_with_cfg``.

    Raises:
        KeyError: If no architecture config can be resolved for the variant.
    """
    is_conv = '_conv' in variant
    out_indices = kwargs.pop('out_indices', (0, 1, 2))
    if kwargs.get('features_only', False) and not is_conv:
        kwargs.setdefault('feature_cls', 'getter')
    if cfg_variant is None:
        if variant in model_cfgs:
            cfg_variant = variant
        elif is_conv:
            cfg_variant = variant.replace('_conv', '')

    # Fail with an actionable message instead of the opaque `KeyError: None` that
    # results when neither resolution rule above matched.
    if cfg_variant not in model_cfgs:
        raise KeyError(
            f"No LeViT config found for variant '{variant}' (resolved cfg_variant={cfg_variant!r}); "
            f"known configs: {', '.join(model_cfgs)}"
        )

    # caller kwargs take precedence over the variant's defaults
    model_cfg = dict(model_cfgs[cfg_variant], **kwargs)
    model = build_model_with_cfg(
        LevitDistilled if distilled else Levit,
        variant,
        pretrained,
        pretrained_filter_fn=checkpoint_filter_fn,
        feature_cfg=dict(flatten_sequential=True, out_indices=out_indices),
        **model_cfg,
    )
    return model
def _cfg(url='', **kwargs):
    """Return the baseline pretrained-config dict, with ``kwargs`` applied as overrides.

    Args:
        url: Value stored under the ``'url'`` key.
        **kwargs: Extra or overriding entries; these win over the defaults below.

    Returns:
        A new dict combining the defaults with ``kwargs``.
    """
    cfg = dict(
        url=url,
        num_classes=1000,
        input_size=(3, 224, 224),
        pool_size=None,
        crop_pct=0.9,
        interpolation='bicubic',
        fixed_input_size=True,
        mean=IMAGENET_DEFAULT_MEAN,
        std=IMAGENET_DEFAULT_STD,
        first_conv='stem.conv1.linear',
        classifier=('head.linear', 'head_dist.linear'),
        license='apache-2.0',
    )
    cfg.update(kwargs)
    return cfg
# Pretrained / default configs for every registered variant, keyed as '<model>.<tag>'.
default_cfgs = generate_default_cfgs({
    # weights in nn.Linear mode
    'levit_128s.fb_dist_in1k': _cfg(hf_hub_id='timm/'),
    'levit_128.fb_dist_in1k': _cfg(hf_hub_id='timm/'),
    'levit_192.fb_dist_in1k': _cfg(hf_hub_id='timm/'),
    'levit_256.fb_dist_in1k': _cfg(hf_hub_id='timm/'),
    'levit_384.fb_dist_in1k': _cfg(hf_hub_id='timm/'),

    # weights in nn.Conv2d mode
    'levit_conv_128s.fb_dist_in1k': _cfg(hf_hub_id='timm/', pool_size=(4, 4)),
    'levit_conv_128.fb_dist_in1k': _cfg(hf_hub_id='timm/', pool_size=(4, 4)),
    'levit_conv_192.fb_dist_in1k': _cfg(hf_hub_id='timm/', pool_size=(4, 4)),
    'levit_conv_256.fb_dist_in1k': _cfg(hf_hub_id='timm/', pool_size=(4, 4)),
    'levit_conv_384.fb_dist_in1k': _cfg(hf_hub_id='timm/', pool_size=(4, 4)),

    # untrained variants (nn.Linear mode), single classifier entry
    'levit_384_s8.untrained': _cfg(classifier='head.linear'),
    'levit_512_s8.untrained': _cfg(classifier='head.linear'),
    'levit_512.untrained': _cfg(classifier='head.linear'),
    'levit_256d.untrained': _cfg(classifier='head.linear'),
    'levit_512d.untrained': _cfg(classifier='head.linear'),

    # untrained variants (nn.Conv2d mode), single classifier entry
    'levit_conv_384_s8.untrained': _cfg(classifier='head.linear'),
    'levit_conv_512_s8.untrained': _cfg(classifier='head.linear'),
    'levit_conv_512.untrained': _cfg(classifier='head.linear'),
    'levit_conv_256d.untrained': _cfg(classifier='head.linear'),
    'levit_conv_512d.untrained': _cfg(classifier='head.linear'),
})
@register_model
def levit_128s(pretrained=False, **kwargs) -> Levit:
    """Create the ``levit_128s`` model; extra ``kwargs`` are forwarded to ``create_levit``."""
    model = create_levit('levit_128s', pretrained=pretrained, **kwargs)
    return model
@register_model
def levit_128(pretrained=False, **kwargs) -> Levit:
    """Create the ``levit_128`` model; extra ``kwargs`` are forwarded to ``create_levit``."""
    model = create_levit('levit_128', pretrained=pretrained, **kwargs)
    return model
@register_model
def levit_192(pretrained=False, **kwargs) -> Levit:
    """Create the ``levit_192`` model; extra ``kwargs`` are forwarded to ``create_levit``."""
    model = create_levit('levit_192', pretrained=pretrained, **kwargs)
    return model
@register_model
def levit_256(pretrained=False, **kwargs) -> Levit:
    """Create the ``levit_256`` model; extra ``kwargs`` are forwarded to ``create_levit``."""
    model = create_levit('levit_256', pretrained=pretrained, **kwargs)
    return model
@register_model
def levit_384(pretrained=False, **kwargs) -> Levit:
    """Create the ``levit_384`` model; extra ``kwargs`` are forwarded to ``create_levit``."""
    model = create_levit('levit_384', pretrained=pretrained, **kwargs)
    return model
@register_model
def levit_384_s8(pretrained=False, **kwargs) -> Levit:
    """Create the ``levit_384_s8`` model; extra ``kwargs`` are forwarded to ``create_levit``."""
    # NOTE(review): no `distilled=False` is passed here, so create_levit's default
    # (distilled=True) applies, while this variant's default cfg lists only
    # 'head.linear' as classifier -- confirm this is intended.
    model = create_levit('levit_384_s8', pretrained=pretrained, **kwargs)
    return model
@register_model
def levit_512_s8(pretrained=False, **kwargs) -> Levit:
    """Create the ``levit_512_s8`` model with ``distilled=False``."""
    model = create_levit(
        'levit_512_s8',
        pretrained=pretrained,
        distilled=False,
        **kwargs,
    )
    return model
@register_model
def levit_512(pretrained=False, **kwargs) -> Levit:
    """Create the ``levit_512`` model with ``distilled=False``."""
    model = create_levit(
        'levit_512',
        pretrained=pretrained,
        distilled=False,
        **kwargs,
    )
    return model
@register_model
def levit_256d(pretrained=False, **kwargs) -> Levit:
    """Create the ``levit_256d`` model with ``distilled=False``."""
    model = create_levit(
        'levit_256d',
        pretrained=pretrained,
        distilled=False,
        **kwargs,
    )
    return model
@register_model
def levit_512d(pretrained=False, **kwargs) -> Levit:
    """Create the ``levit_512d`` model with ``distilled=False``."""
    model = create_levit(
        'levit_512d',
        pretrained=pretrained,
        distilled=False,
        **kwargs,
    )
    return model
@register_model
def levit_conv_128s(pretrained=False, **kwargs) -> Levit:
    """Create the ``levit_conv_128s`` model with ``use_conv=True``."""
    model = create_levit(
        'levit_conv_128s',
        pretrained=pretrained,
        use_conv=True,
        **kwargs,
    )
    return model
@register_model
def levit_conv_128(pretrained=False, **kwargs) -> Levit:
    """Create the ``levit_conv_128`` model with ``use_conv=True``."""
    model = create_levit(
        'levit_conv_128',
        pretrained=pretrained,
        use_conv=True,
        **kwargs,
    )
    return model
@register_model
def levit_conv_192(pretrained=False, **kwargs) -> Levit:
    """Create the ``levit_conv_192`` model with ``use_conv=True``."""
    model = create_levit(
        'levit_conv_192',
        pretrained=pretrained,
        use_conv=True,
        **kwargs,
    )
    return model
@register_model
def levit_conv_256(pretrained=False, **kwargs) -> Levit:
    """Create the ``levit_conv_256`` model with ``use_conv=True``."""
    model = create_levit(
        'levit_conv_256',
        pretrained=pretrained,
        use_conv=True,
        **kwargs,
    )
    return model
@register_model
def levit_conv_384(pretrained=False, **kwargs) -> Levit:
    """Create the ``levit_conv_384`` model with ``use_conv=True``."""
    model = create_levit(
        'levit_conv_384',
        pretrained=pretrained,
        use_conv=True,
        **kwargs,
    )
    return model
@register_model
def levit_conv_384_s8(pretrained=False, **kwargs) -> Levit:
    """Create the ``levit_conv_384_s8`` model with ``use_conv=True``."""
    # NOTE(review): no `distilled=False` is passed here, so create_levit's default
    # (distilled=True) applies, while this variant's default cfg lists only
    # 'head.linear' as classifier -- confirm this is intended.
    model = create_levit(
        'levit_conv_384_s8',
        pretrained=pretrained,
        use_conv=True,
        **kwargs,
    )
    return model
@register_model
def levit_conv_512_s8(pretrained=False, **kwargs) -> Levit:
    """Create the ``levit_conv_512_s8`` model with ``use_conv=True`` and ``distilled=False``."""
    model = create_levit(
        'levit_conv_512_s8',
        pretrained=pretrained,
        use_conv=True,
        distilled=False,
        **kwargs,
    )
    return model
@register_model
def levit_conv_512(pretrained=False, **kwargs) -> Levit:
    """Create the ``levit_conv_512`` model with ``use_conv=True`` and ``distilled=False``."""
    model = create_levit(
        'levit_conv_512',
        pretrained=pretrained,
        use_conv=True,
        distilled=False,
        **kwargs,
    )
    return model
@register_model
def levit_conv_256d(pretrained=False, **kwargs) -> Levit:
    """Create the ``levit_conv_256d`` model with ``use_conv=True`` and ``distilled=False``."""
    model = create_levit(
        'levit_conv_256d',
        pretrained=pretrained,
        use_conv=True,
        distilled=False,
        **kwargs,
    )
    return model
@register_model
def levit_conv_512d(pretrained=False, **kwargs) -> Levit:
    """Create the ``levit_conv_512d`` model with ``use_conv=True`` and ``distilled=False``."""
    model = create_levit(
        'levit_conv_512d',
        pretrained=pretrained,
        use_conv=True,
        distilled=False,
        **kwargs,
    )
    return model
|