| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656657658659660661662663664665666667668669670671672673674675676677678679680681682683684685686687688689690691692693694695696697698699700701702
70370470570670770870971071171271371471571671771871972072172272372472572672772872973073173273373473573673773873974074174274374474574674774874975075175275375475575675775875976076176276376476576676776876977077177277377477577677777877978078178278378478578678778878979079179279379479579679779879980080180280380480580680780880981081181281381481581681781881982082182282382482582682782882983083183283383483583683783883984084184284384484584684784884985085185285385485585685785885986086186286386486586686786886987087187287387487587687787887988088188288388488588688788888989089189289389489589689789889990090190290390490590690790890991091191291391491591691791891992092192292392492592692792892993093193293393493593693793893994094194294394494594694794894995095195295395495595695795895996096196296396496596696796896997097197297397497597697797897998098198298398498598698798898999099199299399499599699799899910001001100210031004100510061007100810091010101110121013101410151016101710181019102010211022102310241025102610271028102910301031103210331034103510361037103810391040104110421043104410451046104710481049105010511052105310541055105610571058105910601061106210631064106510661067106810691070107110721073107410751076107710781079108010811082108310841085108610871088108910901091109210931094109510961097109810991100110111021103110411051106110711081109111011111112111311141115111611171118111911201121112211231124112511261127112811291130113111321133113411351136113711381139114011411142114311441145114611471148114911501151115211531154115511561157115811591160116111621163116411651166116711681169117011711172117311741175117611771178117911801181118211831184118511861187118811891190119111921193119411951196119711981199120012011202120312041205120612071208120912101211121212131214121512161217121812191220122112221223122412251226122712281229123012311232123312341235123612371238123912401241124212431244124512461247124812491250125112521253125412551256125712581259126012611262126312641265126612671268126912701271127212731274127512761
2771278127912801281128212831284128512861287128812891290129112921293129412951296129712981299130013011302130313041305130613071308130913101311131213131314131513161317131813191320132113221323132413251326132713281329133013311332133313341335133613371338133913401341134213431344134513461347134813491350135113521353135413551356135713581359136013611362136313641365136613671368136913701371137213731374137513761377137813791380138113821383138413851386138713881389139013911392139313941395139613971398139914001401140214031404140514061407140814091410 |
- """ ConvNeXt
- Papers:
- * `A ConvNet for the 2020s` - https://arxiv.org/pdf/2201.03545.pdf
- @Article{liu2022convnet,
- author = {Zhuang Liu and Hanzi Mao and Chao-Yuan Wu and Christoph Feichtenhofer and Trevor Darrell and Saining Xie},
- title = {A ConvNet for the 2020s},
- journal = {Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition (CVPR)},
- year = {2022},
- }
- * `ConvNeXt-V2 - Co-designing and Scaling ConvNets with Masked Autoencoders` - https://arxiv.org/abs/2301.00808
- @article{Woo2023ConvNeXtV2,
- title={ConvNeXt V2: Co-designing and Scaling ConvNets with Masked Autoencoders},
- author={Sanghyun Woo and Shoubhik Debnath and Ronghang Hu and Xinlei Chen and Zhuang Liu and In So Kweon and Saining Xie},
- year={2023},
- journal={arXiv preprint arXiv:2301.00808},
- }
- Original code and weights from:
- * https://github.com/facebookresearch/ConvNeXt, original copyright below
- * https://github.com/facebookresearch/ConvNeXt-V2, original copyright below
- Model defs atto, femto, pico, nano and _ols / _hnf variants are timm originals.
- Modifications and additions for timm hacked together by / Copyright 2022, Ross Wightman
- """
- # ConvNeXt
- # Copyright (c) Meta Platforms, Inc. and affiliates.
- # All rights reserved.
- # This source code is licensed under the MIT license
- # ConvNeXt-V2
- # Copyright (c) Meta Platforms, Inc. and affiliates.
- # All rights reserved.
- # This source code is licensed under the license found in the
- # LICENSE file in the root directory of this source tree (Attribution-NonCommercial 4.0 International (CC BY-NC 4.0))
- # No code was used directly from ConvNeXt-V2, however the weights are CC BY-NC 4.0 so beware if using commercially.
- from functools import partial
- from typing import Callable, Dict, List, Optional, Tuple, Union
- import torch
- import torch.nn as nn
- from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD, OPENAI_CLIP_MEAN, OPENAI_CLIP_STD
- from timm.layers import (
- trunc_normal_,
- AvgPool2dSame,
- DropPath,
- calculate_drop_path_rates,
- Mlp,
- GlobalResponseNormMlp,
- LayerNorm2d,
- LayerNorm,
- RmsNorm2d,
- RmsNorm,
- SimpleNorm2d,
- SimpleNorm,
- create_conv2d,
- get_act_layer,
- get_norm_layer,
- make_divisible,
- to_ntuple,
- NormMlpClassifierHead,
- ClassifierHead,
- )
- from ._builder import build_model_with_cfg
- from ._features import feature_take_indices
- from ._manipulate import named_apply, checkpoint_seq
- from ._registry import generate_default_cfgs, register_model, register_model_deprecations
# Public API of this module; the model registry extends it with each registered entrypoint fn.
__all__ = ['ConvNeXt']  # model_registry will add each entrypoint fn to this
class Downsample(nn.Module):
    """Shortcut downsampler used by ConvNeXtBlock when the residual path changes shape.

    Spatial reduction is done with 2x2 average pooling. When dilation is in use, the pool keeps
    stride 1 (AvgPool2dSame) instead of striding. A 1x1 conv projection follows only when the
    channel count changes.
    """

    def __init__(
            self,
            in_chs: int,
            out_chs: int,
            stride: int = 1,
            dilation: int = 1,
            device=None,
            dtype=None,
    ) -> None:
        """Build the pooling and projection layers.

        Args:
            in_chs: Number of input channels.
            out_chs: Number of output channels.
            stride: Requested downsampling stride.
            dilation: Dilation rate; when != 1 the pool runs with stride 1.
            device: Device for created parameters.
            dtype: Dtype for created parameters.
        """
        factory_kwargs = {'device': device, 'dtype': dtype}
        super().__init__()

        pool_stride = 1 if dilation != 1 else stride
        if not (stride > 1 or dilation > 1):
            self.pool = nn.Identity()
        elif pool_stride == 1 and dilation > 1:
            self.pool = AvgPool2dSame(2, pool_stride, ceil_mode=True, count_include_pad=False)
        else:
            self.pool = nn.AvgPool2d(2, pool_stride, ceil_mode=True, count_include_pad=False)

        if in_chs == out_chs:
            self.conv = nn.Identity()
        else:
            self.conv = create_conv2d(in_chs, out_chs, 1, stride=1, **factory_kwargs)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Pool spatially, then project channels."""
        return self.conv(self.pool(x))
class ConvNeXtBlock(nn.Module):
    """ConvNeXt Block.

    Two equivalent layouts are supported:
      (1) DwConv -> LayerNorm (channels_first) -> 1x1 Conv -> GELU -> 1x1 Conv; all in (N, C, H, W)
      (2) DwConv -> Permute to (N, H, W, C); LayerNorm (channels_last) -> Linear -> GELU -> Linear; Permute back

    Unlike the official impl, ``conv_mlp`` selects between 1 and 2. The 1x1 conv variant can be faster
    with an appropriate LayerNorm impl, but as model size increases the tradeoffs appear to change and
    nn.Linear is a better choice. This was observed with PyTorch 1.10 on a 3090 GPU; it could change
    over time and with different HW.
    """

    def __init__(
            self,
            in_chs: int,
            out_chs: Optional[int] = None,
            kernel_size: int = 7,
            stride: int = 1,
            dilation: Union[int, Tuple[int, int]] = (1, 1),
            mlp_ratio: float = 4,
            conv_mlp: bool = False,
            conv_bias: bool = True,
            use_grn: bool = False,
            ls_init_value: Optional[float] = 1e-6,
            act_layer: Union[str, Callable] = 'gelu',
            norm_layer: Optional[Callable] = None,
            drop_path: float = 0.,
            device=None,
            dtype=None,
    ):
        """
        Args:
            in_chs: Block input channels.
            out_chs: Block output channels (same as in_chs if None).
            kernel_size: Depthwise convolution kernel size.
            stride: Stride of depthwise convolution.
            dilation: (depthwise conv, output) dilation pair; an int is used for both.
            mlp_ratio: MLP hidden size as a multiple of out_chs.
            conv_mlp: Use 1x1 convolutions for the MLP (NCHW layout) instead of Linear (NHWC layout).
            conv_bias: Apply bias for the depthwise convolution.
            use_grn: Use GlobalResponseNormMlp (from ConvNeXt-V2) instead of Mlp.
            ls_init_value: Layer-scale init value; layer-scale is disabled if None.
            act_layer: Activation layer (name or callable).
            norm_layer: Normalization layer; LayerNorm2d (conv_mlp) or LayerNorm otherwise if not given.
            drop_path: Stochastic depth probability.
            device: Device for created parameters.
            dtype: Dtype for created parameters.
        """
        factory_kwargs = {'device': device, 'dtype': dtype}
        super().__init__()
        out_chs = out_chs or in_chs
        dilation = to_ntuple(2)(dilation)
        dw_dilation, out_dilation = dilation[0], dilation[1]
        act_layer = get_act_layer(act_layer)
        norm_layer = norm_layer or (LayerNorm2d if conv_mlp else LayerNorm)
        mlp_cls = GlobalResponseNormMlp if use_grn else Mlp
        hidden_chs = int(mlp_ratio * out_chs)

        self.use_conv_mlp = conv_mlp
        # NOTE: submodules are created in a fixed order (conv_dw, norm, mlp, gamma, shortcut, drop_path);
        # keep it so parameter registration and RNG consumption at construction stay stable.
        self.conv_dw = create_conv2d(
            in_chs,
            out_chs,
            kernel_size=kernel_size,
            stride=stride,
            dilation=dw_dilation,
            depthwise=True,
            bias=conv_bias,
            **factory_kwargs,
        )
        self.norm = norm_layer(out_chs, **factory_kwargs)
        self.mlp = mlp_cls(out_chs, hidden_chs, act_layer=act_layer, use_conv=conv_mlp, **factory_kwargs)
        if ls_init_value is None:
            self.gamma = None
        else:
            self.gamma = nn.Parameter(ls_init_value * torch.ones(out_chs, **factory_kwargs))

        # The residual path needs its own resampling whenever the main path changes shape or dilation.
        reshapes = in_chs != out_chs or stride != 1 or dw_dilation != out_dilation
        if reshapes:
            self.shortcut = Downsample(in_chs, out_chs, stride=stride, dilation=dw_dilation, **factory_kwargs)
        else:
            self.shortcut = nn.Identity()
        self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply depthwise conv, norm + MLP, optional layer scale, drop path and the residual add."""
        identity = x
        out = self.conv_dw(x)
        if self.use_conv_mlp:
            out = self.mlp(self.norm(out))
        else:
            # channels-last path: norm and Linear MLP operate on the trailing (channel) dim
            out = self.mlp(self.norm(out.permute(0, 2, 3, 1))).permute(0, 3, 1, 2)
        if self.gamma is not None:
            out = out.mul(self.gamma.reshape(1, -1, 1, 1))
        return self.drop_path(out) + self.shortcut(identity)
class ConvNeXtStage(nn.Module):
    """ConvNeXt stage (multiple blocks).

    An optional norm + conv downsample (only when channels, stride or dilation change) followed by
    ``depth`` ConvNeXtBlocks.
    """

    def __init__(
            self,
            in_chs: int,
            out_chs: int,
            kernel_size: int = 7,
            stride: int = 2,
            depth: int = 2,
            dilation: Union[int, Tuple[int, int]] = (1, 1),
            drop_path_rates: Optional[List[float]] = None,
            ls_init_value: Optional[float] = 1.0,
            conv_mlp: bool = False,
            conv_bias: bool = True,
            use_grn: bool = False,
            act_layer: Union[str, Callable] = 'gelu',
            norm_layer: Optional[Callable] = None,
            norm_layer_cl: Optional[Callable] = None,
            device=None,
            dtype=None,
    ) -> None:
        """Initialize ConvNeXt stage.

        Args:
            in_chs: Number of input channels.
            out_chs: Number of output channels.
            kernel_size: Kernel size for depthwise convolution.
            stride: Stride for downsampling.
            depth: Number of blocks in stage.
            dilation: (downsample conv, blocks) dilation pair; an int is used for both.
            drop_path_rates: Drop path rate for each block (all 0. if None or empty).
            ls_init_value: Initial value for layer scale; layer scale disabled if None.
            conv_mlp: Use convolutional MLP.
            conv_bias: Use bias in convolutions.
            use_grn: Use global response normalization.
            act_layer: Activation layer.
            norm_layer: Channels-first (NCHW) normalization layer, used by the downsample and by the
                blocks when conv_mlp is True. Defaults to LayerNorm2d if None.
            norm_layer_cl: Normalization layer used by the blocks when conv_mlp is False
                (the block picks its own default if None).
            device: Device for created parameters.
            dtype: Dtype for created parameters.
        """
        dd = {'device': device, 'dtype': dtype}
        super().__init__()
        self.grad_checkpointing = False
        dilation = to_ntuple(2)(dilation)
        if norm_layer is None:
            # The downsample below always instantiates norm_layer; default it to the same channels-first
            # norm ConvNeXtBlock falls back to, so direct construction (e.g. default stride=2) works.
            norm_layer = LayerNorm2d

        if in_chs != out_chs or stride > 1 or dilation[0] != dilation[1]:
            ds_ks = 2 if stride > 1 or dilation[0] != dilation[1] else 1
            pad = 'same' if dilation[1] > 1 else 0  # same padding needed if dilation used
            self.downsample = nn.Sequential(
                norm_layer(in_chs, **dd),
                create_conv2d(
                    in_chs,
                    out_chs,
                    kernel_size=ds_ks,
                    stride=stride,
                    dilation=dilation[0],
                    padding=pad,
                    bias=conv_bias,
                    **dd,
                ),
            )
            in_chs = out_chs
        else:
            self.downsample = nn.Identity()

        drop_path_rates = drop_path_rates or [0.] * depth
        stage_blocks = []
        for i in range(depth):
            stage_blocks.append(ConvNeXtBlock(
                in_chs=in_chs,
                out_chs=out_chs,
                kernel_size=kernel_size,
                dilation=dilation[1],
                drop_path=drop_path_rates[i],
                ls_init_value=ls_init_value,
                conv_mlp=conv_mlp,
                conv_bias=conv_bias,
                use_grn=use_grn,
                act_layer=act_layer,
                # conv MLP blocks run in NCHW, otherwise the block norm sees channels-last input
                norm_layer=norm_layer if conv_mlp else norm_layer_cl,
                **dd,
            ))
            in_chs = out_chs
        self.blocks = nn.Sequential(*stage_blocks)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Forward pass: downsample, then blocks (optionally gradient-checkpointed)."""
        x = self.downsample(x)
        if self.grad_checkpointing and not torch.jit.is_scripting():
            x = checkpoint_seq(self.blocks, x)
        else:
            x = self.blocks(x)
        return x
# Map of norm layer names -> (NCHW / 2D variant, channels-last variant).
# The bare and '2d'-suffixed names are aliases that resolve to the same pair.
_NORM_MAP = dict(
    layernorm=(LayerNorm2d, LayerNorm),
    layernorm2d=(LayerNorm2d, LayerNorm),
    simplenorm=(SimpleNorm2d, SimpleNorm),
    simplenorm2d=(SimpleNorm2d, SimpleNorm),
    rmsnorm=(RmsNorm2d, RmsNorm),
    rmsnorm2d=(RmsNorm2d, RmsNorm),
)
def _get_norm_layers(norm_layer: Union[Callable, str], conv_mlp: bool, norm_eps: Optional[float]):
    """Resolve the (channels-first, channels-last) norm layer pair for ConvNeXt.

    Args:
        norm_layer: A key of _NORM_MAP, or any norm layer / name accepted by get_norm_layer.
            Falsy values default to 'layernorm'. Layers outside _NORM_MAP require conv_mlp=True.
        conv_mlp: Whether blocks use the 1x1 conv (NCHW) MLP. For _NORM_MAP entries this selects
            whether the second returned layer is the NCHW or the channels-last variant.
        norm_eps: Optional epsilon override, bound into both returned layers when not None.

    Returns:
        Tuple of (norm_layer, norm_layer_cl): norm_layer takes NCHW input; norm_layer_cl is the norm
        used by blocks when conv_mlp is False.
    """
    norm_layer = norm_layer or 'layernorm'
    if norm_layer in _NORM_MAP:
        nchw_layer, channels_last_layer = _NORM_MAP[norm_layer]
        norm_layer = nchw_layer
        norm_layer_cl = nchw_layer if conv_mlp else channels_last_layer
    else:
        assert conv_mlp, \
            'If a norm_layer is specified, conv MLP must be used so all norm expect rank-4, channels-first input'
        norm_layer = get_norm_layer(norm_layer)
        norm_layer_cl = norm_layer

    if norm_eps is not None:
        # Bind eps into both layers. With conv_mlp=True the stem, downsample, head and blocks all use
        # norm_layer, so wrapping only norm_layer_cl would silently drop the override.
        norm_layer = partial(norm_layer, eps=norm_eps)
        norm_layer_cl = partial(norm_layer_cl, eps=norm_eps)
    return norm_layer, norm_layer_cl
class ConvNeXt(nn.Module):
    """ConvNeXt model architecture.

    A PyTorch impl of : `A ConvNet for the 2020s` - https://arxiv.org/pdf/2201.03545.pdf
    """

    def __init__(
            self,
            in_chans: int = 3,
            num_classes: int = 1000,
            global_pool: str = 'avg',
            output_stride: int = 32,
            depths: Tuple[int, ...] = (3, 3, 9, 3),
            dims: Tuple[int, ...] = (96, 192, 384, 768),
            kernel_sizes: Union[int, Tuple[int, ...]] = 7,
            ls_init_value: Optional[float] = 1e-6,
            stem_type: str = 'patch',
            patch_size: int = 4,
            head_init_scale: float = 1.,
            head_norm_first: bool = False,
            head_hidden_size: Optional[int] = None,
            conv_mlp: bool = False,
            conv_bias: bool = True,
            use_grn: bool = False,
            act_layer: Union[str, Callable] = 'gelu',
            norm_layer: Optional[Union[str, Callable]] = None,
            norm_eps: Optional[float] = None,
            drop_rate: float = 0.,
            drop_path_rate: float = 0.,
            device=None,
            dtype=None,
    ):
        """
        Args:
            in_chans: Number of input image channels.
            num_classes: Number of classes for classification head.
            global_pool: Global pooling type.
            output_stride: Output stride of network, one of (8, 16, 32).
            depths: Number of blocks at each stage.
            dims: Feature dimension at each stage.
            kernel_sizes: Depthwise convolution kernel-sizes for each stage.
            ls_init_value: Init value for Layer Scale, disabled if None.
            stem_type: Type of stem, one of 'patch', 'overlap', 'overlap_tiered', 'overlap_act'.
            patch_size: Stem patch size for patch stem.
            head_init_scale: Init scaling value for classifier weights and biases.
            head_norm_first: Apply normalization before global pool + head.
            head_hidden_size: Size of MLP hidden layer in head if not None and head_norm_first == False.
            conv_mlp: Use 1x1 conv in MLP, improves speed for small networks w/ chan last.
            conv_bias: Use bias layers w/ all convolutions.
            use_grn: Use Global Response Norm (ConvNeXt-V2) in MLP.
            act_layer: Activation layer type.
            norm_layer: Normalization layer type (name or callable). Norm layers other than the built-in
                layernorm / simplenorm / rmsnorm names require conv_mlp=True.
            norm_eps: Optional epsilon override for the norm layers (layer default if None).
            drop_rate: Head pre-classifier dropout rate.
            drop_path_rate: Stochastic depth drop rate.
            device: Device for created parameters.
            dtype: Dtype for created parameters.
        """
        super().__init__()
        dd = {'device': device, 'dtype': dtype}
        assert output_stride in (8, 16, 32)
        kernel_sizes = to_ntuple(4)(kernel_sizes)
        norm_layer, norm_layer_cl = _get_norm_layers(norm_layer, conv_mlp, norm_eps)
        act_layer = get_act_layer(act_layer)

        self.num_classes = num_classes
        self.in_chans = in_chans
        self.drop_rate = drop_rate
        self.feature_info = []

        assert stem_type in ('patch', 'overlap', 'overlap_tiered', 'overlap_act')
        if stem_type == 'patch':
            # NOTE: this stem is a minimal form of ViT PatchEmbed, as used in SwinTransformer w/ patch_size = 4
            self.stem = nn.Sequential(
                nn.Conv2d(in_chans, dims[0], kernel_size=patch_size, stride=patch_size, bias=conv_bias, **dd),
                norm_layer(dims[0], **dd),
            )
            stem_stride = patch_size
        else:
            # two stride-2 3x3 convs; 'tiered' halves the intermediate width, 'act' inserts an activation
            mid_chs = make_divisible(dims[0] // 2) if 'tiered' in stem_type else dims[0]
            self.stem = nn.Sequential(*filter(None, [
                nn.Conv2d(in_chans, mid_chs, kernel_size=3, stride=2, padding=1, bias=conv_bias, **dd),
                act_layer() if 'act' in stem_type else None,
                nn.Conv2d(mid_chs, dims[0], kernel_size=3, stride=2, padding=1, bias=conv_bias, **dd),
                norm_layer(dims[0], **dd),
            ]))
            stem_stride = 4

        dp_rates = calculate_drop_path_rates(drop_path_rate, depths, stagewise=True)
        stages = []
        prev_chs = dims[0]
        curr_stride = stem_stride
        dilation = 1
        # 4 feature resolution stages, each consisting of multiple residual blocks
        for i in range(4):
            # stage 0 keeps the stem resolution (unless the stem only reduced by 2); later stages halve it
            stride = 2 if curr_stride == 2 or i > 0 else 1
            if curr_stride >= output_stride and stride > 1:
                # target output stride reached: trade further striding for dilation
                dilation *= stride
                stride = 1
            curr_stride *= stride
            first_dilation = 1 if dilation in (1, 2) else 2
            out_chs = dims[i]
            stages.append(ConvNeXtStage(
                prev_chs,
                out_chs,
                kernel_size=kernel_sizes[i],
                stride=stride,
                dilation=(first_dilation, dilation),
                depth=depths[i],
                drop_path_rates=dp_rates[i],
                ls_init_value=ls_init_value,
                conv_mlp=conv_mlp,
                conv_bias=conv_bias,
                use_grn=use_grn,
                act_layer=act_layer,
                norm_layer=norm_layer,
                norm_layer_cl=norm_layer_cl,
                **dd,
            ))
            prev_chs = out_chs
            # NOTE feature_info use currently assumes stage 0 == stride 1, rest are stride 2
            self.feature_info += [dict(num_chs=prev_chs, reduction=curr_stride, module=f'stages.{i}')]
        self.stages = nn.Sequential(*stages)
        self.num_features = self.head_hidden_size = prev_chs

        # if head_norm_first == true, norm -> global pool -> fc ordering, like most other nets
        # otherwise pool -> norm -> fc, the default ConvNeXt ordering (pretrained FB weights)
        if head_norm_first:
            assert not head_hidden_size
            self.norm_pre = norm_layer(self.num_features, **dd)
            self.head = ClassifierHead(
                self.num_features,
                num_classes,
                pool_type=global_pool,
                drop_rate=self.drop_rate,
                **dd,
            )
        else:
            self.norm_pre = nn.Identity()
            self.head = NormMlpClassifierHead(
                self.num_features,
                num_classes,
                hidden_size=head_hidden_size,
                pool_type=global_pool,
                drop_rate=self.drop_rate,
                norm_layer=norm_layer,
                act_layer='gelu',
                **dd,
            )
        self.head_hidden_size = self.head.num_features
        named_apply(partial(_init_weights, head_init_scale=head_init_scale), self)

    @torch.jit.ignore
    def group_matcher(self, coarse: bool = False) -> Dict[str, Union[str, List]]:
        """Create regex patterns for parameter grouping.

        Args:
            coarse: Use coarse grouping (one group per stage).

        Returns:
            Dictionary mapping group names to regex patterns.
        """
        return dict(
            stem=r'^stem',
            blocks=r'^stages\.(\d+)' if coarse else [
                (r'^stages\.(\d+)\.downsample', (0,)),  # blocks
                (r'^stages\.(\d+)\.blocks\.(\d+)', None),
                (r'^norm_pre', (99999,))
            ]
        )

    @torch.jit.ignore
    def set_grad_checkpointing(self, enable: bool = True) -> None:
        """Enable or disable gradient checkpointing on every stage.

        Args:
            enable: Whether to enable gradient checkpointing.
        """
        for s in self.stages:
            s.grad_checkpointing = enable

    @torch.jit.ignore
    def get_classifier(self) -> nn.Module:
        """Get the classifier module."""
        return self.head.fc

    def reset_classifier(self, num_classes: int, global_pool: Optional[str] = None) -> None:
        """Reset the classifier head.

        Args:
            num_classes: Number of classes for new classifier.
            global_pool: Global pooling type.
        """
        self.num_classes = num_classes
        self.head.reset(num_classes, global_pool)

    def forward_intermediates(
            self,
            x: torch.Tensor,
            indices: Optional[Union[int, List[int]]] = None,
            norm: bool = False,
            stop_early: bool = False,
            output_fmt: str = 'NCHW',
            intermediates_only: bool = False,
    ) -> Union[List[torch.Tensor], Tuple[torch.Tensor, List[torch.Tensor]]]:
        """Forward features that returns intermediates.

        Args:
            x: Input image tensor.
            indices: Take last n stages if int, all if None, select matching indices if sequence.
            norm: Apply norm layer to compatible intermediates (only the last stage output).
            stop_early: Stop iterating over stages when last desired intermediate hit.
            output_fmt: Shape of intermediate feature outputs (only 'NCHW' supported).
            intermediates_only: Only return intermediate features.

        Returns:
            List of intermediate features or tuple of (final features, intermediates).
        """
        assert output_fmt in ('NCHW',), 'Output shape must be NCHW.'
        intermediates = []
        take_indices, max_index = feature_take_indices(len(self.stages), indices)

        # forward pass
        x = self.stem(x)

        last_idx = len(self.stages) - 1
        if torch.jit.is_scripting() or not stop_early:  # can't slice blocks in torchscript
            stages = self.stages
        else:
            stages = self.stages[:max_index + 1]
        for feat_idx, stage in enumerate(stages):
            x = stage(x)
            if feat_idx in take_indices:
                if norm and feat_idx == last_idx:
                    intermediates.append(self.norm_pre(x))
                else:
                    intermediates.append(x)

        if intermediates_only:
            return intermediates

        # only the full-depth output gets norm_pre (it may have stopped early)
        if feat_idx == last_idx:
            x = self.norm_pre(x)

        return x, intermediates

    def prune_intermediate_layers(
            self,
            indices: Union[int, List[int]] = 1,
            prune_norm: bool = False,
            prune_head: bool = True,
    ) -> List[int]:
        """Prune layers not required for specified intermediates.

        Args:
            indices: Indices of intermediate layers to keep.
            prune_norm: Whether to prune normalization layer.
            prune_head: Whether to prune the classifier head.

        Returns:
            List of indices that were kept.
        """
        take_indices, max_index = feature_take_indices(len(self.stages), indices)
        self.stages = self.stages[:max_index + 1]  # truncate blocks w/ stem as idx 0
        if prune_norm:
            self.norm_pre = nn.Identity()
        if prune_head:
            self.reset_classifier(0, '')
        return take_indices

    def forward_features(self, x: torch.Tensor) -> torch.Tensor:
        """Forward pass through feature extraction layers (stem, stages, norm_pre)."""
        x = self.stem(x)
        x = self.stages(x)
        x = self.norm_pre(x)
        return x

    def forward_head(self, x: torch.Tensor, pre_logits: bool = False) -> torch.Tensor:
        """Forward pass through classifier head.

        Args:
            x: Feature tensor.
            pre_logits: Return features before final classifier.

        Returns:
            Output tensor.
        """
        return self.head(x, pre_logits=True) if pre_logits else self.head(x)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Forward pass."""
        x = self.forward_features(x)
        x = self.forward_head(x)
        return x
def _init_weights(module: nn.Module, name: Optional[str] = None, head_init_scale: float = 1.0) -> None:
    """Initialize Conv2d / Linear weights of a ConvNeXt module (applied per-module via named_apply).

    Weights get a truncated normal init (std=.02) and biases are zeroed. Linear layers whose
    qualified name contains 'head.' are additionally scaled by ``head_init_scale``.

    Args:
        module: Module to initialize; modules other than Conv2d / Linear are left untouched.
        name: Qualified module name, used to detect classifier head layers.
        head_init_scale: Scale factor for head Linear weights and biases.
    """
    if isinstance(module, nn.Conv2d):
        trunc_normal_(module.weight, std=.02)
        if module.bias is not None:
            nn.init.zeros_(module.bias)
    elif isinstance(module, nn.Linear):
        trunc_normal_(module.weight, std=.02)
        # Linear layers can be built with bias=False; guard like the Conv2d branch does.
        if module.bias is not None:
            nn.init.zeros_(module.bias)
        if name and 'head.' in name:
            module.weight.data.mul_(head_init_scale)
            if module.bias is not None:
                module.bias.data.mul_(head_init_scale)
def checkpoint_filter_fn(state_dict, model):
    """Remap original FB / open_clip ConvNeXt checkpoints to timm key names.

    * State dicts already containing 'head.norm.weight' or 'norm_pre.weight' are returned unchanged.
    * A top-level 'model' entry is unwrapped.
    * open_clip checkpoints ('visual.trunk.*') keep only trunk weights (prefix stripped), and their
      projection / MLP head is mapped onto the timm classifier head with a zero-filled final bias.
    * Otherwise FB ConvNeXt / ConvNeXt-V2 key names are rewritten, GRN params are flattened, and 2D
      non-head weights are reshaped to the shape of the matching entry in ``model.state_dict()``.

    Args:
        state_dict: Checkpoint state dict.
        model: Target model; its state_dict() is only consulted for 2D non-head weights.

    Returns:
        Remapped state dict.
    """
    if 'head.norm.weight' in state_dict or 'norm_pre.weight' in state_dict:
        return state_dict  # non-FB checkpoint
    if 'model' in state_dict:
        state_dict = state_dict['model']

    out_dict = {}
    if 'visual.trunk.stem.0.weight' in state_dict:
        out_dict = {k.replace('visual.trunk.', ''): v for k, v in state_dict.items() if k.startswith('visual.trunk.')}
        if 'visual.head.proj.weight' in state_dict:
            out_dict['head.fc.weight'] = state_dict['visual.head.proj.weight']
            out_dict['head.fc.bias'] = torch.zeros(state_dict['visual.head.proj.weight'].shape[0])
        elif 'visual.head.mlp.fc1.weight' in state_dict:
            out_dict['head.pre_logits.fc.weight'] = state_dict['visual.head.mlp.fc1.weight']
            out_dict['head.pre_logits.fc.bias'] = state_dict['visual.head.mlp.fc1.bias']
            out_dict['head.fc.weight'] = state_dict['visual.head.mlp.fc2.weight']
            out_dict['head.fc.bias'] = torch.zeros(state_dict['visual.head.mlp.fc2.weight'].shape[0])
        return out_dict

    import re
    # Dots are escaped so only literal 'stages.<i>.<j>' / 'downsample_layers.<i>.<j>' keys match.
    stage_block_re = re.compile(r'stages\.([0-9]+)\.([0-9]+)')
    downsample_re = re.compile(r'downsample_layers\.([0-9]+)\.([0-9]+)')
    model_state = None  # fetched lazily, once: state_dict() rebuilds the full mapping on every call
    for k, v in state_dict.items():
        k = k.replace('downsample_layers.0.', 'stem.')
        k = stage_block_re.sub(r'stages.\1.blocks.\2', k)
        k = downsample_re.sub(r'stages.\1.downsample.\2', k)
        k = k.replace('dwconv', 'conv_dw')
        k = k.replace('pwconv', 'mlp.fc')
        if 'grn' in k:
            k = k.replace('grn.beta', 'mlp.grn.bias')
            k = k.replace('grn.gamma', 'mlp.grn.weight')
            v = v.reshape(v.shape[-1])
        k = k.replace('head.', 'head.fc.')
        if k.startswith('norm.'):
            k = k.replace('norm', 'head.norm')
        if v.ndim == 2 and 'head' not in k:
            # match the target layout (e.g. a 2D weight loaded into a layer with a different shape)
            if model_state is None:
                model_state = model.state_dict()
            v = v.reshape(model_state[k].shape)
        out_dict[k] = v

    return out_dict
def _create_convnext(variant, pretrained=False, **kwargs):
    """Build a ConvNeXt ``variant`` through build_model_with_cfg.

    fcmae pretrained weights have no classifier or final norm-layer (``head.norm``), so when that
    pretrained cfg is requested, non-strict loading is enabled unless the caller chose otherwise.
    This allows loading with num_classes=0 without removing the norm layer.
    """
    if kwargs.get('pretrained_cfg', '') == 'fcmae':
        kwargs.setdefault('pretrained_strict', False)

    feature_cfg = dict(out_indices=(0, 1, 2, 3), flatten_sequential=True)
    return build_model_with_cfg(
        ConvNeXt,
        variant,
        pretrained,
        pretrained_filter_fn=checkpoint_filter_fn,
        feature_cfg=feature_cfg,
        **kwargs,
    )
def _cfg(url='', **kwargs):
    """Default pretrained-weight config for ConvNeXt (v1 / timm) models; ``kwargs`` override any field."""
    cfg = dict(
        url=url,
        num_classes=1000,
        input_size=(3, 224, 224),
        pool_size=(7, 7),
        crop_pct=0.875,
        interpolation='bicubic',
        mean=IMAGENET_DEFAULT_MEAN,
        std=IMAGENET_DEFAULT_STD,
        first_conv='stem.0',
        classifier='head.fc',
        license='apache-2.0',
    )
    cfg.update(kwargs)
    return cfg
def _cfgv2(url='', **kwargs):
    """Default pretrained-weight config for ConvNeXt-V2 models; ``kwargs`` override any field.

    Shares the base defaults of ``_cfg``, replacing the license (V2 weights are CC BY-NC 4.0)
    and adding paper / origin metadata.
    """
    v2_fields = {
        'license': 'cc-by-nc-4.0',
        'paper_ids': 'arXiv:2301.00808',
        'paper_name': 'ConvNeXt-V2: Co-designing and Scaling ConvNets with Masked Autoencoders',
        'origin_url': 'https://github.com/facebookresearch/ConvNeXt-V2',
    }
    return _cfg(url, **{**v2_fields, **kwargs})
- default_cfgs = generate_default_cfgs({
- # timm specific variants
- 'convnext_tiny.in12k_ft_in1k': _cfg(
- hf_hub_id='timm/',
- crop_pct=0.95, test_input_size=(3, 288, 288), test_crop_pct=1.0),
- 'convnext_small.in12k_ft_in1k': _cfg(
- hf_hub_id='timm/',
- crop_pct=0.95, test_input_size=(3, 288, 288), test_crop_pct=1.0),
- 'convnext_zepto_rms.ra4_e3600_r224_in1k': _cfg(
- hf_hub_id='timm/',
- mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)),
- 'convnext_zepto_rms_ols.ra4_e3600_r224_in1k': _cfg(
- hf_hub_id='timm/',
- mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5),
- crop_pct=0.9),
- 'convnext_atto.d2_in1k': _cfg(
- url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-rsb-weights/convnext_atto_d2-01bb0f51.pth',
- hf_hub_id='timm/',
- test_input_size=(3, 288, 288), test_crop_pct=0.95),
- 'convnext_atto_ols.a2_in1k': _cfg(
- url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-rsb-weights/convnext_atto_ols_a2-78d1c8f3.pth',
- hf_hub_id='timm/',
- test_input_size=(3, 288, 288), test_crop_pct=0.95),
- 'convnext_atto_rms.untrained': _cfg(
- #hf_hub_id='timm/',
- test_input_size=(3, 256, 256), test_crop_pct=0.95),
- 'convnext_femto.d1_in1k': _cfg(
- url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-rsb-weights/convnext_femto_d1-d71d5b4c.pth',
- hf_hub_id='timm/',
- test_input_size=(3, 288, 288), test_crop_pct=0.95),
- 'convnext_femto_ols.d1_in1k': _cfg(
- url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-rsb-weights/convnext_femto_ols_d1-246bf2ed.pth',
- hf_hub_id='timm/',
- test_input_size=(3, 288, 288), test_crop_pct=0.95),
- 'convnext_pico.d1_in1k': _cfg(
- url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-rsb-weights/convnext_pico_d1-10ad7f0d.pth',
- hf_hub_id='timm/',
- test_input_size=(3, 288, 288), test_crop_pct=0.95),
- 'convnext_pico_ols.d1_in1k': _cfg(
- url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-rsb-weights/convnext_pico_ols_d1-611f0ca7.pth',
- hf_hub_id='timm/',
- crop_pct=0.95, test_input_size=(3, 288, 288), test_crop_pct=1.0),
- 'convnext_nano.in12k_ft_in1k': _cfg(
- hf_hub_id='timm/',
- crop_pct=0.95, test_input_size=(3, 288, 288), test_crop_pct=1.0),
- 'convnext_nano.d1h_in1k': _cfg(
- url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-rsb-weights/convnext_nano_d1h-7eb4bdea.pth',
- hf_hub_id='timm/',
- crop_pct=0.95, test_input_size=(3, 288, 288), test_crop_pct=1.0),
- 'convnext_nano_ols.d1h_in1k': _cfg(
- url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-rsb-weights/convnext_nano_ols_d1h-ae424a9a.pth',
- hf_hub_id='timm/',
- crop_pct=0.95, test_input_size=(3, 288, 288), test_crop_pct=1.0),
- 'convnext_tiny_hnf.a2h_in1k': _cfg(
- url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-rsb-weights/convnext_tiny_hnf_a2h-ab7e9df2.pth',
- hf_hub_id='timm/',
- crop_pct=0.95, test_input_size=(3, 288, 288), test_crop_pct=1.0),
- 'convnext_nano.r384_in12k_ft_in1k': _cfg(
- hf_hub_id='timm/',
- input_size=(3, 384, 384), pool_size=(12, 12), crop_pct=1.0),
- 'convnext_tiny.in12k_ft_in1k_384': _cfg(
- hf_hub_id='timm/',
- input_size=(3, 384, 384), pool_size=(12, 12), crop_pct=1.0, crop_mode='squash'),
- 'convnext_small.in12k_ft_in1k_384': _cfg(
- hf_hub_id='timm/',
- input_size=(3, 384, 384), pool_size=(12, 12), crop_pct=1.0, crop_mode='squash'),
- 'convnext_nano.in12k': _cfg(
- hf_hub_id='timm/',
- crop_pct=0.95, num_classes=11821),
- 'convnext_nano.r384_in12k': _cfg(
- hf_hub_id='timm/',
- input_size=(3, 384, 384), pool_size=(12, 12), crop_pct=1.0, num_classes=11821),
- 'convnext_nano.r384_ad_in12k': _cfg(
- hf_hub_id='timm/',
- input_size=(3, 384, 384), pool_size=(12, 12), crop_pct=1.0, num_classes=11821),
- 'convnext_tiny.in12k': _cfg(
- hf_hub_id='timm/',
- crop_pct=0.95, num_classes=11821),
- 'convnext_small.in12k': _cfg(
- hf_hub_id='timm/',
- crop_pct=0.95, num_classes=11821),
- 'convnext_tiny.fb_in22k_ft_in1k': _cfg(
- url='https://dl.fbaipublicfiles.com/convnext/convnext_tiny_22k_1k_224.pth',
- hf_hub_id='timm/',
- test_input_size=(3, 288, 288), test_crop_pct=1.0),
- 'convnext_small.fb_in22k_ft_in1k': _cfg(
- url='https://dl.fbaipublicfiles.com/convnext/convnext_small_22k_1k_224.pth',
- hf_hub_id='timm/',
- test_input_size=(3, 288, 288), test_crop_pct=1.0),
- 'convnext_base.fb_in22k_ft_in1k': _cfg(
- url='https://dl.fbaipublicfiles.com/convnext/convnext_base_22k_1k_224.pth',
- hf_hub_id='timm/',
- test_input_size=(3, 288, 288), test_crop_pct=1.0),
- 'convnext_large.fb_in22k_ft_in1k': _cfg(
- url='https://dl.fbaipublicfiles.com/convnext/convnext_large_22k_1k_224.pth',
- hf_hub_id='timm/',
- test_input_size=(3, 288, 288), test_crop_pct=1.0),
- 'convnext_xlarge.fb_in22k_ft_in1k': _cfg(
- url='https://dl.fbaipublicfiles.com/convnext/convnext_xlarge_22k_1k_224_ema.pth',
- hf_hub_id='timm/',
- test_input_size=(3, 288, 288), test_crop_pct=1.0),
- 'convnext_tiny.fb_in1k': _cfg(
- url="https://dl.fbaipublicfiles.com/convnext/convnext_tiny_1k_224_ema.pth",
- hf_hub_id='timm/',
- test_input_size=(3, 288, 288), test_crop_pct=1.0),
- 'convnext_small.fb_in1k': _cfg(
- url="https://dl.fbaipublicfiles.com/convnext/convnext_small_1k_224_ema.pth",
- hf_hub_id='timm/',
- test_input_size=(3, 288, 288), test_crop_pct=1.0),
- 'convnext_base.fb_in1k': _cfg(
- url="https://dl.fbaipublicfiles.com/convnext/convnext_base_1k_224_ema.pth",
- hf_hub_id='timm/',
- test_input_size=(3, 288, 288), test_crop_pct=1.0),
- 'convnext_large.fb_in1k': _cfg(
- url="https://dl.fbaipublicfiles.com/convnext/convnext_large_1k_224_ema.pth",
- hf_hub_id='timm/',
- test_input_size=(3, 288, 288), test_crop_pct=1.0),
- 'convnext_tiny.fb_in22k_ft_in1k_384': _cfg(
- url='https://dl.fbaipublicfiles.com/convnext/convnext_tiny_22k_1k_384.pth',
- hf_hub_id='timm/',
- input_size=(3, 384, 384), pool_size=(12, 12), crop_pct=1.0, crop_mode='squash'),
- 'convnext_small.fb_in22k_ft_in1k_384': _cfg(
- url='https://dl.fbaipublicfiles.com/convnext/convnext_small_22k_1k_384.pth',
- hf_hub_id='timm/',
- input_size=(3, 384, 384), pool_size=(12, 12), crop_pct=1.0, crop_mode='squash'),
- 'convnext_base.fb_in22k_ft_in1k_384': _cfg(
- url='https://dl.fbaipublicfiles.com/convnext/convnext_base_22k_1k_384.pth',
- hf_hub_id='timm/',
- input_size=(3, 384, 384), pool_size=(12, 12), crop_pct=1.0, crop_mode='squash'),
- 'convnext_large.fb_in22k_ft_in1k_384': _cfg(
- url='https://dl.fbaipublicfiles.com/convnext/convnext_large_22k_1k_384.pth',
- hf_hub_id='timm/',
- input_size=(3, 384, 384), pool_size=(12, 12), crop_pct=1.0, crop_mode='squash'),
- 'convnext_xlarge.fb_in22k_ft_in1k_384': _cfg(
- url='https://dl.fbaipublicfiles.com/convnext/convnext_xlarge_22k_1k_384_ema.pth',
- hf_hub_id='timm/',
- input_size=(3, 384, 384), pool_size=(12, 12), crop_pct=1.0, crop_mode='squash'),
- 'convnext_tiny.fb_in22k': _cfg(
- url="https://dl.fbaipublicfiles.com/convnext/convnext_tiny_22k_224.pth",
- hf_hub_id='timm/',
- num_classes=21841),
- 'convnext_small.fb_in22k': _cfg(
- url="https://dl.fbaipublicfiles.com/convnext/convnext_small_22k_224.pth",
- hf_hub_id='timm/',
- num_classes=21841),
- 'convnext_base.fb_in22k': _cfg(
- url="https://dl.fbaipublicfiles.com/convnext/convnext_base_22k_224.pth",
- hf_hub_id='timm/',
- num_classes=21841),
- 'convnext_large.fb_in22k': _cfg(
- url="https://dl.fbaipublicfiles.com/convnext/convnext_large_22k_224.pth",
- hf_hub_id='timm/',
- num_classes=21841),
- 'convnext_xlarge.fb_in22k': _cfg(
- url="https://dl.fbaipublicfiles.com/convnext/convnext_xlarge_22k_224.pth",
- hf_hub_id='timm/',
- num_classes=21841),
- 'convnextv2_nano.fcmae_ft_in22k_in1k': _cfgv2(
- url='https://dl.fbaipublicfiles.com/convnext/convnextv2/im22k/convnextv2_nano_22k_224_ema.pt',
- hf_hub_id='timm/',
- test_input_size=(3, 288, 288), test_crop_pct=1.0),
- 'convnextv2_nano.fcmae_ft_in22k_in1k_384': _cfgv2(
- url='https://dl.fbaipublicfiles.com/convnext/convnextv2/im22k/convnextv2_nano_22k_384_ema.pt',
- hf_hub_id='timm/',
- input_size=(3, 384, 384), pool_size=(12, 12), crop_pct=1.0, crop_mode='squash'),
- 'convnextv2_tiny.fcmae_ft_in22k_in1k': _cfgv2(
- url="https://dl.fbaipublicfiles.com/convnext/convnextv2/im22k/convnextv2_tiny_22k_224_ema.pt",
- hf_hub_id='timm/',
- test_input_size=(3, 288, 288), test_crop_pct=1.0),
- 'convnextv2_tiny.fcmae_ft_in22k_in1k_384': _cfgv2(
- url="https://dl.fbaipublicfiles.com/convnext/convnextv2/im22k/convnextv2_tiny_22k_384_ema.pt",
- hf_hub_id='timm/',
- input_size=(3, 384, 384), pool_size=(12, 12), crop_pct=1.0, crop_mode='squash'),
- 'convnextv2_base.fcmae_ft_in22k_in1k': _cfgv2(
- url="https://dl.fbaipublicfiles.com/convnext/convnextv2/im22k/convnextv2_base_22k_224_ema.pt",
- hf_hub_id='timm/',
- test_input_size=(3, 288, 288), test_crop_pct=1.0),
- 'convnextv2_base.fcmae_ft_in22k_in1k_384': _cfgv2(
- url="https://dl.fbaipublicfiles.com/convnext/convnextv2/im22k/convnextv2_base_22k_384_ema.pt",
- hf_hub_id='timm/',
- input_size=(3, 384, 384), pool_size=(12, 12), crop_pct=1.0, crop_mode='squash'),
- 'convnextv2_large.fcmae_ft_in22k_in1k': _cfgv2(
- url="https://dl.fbaipublicfiles.com/convnext/convnextv2/im22k/convnextv2_large_22k_224_ema.pt",
- hf_hub_id='timm/',
- test_input_size=(3, 288, 288), test_crop_pct=1.0),
- 'convnextv2_large.fcmae_ft_in22k_in1k_384': _cfgv2(
- url="https://dl.fbaipublicfiles.com/convnext/convnextv2/im22k/convnextv2_large_22k_384_ema.pt",
- hf_hub_id='timm/',
- input_size=(3, 384, 384), pool_size=(12, 12), crop_pct=1.0, crop_mode='squash'),
- 'convnextv2_huge.fcmae_ft_in22k_in1k_384': _cfgv2(
- url="https://dl.fbaipublicfiles.com/convnext/convnextv2/im22k/convnextv2_huge_22k_384_ema.pt",
- hf_hub_id='timm/',
- input_size=(3, 384, 384), pool_size=(12, 12), crop_pct=1.0, crop_mode='squash'),
- 'convnextv2_huge.fcmae_ft_in22k_in1k_512': _cfgv2(
- url="https://dl.fbaipublicfiles.com/convnext/convnextv2/im22k/convnextv2_huge_22k_512_ema.pt",
- hf_hub_id='timm/',
- input_size=(3, 512, 512), pool_size=(15, 15), crop_pct=1.0, crop_mode='squash'),
- 'convnextv2_atto.fcmae_ft_in1k': _cfgv2(
- url='https://dl.fbaipublicfiles.com/convnext/convnextv2/im1k/convnextv2_atto_1k_224_ema.pt',
- hf_hub_id='timm/',
- test_input_size=(3, 288, 288), test_crop_pct=0.95),
- 'convnextv2_femto.fcmae_ft_in1k': _cfgv2(
- url='https://dl.fbaipublicfiles.com/convnext/convnextv2/im1k/convnextv2_femto_1k_224_ema.pt',
- hf_hub_id='timm/',
- test_input_size=(3, 288, 288), test_crop_pct=0.95),
- 'convnextv2_pico.fcmae_ft_in1k': _cfgv2(
- url='https://dl.fbaipublicfiles.com/convnext/convnextv2/im1k/convnextv2_pico_1k_224_ema.pt',
- hf_hub_id='timm/',
- test_input_size=(3, 288, 288), test_crop_pct=0.95),
- 'convnextv2_nano.fcmae_ft_in1k': _cfgv2(
- url='https://dl.fbaipublicfiles.com/convnext/convnextv2/im1k/convnextv2_nano_1k_224_ema.pt',
- hf_hub_id='timm/',
- test_input_size=(3, 288, 288), test_crop_pct=1.0),
- 'convnextv2_tiny.fcmae_ft_in1k': _cfgv2(
- url="https://dl.fbaipublicfiles.com/convnext/convnextv2/im1k/convnextv2_tiny_1k_224_ema.pt",
- hf_hub_id='timm/',
- test_input_size=(3, 288, 288), test_crop_pct=1.0),
- 'convnextv2_base.fcmae_ft_in1k': _cfgv2(
- url="https://dl.fbaipublicfiles.com/convnext/convnextv2/im1k/convnextv2_base_1k_224_ema.pt",
- hf_hub_id='timm/',
- test_input_size=(3, 288, 288), test_crop_pct=1.0),
- 'convnextv2_large.fcmae_ft_in1k': _cfgv2(
- url="https://dl.fbaipublicfiles.com/convnext/convnextv2/im1k/convnextv2_large_1k_224_ema.pt",
- hf_hub_id='timm/',
- test_input_size=(3, 288, 288), test_crop_pct=1.0),
- 'convnextv2_huge.fcmae_ft_in1k': _cfgv2(
- url="https://dl.fbaipublicfiles.com/convnext/convnextv2/im1k/convnextv2_huge_1k_224_ema.pt",
- hf_hub_id='timm/',
- test_input_size=(3, 288, 288), test_crop_pct=1.0),
- 'convnextv2_atto.fcmae': _cfgv2(
- url='https://dl.fbaipublicfiles.com/convnext/convnextv2/pt_only/convnextv2_atto_1k_224_fcmae.pt',
- hf_hub_id='timm/',
- num_classes=0),
- 'convnextv2_femto.fcmae': _cfgv2(
- url='https://dl.fbaipublicfiles.com/convnext/convnextv2/pt_only/convnextv2_femto_1k_224_fcmae.pt',
- hf_hub_id='timm/',
- num_classes=0),
- 'convnextv2_pico.fcmae': _cfgv2(
- url='https://dl.fbaipublicfiles.com/convnext/convnextv2/pt_only/convnextv2_pico_1k_224_fcmae.pt',
- hf_hub_id='timm/',
- num_classes=0),
- 'convnextv2_nano.fcmae': _cfgv2(
- url='https://dl.fbaipublicfiles.com/convnext/convnextv2/pt_only/convnextv2_nano_1k_224_fcmae.pt',
- hf_hub_id='timm/',
- num_classes=0),
- 'convnextv2_tiny.fcmae': _cfgv2(
- url="https://dl.fbaipublicfiles.com/convnext/convnextv2/pt_only/convnextv2_tiny_1k_224_fcmae.pt",
- hf_hub_id='timm/',
- num_classes=0),
- 'convnextv2_base.fcmae': _cfgv2(
- url="https://dl.fbaipublicfiles.com/convnext/convnextv2/pt_only/convnextv2_base_1k_224_fcmae.pt",
- hf_hub_id='timm/',
- num_classes=0),
- 'convnextv2_large.fcmae': _cfgv2(
- url="https://dl.fbaipublicfiles.com/convnext/convnextv2/pt_only/convnextv2_large_1k_224_fcmae.pt",
- hf_hub_id='timm/',
- num_classes=0),
- 'convnextv2_huge.fcmae': _cfgv2(
- url="https://dl.fbaipublicfiles.com/convnext/convnextv2/pt_only/convnextv2_huge_1k_224_fcmae.pt",
- hf_hub_id='timm/',
- num_classes=0),
- 'convnextv2_small.untrained': _cfg(),
- # CLIP weights, fine-tuned on in1k or in12k + in1k
- 'convnext_base.clip_laion2b_augreg_ft_in12k_in1k': _cfg(
- hf_hub_id='timm/',
- mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD,
- input_size=(3, 256, 256), pool_size=(8, 8), crop_pct=1.0),
- 'convnext_base.clip_laion2b_augreg_ft_in12k_in1k_384': _cfg(
- hf_hub_id='timm/',
- mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD,
- input_size=(3, 384, 384), pool_size=(12, 12), crop_pct=1.0, crop_mode='squash'),
- 'convnext_large_mlp.clip_laion2b_soup_ft_in12k_in1k_320': _cfg(
- hf_hub_id='timm/',
- mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD,
- input_size=(3, 320, 320), pool_size=(10, 10), crop_pct=1.0),
- 'convnext_large_mlp.clip_laion2b_soup_ft_in12k_in1k_384': _cfg(
- hf_hub_id='timm/',
- mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD,
- input_size=(3, 384, 384), pool_size=(12, 12), crop_pct=1.0, crop_mode='squash'),
- 'convnext_base.clip_laion2b_augreg_ft_in1k': _cfg(
- hf_hub_id='timm/',
- mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD,
- input_size=(3, 256, 256), pool_size=(8, 8), crop_pct=1.0),
- 'convnext_base.clip_laiona_augreg_ft_in1k_384': _cfg(
- hf_hub_id='timm/',
- mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD,
- input_size=(3, 384, 384), pool_size=(12, 12), crop_pct=1.0),
- 'convnext_large_mlp.clip_laion2b_augreg_ft_in1k': _cfg(
- hf_hub_id='timm/',
- mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD,
- input_size=(3, 256, 256), pool_size=(8, 8), crop_pct=1.0
- ),
- 'convnext_large_mlp.clip_laion2b_augreg_ft_in1k_384': _cfg(
- hf_hub_id='timm/',
- mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD,
- input_size=(3, 384, 384), pool_size=(12, 12), crop_pct=1.0, crop_mode='squash'
- ),
- 'convnext_xxlarge.clip_laion2b_soup_ft_in1k': _cfg(
- hf_hub_id='timm/',
- mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD,
- input_size=(3, 256, 256), pool_size=(8, 8), crop_pct=1.0),
- 'convnext_base.clip_laion2b_augreg_ft_in12k': _cfg(
- hf_hub_id='timm/',
- mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD, num_classes=11821,
- input_size=(3, 256, 256), pool_size=(8, 8), crop_pct=1.0),
- 'convnext_large_mlp.clip_laion2b_soup_ft_in12k_320': _cfg(
- hf_hub_id='timm/',
- mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD, num_classes=11821,
- input_size=(3, 320, 320), pool_size=(10, 10), crop_pct=1.0),
- 'convnext_large_mlp.clip_laion2b_augreg_ft_in12k_384': _cfg(
- hf_hub_id='timm/',
- mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD, num_classes=11821,
- input_size=(3, 384, 384), pool_size=(12, 12), crop_pct=1.0, crop_mode='squash'),
- 'convnext_large_mlp.clip_laion2b_soup_ft_in12k_384': _cfg(
- hf_hub_id='timm/',
- mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD, num_classes=11821,
- input_size=(3, 384, 384), pool_size=(12, 12), crop_pct=1.0, crop_mode='squash'),
- 'convnext_xxlarge.clip_laion2b_soup_ft_in12k': _cfg(
- hf_hub_id='timm/',
- mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD, num_classes=11821,
- input_size=(3, 256, 256), pool_size=(8, 8), crop_pct=1.0),
- # CLIP original image tower weights
- 'convnext_base.clip_laion2b': _cfg(
- hf_hub_id='timm/',
- mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD,
- input_size=(3, 256, 256), pool_size=(8, 8), crop_pct=1.0, num_classes=640),
- 'convnext_base.clip_laion2b_augreg': _cfg(
- hf_hub_id='timm/',
- mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD,
- input_size=(3, 256, 256), pool_size=(8, 8), crop_pct=1.0, num_classes=640),
- 'convnext_base.clip_laiona': _cfg(
- hf_hub_id='timm/',
- mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD,
- input_size=(3, 256, 256), pool_size=(8, 8), crop_pct=1.0, num_classes=640),
- 'convnext_base.clip_laiona_320': _cfg(
- hf_hub_id='timm/',
- mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD,
- input_size=(3, 320, 320), pool_size=(10, 10), crop_pct=1.0, num_classes=640),
- 'convnext_base.clip_laiona_augreg_320': _cfg(
- hf_hub_id='timm/',
- mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD,
- input_size=(3, 320, 320), pool_size=(10, 10), crop_pct=1.0, num_classes=640),
- 'convnext_large_mlp.clip_laion2b_augreg': _cfg(
- hf_hub_id='timm/',
- mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD,
- input_size=(3, 256, 256), pool_size=(8, 8), crop_pct=1.0, num_classes=768),
- 'convnext_large_mlp.clip_laion2b_ft_320': _cfg(
- hf_hub_id='timm/',
- mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD,
- input_size=(3, 320, 320), pool_size=(10, 10), crop_pct=1.0, num_classes=768),
- 'convnext_large_mlp.clip_laion2b_ft_soup_320': _cfg(
- hf_hub_id='timm/',
- mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD,
- input_size=(3, 320, 320), pool_size=(10, 10), crop_pct=1.0, num_classes=768),
- 'convnext_xxlarge.clip_laion2b_soup': _cfg(
- hf_hub_id='timm/',
- mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD,
- input_size=(3, 256, 256), pool_size=(8, 8), crop_pct=1.0, num_classes=1024),
- 'convnext_xxlarge.clip_laion2b_rewind': _cfg(
- hf_hub_id='timm/',
- mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD,
- input_size=(3, 256, 256), pool_size=(8, 8), crop_pct=1.0, num_classes=1024),
- # NOTE dinov3 convnext weights are under a specific license, and downstream outputs must be shared with this
- # https://ai.meta.com/resources/models-and-libraries/dinov3-license/
- 'convnext_tiny.dinov3_lvd1689m': _cfg(
- hf_hub_id='timm/',
- crop_pct=1.0,
- num_classes=0,
- license='dinov3-license',
- ),
- 'convnext_small.dinov3_lvd1689m': _cfg(
- hf_hub_id='timm/',
- crop_pct=1.0,
- num_classes=0,
- license='dinov3-license',
- ),
- 'convnext_base.dinov3_lvd1689m': _cfg(
- hf_hub_id='timm/',
- crop_pct=1.0,
- num_classes=0,
- license='dinov3-license',
- ),
- 'convnext_large.dinov3_lvd1689m': _cfg(
- hf_hub_id='timm/',
- crop_pct=1.0,
- num_classes=0,
- license='dinov3-license',
- ),
- "test_convnext.r160_in1k": _cfg(
- hf_hub_id='timm/',
- mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5),
- input_size=(3, 160, 160), pool_size=(5, 5), crop_pct=0.95),
- "test_convnext2.r160_in1k": _cfg(
- hf_hub_id='timm/',
- mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5),
- input_size=(3, 160, 160), pool_size=(5, 5), crop_pct=0.95),
- "test_convnext3.r160_in1k": _cfg(
- hf_hub_id='timm/',
- mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5),
- input_size=(3, 160, 160), pool_size=(5, 5), crop_pct=0.95),
- })
@register_model
def convnext_zepto_rms(pretrained=False, **kwargs) -> ConvNeXt:
    # timm zepto variant, the smallest ConvNeXt registered here:
    # depths (2, 2, 4, 2), dims (32, 64, 128, 256), conv_mlp blocks, 'simplenorm' norm layer.
    # Caller kwargs override these defaults.
    model_args = dict(depths=(2, 2, 4, 2), dims=(32, 64, 128, 256), conv_mlp=True, norm_layer='simplenorm')
    model = _create_convnext('convnext_zepto_rms', pretrained=pretrained, **dict(model_args, **kwargs))
    return model
@register_model
def convnext_zepto_rms_ols(pretrained=False, **kwargs) -> ConvNeXt:
    # timm zepto variant ('simplenorm' norm layer) with an overlapping stem (stem_type='overlap_act');
    # depths/dims are identical to convnext_zepto_rms. Caller kwargs override these defaults.
    model_args = dict(
        depths=(2, 2, 4, 2), dims=(32, 64, 128, 256), conv_mlp=True, norm_layer='simplenorm', stem_type='overlap_act')
    model = _create_convnext('convnext_zepto_rms_ols', pretrained=pretrained, **dict(model_args, **kwargs))
    return model
@register_model
def convnext_atto(pretrained=False, **kwargs) -> ConvNeXt:
    # timm atto variant: depths (2, 2, 6, 2), dims (40, 80, 160, 320), conv_mlp blocks.
    # (NOTE: depths may still be tweaked; target is 3-4M params, ~3.7M currently.)
    model_args = dict(depths=(2, 2, 6, 2), dims=(40, 80, 160, 320), conv_mlp=True)
    model = _create_convnext('convnext_atto', pretrained=pretrained, **dict(model_args, **kwargs))
    return model
@register_model
def convnext_atto_ols(pretrained=False, **kwargs) -> ConvNeXt:
    # timm atto variant with an overlapping conv stem (stem_type='overlap_tiered');
    # depths/dims are identical to convnext_atto, only the stem differs.
    model_args = dict(depths=(2, 2, 6, 2), dims=(40, 80, 160, 320), conv_mlp=True, stem_type='overlap_tiered')
    model = _create_convnext('convnext_atto_ols', pretrained=pretrained, **dict(model_args, **kwargs))
    return model
@register_model
def convnext_atto_rms(pretrained=False, **kwargs) -> ConvNeXt:
    # timm atto variant using an 'rmsnorm2d' norm layer; depths/dims are identical to convnext_atto.
    model_args = dict(depths=(2, 2, 6, 2), dims=(40, 80, 160, 320), conv_mlp=True, norm_layer='rmsnorm2d')
    model = _create_convnext('convnext_atto_rms', pretrained=pretrained, **dict(model_args, **kwargs))
    return model
@register_model
def convnext_femto(pretrained=False, **kwargs) -> ConvNeXt:
    """timm femto variant: depths (2, 2, 6, 2), dims (48, 96, 192, 384), conv_mlp blocks.

    Any keyword arguments override these defaults before being forwarded to ``_create_convnext``.
    """
    arch = {'depths': (2, 2, 6, 2), 'dims': (48, 96, 192, 384), 'conv_mlp': True}
    return _create_convnext('convnext_femto', pretrained=pretrained, **{**arch, **kwargs})
@register_model
def convnext_femto_ols(pretrained=False, **kwargs) -> ConvNeXt:
    """timm femto variant with an overlapping ('overlap_tiered') stem; same depths/dims as convnext_femto."""
    arch = {
        'depths': (2, 2, 6, 2),
        'dims': (48, 96, 192, 384),
        'conv_mlp': True,
        'stem_type': 'overlap_tiered',
    }
    return _create_convnext('convnext_femto_ols', pretrained=pretrained, **{**arch, **kwargs})
@register_model
def convnext_pico(pretrained=False, **kwargs) -> ConvNeXt:
    """timm pico variant: depths (2, 2, 6, 2), dims (64, 128, 256, 512), conv_mlp blocks."""
    merged = dict(depths=(2, 2, 6, 2), dims=(64, 128, 256, 512), conv_mlp=True)
    merged.update(kwargs)  # caller-supplied kwargs win over the defaults
    return _create_convnext('convnext_pico', pretrained=pretrained, **merged)
@register_model
def convnext_pico_ols(pretrained=False, **kwargs) -> ConvNeXt:
    # timm pico variant with an overlapping 3x3 conv stem (stem_type='overlap_tiered');
    # depths/dims are identical to convnext_pico.
    model_args = dict(depths=(2, 2, 6, 2), dims=(64, 128, 256, 512), conv_mlp=True, stem_type='overlap_tiered')
    model = _create_convnext('convnext_pico_ols', pretrained=pretrained, **dict(model_args, **kwargs))
    return model
@register_model
def convnext_nano(pretrained=False, **kwargs) -> ConvNeXt:
    """timm nano variant (standard stem and head): depths (2, 2, 8, 2), dims (80, 160, 320, 640)."""
    arch = {'depths': (2, 2, 8, 2), 'dims': (80, 160, 320, 640), 'conv_mlp': True}
    return _create_convnext('convnext_nano', pretrained=pretrained, **{**arch, **kwargs})
@register_model
def convnext_nano_ols(pretrained=False, **kwargs) -> ConvNeXt:
    """Experimental nano variant with an overlapping conv stem (stem_type='overlap')."""
    merged = dict(depths=(2, 2, 8, 2), dims=(80, 160, 320, 640), conv_mlp=True, stem_type='overlap')
    merged.update(kwargs)
    return _create_convnext('convnext_nano_ols', pretrained=pretrained, **merged)
@register_model
def convnext_tiny_hnf(pretrained=False, **kwargs) -> ConvNeXt:
    """Experimental tiny variant with ``head_norm_first=True`` (norm applied before pooling in the head)."""
    arch = {
        'depths': (3, 3, 9, 3),
        'dims': (96, 192, 384, 768),
        'head_norm_first': True,
        'conv_mlp': True,
    }
    return _create_convnext('convnext_tiny_hnf', pretrained=pretrained, **{**arch, **kwargs})
@register_model
def convnext_tiny(pretrained=False, **kwargs) -> ConvNeXt:
    """ConvNeXt-Tiny: depths (3, 3, 9, 3), dims (96, 192, 384, 768)."""
    arch = {'depths': (3, 3, 9, 3), 'dims': (96, 192, 384, 768)}
    return _create_convnext('convnext_tiny', pretrained=pretrained, **{**arch, **kwargs})
@register_model
def convnext_small(pretrained=False, **kwargs) -> ConvNeXt:
    """ConvNeXt-Small: depths [3, 3, 27, 3], dims [96, 192, 384, 768]."""
    merged = dict(depths=[3, 3, 27, 3], dims=[96, 192, 384, 768])
    merged.update(kwargs)
    return _create_convnext('convnext_small', pretrained=pretrained, **merged)
@register_model
def convnext_base(pretrained=False, **kwargs) -> ConvNeXt:
    """ConvNeXt-Base: depths [3, 3, 27, 3], dims [128, 256, 512, 1024]."""
    arch = {'depths': [3, 3, 27, 3], 'dims': [128, 256, 512, 1024]}
    return _create_convnext('convnext_base', pretrained=pretrained, **{**arch, **kwargs})
@register_model
def convnext_large(pretrained=False, **kwargs) -> ConvNeXt:
    """ConvNeXt-Large: depths [3, 3, 27, 3], dims [192, 384, 768, 1536]."""
    merged = dict(depths=[3, 3, 27, 3], dims=[192, 384, 768, 1536])
    merged.update(kwargs)
    return _create_convnext('convnext_large', pretrained=pretrained, **merged)
@register_model
def convnext_large_mlp(pretrained=False, **kwargs) -> ConvNeXt:
    """ConvNeXt-Large layout with ``head_hidden_size=1536`` (an MLP-style head, per the name)."""
    arch = {'depths': [3, 3, 27, 3], 'dims': [192, 384, 768, 1536], 'head_hidden_size': 1536}
    return _create_convnext('convnext_large_mlp', pretrained=pretrained, **{**arch, **kwargs})
@register_model
def convnext_xlarge(pretrained=False, **kwargs) -> ConvNeXt:
    """ConvNeXt-XLarge: depths [3, 3, 27, 3], dims [256, 512, 1024, 2048]."""
    merged = dict(depths=[3, 3, 27, 3], dims=[256, 512, 1024, 2048])
    merged.update(kwargs)
    return _create_convnext('convnext_xlarge', pretrained=pretrained, **merged)
@register_model
def convnext_xxlarge(pretrained=False, **kwargs) -> ConvNeXt:
    """ConvNeXt-XXLarge: depths [3, 4, 30, 3], dims [384, 768, 1536, 3072].

    ``norm_eps`` defaults to 1e-5 unless the caller passes one explicitly.
    """
    eps = kwargs.pop('norm_eps', 1e-5)
    arch = {'depths': [3, 4, 30, 3], 'dims': [384, 768, 1536, 3072], 'norm_eps': eps}
    return _create_convnext('convnext_xxlarge', pretrained=pretrained, **{**arch, **kwargs})
@register_model
def convnextv2_atto(pretrained=False, **kwargs) -> ConvNeXt:
    # ConvNeXt-V2 atto variant: same depths/dims as convnext_atto, with use_grn=True and
    # ls_init_value=None (V2 settings), conv_mlp blocks. Caller kwargs override these defaults.
    model_args = dict(
        depths=(2, 2, 6, 2), dims=(40, 80, 160, 320), use_grn=True, ls_init_value=None, conv_mlp=True)
    model = _create_convnext('convnextv2_atto', pretrained=pretrained, **dict(model_args, **kwargs))
    return model
@register_model
def convnextv2_femto(pretrained=False, **kwargs) -> ConvNeXt:
    """ConvNeXt-V2 femto variant: dims (48, 96, 192, 384), ``use_grn=True``, ``ls_init_value=None``."""
    arch = {
        'depths': (2, 2, 6, 2),
        'dims': (48, 96, 192, 384),
        'use_grn': True,
        'ls_init_value': None,
        'conv_mlp': True,
    }
    return _create_convnext('convnextv2_femto', pretrained=pretrained, **{**arch, **kwargs})
@register_model
def convnextv2_pico(pretrained=False, **kwargs) -> ConvNeXt:
    """ConvNeXt-V2 pico variant: dims (64, 128, 256, 512), ``use_grn=True``, ``ls_init_value=None``."""
    merged = dict(depths=(2, 2, 6, 2), dims=(64, 128, 256, 512), use_grn=True, ls_init_value=None, conv_mlp=True)
    merged.update(kwargs)
    return _create_convnext('convnextv2_pico', pretrained=pretrained, **merged)
@register_model
def convnextv2_nano(pretrained=False, **kwargs) -> ConvNeXt:
    """ConvNeXt-V2 nano variant (standard stem/head): depths (2, 2, 8, 2), dims (80, 160, 320, 640)."""
    arch = {
        'depths': (2, 2, 8, 2),
        'dims': (80, 160, 320, 640),
        'use_grn': True,
        'ls_init_value': None,
        'conv_mlp': True,
    }
    return _create_convnext('convnextv2_nano', pretrained=pretrained, **{**arch, **kwargs})
@register_model
def convnextv2_tiny(pretrained=False, **kwargs) -> ConvNeXt:
    """ConvNeXt-V2 Tiny: depths (3, 3, 9, 3), dims (96, 192, 384, 768), GRN on, ``ls_init_value=None``."""
    merged = dict(depths=(3, 3, 9, 3), dims=(96, 192, 384, 768), use_grn=True, ls_init_value=None)
    merged.update(kwargs)
    return _create_convnext('convnextv2_tiny', pretrained=pretrained, **merged)
@register_model
def convnextv2_small(pretrained=False, **kwargs) -> ConvNeXt:
    """ConvNeXt-V2 Small: depths [3, 3, 27, 3], dims [96, 192, 384, 768], GRN on, ``ls_init_value=None``."""
    arch = {'depths': [3, 3, 27, 3], 'dims': [96, 192, 384, 768], 'use_grn': True, 'ls_init_value': None}
    return _create_convnext('convnextv2_small', pretrained=pretrained, **{**arch, **kwargs})
@register_model
def convnextv2_base(pretrained=False, **kwargs) -> ConvNeXt:
    """ConvNeXt-V2 Base: depths [3, 3, 27, 3], dims [128, 256, 512, 1024], GRN on, ``ls_init_value=None``."""
    merged = dict(depths=[3, 3, 27, 3], dims=[128, 256, 512, 1024], use_grn=True, ls_init_value=None)
    merged.update(kwargs)
    return _create_convnext('convnextv2_base', pretrained=pretrained, **merged)
@register_model
def convnextv2_large(pretrained=False, **kwargs) -> ConvNeXt:
    """ConvNeXt-V2 Large: depths [3, 3, 27, 3], dims [192, 384, 768, 1536], GRN on, ``ls_init_value=None``."""
    arch = {'depths': [3, 3, 27, 3], 'dims': [192, 384, 768, 1536], 'use_grn': True, 'ls_init_value': None}
    return _create_convnext('convnextv2_large', pretrained=pretrained, **{**arch, **kwargs})
@register_model
def convnextv2_huge(pretrained=False, **kwargs) -> ConvNeXt:
    """ConvNeXt-V2 Huge: depths [3, 3, 27, 3], dims [352, 704, 1408, 2816], GRN on, ``ls_init_value=None``."""
    merged = dict(depths=[3, 3, 27, 3], dims=[352, 704, 1408, 2816], use_grn=True, ls_init_value=None)
    merged.update(kwargs)
    return _create_convnext('convnextv2_huge', pretrained=pretrained, **merged)
@register_model
def test_convnext(pretrained=False, **kwargs) -> ConvNeXt:
    """Tiny test model: depths [1, 2, 4, 2], dims [24, 32, 48, 64], 'gelu_tanh' activation.

    ``norm_eps`` defaults to 1e-5 unless supplied by the caller.
    """
    eps = kwargs.pop('norm_eps', 1e-5)
    arch = {
        'depths': [1, 2, 4, 2],
        'dims': [24, 32, 48, 64],
        'norm_eps': eps,
        'act_layer': 'gelu_tanh',
    }
    return _create_convnext('test_convnext', pretrained=pretrained, **{**arch, **kwargs})
@register_model
def test_convnext2(pretrained=False, **kwargs) -> ConvNeXt:
    """Tiny test model: one block per stage, dims [32, 64, 96, 128], 'gelu_tanh' activation.

    ``norm_eps`` defaults to 1e-5 unless supplied by the caller.
    """
    eps = kwargs.pop('norm_eps', 1e-5)
    merged = dict(depths=[1, 1, 1, 1], dims=[32, 64, 96, 128], norm_eps=eps, act_layer='gelu_tanh')
    merged.update(kwargs)
    return _create_convnext('test_convnext2', pretrained=pretrained, **merged)
@register_model
def test_convnext3(pretrained=False, **kwargs) -> ConvNeXt:
    """Tiny test model with per-stage kernel sizes (7, 5, 5, 3) and 'silu' activation.

    ``norm_eps`` defaults to 1e-5 unless supplied by the caller.
    """
    eps = kwargs.pop('norm_eps', 1e-5)
    arch = {
        'depths': [1, 1, 1, 1],
        'dims': [32, 64, 96, 128],
        'norm_eps': eps,
        'kernel_sizes': (7, 5, 5, 3),
        'act_layer': 'silu',
    }
    return _create_convnext('test_convnext3', pretrained=pretrained, **{**arch, **kwargs})
# Map legacy (pre `arch.pretrained_tag`) model names to their current `arch.tag` names.
# Resolution/warning behaviour for the old names is implemented by timm's registry
# (`register_model_deprecations`), not here — presumably old names still resolve with a warning.
register_model_deprecations(__name__, {
    # in22k pretrain -> in1k fine-tune, 224 resolution
    'convnext_tiny_in22ft1k': 'convnext_tiny.fb_in22k_ft_in1k',
    'convnext_small_in22ft1k': 'convnext_small.fb_in22k_ft_in1k',
    'convnext_base_in22ft1k': 'convnext_base.fb_in22k_ft_in1k',
    'convnext_large_in22ft1k': 'convnext_large.fb_in22k_ft_in1k',
    'convnext_xlarge_in22ft1k': 'convnext_xlarge.fb_in22k_ft_in1k',
    # in22k pretrain -> in1k fine-tune at 384
    'convnext_tiny_384_in22ft1k': 'convnext_tiny.fb_in22k_ft_in1k_384',
    'convnext_small_384_in22ft1k': 'convnext_small.fb_in22k_ft_in1k_384',
    'convnext_base_384_in22ft1k': 'convnext_base.fb_in22k_ft_in1k_384',
    'convnext_large_384_in22ft1k': 'convnext_large.fb_in22k_ft_in1k_384',
    'convnext_xlarge_384_in22ft1k': 'convnext_xlarge.fb_in22k_ft_in1k_384',
    # in22k-only weights (21841-class heads per the fb_in22k cfgs above)
    'convnext_tiny_in22k': 'convnext_tiny.fb_in22k',
    'convnext_small_in22k': 'convnext_small.fb_in22k',
    'convnext_base_in22k': 'convnext_base.fb_in22k',
    'convnext_large_in22k': 'convnext_large.fb_in22k',
    'convnext_xlarge_in22k': 'convnext_xlarge.fb_in22k',
})
|