| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656657658659660661662663664665666667668669670671672673674675676677678679680681682683684685686687688689690691692693 |
- """ RepViT
- Paper: `RepViT: Revisiting Mobile CNN From ViT Perspective`
- - https://arxiv.org/abs/2307.09283
- @misc{wang2023repvit,
- title={RepViT: Revisiting Mobile CNN From ViT Perspective},
- author={Ao Wang and Hui Chen and Zijia Lin and Hengjun Pu and Guiguang Ding},
- year={2023},
- eprint={2307.09283},
- archivePrefix={arXiv},
- primaryClass={cs.CV}
- }
- Adapted from official impl at https://github.com/jameslahm/RepViT
- """
- from typing import List, Optional, Tuple, Union, Type
- import torch
- import torch.nn as nn
- from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
- from timm.layers import SqueezeExcite, trunc_normal_, to_ntuple, to_2tuple
- from ._builder import build_model_with_cfg
- from ._features import feature_take_indices
- from ._manipulate import checkpoint, checkpoint_seq
- from ._registry import register_model, generate_default_cfgs
# Public API of this module: only the model class is exported on ``from ... import *``.
__all__ = ["RepVit"]
class ConvNorm(nn.Sequential):
    """Bias-free ``Conv2d`` followed by ``BatchNorm2d``, foldable into a single conv.

    Children are registered as ``c`` (the convolution) and ``bn`` (the norm), so the
    forward pass is the ``nn.Sequential`` default ``bn(c(x))``.

    Args:
        in_dim: Input channels.
        out_dim: Output channels.
        ks: Convolution kernel size.
        stride: Convolution stride.
        pad: Convolution padding.
        dilation: Convolution dilation.
        groups: Convolution groups.
        bn_weight_init: Initial value of the BatchNorm scale (``weight``); bias starts at 0.
        device: Device for parameters/buffers.
        dtype: Dtype for parameters/buffers.
    """

    def __init__(
            self,
            in_dim: int,
            out_dim: int,
            ks: int = 1,
            stride: int = 1,
            pad: int = 0,
            dilation: int = 1,
            groups: int = 1,
            bn_weight_init: float = 1,
            device=None,
            dtype=None,
    ):
        dd = {'device': device, 'dtype': dtype}
        super().__init__()
        self.add_module('c', nn.Conv2d(in_dim, out_dim, ks, stride, pad, dilation, groups, bias=False, **dd))
        self.add_module('bn', nn.BatchNorm2d(out_dim, **dd))
        nn.init.constant_(self.bn.weight, bn_weight_init)
        nn.init.constant_(self.bn.bias, 0)

    @torch.no_grad()
    def fuse(self):
        """Fold the BatchNorm running statistics into the convolution.

        Returns:
            A new ``nn.Conv2d`` (with bias) whose output equals ``bn(c(x))`` when the
            BatchNorm uses its running statistics (eval mode). The new conv is created on
            the same device and with the same dtype as the original conv weight, so a
            reduced-precision model stays in that precision after fusing.
        """
        c, bn = self._modules.values()
        # Per-output-channel BN scale: gamma / sqrt(var + eps).
        scale = bn.weight / (bn.running_var + bn.eps) ** 0.5
        w = c.weight * scale[:, None, None, None]
        b = bn.bias - bn.running_mean * scale
        m = nn.Conv2d(
            w.size(1) * c.groups,
            w.size(0),
            c.kernel_size,
            stride=c.stride,
            padding=c.padding,
            dilation=c.dilation,
            groups=c.groups,
            device=c.weight.device,
            dtype=c.weight.dtype,
        )
        m.weight.data.copy_(w)
        m.bias.data.copy_(b)
        return m
class NormLinear(nn.Sequential):
    """``BatchNorm1d`` followed by ``Linear``, foldable into a single ``Linear``.

    Children are registered as ``bn`` and ``l``; the forward pass is ``l(bn(x))``.

    Args:
        in_dim: Input features.
        out_dim: Output features.
        bias: Whether the linear layer has a bias (zero-initialised).
        std: Std of the truncated-normal init of the linear weight.
        device: Device for parameters/buffers.
        dtype: Dtype for parameters/buffers.
    """

    def __init__(
            self,
            in_dim: int,
            out_dim: int,
            bias: bool = True,
            std: float = 0.02,
            device=None,
            dtype=None,
    ):
        dd = {'device': device, 'dtype': dtype}
        super().__init__()
        self.add_module('bn', nn.BatchNorm1d(in_dim, **dd))
        self.add_module('l', nn.Linear(in_dim, out_dim, bias=bias, **dd))
        trunc_normal_(self.l.weight, std=std)
        if bias:
            nn.init.constant_(self.l.bias, 0)

    @torch.no_grad()
    def fuse(self):
        """Fold the BatchNorm running statistics into the linear layer.

        Returns:
            A new ``nn.Linear`` (always with bias) whose output equals ``l(bn(x))`` when
            the BatchNorm uses its running statistics (eval mode). It is created on the
            same device and with the same dtype as the original linear weight.
        """
        bn, linear = self._modules.values()
        # BN in eval mode is the affine map x * scale + shift (per input feature).
        scale = bn.weight / (bn.running_var + bn.eps) ** 0.5
        shift = bn.bias - bn.running_mean * scale
        w = linear.weight * scale[None, :]
        # Push the BN shift through the linear weight: W @ shift (+ original bias).
        b = linear.weight @ shift
        if linear.bias is not None:
            b = b + linear.bias
        m = nn.Linear(w.size(1), w.size(0), device=linear.weight.device, dtype=linear.weight.dtype)
        m.weight.data.copy_(w)
        m.bias.data.copy_(b)
        return m
class RepVggDw(nn.Module):
    """Depthwise RepVGG-style token mixer.

    The forward pass sums three parallel branches: a ``kernel_size`` x ``kernel_size``
    depthwise ConvNorm, a 1x1 depthwise conv, and the identity. In non-legacy mode the
    sum is then passed through a BatchNorm. ``fuse()`` collapses all of it into a single
    depthwise ``nn.Conv2d``.

    Args:
        ed: Number of channels (also the conv group count, i.e. depthwise).
        kernel_size: Spatial kernel size of the main depthwise branch.
        legacy: If True, the 1x1 branch is a ConvNorm and no BatchNorm follows the sum.
            Otherwise the 1x1 branch is a plain conv with bias and the sum is batch-normed.
        device: Device for parameters/buffers.
        dtype: Dtype for parameters/buffers.
    """

    def __init__(
            self,
            ed: int,
            kernel_size: int,
            legacy: bool = False,
            device=None,
            dtype=None,
    ):
        dd = {'device': device, 'dtype': dtype}
        super().__init__()
        self.conv = ConvNorm(ed, ed, kernel_size, 1, (kernel_size - 1) // 2, groups=ed, **dd)
        if legacy:
            self.conv1 = ConvNorm(ed, ed, 1, 1, 0, groups=ed, **dd)
            # Make torchscript happy.
            self.bn = nn.Identity()
        else:
            self.conv1 = nn.Conv2d(ed, ed, 1, 1, 0, groups=ed, **dd)
            self.bn = nn.BatchNorm2d(ed, **dd)
        self.dim = ed
        self.legacy = legacy

    def forward(self, x):
        return self.bn(self.conv(x) + self.conv1(x) + x)

    @torch.no_grad()
    def fuse(self):
        """Re-parameterise the three branches (and trailing BN) into one depthwise conv.

        Returns:
            An ``nn.Conv2d`` equivalent to ``forward`` when all BatchNorms use their
            running statistics (eval mode). Assumes an odd kernel size. ``forward`` adds
            ``x`` to ``conv(x)``, which only shape-matches for odd kernels with
            ``(k - 1) // 2`` padding.
        """
        conv = self.conv.fuse()
        conv1 = self.conv1.fuse() if self.legacy else self.conv1

        # Zero-pad the 1x1 kernel, and a unit kernel representing the identity branch,
        # up to the main kernel's spatial size so the branches can be summed. The pad
        # width is derived from the real kernel size instead of assuming 3x3.
        p = (conv.weight.shape[-1] - 1) // 2
        conv1_w = nn.functional.pad(conv1.weight, [p, p, p, p])
        identity = nn.functional.pad(torch.ones_like(conv1.weight), [p, p, p, p])

        final_w = conv.weight + conv1_w + identity
        final_b = conv.bias + conv1.bias
        if not self.legacy:
            # Fold the BatchNorm applied to the branch sum.
            bn = self.bn
            scale = bn.weight / (bn.running_var + bn.eps) ** 0.5
            final_w = final_w * scale[:, None, None, None]
            final_b = bn.bias + (final_b - bn.running_mean) * scale
        conv.weight.data.copy_(final_w)
        conv.bias.data.copy_(final_b)
        return conv
class RepVitMlp(nn.Module):
    """Pointwise (1x1) MLP: ConvNorm -> activation -> ConvNorm.

    The second ConvNorm is built with ``bn_weight_init=0``, so its BatchNorm scale and
    bias both start at zero and the module initially outputs zeros.

    Args:
        in_dim: Input (and output) channels.
        hidden_dim: Hidden channels between the two 1x1 ConvNorms.
        act_layer: Activation layer class, instantiated with no arguments.
        device: Device for parameters/buffers.
        dtype: Dtype for parameters/buffers.
    """

    def __init__(
            self,
            in_dim: int,
            hidden_dim: int,
            act_layer: Type[nn.Module],
            device=None,
            dtype=None,
    ):
        factory_kwargs = dict(device=device, dtype=dtype)
        super().__init__()
        self.conv1 = ConvNorm(in_dim, hidden_dim, 1, 1, 0, **factory_kwargs)
        self.act = act_layer()
        self.conv2 = ConvNorm(hidden_dim, in_dim, 1, 1, 0, bn_weight_init=0, **factory_kwargs)

    def forward(self, x):
        hidden = self.act(self.conv1(x))
        return self.conv2(hidden)
class RepViTBlock(nn.Module):
    """RepViT block: depthwise token mixer, optional Squeeze-Excite, residual 1x1 MLP.

    Args:
        in_dim: Channels (unchanged by the block).
        mlp_ratio: Hidden-width multiplier of the channel-mixer MLP.
        kernel_size: Kernel size of the depthwise token mixer.
        use_se: Insert a SqueezeExcite (ratio 0.25) after the token mixer.
        act_layer: Activation layer class for the MLP.
        legacy: Passed through to ``RepVggDw``.
        device: Device for parameters/buffers.
        dtype: Dtype for parameters/buffers.
    """

    def __init__(
            self,
            in_dim: int,
            mlp_ratio: float,
            kernel_size: int,
            use_se: bool,
            act_layer: Type[nn.Module],
            legacy: bool = False,
            device=None,
            dtype=None,
    ):
        dd = {'device': device, 'dtype': dtype}
        super().__init__()
        self.token_mixer = RepVggDw(in_dim, kernel_size, legacy, **dd)
        self.se = SqueezeExcite(in_dim, 0.25, **dd) if use_se else nn.Identity()
        # ``mlp_ratio`` is typed as float; conv channel counts must be ints, so cast.
        # For integer ratios (the default configs) this is a no-op.
        self.channel_mixer = RepVitMlp(in_dim, int(in_dim * mlp_ratio), act_layer, **dd)

    def forward(self, x):
        x = self.se(self.token_mixer(x))
        # Residual connection around the channel-mixer only.
        return x + self.channel_mixer(x)
class RepVitStem(nn.Module):
    """Stem: two stride-2 3x3 ConvNorms with an activation between them.

    The first conv outputs ``out_chs // 2`` channels and the second ``out_chs``. The
    overall spatial reduction is 4, exposed as ``self.stride``.

    Args:
        in_chs: Input image channels.
        out_chs: Output channels.
        act_layer: Activation layer class, instantiated with no arguments.
        device: Device for parameters/buffers.
        dtype: Dtype for parameters/buffers.
    """

    def __init__(
            self,
            in_chs: int,
            out_chs: int,
            act_layer: Type[nn.Module],
            device=None,
            dtype=None,
    ):
        factory_kwargs = dict(device=device, dtype=dtype)
        super().__init__()
        mid_chs = out_chs // 2
        self.conv1 = ConvNorm(in_chs, mid_chs, 3, 2, 1, **factory_kwargs)
        self.act1 = act_layer()
        self.conv2 = ConvNorm(mid_chs, out_chs, 3, 2, 1, **factory_kwargs)
        self.stride = 4  # two stride-2 convs

    def forward(self, x):
        x = self.act1(self.conv1(x))
        return self.conv2(x)
class RepVitDownsample(nn.Module):
    """Stage transition that halves resolution and changes channel count.

    Sequence: a RepViT block (no SE) at ``in_dim``, a stride-2 depthwise ConvNorm, a 1x1
    ConvNorm projecting to ``out_dim``, then a residual 1x1 MLP at ``out_dim``.

    Args:
        in_dim: Input channels.
        mlp_ratio: Hidden-width multiplier for the MLPs.
        out_dim: Output channels.
        kernel_size: Kernel size of the depthwise convs.
        act_layer: Activation layer class.
        legacy: Passed through to the pre-block's ``RepVggDw``.
        device: Device for parameters/buffers.
        dtype: Dtype for parameters/buffers.
    """

    def __init__(
            self,
            in_dim: int,
            mlp_ratio: float,
            out_dim: int,
            kernel_size: int,
            act_layer: Type[nn.Module],
            legacy: bool = False,
            device=None,
            dtype=None,
    ):
        dd = {'device': device, 'dtype': dtype}
        super().__init__()
        self.pre_block = RepViTBlock(
            in_dim,
            mlp_ratio,
            kernel_size,
            use_se=False,
            act_layer=act_layer,
            legacy=legacy,
            **dd,
        )
        self.spatial_downsample = ConvNorm(
            in_dim,
            in_dim,
            kernel_size,
            stride=2,
            pad=(kernel_size - 1) // 2,
            groups=in_dim,
            **dd,
        )
        self.channel_downsample = ConvNorm(in_dim, out_dim, 1, 1, **dd)
        # ``mlp_ratio`` is typed as float; conv channel counts must be ints, so cast.
        self.ffn = RepVitMlp(out_dim, int(out_dim * mlp_ratio), act_layer, **dd)

    def forward(self, x):
        x = self.pre_block(x)
        x = self.spatial_downsample(x)
        x = self.channel_downsample(x)
        # Residual connection around the FFN only.
        return x + self.ffn(x)
class RepVitClassifier(nn.Module):
    """Classifier head: dropout then a NormLinear, plus an optional distillation head.

    With ``distillation`` enabled, both heads run. During training with
    ``distilled_training`` set (and not under TorchScript), both logits are returned as a
    tuple. Otherwise their average is returned. With ``num_classes <= 0`` the heads are
    ``nn.Identity``.

    Args:
        dim: Input feature dimension.
        num_classes: Number of output classes (``<= 0`` disables the linear heads).
        distillation: Add the ``head_dist`` distillation head.
        drop: Dropout rate applied before the heads.
        device: Device for parameters/buffers.
        dtype: Dtype for parameters/buffers.
    """

    def __init__(
            self,
            dim: int,
            num_classes: int,
            distillation: bool = False,
            drop: float = 0.0,
            device=None,
            dtype=None,
    ):
        factory_kwargs = dict(device=device, dtype=dtype)
        super().__init__()
        has_classes = num_classes > 0
        self.head_drop = nn.Dropout(drop)
        self.head = NormLinear(dim, num_classes, **factory_kwargs) if has_classes else nn.Identity()
        self.distillation = distillation
        self.distilled_training = False
        self.num_classes = num_classes
        if distillation:
            self.head_dist = NormLinear(dim, num_classes, **factory_kwargs) if has_classes else nn.Identity()

    def forward(self, x):
        x = self.head_drop(x)
        if self.distillation:
            logits, logits_dist = self.head(x), self.head_dist(x)
            if self.training and self.distilled_training and not torch.jit.is_scripting():
                return logits, logits_dist
            return (logits + logits_dist) / 2
        return self.head(x)

    @torch.no_grad()
    def fuse(self):
        """Fuse the head(s) into one ``nn.Linear`` (averaging the two heads if distilled).

        Returns ``nn.Identity()`` when there are no classes.
        """
        if self.num_classes <= 0:
            return nn.Identity()
        head = self.head.fuse()
        if not self.distillation:
            return head
        head_dist = self.head_dist.fuse()
        head.weight += head_dist.weight
        head.bias += head_dist.bias
        head.weight /= 2
        head.bias /= 2
        return head
class RepVitStage(nn.Module):
    """One RepViT stage: optional downsample transition followed by ``depth`` blocks.

    Squeeze-Excite is enabled on every other block, starting with the first.

    Args:
        in_dim: Input channels.
        out_dim: Output channels (must equal ``in_dim`` when ``downsample`` is False).
        depth: Number of RepViT blocks.
        mlp_ratio: Hidden-width multiplier for the MLPs.
        act_layer: Activation layer class.
        kernel_size: Kernel size of the depthwise convs.
        downsample: Prepend a ``RepVitDownsample`` transition.
        legacy: Passed through to the blocks.
        device: Device for parameters/buffers.
        dtype: Dtype for parameters/buffers.
    """

    def __init__(
            self,
            in_dim: int,
            out_dim: int,
            depth: int,
            mlp_ratio: float,
            act_layer: Type[nn.Module],
            kernel_size: int = 3,
            downsample: bool = True,
            legacy: bool = False,
            device=None,
            dtype=None,
    ):
        factory_kwargs = dict(device=device, dtype=dtype)
        super().__init__()
        if not downsample:
            assert in_dim == out_dim
            self.downsample = nn.Identity()
        else:
            self.downsample = RepVitDownsample(
                in_dim,
                mlp_ratio,
                out_dim,
                kernel_size,
                act_layer=act_layer,
                legacy=legacy,
                **factory_kwargs,
            )
        self.blocks = nn.Sequential(*[
            RepViTBlock(out_dim, mlp_ratio, kernel_size, idx % 2 == 0, act_layer, legacy, **factory_kwargs)
            for idx in range(depth)
        ])

    def forward(self, x):
        return self.blocks(self.downsample(x))
class RepVit(nn.Module):
    """RepViT classification backbone.

    Layout: a stride-4 stem, then one ``RepVitStage`` per entry of ``embed_dim``. The
    first stage keeps the stem resolution and every later stage halves it. These are
    followed by optional average pooling, dropout and a ``RepVitClassifier`` head.

    Args:
        in_chans: Number of input image channels.
        img_size: Nominal input size. Stored in no attribute and used by no layer; kept for
            config/interface compatibility.
        embed_dim: Channel width of each stage. Its length sets the number of stages.
        depth: Number of blocks per stage (indexed alongside ``embed_dim``).
        mlp_ratio: MLP hidden-width multiplier, expanded to one value per stage via ``to_ntuple``.
        global_pool: ``'avg'`` mean-pools the spatial dims in ``forward_head``. Any other
            value skips pooling.
        kernel_size: Depthwise kernel size used in every stage.
        num_classes: Classifier outputs (``<= 0`` for no linear head).
        act_layer: Activation layer class.
        distillation: Add a distillation head to the classifier.
        drop_rate: Dropout rate applied in ``forward_head`` before the classifier.
        legacy: Use the legacy ``RepVggDw`` layout.
        device: Device for parameters/buffers.
        dtype: Dtype for parameters/buffers.
    """

    def __init__(
            self,
            in_chans: int = 3,
            img_size: int = 224,
            embed_dim: Tuple[int, ...] = (48,),
            depth: Tuple[int, ...] = (2,),
            mlp_ratio: float = 2,
            global_pool: str = 'avg',
            kernel_size: int = 3,
            num_classes: int = 1000,
            act_layer: Type[nn.Module] = nn.GELU,
            distillation: bool = True,
            drop_rate: float = 0.0,
            legacy: bool = False,
            device=None,
            dtype=None,
    ):
        super().__init__()
        dd = {'device': device, 'dtype': dtype}
        self.grad_checkpointing = False
        self.global_pool = global_pool
        self.embed_dim = embed_dim
        self.num_classes = num_classes
        self.in_chans = in_chans

        in_dim = embed_dim[0]
        self.stem = RepVitStem(in_chans, in_dim, act_layer, **dd)
        stride = self.stem.stride

        num_stages = len(embed_dim)
        mlp_ratios = to_ntuple(num_stages)(mlp_ratio)
        self.feature_info = []
        stages = []
        for i in range(num_stages):
            # The first stage keeps the stem resolution; every later stage halves it.
            downsample = i != 0
            stages.append(
                RepVitStage(
                    in_dim,
                    embed_dim[i],
                    depth[i],
                    mlp_ratio=mlp_ratios[i],
                    act_layer=act_layer,
                    kernel_size=kernel_size,
                    downsample=downsample,
                    legacy=legacy,
                    **dd,
                )
            )
            if downsample:
                stride *= 2
            self.feature_info += [dict(num_chs=embed_dim[i], reduction=stride, module=f'stages.{i}')]
            in_dim = embed_dim[i]
        self.stages = nn.Sequential(*stages)

        self.num_features = self.head_hidden_size = embed_dim[-1]
        self.head_drop = nn.Dropout(drop_rate)
        self.head = RepVitClassifier(embed_dim[-1], num_classes, distillation, **dd)

    @torch.jit.ignore
    def group_matcher(self, coarse=False):
        """Regex groups for parameter grouping (e.g. layer-wise LR decay).

        The trunk lives under ``stages.<i>.downsample`` and ``stages.<i>.blocks.<j>``.
        There is no top-level ``blocks`` or ``norm`` module, so match on ``stages``.
        """
        matcher = dict(
            stem=r'^stem',  # stem and embed
            blocks=r'^stages\.(\d+)' if coarse else [
                (r'^stages\.(\d+)\.downsample', (0,)),
                (r'^stages\.(\d+)\.blocks\.(\d+)', None),
            ],
        )
        return matcher

    @torch.jit.ignore
    def set_grad_checkpointing(self, enable=True):
        self.grad_checkpointing = enable

    @torch.jit.ignore
    def get_classifier(self) -> nn.Module:
        return self.head

    def reset_classifier(self, num_classes: int, global_pool: Optional[str] = None, distillation: bool = False, device=None, dtype=None):
        """Replace the classifier head (and optionally the pooling mode)."""
        self.num_classes = num_classes
        if global_pool is not None:
            self.global_pool = global_pool
        dd = {'device': device, 'dtype': dtype}
        self.head = RepVitClassifier(self.embed_dim[-1], num_classes, distillation, **dd)

    @torch.jit.ignore
    def set_distilled_training(self, enable=True):
        self.head.distilled_training = enable

    def forward_intermediates(
            self,
            x: torch.Tensor,
            indices: Optional[Union[int, List[int]]] = None,
            norm: bool = False,
            stop_early: bool = False,
            output_fmt: str = 'NCHW',
            intermediates_only: bool = False,
    ) -> Union[List[torch.Tensor], Tuple[torch.Tensor, List[torch.Tensor]]]:
        """ Forward features that returns intermediates.

        Args:
            x: Input image tensor
            indices: Take last n stages if int, all if None, select matching indices if sequence
            norm: Unused (this model has no final norm); accepted for API compatibility
            stop_early: Stop iterating over stages when last desired intermediate hit
            output_fmt: Shape of intermediate feature outputs (only 'NCHW' is supported)
            intermediates_only: Only return intermediate features
        Returns:
            If ``intermediates_only``, the list of selected stage outputs. Otherwise a
            tuple of (output of the last stage run, list of selected stage outputs).
        """
        assert output_fmt in ('NCHW',), 'Output shape must be NCHW.'
        intermediates = []
        take_indices, max_index = feature_take_indices(len(self.stages), indices)

        # forward pass
        x = self.stem(x)
        if torch.jit.is_scripting() or not stop_early:  # can't slice blocks in torchscript
            stages = self.stages
        else:
            stages = self.stages[:max_index + 1]

        for feat_idx, stage in enumerate(stages):
            if self.grad_checkpointing and not torch.jit.is_scripting():
                x = checkpoint(stage, x)
            else:
                x = stage(x)
            if feat_idx in take_indices:
                intermediates.append(x)

        if intermediates_only:
            return intermediates

        return x, intermediates

    def prune_intermediate_layers(
            self,
            indices: Union[int, List[int]] = 1,
            prune_norm: bool = False,
            prune_head: bool = True,
    ):
        """ Prune layers not required for specified intermediates.

        Truncates ``self.stages`` after the deepest requested stage and, if
        ``prune_head``, replaces the classifier with an empty one. ``prune_norm`` is
        unused (there is no final norm). Returns the resolved stage indices.
        """
        take_indices, max_index = feature_take_indices(len(self.stages), indices)
        # Indices refer to stages only (the stem is not counted).
        self.stages = self.stages[:max_index + 1]
        if prune_head:
            self.reset_classifier(0, '')
        return take_indices

    def forward_features(self, x):
        x = self.stem(x)
        if self.grad_checkpointing and not torch.jit.is_scripting():
            x = checkpoint_seq(self.stages, x)
        else:
            x = self.stages(x)
        return x

    def forward_head(self, x, pre_logits: bool = False):
        if self.global_pool == 'avg':
            x = x.mean((2, 3), keepdim=False)
        x = self.head_drop(x)
        if pre_logits:
            return x
        return self.head(x)

    def forward(self, x):
        x = self.forward_features(x)
        x = self.forward_head(x)
        return x

    @torch.no_grad()
    def fuse(self):
        """Recursively replace, in place, every child that has ``fuse()`` with its result.

        The replacement module is itself then searched for further fusable children.
        """
        def fuse_children(net):
            for child_name, child in net.named_children():
                if hasattr(child, 'fuse'):
                    fused = child.fuse()
                    setattr(net, child_name, fused)
                    fuse_children(fused)
                else:
                    fuse_children(child)

        fuse_children(self)
def _cfg(url='', **kwargs):
    """Return the default pretrained config dict for a RepViT variant.

    Keyword arguments override matching defaults or add new keys.
    """
    cfg = dict(
        url=url,
        num_classes=1000,
        input_size=(3, 224, 224),
        pool_size=(7, 7),
        crop_pct=0.95,
        interpolation='bicubic',
        mean=IMAGENET_DEFAULT_MEAN,
        std=IMAGENET_DEFAULT_STD,
        first_conv='stem.conv1.c',
        classifier=('head.head.l', 'head.head_dist.l'),
        license='apache-2.0',
    )
    cfg.update(kwargs)
    return cfg
# Every released RepViT checkpoint uses the same default config, hosted on the HF hub under timm/.
default_cfgs = generate_default_cfgs({
    tag: _cfg(hf_hub_id='timm/')
    for tag in (
        'repvit_m1.dist_in1k',
        'repvit_m2.dist_in1k',
        'repvit_m3.dist_in1k',
        'repvit_m0_9.dist_300e_in1k',
        'repvit_m0_9.dist_450e_in1k',
        'repvit_m1_0.dist_300e_in1k',
        'repvit_m1_0.dist_450e_in1k',
        'repvit_m1_1.dist_300e_in1k',
        'repvit_m1_1.dist_450e_in1k',
        'repvit_m1_5.dist_300e_in1k',
        'repvit_m1_5.dist_450e_in1k',
        'repvit_m2_3.dist_300e_in1k',
        'repvit_m2_3.dist_450e_in1k',
    )
})
def _create_repvit(variant, pretrained=False, **kwargs):
    """Build a ``RepVit`` through timm's builder.

    ``out_indices`` (default ``(0, 1, 2, 3)``) is taken out of ``kwargs`` and routed into
    the feature config. All remaining kwargs go to the builder/model.
    """
    feature_cfg = dict(
        flatten_sequential=True,
        out_indices=kwargs.pop('out_indices', (0, 1, 2, 3)),
    )
    return build_model_with_cfg(RepVit, variant, pretrained, feature_cfg=feature_cfg, **kwargs)
@register_model
def repvit_m1(pretrained=False, **kwargs):
    """Construct RepViT-M1: widths (48, 96, 192, 384), depths (2, 2, 14, 2), legacy RepVggDw."""
    model_args = dict(
        embed_dim=(48, 96, 192, 384),
        depth=(2, 2, 14, 2),
        legacy=True,
    )
    model_args.update(kwargs)
    return _create_repvit('repvit_m1', pretrained=pretrained, **model_args)
@register_model
def repvit_m2(pretrained=False, **kwargs):
    """Construct RepViT-M2: widths (64, 128, 256, 512), depths (2, 2, 12, 2), legacy RepVggDw."""
    model_args = dict(
        embed_dim=(64, 128, 256, 512),
        depth=(2, 2, 12, 2),
        legacy=True,
    )
    model_args.update(kwargs)
    return _create_repvit('repvit_m2', pretrained=pretrained, **model_args)
@register_model
def repvit_m3(pretrained=False, **kwargs):
    """Construct RepViT-M3: widths (64, 128, 256, 512), depths (4, 4, 18, 2), legacy RepVggDw."""
    model_args = dict(
        embed_dim=(64, 128, 256, 512),
        depth=(4, 4, 18, 2),
        legacy=True,
    )
    model_args.update(kwargs)
    return _create_repvit('repvit_m3', pretrained=pretrained, **model_args)
@register_model
def repvit_m0_9(pretrained=False, **kwargs):
    """Construct RepViT-M0.9: widths (48, 96, 192, 384), depths (2, 2, 14, 2)."""
    model_args = dict(
        embed_dim=(48, 96, 192, 384),
        depth=(2, 2, 14, 2),
    )
    model_args.update(kwargs)
    return _create_repvit('repvit_m0_9', pretrained=pretrained, **model_args)
@register_model
def repvit_m1_0(pretrained=False, **kwargs):
    """Construct RepViT-M1.0: widths (56, 112, 224, 448), depths (2, 2, 14, 2)."""
    model_args = dict(
        embed_dim=(56, 112, 224, 448),
        depth=(2, 2, 14, 2),
    )
    model_args.update(kwargs)
    return _create_repvit('repvit_m1_0', pretrained=pretrained, **model_args)
@register_model
def repvit_m1_1(pretrained=False, **kwargs):
    """Construct RepViT-M1.1: widths (64, 128, 256, 512), depths (2, 2, 12, 2)."""
    model_args = dict(
        embed_dim=(64, 128, 256, 512),
        depth=(2, 2, 12, 2),
    )
    model_args.update(kwargs)
    return _create_repvit('repvit_m1_1', pretrained=pretrained, **model_args)
@register_model
def repvit_m1_5(pretrained=False, **kwargs):
    """Construct RepViT-M1.5: widths (64, 128, 256, 512), depths (4, 4, 24, 4)."""
    model_args = dict(
        embed_dim=(64, 128, 256, 512),
        depth=(4, 4, 24, 4),
    )
    model_args.update(kwargs)
    return _create_repvit('repvit_m1_5', pretrained=pretrained, **model_args)
@register_model
def repvit_m2_3(pretrained=False, **kwargs):
    """Construct RepViT-M2.3: widths (80, 160, 320, 640), depths (6, 6, 34, 2)."""
    model_args = dict(
        embed_dim=(80, 160, 320, 640),
        depth=(6, 6, 34, 2),
    )
    model_args.update(kwargs)
    return _create_repvit('repvit_m2_3', pretrained=pretrained, **model_args)
|