convnext.py 59 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
37047057067077087097107117127137147157167177187197207217227237247257267277287297307317327337347357367377387397407417427437447457467477487497507517527537547557567577587597607617627637647657667677687697707717727737747757767777787797807817827837847857867877887897907917927937947957967977987998008018028038048058068078088098108118128138148158168178188198208218228238248258268278288298308318328338348358368378388398408418428438448458468478488498508518528538548558568578588598608618628638648658668678688698708718728738748758768778788798808818828838848858868878888898908918928938948958968978988999009019029039049059069079089099109119129139149159169179189199209219229239249259269279289299309319329339349359369379389399409419429439449459469479489499509519529539549559569579589599609619629639649659669679689699709719729739749759769779789799809819829839849859869879889899909919929939949959969979989991000100110021003100410051006100710081009101010111012101310141015101610171018101910201021102210231024102510261027102810291030103110321033103410351036103710381039104010411042104310441045104610471048104910501051105210531054105510561057105810591060106110621063106410651066106710681069107010711072107310741075107610771078107910801081108210831084108510861087108810891090109110921093109410951096109710981099110011011102110311041105110611071108110911101111111211131114111511161117111811191120112111221123112411251126112711281129113011311132113311341135113611371138113911401141114211431144114511461147114811491150115111521153115411551156115711581159116011611162116311641165116611671168116911701171117211731174117511761177117811791180118111821183118411851186118711881189119011911192119311941195119611971198119912001201120212031204120512061207120812091210121112121213121412151216121712181219122012211222122312241225122612271228122912301231123212331234123512361237123812391240124112421243124412451246124712481249125012511252125312541255125612571258125912601261126212631264126512661267126812691270127112721273127412751276127
71278127912801281128212831284128512861287128812891290129112921293129412951296129712981299130013011302130313041305130613071308130913101311131213131314131513161317131813191320132113221323132413251326132713281329133013311332133313341335133613371338133913401341134213431344134513461347134813491350135113521353135413551356135713581359136013611362136313641365136613671368136913701371137213731374137513761377137813791380138113821383138413851386138713881389139013911392139313941395139613971398139914001401140214031404140514061407140814091410
  1. """ ConvNeXt
  2. Papers:
  3. * `A ConvNet for the 2020s` - https://arxiv.org/pdf/2201.03545.pdf
  4. @Article{liu2022convnet,
  5. author = {Zhuang Liu and Hanzi Mao and Chao-Yuan Wu and Christoph Feichtenhofer and Trevor Darrell and Saining Xie},
  6. title = {A ConvNet for the 2020s},
  7. journal = {Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition (CVPR)},
  8. year = {2022},
  9. }
  10. * `ConvNeXt-V2 - Co-designing and Scaling ConvNets with Masked Autoencoders` - https://arxiv.org/abs/2301.00808
  11. @article{Woo2023ConvNeXtV2,
  12. title={ConvNeXt V2: Co-designing and Scaling ConvNets with Masked Autoencoders},
  13. author={Sanghyun Woo, Shoubhik Debnath, Ronghang Hu, Xinlei Chen, Zhuang Liu, In So Kweon and Saining Xie},
  14. year={2023},
  15. journal={arXiv preprint arXiv:2301.00808},
  16. }
  17. Original code and weights from:
  18. * https://github.com/facebookresearch/ConvNeXt, original copyright below
  19. * https://github.com/facebookresearch/ConvNeXt-V2, original copyright below
  20. Model defs atto, femto, pico, nano and _ols / _hnf variants are timm originals.
  21. Modifications and additions for timm hacked together by / Copyright 2022, Ross Wightman
  22. """
  23. # ConvNeXt
  24. # Copyright (c) Meta Platforms, Inc. and affiliates.
  25. # All rights reserved.
  26. # This source code is licensed under the MIT license
  27. # ConvNeXt-V2
  28. # Copyright (c) Meta Platforms, Inc. and affiliates.
  29. # All rights reserved.
  30. # This source code is licensed under the license found in the
  31. # LICENSE file in the root directory of this source tree (Attribution-NonCommercial 4.0 International (CC BY-NC 4.0))
  32. # No code was used directly from ConvNeXt-V2, however the weights are CC BY-NC 4.0 so beware if using commercially.
  33. from functools import partial
  34. from typing import Callable, Dict, List, Optional, Tuple, Union
  35. import torch
  36. import torch.nn as nn
  37. from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD, OPENAI_CLIP_MEAN, OPENAI_CLIP_STD
  38. from timm.layers import (
  39. trunc_normal_,
  40. AvgPool2dSame,
  41. DropPath,
  42. calculate_drop_path_rates,
  43. Mlp,
  44. GlobalResponseNormMlp,
  45. LayerNorm2d,
  46. LayerNorm,
  47. RmsNorm2d,
  48. RmsNorm,
  49. SimpleNorm2d,
  50. SimpleNorm,
  51. create_conv2d,
  52. get_act_layer,
  53. get_norm_layer,
  54. make_divisible,
  55. to_ntuple,
  56. NormMlpClassifierHead,
  57. ClassifierHead,
  58. )
  59. from ._builder import build_model_with_cfg
  60. from ._features import feature_take_indices
  61. from ._manipulate import named_apply, checkpoint_seq
  62. from ._registry import generate_default_cfgs, register_model, register_model_deprecations
# Explicit public API of this module. Only the model class is listed here;
# model_registry will add each entrypoint fn to this.
__all__ = ['ConvNeXt']
  64. class Downsample(nn.Module):
  65. """Downsample module for ConvNeXt."""
  66. def __init__(
  67. self,
  68. in_chs: int,
  69. out_chs: int,
  70. stride: int = 1,
  71. dilation: int = 1,
  72. device=None,
  73. dtype=None,
  74. ) -> None:
  75. """Initialize Downsample module.
  76. Args:
  77. in_chs: Number of input channels.
  78. out_chs: Number of output channels.
  79. stride: Stride for downsampling.
  80. dilation: Dilation rate.
  81. """
  82. dd = {'device': device, 'dtype': dtype}
  83. super().__init__()
  84. avg_stride = stride if dilation == 1 else 1
  85. if stride > 1 or dilation > 1:
  86. avg_pool_fn = AvgPool2dSame if avg_stride == 1 and dilation > 1 else nn.AvgPool2d
  87. self.pool = avg_pool_fn(2, avg_stride, ceil_mode=True, count_include_pad=False)
  88. else:
  89. self.pool = nn.Identity()
  90. if in_chs != out_chs:
  91. self.conv = create_conv2d(in_chs, out_chs, 1, stride=1, **dd)
  92. else:
  93. self.conv = nn.Identity()
  94. def forward(self, x: torch.Tensor) -> torch.Tensor:
  95. """Forward pass."""
  96. x = self.pool(x)
  97. x = self.conv(x)
  98. return x
  99. class ConvNeXtBlock(nn.Module):
  100. """ConvNeXt Block.
  101. There are two equivalent implementations:
  102. (1) DwConv -> LayerNorm (channels_first) -> 1x1 Conv -> GELU -> 1x1 Conv; all in (N, C, H, W)
  103. (2) DwConv -> Permute to (N, H, W, C); LayerNorm (channels_last) -> Linear -> GELU -> Linear; Permute back
  104. Unlike the official impl, this one allows choice of 1 or 2, 1x1 conv can be faster with appropriate
  105. choice of LayerNorm impl, however as model size increases the tradeoffs appear to change and nn.Linear
  106. is a better choice. This was observed with PyTorch 1.10 on 3090 GPU, it could change over time & w/ different HW.
  107. """
  108. def __init__(
  109. self,
  110. in_chs: int,
  111. out_chs: Optional[int] = None,
  112. kernel_size: int = 7,
  113. stride: int = 1,
  114. dilation: Union[int, Tuple[int, int]] = (1, 1),
  115. mlp_ratio: float = 4,
  116. conv_mlp: bool = False,
  117. conv_bias: bool = True,
  118. use_grn: bool = False,
  119. ls_init_value: Optional[float] = 1e-6,
  120. act_layer: Union[str, Callable] = 'gelu',
  121. norm_layer: Optional[Callable] = None,
  122. drop_path: float = 0.,
  123. device=None,
  124. dtype=None,
  125. ):
  126. """
  127. Args:
  128. in_chs: Block input channels.
  129. out_chs: Block output channels (same as in_chs if None).
  130. kernel_size: Depthwise convolution kernel size.
  131. stride: Stride of depthwise convolution.
  132. dilation: Tuple specifying input and output dilation of block.
  133. mlp_ratio: MLP expansion ratio.
  134. conv_mlp: Use 1x1 convolutions for MLP and a NCHW compatible norm layer if True.
  135. conv_bias: Apply bias for all convolution (linear) layers.
  136. use_grn: Use GlobalResponseNorm in MLP (from ConvNeXt-V2)
  137. ls_init_value: Layer-scale init values, layer-scale applied if not None.
  138. act_layer: Activation layer.
  139. norm_layer: Normalization layer (defaults to LN if not specified).
  140. drop_path: Stochastic depth probability.
  141. """
  142. dd = {'device': device, 'dtype': dtype}
  143. super().__init__()
  144. out_chs = out_chs or in_chs
  145. dilation = to_ntuple(2)(dilation)
  146. act_layer = get_act_layer(act_layer)
  147. if not norm_layer:
  148. norm_layer = LayerNorm2d if conv_mlp else LayerNorm
  149. mlp_layer = partial(GlobalResponseNormMlp if use_grn else Mlp, use_conv=conv_mlp)
  150. self.use_conv_mlp = conv_mlp
  151. self.conv_dw = create_conv2d(
  152. in_chs,
  153. out_chs,
  154. kernel_size=kernel_size,
  155. stride=stride,
  156. dilation=dilation[0],
  157. depthwise=True,
  158. bias=conv_bias,
  159. **dd,
  160. )
  161. self.norm = norm_layer(out_chs, **dd)
  162. self.mlp = mlp_layer(
  163. out_chs,
  164. int(mlp_ratio * out_chs),
  165. act_layer=act_layer,
  166. **dd,
  167. )
  168. self.gamma = nn.Parameter(ls_init_value * torch.ones(out_chs, **dd)) if ls_init_value is not None else None
  169. if in_chs != out_chs or stride != 1 or dilation[0] != dilation[1]:
  170. self.shortcut = Downsample(in_chs, out_chs, stride=stride, dilation=dilation[0], **dd)
  171. else:
  172. self.shortcut = nn.Identity()
  173. self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
  174. def forward(self, x: torch.Tensor) -> torch.Tensor:
  175. """Forward pass."""
  176. shortcut = x
  177. x = self.conv_dw(x)
  178. if self.use_conv_mlp:
  179. x = self.norm(x)
  180. x = self.mlp(x)
  181. else:
  182. x = x.permute(0, 2, 3, 1)
  183. x = self.norm(x)
  184. x = self.mlp(x)
  185. x = x.permute(0, 3, 1, 2)
  186. if self.gamma is not None:
  187. x = x.mul(self.gamma.reshape(1, -1, 1, 1))
  188. x = self.drop_path(x) + self.shortcut(shortcut)
  189. return x
  190. class ConvNeXtStage(nn.Module):
  191. """ConvNeXt stage (multiple blocks)."""
  192. def __init__(
  193. self,
  194. in_chs: int,
  195. out_chs: int,
  196. kernel_size: int = 7,
  197. stride: int = 2,
  198. depth: int = 2,
  199. dilation: Tuple[int, int] = (1, 1),
  200. drop_path_rates: Optional[List[float]] = None,
  201. ls_init_value: float = 1.0,
  202. conv_mlp: bool = False,
  203. conv_bias: bool = True,
  204. use_grn: bool = False,
  205. act_layer: Union[str, Callable] = 'gelu',
  206. norm_layer: Optional[Callable] = None,
  207. norm_layer_cl: Optional[Callable] = None,
  208. device=None,
  209. dtype=None,
  210. ) -> None:
  211. """Initialize ConvNeXt stage.
  212. Args:
  213. in_chs: Number of input channels.
  214. out_chs: Number of output channels.
  215. kernel_size: Kernel size for depthwise convolution.
  216. stride: Stride for downsampling.
  217. depth: Number of blocks in stage.
  218. dilation: Dilation rates.
  219. drop_path_rates: Drop path rates for each block.
  220. ls_init_value: Initial value for layer scale.
  221. conv_mlp: Use convolutional MLP.
  222. conv_bias: Use bias in convolutions.
  223. use_grn: Use global response normalization.
  224. act_layer: Activation layer.
  225. norm_layer: Normalization layer.
  226. norm_layer_cl: Normalization layer for channels last.
  227. """
  228. dd = {'device': device, 'dtype': dtype}
  229. super().__init__()
  230. self.grad_checkpointing = False
  231. if in_chs != out_chs or stride > 1 or dilation[0] != dilation[1]:
  232. ds_ks = 2 if stride > 1 or dilation[0] != dilation[1] else 1
  233. pad = 'same' if dilation[1] > 1 else 0 # same padding needed if dilation used
  234. self.downsample = nn.Sequential(
  235. norm_layer(in_chs, **dd),
  236. create_conv2d(
  237. in_chs,
  238. out_chs,
  239. kernel_size=ds_ks,
  240. stride=stride,
  241. dilation=dilation[0],
  242. padding=pad,
  243. bias=conv_bias,
  244. **dd,
  245. ),
  246. )
  247. in_chs = out_chs
  248. else:
  249. self.downsample = nn.Identity()
  250. drop_path_rates = drop_path_rates or [0.] * depth
  251. stage_blocks = []
  252. for i in range(depth):
  253. stage_blocks.append(ConvNeXtBlock(
  254. in_chs=in_chs,
  255. out_chs=out_chs,
  256. kernel_size=kernel_size,
  257. dilation=dilation[1],
  258. drop_path=drop_path_rates[i],
  259. ls_init_value=ls_init_value,
  260. conv_mlp=conv_mlp,
  261. conv_bias=conv_bias,
  262. use_grn=use_grn,
  263. act_layer=act_layer,
  264. norm_layer=norm_layer if conv_mlp else norm_layer_cl,
  265. **dd,
  266. ))
  267. in_chs = out_chs
  268. self.blocks = nn.Sequential(*stage_blocks)
  269. def forward(self, x: torch.Tensor) -> torch.Tensor:
  270. """Forward pass."""
  271. x = self.downsample(x)
  272. if self.grad_checkpointing and not torch.jit.is_scripting():
  273. x = checkpoint_seq(self.blocks, x)
  274. else:
  275. x = self.blocks(x)
  276. return x
  277. # map of norm layers with NCHW (2D) and channels last variants
  278. _NORM_MAP = {
  279. 'layernorm': (LayerNorm2d, LayerNorm),
  280. 'layernorm2d': (LayerNorm2d, LayerNorm),
  281. 'simplenorm': (SimpleNorm2d, SimpleNorm),
  282. 'simplenorm2d': (SimpleNorm2d, SimpleNorm),
  283. 'rmsnorm': (RmsNorm2d, RmsNorm),
  284. 'rmsnorm2d': (RmsNorm2d, RmsNorm),
  285. }
  286. def _get_norm_layers(norm_layer: Union[Callable, str], conv_mlp: bool, norm_eps: float):
  287. norm_layer = norm_layer or 'layernorm'
  288. if norm_layer in _NORM_MAP:
  289. norm_layer_cl = _NORM_MAP[norm_layer][0] if conv_mlp else _NORM_MAP[norm_layer][1]
  290. norm_layer = _NORM_MAP[norm_layer][0]
  291. if norm_eps is not None:
  292. norm_layer = partial(norm_layer, eps=norm_eps)
  293. norm_layer_cl = partial(norm_layer_cl, eps=norm_eps)
  294. else:
  295. assert conv_mlp, \
  296. 'If a norm_layer is specified, conv MLP must be used so all norm expect rank-4, channels-first input'
  297. norm_layer = get_norm_layer(norm_layer)
  298. norm_layer_cl = norm_layer
  299. if norm_eps is not None:
  300. norm_layer_cl = partial(norm_layer_cl, eps=norm_eps)
  301. return norm_layer, norm_layer_cl
class ConvNeXt(nn.Module):
    """ConvNeXt model architecture.

    A PyTorch impl of : `A ConvNet for the 2020s` - https://arxiv.org/pdf/2201.03545.pdf

    The model is assembled as: stem -> four `ConvNeXtStage` stages -> `norm_pre` -> `head`.
    """

    def __init__(
            self,
            in_chans: int = 3,
            num_classes: int = 1000,
            global_pool: str = 'avg',
            output_stride: int = 32,
            depths: Tuple[int, ...] = (3, 3, 9, 3),
            dims: Tuple[int, ...] = (96, 192, 384, 768),
            kernel_sizes: Union[int, Tuple[int, ...]] = 7,
            ls_init_value: Optional[float] = 1e-6,
            stem_type: str = 'patch',
            patch_size: int = 4,
            head_init_scale: float = 1.,
            head_norm_first: bool = False,
            head_hidden_size: Optional[int] = None,
            conv_mlp: bool = False,
            conv_bias: bool = True,
            use_grn: bool = False,
            act_layer: Union[str, Callable] = 'gelu',
            norm_layer: Optional[Union[str, Callable]] = None,
            norm_eps: Optional[float] = None,
            drop_rate: float = 0.,
            drop_path_rate: float = 0.,
            device=None,
            dtype=None,
    ):
        """
        Args:
            in_chans: Number of input image channels.
            num_classes: Number of classes for classification head.
            global_pool: Global pooling type.
            output_stride: Output stride of network, one of (8, 16, 32).
            depths: Number of blocks at each stage.
            dims: Feature dimension at each stage.
            kernel_sizes: Depthwise convolution kernel-sizes for each stage.
            ls_init_value: Init value for Layer Scale, disabled if None.
            stem_type: Type of stem, one of 'patch', 'overlap', 'overlap_tiered', 'overlap_act'.
            patch_size: Stem patch size for patch stem.
            head_init_scale: Init scaling value for classifier weights and biases.
            head_norm_first: Apply normalization before global pool + head.
            head_hidden_size: Size of MLP hidden layer in head if not None and head_norm_first == False.
            conv_mlp: Use 1x1 conv in MLP, improves speed for small networks w/ chan last.
            conv_bias: Use bias layers w/ all convolutions.
            use_grn: Use Global Response Norm (ConvNeXt-V2) in MLP.
            act_layer: Activation layer type.
            norm_layer: Normalization layer type.
            norm_eps: Optional epsilon handed to `_get_norm_layers` when resolving the norm layers.
            drop_rate: Head pre-classifier dropout rate.
            drop_path_rate: Stochastic depth drop rate.
            device: Device forwarded to the constructors of parameterized submodules.
            dtype: Dtype forwarded to the constructors of parameterized submodules.
        """
        super().__init__()
        dd = {'device': device, 'dtype': dtype}
        assert output_stride in (8, 16, 32)
        # A single int kernel size is broadcast to one entry per stage.
        kernel_sizes = to_ntuple(4)(kernel_sizes)
        norm_layer, norm_layer_cl = _get_norm_layers(norm_layer, conv_mlp, norm_eps)
        act_layer = get_act_layer(act_layer)

        self.num_classes = num_classes
        self.in_chans = in_chans
        self.drop_rate = drop_rate
        self.feature_info = []

        assert stem_type in ('patch', 'overlap', 'overlap_tiered', 'overlap_act')
        if stem_type == 'patch':
            # NOTE: this stem is a minimal form of ViT PatchEmbed, as used in SwinTransformer w/ patch_size = 4
            self.stem = nn.Sequential(
                nn.Conv2d(in_chans, dims[0], kernel_size=patch_size, stride=patch_size, bias=conv_bias, **dd),
                norm_layer(dims[0], **dd),
            )
            stem_stride = patch_size
        else:
            # Overlapping stems: two stride-2 3x3 convs (total stride 4); the 'tiered' variant
            # uses a narrower middle width and the 'act' variant inserts an activation between them.
            mid_chs = make_divisible(dims[0] // 2) if 'tiered' in stem_type else dims[0]
            self.stem = nn.Sequential(*filter(None, [
                nn.Conv2d(in_chans, mid_chs, kernel_size=3, stride=2, padding=1, bias=conv_bias, **dd),
                act_layer() if 'act' in stem_type else None,
                nn.Conv2d(mid_chs, dims[0], kernel_size=3, stride=2, padding=1, bias=conv_bias, **dd),
                norm_layer(dims[0], **dd),
            ]))
            stem_stride = 4

        self.stages = nn.Sequential()
        dp_rates = calculate_drop_path_rates(drop_path_rate, depths, stagewise=True)
        stages = []
        prev_chs = dims[0]
        curr_stride = stem_stride
        dilation = 1
        # 4 feature resolution stages, each consisting of multiple residual blocks
        for i in range(4):
            stride = 2 if curr_stride == 2 or i > 0 else 1
            # Once the requested output stride is reached, further striding is
            # converted into dilation so the feature resolution stops shrinking.
            if curr_stride >= output_stride and stride > 1:
                dilation *= stride
                stride = 1
            curr_stride *= stride
            first_dilation = 1 if dilation in (1, 2) else 2
            out_chs = dims[i]
            stages.append(ConvNeXtStage(
                prev_chs,
                out_chs,
                kernel_size=kernel_sizes[i],
                stride=stride,
                dilation=(first_dilation, dilation),
                depth=depths[i],
                drop_path_rates=dp_rates[i],
                ls_init_value=ls_init_value,
                conv_mlp=conv_mlp,
                conv_bias=conv_bias,
                use_grn=use_grn,
                act_layer=act_layer,
                norm_layer=norm_layer,
                norm_layer_cl=norm_layer_cl,
                **dd,
            ))
            prev_chs = out_chs
            # NOTE feature_info use currently assumes stage 0 == stride 1, rest are stride 2
            self.feature_info += [dict(num_chs=prev_chs, reduction=curr_stride, module=f'stages.{i}')]
        self.stages = nn.Sequential(*stages)
        self.num_features = self.head_hidden_size = prev_chs

        # if head_norm_first == true, norm -> global pool -> fc ordering, like most other nets
        # otherwise pool -> norm -> fc, the default ConvNeXt ordering (pretrained FB weights)
        if head_norm_first:
            assert not head_hidden_size
            self.norm_pre = norm_layer(self.num_features, **dd)
            self.head = ClassifierHead(
                self.num_features,
                num_classes,
                pool_type=global_pool,
                drop_rate=self.drop_rate,
                **dd,
            )
        else:
            self.norm_pre = nn.Identity()
            self.head = NormMlpClassifierHead(
                self.num_features,
                num_classes,
                hidden_size=head_hidden_size,
                pool_type=global_pool,
                drop_rate=self.drop_rate,
                norm_layer=norm_layer,
                act_layer='gelu',
                **dd,
            )
            self.head_hidden_size = self.head.num_features
        # Initialize every submodule; `_init_weights` receives each module together with its name.
        named_apply(partial(_init_weights, head_init_scale=head_init_scale), self)

    @torch.jit.ignore
    def group_matcher(self, coarse: bool = False) -> Dict[str, Union[str, List]]:
        """Create regex patterns for parameter grouping.

        Args:
            coarse: Use coarse grouping.

        Returns:
            Dictionary mapping group names to regex patterns.
        """
        return dict(
            stem=r'^stem',
            blocks=r'^stages\.(\d+)' if coarse else [
                (r'^stages\.(\d+)\.downsample', (0,)),  # blocks
                (r'^stages\.(\d+)\.blocks\.(\d+)', None),
                (r'^norm_pre', (99999,))
            ]
        )

    @torch.jit.ignore
    def set_grad_checkpointing(self, enable: bool = True) -> None:
        """Enable or disable gradient checkpointing.

        Args:
            enable: Whether to enable gradient checkpointing.
        """
        # The flag lives on each stage and is read in that stage's forward.
        for s in self.stages:
            s.grad_checkpointing = enable

    @torch.jit.ignore
    def get_classifier(self) -> nn.Module:
        """Get the classifier module."""
        return self.head.fc

    def reset_classifier(self, num_classes: int, global_pool: Optional[str] = None) -> None:
        """Reset the classifier head.

        Args:
            num_classes: Number of classes for new classifier.
            global_pool: Global pooling type.
        """
        self.num_classes = num_classes
        self.head.reset(num_classes, global_pool)

    def forward_intermediates(
            self,
            x: torch.Tensor,
            indices: Optional[Union[int, List[int]]] = None,
            norm: bool = False,
            stop_early: bool = False,
            output_fmt: str = 'NCHW',
            intermediates_only: bool = False,
    ) -> Union[List[torch.Tensor], Tuple[torch.Tensor, List[torch.Tensor]]]:
        """Forward features that returns intermediates.

        Args:
            x: Input image tensor.
            indices: Take last n blocks if int, all if None, select matching indices if sequence.
            norm: Apply norm layer to compatible intermediates.
            stop_early: Stop iterating over blocks when last desired intermediate hit.
            output_fmt: Shape of intermediate feature outputs.
            intermediates_only: Only return intermediate features.

        Returns:
            List of intermediate features or tuple of (final features, intermediates).
        """
        assert output_fmt in ('NCHW',), 'Output shape must be NCHW.'
        intermediates = []
        take_indices, max_index = feature_take_indices(len(self.stages), indices)

        # forward pass
        x = self.stem(x)
        last_idx = len(self.stages) - 1
        if torch.jit.is_scripting() or not stop_early:  # can't slice blocks in torchscript
            stages = self.stages
        else:
            stages = self.stages[:max_index + 1]
        for feat_idx, stage in enumerate(stages):
            x = stage(x)
            if feat_idx in take_indices:
                # Only the output of the final stage can be passed through norm_pre.
                if norm and feat_idx == last_idx:
                    intermediates.append(self.norm_pre(x))
                else:
                    intermediates.append(x)

        if intermediates_only:
            return intermediates

        # norm_pre is applied to the final output only if the loop ran through the last stage.
        if feat_idx == last_idx:
            x = self.norm_pre(x)

        return x, intermediates

    def prune_intermediate_layers(
            self,
            indices: Union[int, List[int]] = 1,
            prune_norm: bool = False,
            prune_head: bool = True,
    ) -> List[int]:
        """Prune layers not required for specified intermediates.

        Args:
            indices: Indices of intermediate layers to keep.
            prune_norm: Whether to prune normalization layer.
            prune_head: Whether to prune the classifier head.

        Returns:
            List of indices that were kept.
        """
        take_indices, max_index = feature_take_indices(len(self.stages), indices)
        self.stages = self.stages[:max_index + 1]  # truncate blocks w/ stem as idx 0
        if prune_norm:
            self.norm_pre = nn.Identity()
        if prune_head:
            self.reset_classifier(0, '')
        return take_indices

    def forward_features(self, x: torch.Tensor) -> torch.Tensor:
        """Forward pass through feature extraction layers."""
        x = self.stem(x)
        x = self.stages(x)
        x = self.norm_pre(x)
        return x

    def forward_head(self, x: torch.Tensor, pre_logits: bool = False) -> torch.Tensor:
        """Forward pass through classifier head.

        Args:
            x: Feature tensor.
            pre_logits: Return features before final classifier.

        Returns:
            Output tensor.
        """
        return self.head(x, pre_logits=True) if pre_logits else self.head(x)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Forward pass."""
        x = self.forward_features(x)
        x = self.forward_head(x)
        return x
  564. def _init_weights(module: nn.Module, name: Optional[str] = None, head_init_scale: float = 1.0) -> None:
  565. """Initialize model weights.
  566. Args:
  567. module: Module to initialize.
  568. name: Module name.
  569. head_init_scale: Scale factor for head initialization.
  570. """
  571. if isinstance(module, nn.Conv2d):
  572. trunc_normal_(module.weight, std=.02)
  573. if module.bias is not None:
  574. nn.init.zeros_(module.bias)
  575. elif isinstance(module, nn.Linear):
  576. trunc_normal_(module.weight, std=.02)
  577. nn.init.zeros_(module.bias)
  578. if name and 'head.' in name:
  579. module.weight.data.mul_(head_init_scale)
  580. module.bias.data.mul_(head_init_scale)
  581. def checkpoint_filter_fn(state_dict, model):
  582. """ Remap FB checkpoints -> timm """
  583. if 'head.norm.weight' in state_dict or 'norm_pre.weight' in state_dict:
  584. return state_dict # non-FB checkpoint
  585. if 'model' in state_dict:
  586. state_dict = state_dict['model']
  587. out_dict = {}
  588. if 'visual.trunk.stem.0.weight' in state_dict:
  589. out_dict = {k.replace('visual.trunk.', ''): v for k, v in state_dict.items() if k.startswith('visual.trunk.')}
  590. if 'visual.head.proj.weight' in state_dict:
  591. out_dict['head.fc.weight'] = state_dict['visual.head.proj.weight']
  592. out_dict['head.fc.bias'] = torch.zeros(state_dict['visual.head.proj.weight'].shape[0])
  593. elif 'visual.head.mlp.fc1.weight' in state_dict:
  594. out_dict['head.pre_logits.fc.weight'] = state_dict['visual.head.mlp.fc1.weight']
  595. out_dict['head.pre_logits.fc.bias'] = state_dict['visual.head.mlp.fc1.bias']
  596. out_dict['head.fc.weight'] = state_dict['visual.head.mlp.fc2.weight']
  597. out_dict['head.fc.bias'] = torch.zeros(state_dict['visual.head.mlp.fc2.weight'].shape[0])
  598. return out_dict
  599. import re
  600. for k, v in state_dict.items():
  601. k = k.replace('downsample_layers.0.', 'stem.')
  602. k = re.sub(r'stages.([0-9]+).([0-9]+)', r'stages.\1.blocks.\2', k)
  603. k = re.sub(r'downsample_layers.([0-9]+).([0-9]+)', r'stages.\1.downsample.\2', k)
  604. k = k.replace('dwconv', 'conv_dw')
  605. k = k.replace('pwconv', 'mlp.fc')
  606. if 'grn' in k:
  607. k = k.replace('grn.beta', 'mlp.grn.bias')
  608. k = k.replace('grn.gamma', 'mlp.grn.weight')
  609. v = v.reshape(v.shape[-1])
  610. k = k.replace('head.', 'head.fc.')
  611. if k.startswith('norm.'):
  612. k = k.replace('norm', 'head.norm')
  613. if v.ndim == 2 and 'head' not in k:
  614. model_shape = model.state_dict()[k].shape
  615. v = v.reshape(model_shape)
  616. out_dict[k] = v
  617. return out_dict
  618. def _create_convnext(variant, pretrained=False, **kwargs):
  619. if kwargs.get('pretrained_cfg', '') == 'fcmae':
  620. # NOTE fcmae pretrained weights have no classifier or final norm-layer (`head.norm`)
  621. # This is workaround loading with num_classes=0 w/o removing norm-layer.
  622. kwargs.setdefault('pretrained_strict', False)
  623. model = build_model_with_cfg(
  624. ConvNeXt, variant, pretrained,
  625. pretrained_filter_fn=checkpoint_filter_fn,
  626. feature_cfg=dict(out_indices=(0, 1, 2, 3), flatten_sequential=True),
  627. **kwargs)
  628. return model
  629. def _cfg(url='', **kwargs):
  630. return {
  631. 'url': url,
  632. 'num_classes': 1000, 'input_size': (3, 224, 224), 'pool_size': (7, 7),
  633. 'crop_pct': 0.875, 'interpolation': 'bicubic',
  634. 'mean': IMAGENET_DEFAULT_MEAN, 'std': IMAGENET_DEFAULT_STD,
  635. 'first_conv': 'stem.0', 'classifier': 'head.fc',
  636. 'license': 'apache-2.0', **kwargs
  637. }
  638. def _cfgv2(url='', **kwargs):
  639. return {
  640. 'url': url,
  641. 'num_classes': 1000, 'input_size': (3, 224, 224), 'pool_size': (7, 7),
  642. 'crop_pct': 0.875, 'interpolation': 'bicubic',
  643. 'mean': IMAGENET_DEFAULT_MEAN, 'std': IMAGENET_DEFAULT_STD,
  644. 'first_conv': 'stem.0', 'classifier': 'head.fc',
  645. 'license': 'cc-by-nc-4.0', 'paper_ids': 'arXiv:2301.00808',
  646. 'paper_name': 'ConvNeXt-V2: Co-designing and Scaling ConvNets with Masked Autoencoders',
  647. 'origin_url': 'https://github.com/facebookresearch/ConvNeXt-V2',
  648. **kwargs
  649. }
# Pretrained configuration table, keyed by '<model_name>.<pretrained_tag>'.
# Each value is built by _cfg / _cfgv2 with per-weight overrides (input size, pooling size,
# crop settings, normalization stats, class count, license); the whole mapping is passed to
# generate_default_cfgs and the result is bound to `default_cfgs`.
default_cfgs = generate_default_cfgs({
    # timm specific variants
    'convnext_tiny.in12k_ft_in1k': _cfg(
        hf_hub_id='timm/',
        crop_pct=0.95, test_input_size=(3, 288, 288), test_crop_pct=1.0),
    'convnext_small.in12k_ft_in1k': _cfg(
        hf_hub_id='timm/',
        crop_pct=0.95, test_input_size=(3, 288, 288), test_crop_pct=1.0),

    'convnext_zepto_rms.ra4_e3600_r224_in1k': _cfg(
        hf_hub_id='timm/',
        mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)),
    'convnext_zepto_rms_ols.ra4_e3600_r224_in1k': _cfg(
        hf_hub_id='timm/',
        mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5),
        crop_pct=0.9),
    'convnext_atto.d2_in1k': _cfg(
        url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-rsb-weights/convnext_atto_d2-01bb0f51.pth',
        hf_hub_id='timm/',
        test_input_size=(3, 288, 288), test_crop_pct=0.95),
    'convnext_atto_ols.a2_in1k': _cfg(
        url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-rsb-weights/convnext_atto_ols_a2-78d1c8f3.pth',
        hf_hub_id='timm/',
        test_input_size=(3, 288, 288), test_crop_pct=0.95),
    # No hub id or url is set for this entry (the hf_hub_id line is commented out).
    'convnext_atto_rms.untrained': _cfg(
        #hf_hub_id='timm/',
        test_input_size=(3, 256, 256), test_crop_pct=0.95),
    'convnext_femto.d1_in1k': _cfg(
        url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-rsb-weights/convnext_femto_d1-d71d5b4c.pth',
        hf_hub_id='timm/',
        test_input_size=(3, 288, 288), test_crop_pct=0.95),
    'convnext_femto_ols.d1_in1k': _cfg(
        url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-rsb-weights/convnext_femto_ols_d1-246bf2ed.pth',
        hf_hub_id='timm/',
        test_input_size=(3, 288, 288), test_crop_pct=0.95),
    'convnext_pico.d1_in1k': _cfg(
        url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-rsb-weights/convnext_pico_d1-10ad7f0d.pth',
        hf_hub_id='timm/',
        test_input_size=(3, 288, 288), test_crop_pct=0.95),
    'convnext_pico_ols.d1_in1k': _cfg(
        url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-rsb-weights/convnext_pico_ols_d1-611f0ca7.pth',
        hf_hub_id='timm/',
        crop_pct=0.95, test_input_size=(3, 288, 288), test_crop_pct=1.0),
    'convnext_nano.in12k_ft_in1k': _cfg(
        hf_hub_id='timm/',
        crop_pct=0.95, test_input_size=(3, 288, 288), test_crop_pct=1.0),
    'convnext_nano.d1h_in1k': _cfg(
        url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-rsb-weights/convnext_nano_d1h-7eb4bdea.pth',
        hf_hub_id='timm/',
        crop_pct=0.95, test_input_size=(3, 288, 288), test_crop_pct=1.0),
    'convnext_nano_ols.d1h_in1k': _cfg(
        url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-rsb-weights/convnext_nano_ols_d1h-ae424a9a.pth',
        hf_hub_id='timm/',
        crop_pct=0.95, test_input_size=(3, 288, 288), test_crop_pct=1.0),
    'convnext_tiny_hnf.a2h_in1k': _cfg(
        url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-rsb-weights/convnext_tiny_hnf_a2h-ab7e9df2.pth',
        hf_hub_id='timm/',
        crop_pct=0.95, test_input_size=(3, 288, 288), test_crop_pct=1.0),

    'convnext_nano.r384_in12k_ft_in1k': _cfg(
        hf_hub_id='timm/',
        input_size=(3, 384, 384), pool_size=(12, 12), crop_pct=1.0),
    'convnext_tiny.in12k_ft_in1k_384': _cfg(
        hf_hub_id='timm/',
        input_size=(3, 384, 384), pool_size=(12, 12), crop_pct=1.0, crop_mode='squash'),
    'convnext_small.in12k_ft_in1k_384': _cfg(
        hf_hub_id='timm/',
        input_size=(3, 384, 384), pool_size=(12, 12), crop_pct=1.0, crop_mode='squash'),

    # 'in12k' entries override num_classes to 11821
    'convnext_nano.in12k': _cfg(
        hf_hub_id='timm/',
        crop_pct=0.95, num_classes=11821),
    'convnext_nano.r384_in12k': _cfg(
        hf_hub_id='timm/',
        input_size=(3, 384, 384), pool_size=(12, 12), crop_pct=1.0, num_classes=11821),
    'convnext_nano.r384_ad_in12k': _cfg(
        hf_hub_id='timm/',
        input_size=(3, 384, 384), pool_size=(12, 12), crop_pct=1.0, num_classes=11821),
    'convnext_tiny.in12k': _cfg(
        hf_hub_id='timm/',
        crop_pct=0.95, num_classes=11821),
    'convnext_small.in12k': _cfg(
        hf_hub_id='timm/',
        crop_pct=0.95, num_classes=11821),

    # 'fb_' tagged entries: urls point at dl.fbaipublicfiles.com/convnext
    'convnext_tiny.fb_in22k_ft_in1k': _cfg(
        url='https://dl.fbaipublicfiles.com/convnext/convnext_tiny_22k_1k_224.pth',
        hf_hub_id='timm/',
        test_input_size=(3, 288, 288), test_crop_pct=1.0),
    'convnext_small.fb_in22k_ft_in1k': _cfg(
        url='https://dl.fbaipublicfiles.com/convnext/convnext_small_22k_1k_224.pth',
        hf_hub_id='timm/',
        test_input_size=(3, 288, 288), test_crop_pct=1.0),
    'convnext_base.fb_in22k_ft_in1k': _cfg(
        url='https://dl.fbaipublicfiles.com/convnext/convnext_base_22k_1k_224.pth',
        hf_hub_id='timm/',
        test_input_size=(3, 288, 288), test_crop_pct=1.0),
    'convnext_large.fb_in22k_ft_in1k': _cfg(
        url='https://dl.fbaipublicfiles.com/convnext/convnext_large_22k_1k_224.pth',
        hf_hub_id='timm/',
        test_input_size=(3, 288, 288), test_crop_pct=1.0),
    'convnext_xlarge.fb_in22k_ft_in1k': _cfg(
        url='https://dl.fbaipublicfiles.com/convnext/convnext_xlarge_22k_1k_224_ema.pth',
        hf_hub_id='timm/',
        test_input_size=(3, 288, 288), test_crop_pct=1.0),

    'convnext_tiny.fb_in1k': _cfg(
        url="https://dl.fbaipublicfiles.com/convnext/convnext_tiny_1k_224_ema.pth",
        hf_hub_id='timm/',
        test_input_size=(3, 288, 288), test_crop_pct=1.0),
    'convnext_small.fb_in1k': _cfg(
        url="https://dl.fbaipublicfiles.com/convnext/convnext_small_1k_224_ema.pth",
        hf_hub_id='timm/',
        test_input_size=(3, 288, 288), test_crop_pct=1.0),
    'convnext_base.fb_in1k': _cfg(
        url="https://dl.fbaipublicfiles.com/convnext/convnext_base_1k_224_ema.pth",
        hf_hub_id='timm/',
        test_input_size=(3, 288, 288), test_crop_pct=1.0),
    'convnext_large.fb_in1k': _cfg(
        url="https://dl.fbaipublicfiles.com/convnext/convnext_large_1k_224_ema.pth",
        hf_hub_id='timm/',
        test_input_size=(3, 288, 288), test_crop_pct=1.0),

    'convnext_tiny.fb_in22k_ft_in1k_384': _cfg(
        url='https://dl.fbaipublicfiles.com/convnext/convnext_tiny_22k_1k_384.pth',
        hf_hub_id='timm/',
        input_size=(3, 384, 384), pool_size=(12, 12), crop_pct=1.0, crop_mode='squash'),
    'convnext_small.fb_in22k_ft_in1k_384': _cfg(
        url='https://dl.fbaipublicfiles.com/convnext/convnext_small_22k_1k_384.pth',
        hf_hub_id='timm/',
        input_size=(3, 384, 384), pool_size=(12, 12), crop_pct=1.0, crop_mode='squash'),
    'convnext_base.fb_in22k_ft_in1k_384': _cfg(
        url='https://dl.fbaipublicfiles.com/convnext/convnext_base_22k_1k_384.pth',
        hf_hub_id='timm/',
        input_size=(3, 384, 384), pool_size=(12, 12), crop_pct=1.0, crop_mode='squash'),
    'convnext_large.fb_in22k_ft_in1k_384': _cfg(
        url='https://dl.fbaipublicfiles.com/convnext/convnext_large_22k_1k_384.pth',
        hf_hub_id='timm/',
        input_size=(3, 384, 384), pool_size=(12, 12), crop_pct=1.0, crop_mode='squash'),
    'convnext_xlarge.fb_in22k_ft_in1k_384': _cfg(
        url='https://dl.fbaipublicfiles.com/convnext/convnext_xlarge_22k_1k_384_ema.pth',
        hf_hub_id='timm/',
        input_size=(3, 384, 384), pool_size=(12, 12), crop_pct=1.0, crop_mode='squash'),

    # 'fb_in22k' entries override num_classes to 21841
    'convnext_tiny.fb_in22k': _cfg(
        url="https://dl.fbaipublicfiles.com/convnext/convnext_tiny_22k_224.pth",
        hf_hub_id='timm/',
        num_classes=21841),
    'convnext_small.fb_in22k': _cfg(
        url="https://dl.fbaipublicfiles.com/convnext/convnext_small_22k_224.pth",
        hf_hub_id='timm/',
        num_classes=21841),
    'convnext_base.fb_in22k': _cfg(
        url="https://dl.fbaipublicfiles.com/convnext/convnext_base_22k_224.pth",
        hf_hub_id='timm/',
        num_classes=21841),
    'convnext_large.fb_in22k': _cfg(
        url="https://dl.fbaipublicfiles.com/convnext/convnext_large_22k_224.pth",
        hf_hub_id='timm/',
        num_classes=21841),
    'convnext_xlarge.fb_in22k': _cfg(
        url="https://dl.fbaipublicfiles.com/convnext/convnext_xlarge_22k_224.pth",
        hf_hub_id='timm/',
        num_classes=21841),

    # ConvNeXt-V2 entries, built with _cfgv2 (cc-by-nc-4.0 license plus paper/origin metadata)
    'convnextv2_nano.fcmae_ft_in22k_in1k': _cfgv2(
        url='https://dl.fbaipublicfiles.com/convnext/convnextv2/im22k/convnextv2_nano_22k_224_ema.pt',
        hf_hub_id='timm/',
        test_input_size=(3, 288, 288), test_crop_pct=1.0),
    'convnextv2_nano.fcmae_ft_in22k_in1k_384': _cfgv2(
        url='https://dl.fbaipublicfiles.com/convnext/convnextv2/im22k/convnextv2_nano_22k_384_ema.pt',
        hf_hub_id='timm/',
        input_size=(3, 384, 384), pool_size=(12, 12), crop_pct=1.0, crop_mode='squash'),
    'convnextv2_tiny.fcmae_ft_in22k_in1k': _cfgv2(
        url="https://dl.fbaipublicfiles.com/convnext/convnextv2/im22k/convnextv2_tiny_22k_224_ema.pt",
        hf_hub_id='timm/',
        test_input_size=(3, 288, 288), test_crop_pct=1.0),
    'convnextv2_tiny.fcmae_ft_in22k_in1k_384': _cfgv2(
        url="https://dl.fbaipublicfiles.com/convnext/convnextv2/im22k/convnextv2_tiny_22k_384_ema.pt",
        hf_hub_id='timm/',
        input_size=(3, 384, 384), pool_size=(12, 12), crop_pct=1.0, crop_mode='squash'),
    'convnextv2_base.fcmae_ft_in22k_in1k': _cfgv2(
        url="https://dl.fbaipublicfiles.com/convnext/convnextv2/im22k/convnextv2_base_22k_224_ema.pt",
        hf_hub_id='timm/',
        test_input_size=(3, 288, 288), test_crop_pct=1.0),
    'convnextv2_base.fcmae_ft_in22k_in1k_384': _cfgv2(
        url="https://dl.fbaipublicfiles.com/convnext/convnextv2/im22k/convnextv2_base_22k_384_ema.pt",
        hf_hub_id='timm/',
        input_size=(3, 384, 384), pool_size=(12, 12), crop_pct=1.0, crop_mode='squash'),
    'convnextv2_large.fcmae_ft_in22k_in1k': _cfgv2(
        url="https://dl.fbaipublicfiles.com/convnext/convnextv2/im22k/convnextv2_large_22k_224_ema.pt",
        hf_hub_id='timm/',
        test_input_size=(3, 288, 288), test_crop_pct=1.0),
    'convnextv2_large.fcmae_ft_in22k_in1k_384': _cfgv2(
        url="https://dl.fbaipublicfiles.com/convnext/convnextv2/im22k/convnextv2_large_22k_384_ema.pt",
        hf_hub_id='timm/',
        input_size=(3, 384, 384), pool_size=(12, 12), crop_pct=1.0, crop_mode='squash'),
    'convnextv2_huge.fcmae_ft_in22k_in1k_384': _cfgv2(
        url="https://dl.fbaipublicfiles.com/convnext/convnextv2/im22k/convnextv2_huge_22k_384_ema.pt",
        hf_hub_id='timm/',
        input_size=(3, 384, 384), pool_size=(12, 12), crop_pct=1.0, crop_mode='squash'),
    'convnextv2_huge.fcmae_ft_in22k_in1k_512': _cfgv2(
        url="https://dl.fbaipublicfiles.com/convnext/convnextv2/im22k/convnextv2_huge_22k_512_ema.pt",
        hf_hub_id='timm/',
        input_size=(3, 512, 512), pool_size=(15, 15), crop_pct=1.0, crop_mode='squash'),

    'convnextv2_atto.fcmae_ft_in1k': _cfgv2(
        url='https://dl.fbaipublicfiles.com/convnext/convnextv2/im1k/convnextv2_atto_1k_224_ema.pt',
        hf_hub_id='timm/',
        test_input_size=(3, 288, 288), test_crop_pct=0.95),
    'convnextv2_femto.fcmae_ft_in1k': _cfgv2(
        url='https://dl.fbaipublicfiles.com/convnext/convnextv2/im1k/convnextv2_femto_1k_224_ema.pt',
        hf_hub_id='timm/',
        test_input_size=(3, 288, 288), test_crop_pct=0.95),
    'convnextv2_pico.fcmae_ft_in1k': _cfgv2(
        url='https://dl.fbaipublicfiles.com/convnext/convnextv2/im1k/convnextv2_pico_1k_224_ema.pt',
        hf_hub_id='timm/',
        test_input_size=(3, 288, 288), test_crop_pct=0.95),
    'convnextv2_nano.fcmae_ft_in1k': _cfgv2(
        url='https://dl.fbaipublicfiles.com/convnext/convnextv2/im1k/convnextv2_nano_1k_224_ema.pt',
        hf_hub_id='timm/',
        test_input_size=(3, 288, 288), test_crop_pct=1.0),
    'convnextv2_tiny.fcmae_ft_in1k': _cfgv2(
        url="https://dl.fbaipublicfiles.com/convnext/convnextv2/im1k/convnextv2_tiny_1k_224_ema.pt",
        hf_hub_id='timm/',
        test_input_size=(3, 288, 288), test_crop_pct=1.0),
    'convnextv2_base.fcmae_ft_in1k': _cfgv2(
        url="https://dl.fbaipublicfiles.com/convnext/convnextv2/im1k/convnextv2_base_1k_224_ema.pt",
        hf_hub_id='timm/',
        test_input_size=(3, 288, 288), test_crop_pct=1.0),
    'convnextv2_large.fcmae_ft_in1k': _cfgv2(
        url="https://dl.fbaipublicfiles.com/convnext/convnextv2/im1k/convnextv2_large_1k_224_ema.pt",
        hf_hub_id='timm/',
        test_input_size=(3, 288, 288), test_crop_pct=1.0),
    'convnextv2_huge.fcmae_ft_in1k': _cfgv2(
        url="https://dl.fbaipublicfiles.com/convnext/convnextv2/im1k/convnextv2_huge_1k_224_ema.pt",
        hf_hub_id='timm/',
        test_input_size=(3, 288, 288), test_crop_pct=1.0),

    # '.fcmae' entries set num_classes=0 (urls live under the 'pt_only' path)
    'convnextv2_atto.fcmae': _cfgv2(
        url='https://dl.fbaipublicfiles.com/convnext/convnextv2/pt_only/convnextv2_atto_1k_224_fcmae.pt',
        hf_hub_id='timm/',
        num_classes=0),
    'convnextv2_femto.fcmae': _cfgv2(
        url='https://dl.fbaipublicfiles.com/convnext/convnextv2/pt_only/convnextv2_femto_1k_224_fcmae.pt',
        hf_hub_id='timm/',
        num_classes=0),
    'convnextv2_pico.fcmae': _cfgv2(
        url='https://dl.fbaipublicfiles.com/convnext/convnextv2/pt_only/convnextv2_pico_1k_224_fcmae.pt',
        hf_hub_id='timm/',
        num_classes=0),
    'convnextv2_nano.fcmae': _cfgv2(
        url='https://dl.fbaipublicfiles.com/convnext/convnextv2/pt_only/convnextv2_nano_1k_224_fcmae.pt',
        hf_hub_id='timm/',
        num_classes=0),
    'convnextv2_tiny.fcmae': _cfgv2(
        url="https://dl.fbaipublicfiles.com/convnext/convnextv2/pt_only/convnextv2_tiny_1k_224_fcmae.pt",
        hf_hub_id='timm/',
        num_classes=0),
    'convnextv2_base.fcmae': _cfgv2(
        url="https://dl.fbaipublicfiles.com/convnext/convnextv2/pt_only/convnextv2_base_1k_224_fcmae.pt",
        hf_hub_id='timm/',
        num_classes=0),
    'convnextv2_large.fcmae': _cfgv2(
        url="https://dl.fbaipublicfiles.com/convnext/convnextv2/pt_only/convnextv2_large_1k_224_fcmae.pt",
        hf_hub_id='timm/',
        num_classes=0),
    'convnextv2_huge.fcmae': _cfgv2(
        url="https://dl.fbaipublicfiles.com/convnext/convnextv2/pt_only/convnextv2_huge_1k_224_fcmae.pt",
        hf_hub_id='timm/',
        num_classes=0),

    # NOTE(review): this v2 entry is built with _cfg() (apache-2.0 defaults, no paper metadata)
    # rather than _cfgv2 -- confirm that is intended for the untrained placeholder.
    'convnextv2_small.untrained': _cfg(),

    # CLIP weights, fine-tuned on in1k or in12k + in1k
    'convnext_base.clip_laion2b_augreg_ft_in12k_in1k': _cfg(
        hf_hub_id='timm/',
        mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD,
        input_size=(3, 256, 256), pool_size=(8, 8), crop_pct=1.0),
    'convnext_base.clip_laion2b_augreg_ft_in12k_in1k_384': _cfg(
        hf_hub_id='timm/',
        mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD,
        input_size=(3, 384, 384), pool_size=(12, 12), crop_pct=1.0, crop_mode='squash'),
    'convnext_large_mlp.clip_laion2b_soup_ft_in12k_in1k_320': _cfg(
        hf_hub_id='timm/',
        mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD,
        input_size=(3, 320, 320), pool_size=(10, 10), crop_pct=1.0),
    'convnext_large_mlp.clip_laion2b_soup_ft_in12k_in1k_384': _cfg(
        hf_hub_id='timm/',
        mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD,
        input_size=(3, 384, 384), pool_size=(12, 12), crop_pct=1.0, crop_mode='squash'),

    'convnext_base.clip_laion2b_augreg_ft_in1k': _cfg(
        hf_hub_id='timm/',
        mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD,
        input_size=(3, 256, 256), pool_size=(8, 8), crop_pct=1.0),
    'convnext_base.clip_laiona_augreg_ft_in1k_384': _cfg(
        hf_hub_id='timm/',
        mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD,
        input_size=(3, 384, 384), pool_size=(12, 12), crop_pct=1.0),
    'convnext_large_mlp.clip_laion2b_augreg_ft_in1k': _cfg(
        hf_hub_id='timm/',
        mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD,
        input_size=(3, 256, 256), pool_size=(8, 8), crop_pct=1.0
    ),
    'convnext_large_mlp.clip_laion2b_augreg_ft_in1k_384': _cfg(
        hf_hub_id='timm/',
        mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD,
        input_size=(3, 384, 384), pool_size=(12, 12), crop_pct=1.0, crop_mode='squash'
    ),
    'convnext_xxlarge.clip_laion2b_soup_ft_in1k': _cfg(
        hf_hub_id='timm/',
        mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD,
        input_size=(3, 256, 256), pool_size=(8, 8), crop_pct=1.0),

    'convnext_base.clip_laion2b_augreg_ft_in12k': _cfg(
        hf_hub_id='timm/',
        mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD, num_classes=11821,
        input_size=(3, 256, 256), pool_size=(8, 8), crop_pct=1.0),
    'convnext_large_mlp.clip_laion2b_soup_ft_in12k_320': _cfg(
        hf_hub_id='timm/',
        mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD, num_classes=11821,
        input_size=(3, 320, 320), pool_size=(10, 10), crop_pct=1.0),
    'convnext_large_mlp.clip_laion2b_augreg_ft_in12k_384': _cfg(
        hf_hub_id='timm/',
        mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD, num_classes=11821,
        input_size=(3, 384, 384), pool_size=(12, 12), crop_pct=1.0, crop_mode='squash'),
    'convnext_large_mlp.clip_laion2b_soup_ft_in12k_384': _cfg(
        hf_hub_id='timm/',
        mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD, num_classes=11821,
        input_size=(3, 384, 384), pool_size=(12, 12), crop_pct=1.0, crop_mode='squash'),
    'convnext_xxlarge.clip_laion2b_soup_ft_in12k': _cfg(
        hf_hub_id='timm/',
        mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD, num_classes=11821,
        input_size=(3, 256, 256), pool_size=(8, 8), crop_pct=1.0),

    # CLIP original image tower weights
    # (num_classes here is 640 / 768 / 1024 depending on the model -- presumably the CLIP
    # embedding width rather than a label count; confirm against the hub checkpoints)
    'convnext_base.clip_laion2b': _cfg(
        hf_hub_id='timm/',
        mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD,
        input_size=(3, 256, 256), pool_size=(8, 8), crop_pct=1.0, num_classes=640),
    'convnext_base.clip_laion2b_augreg': _cfg(
        hf_hub_id='timm/',
        mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD,
        input_size=(3, 256, 256), pool_size=(8, 8), crop_pct=1.0, num_classes=640),
    'convnext_base.clip_laiona': _cfg(
        hf_hub_id='timm/',
        mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD,
        input_size=(3, 256, 256), pool_size=(8, 8), crop_pct=1.0, num_classes=640),
    'convnext_base.clip_laiona_320': _cfg(
        hf_hub_id='timm/',
        mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD,
        input_size=(3, 320, 320), pool_size=(10, 10), crop_pct=1.0, num_classes=640),
    'convnext_base.clip_laiona_augreg_320': _cfg(
        hf_hub_id='timm/',
        mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD,
        input_size=(3, 320, 320), pool_size=(10, 10), crop_pct=1.0, num_classes=640),
    'convnext_large_mlp.clip_laion2b_augreg': _cfg(
        hf_hub_id='timm/',
        mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD,
        input_size=(3, 256, 256), pool_size=(8, 8), crop_pct=1.0, num_classes=768),
    'convnext_large_mlp.clip_laion2b_ft_320': _cfg(
        hf_hub_id='timm/',
        mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD,
        input_size=(3, 320, 320), pool_size=(10, 10), crop_pct=1.0, num_classes=768),
    'convnext_large_mlp.clip_laion2b_ft_soup_320': _cfg(
        hf_hub_id='timm/',
        mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD,
        input_size=(3, 320, 320), pool_size=(10, 10), crop_pct=1.0, num_classes=768),
    'convnext_xxlarge.clip_laion2b_soup': _cfg(
        hf_hub_id='timm/',
        mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD,
        input_size=(3, 256, 256), pool_size=(8, 8), crop_pct=1.0, num_classes=1024),
    'convnext_xxlarge.clip_laion2b_rewind': _cfg(
        hf_hub_id='timm/',
        mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD,
        input_size=(3, 256, 256), pool_size=(8, 8), crop_pct=1.0, num_classes=1024),

    # NOTE dinov3 convnext weights are under a specific license, and downstream outputs must be shared with this
    # https://ai.meta.com/resources/models-and-libraries/dinov3-license/
    'convnext_tiny.dinov3_lvd1689m': _cfg(
        hf_hub_id='timm/',
        crop_pct=1.0,
        num_classes=0,
        license='dinov3-license',
    ),
    'convnext_small.dinov3_lvd1689m': _cfg(
        hf_hub_id='timm/',
        crop_pct=1.0,
        num_classes=0,
        license='dinov3-license',
    ),
    'convnext_base.dinov3_lvd1689m': _cfg(
        hf_hub_id='timm/',
        crop_pct=1.0,
        num_classes=0,
        license='dinov3-license',
    ),
    'convnext_large.dinov3_lvd1689m': _cfg(
        hf_hub_id='timm/',
        crop_pct=1.0,
        num_classes=0,
        license='dinov3-license',
    ),

    # test_convnext* entries: 160x160 input, 0.5 mean/std normalization
    "test_convnext.r160_in1k": _cfg(
        hf_hub_id='timm/',
        mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5),
        input_size=(3, 160, 160), pool_size=(5, 5), crop_pct=0.95),
    "test_convnext2.r160_in1k": _cfg(
        hf_hub_id='timm/',
        mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5),
        input_size=(3, 160, 160), pool_size=(5, 5), crop_pct=0.95),
    "test_convnext3.r160_in1k": _cfg(
        hf_hub_id='timm/',
        mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5),
        input_size=(3, 160, 160), pool_size=(5, 5), crop_pct=0.95),
})
  1051. @register_model
  1052. def convnext_zepto_rms(pretrained=False, **kwargs) -> ConvNeXt:
  1053. # timm femto variant (NOTE: still tweaking depths, will vary between 3-4M param, current is 3.7M
  1054. model_args = dict(depths=(2, 2, 4, 2), dims=(32, 64, 128, 256), conv_mlp=True, norm_layer='simplenorm')
  1055. model = _create_convnext('convnext_zepto_rms', pretrained=pretrained, **dict(model_args, **kwargs))
  1056. return model
  1057. @register_model
  1058. def convnext_zepto_rms_ols(pretrained=False, **kwargs) -> ConvNeXt:
  1059. # timm femto variant (NOTE: still tweaking depths, will vary between 3-4M param, current is 3.7M
  1060. model_args = dict(
  1061. depths=(2, 2, 4, 2), dims=(32, 64, 128, 256), conv_mlp=True, norm_layer='simplenorm', stem_type='overlap_act')
  1062. model = _create_convnext('convnext_zepto_rms_ols', pretrained=pretrained, **dict(model_args, **kwargs))
  1063. return model
  1064. @register_model
  1065. def convnext_atto(pretrained=False, **kwargs) -> ConvNeXt:
  1066. # timm femto variant (NOTE: still tweaking depths, will vary between 3-4M param, current is 3.7M
  1067. model_args = dict(depths=(2, 2, 6, 2), dims=(40, 80, 160, 320), conv_mlp=True)
  1068. model = _create_convnext('convnext_atto', pretrained=pretrained, **dict(model_args, **kwargs))
  1069. return model
  1070. @register_model
  1071. def convnext_atto_ols(pretrained=False, **kwargs) -> ConvNeXt:
  1072. # timm femto variant with overlapping 3x3 conv stem, wider than non-ols femto above, current param count 3.7M
  1073. model_args = dict(depths=(2, 2, 6, 2), dims=(40, 80, 160, 320), conv_mlp=True, stem_type='overlap_tiered')
  1074. model = _create_convnext('convnext_atto_ols', pretrained=pretrained, **dict(model_args, **kwargs))
  1075. return model
  1076. @register_model
  1077. def convnext_atto_rms(pretrained=False, **kwargs) -> ConvNeXt:
  1078. # timm femto variant (NOTE: still tweaking depths, will vary between 3-4M param, current is 3.7M
  1079. model_args = dict(depths=(2, 2, 6, 2), dims=(40, 80, 160, 320), conv_mlp=True, norm_layer='rmsnorm2d')
  1080. model = _create_convnext('convnext_atto_rms', pretrained=pretrained, **dict(model_args, **kwargs))
  1081. return model
  1082. @register_model
  1083. def convnext_femto(pretrained=False, **kwargs) -> ConvNeXt:
  1084. # timm femto variant
  1085. model_args = dict(depths=(2, 2, 6, 2), dims=(48, 96, 192, 384), conv_mlp=True)
  1086. model = _create_convnext('convnext_femto', pretrained=pretrained, **dict(model_args, **kwargs))
  1087. return model
  1088. @register_model
  1089. def convnext_femto_ols(pretrained=False, **kwargs) -> ConvNeXt:
  1090. # timm femto variant
  1091. model_args = dict(depths=(2, 2, 6, 2), dims=(48, 96, 192, 384), conv_mlp=True, stem_type='overlap_tiered')
  1092. model = _create_convnext('convnext_femto_ols', pretrained=pretrained, **dict(model_args, **kwargs))
  1093. return model
  1094. @register_model
  1095. def convnext_pico(pretrained=False, **kwargs) -> ConvNeXt:
  1096. # timm pico variant
  1097. model_args = dict(depths=(2, 2, 6, 2), dims=(64, 128, 256, 512), conv_mlp=True)
  1098. model = _create_convnext('convnext_pico', pretrained=pretrained, **dict(model_args, **kwargs))
  1099. return model
  1100. @register_model
  1101. def convnext_pico_ols(pretrained=False, **kwargs) -> ConvNeXt:
  1102. # timm nano variant with overlapping 3x3 conv stem
  1103. model_args = dict(depths=(2, 2, 6, 2), dims=(64, 128, 256, 512), conv_mlp=True, stem_type='overlap_tiered')
  1104. model = _create_convnext('convnext_pico_ols', pretrained=pretrained, **dict(model_args, **kwargs))
  1105. return model
  1106. @register_model
  1107. def convnext_nano(pretrained=False, **kwargs) -> ConvNeXt:
  1108. # timm nano variant with standard stem and head
  1109. model_args = dict(depths=(2, 2, 8, 2), dims=(80, 160, 320, 640), conv_mlp=True)
  1110. model = _create_convnext('convnext_nano', pretrained=pretrained, **dict(model_args, **kwargs))
  1111. return model
  1112. @register_model
  1113. def convnext_nano_ols(pretrained=False, **kwargs) -> ConvNeXt:
  1114. # experimental nano variant with overlapping conv stem
  1115. model_args = dict(depths=(2, 2, 8, 2), dims=(80, 160, 320, 640), conv_mlp=True, stem_type='overlap')
  1116. model = _create_convnext('convnext_nano_ols', pretrained=pretrained, **dict(model_args, **kwargs))
  1117. return model
  1118. @register_model
  1119. def convnext_tiny_hnf(pretrained=False, **kwargs) -> ConvNeXt:
  1120. # experimental tiny variant with norm before pooling in head (head norm first)
  1121. model_args = dict(depths=(3, 3, 9, 3), dims=(96, 192, 384, 768), head_norm_first=True, conv_mlp=True)
  1122. model = _create_convnext('convnext_tiny_hnf', pretrained=pretrained, **dict(model_args, **kwargs))
  1123. return model
  1124. @register_model
  1125. def convnext_tiny(pretrained=False, **kwargs) -> ConvNeXt:
  1126. model_args = dict(depths=(3, 3, 9, 3), dims=(96, 192, 384, 768))
  1127. model = _create_convnext('convnext_tiny', pretrained=pretrained, **dict(model_args, **kwargs))
  1128. return model
  1129. @register_model
  1130. def convnext_small(pretrained=False, **kwargs) -> ConvNeXt:
  1131. model_args = dict(depths=[3, 3, 27, 3], dims=[96, 192, 384, 768])
  1132. model = _create_convnext('convnext_small', pretrained=pretrained, **dict(model_args, **kwargs))
  1133. return model
  1134. @register_model
  1135. def convnext_base(pretrained=False, **kwargs) -> ConvNeXt:
  1136. model_args = dict(depths=[3, 3, 27, 3], dims=[128, 256, 512, 1024])
  1137. model = _create_convnext('convnext_base', pretrained=pretrained, **dict(model_args, **kwargs))
  1138. return model
  1139. @register_model
  1140. def convnext_large(pretrained=False, **kwargs) -> ConvNeXt:
  1141. model_args = dict(depths=[3, 3, 27, 3], dims=[192, 384, 768, 1536])
  1142. model = _create_convnext('convnext_large', pretrained=pretrained, **dict(model_args, **kwargs))
  1143. return model
  1144. @register_model
  1145. def convnext_large_mlp(pretrained=False, **kwargs) -> ConvNeXt:
  1146. model_args = dict(depths=[3, 3, 27, 3], dims=[192, 384, 768, 1536], head_hidden_size=1536)
  1147. model = _create_convnext('convnext_large_mlp', pretrained=pretrained, **dict(model_args, **kwargs))
  1148. return model
  1149. @register_model
  1150. def convnext_xlarge(pretrained=False, **kwargs) -> ConvNeXt:
  1151. model_args = dict(depths=[3, 3, 27, 3], dims=[256, 512, 1024, 2048])
  1152. model = _create_convnext('convnext_xlarge', pretrained=pretrained, **dict(model_args, **kwargs))
  1153. return model
  1154. @register_model
  1155. def convnext_xxlarge(pretrained=False, **kwargs) -> ConvNeXt:
  1156. model_args = dict(depths=[3, 4, 30, 3], dims=[384, 768, 1536, 3072], norm_eps=kwargs.pop('norm_eps', 1e-5))
  1157. model = _create_convnext('convnext_xxlarge', pretrained=pretrained, **dict(model_args, **kwargs))
  1158. return model
  1159. @register_model
  1160. def convnextv2_atto(pretrained=False, **kwargs) -> ConvNeXt:
  1161. # timm femto variant (NOTE: still tweaking depths, will vary between 3-4M param, current is 3.7M
  1162. model_args = dict(
  1163. depths=(2, 2, 6, 2), dims=(40, 80, 160, 320), use_grn=True, ls_init_value=None, conv_mlp=True)
  1164. model = _create_convnext('convnextv2_atto', pretrained=pretrained, **dict(model_args, **kwargs))
  1165. return model
  1166. @register_model
  1167. def convnextv2_femto(pretrained=False, **kwargs) -> ConvNeXt:
  1168. # timm femto variant
  1169. model_args = dict(
  1170. depths=(2, 2, 6, 2), dims=(48, 96, 192, 384), use_grn=True, ls_init_value=None, conv_mlp=True)
  1171. model = _create_convnext('convnextv2_femto', pretrained=pretrained, **dict(model_args, **kwargs))
  1172. return model
  1173. @register_model
  1174. def convnextv2_pico(pretrained=False, **kwargs) -> ConvNeXt:
  1175. # timm pico variant
  1176. model_args = dict(
  1177. depths=(2, 2, 6, 2), dims=(64, 128, 256, 512), use_grn=True, ls_init_value=None, conv_mlp=True)
  1178. model = _create_convnext('convnextv2_pico', pretrained=pretrained, **dict(model_args, **kwargs))
  1179. return model
  1180. @register_model
  1181. def convnextv2_nano(pretrained=False, **kwargs) -> ConvNeXt:
  1182. # timm nano variant with standard stem and head
  1183. model_args = dict(
  1184. depths=(2, 2, 8, 2), dims=(80, 160, 320, 640), use_grn=True, ls_init_value=None, conv_mlp=True)
  1185. model = _create_convnext('convnextv2_nano', pretrained=pretrained, **dict(model_args, **kwargs))
  1186. return model
  1187. @register_model
  1188. def convnextv2_tiny(pretrained=False, **kwargs) -> ConvNeXt:
  1189. model_args = dict(depths=(3, 3, 9, 3), dims=(96, 192, 384, 768), use_grn=True, ls_init_value=None)
  1190. model = _create_convnext('convnextv2_tiny', pretrained=pretrained, **dict(model_args, **kwargs))
  1191. return model
  1192. @register_model
  1193. def convnextv2_small(pretrained=False, **kwargs) -> ConvNeXt:
  1194. model_args = dict(depths=[3, 3, 27, 3], dims=[96, 192, 384, 768], use_grn=True, ls_init_value=None)
  1195. model = _create_convnext('convnextv2_small', pretrained=pretrained, **dict(model_args, **kwargs))
  1196. return model
  1197. @register_model
  1198. def convnextv2_base(pretrained=False, **kwargs) -> ConvNeXt:
  1199. model_args = dict(depths=[3, 3, 27, 3], dims=[128, 256, 512, 1024], use_grn=True, ls_init_value=None)
  1200. model = _create_convnext('convnextv2_base', pretrained=pretrained, **dict(model_args, **kwargs))
  1201. return model
  1202. @register_model
  1203. def convnextv2_large(pretrained=False, **kwargs) -> ConvNeXt:
  1204. model_args = dict(depths=[3, 3, 27, 3], dims=[192, 384, 768, 1536], use_grn=True, ls_init_value=None)
  1205. model = _create_convnext('convnextv2_large', pretrained=pretrained, **dict(model_args, **kwargs))
  1206. return model
  1207. @register_model
  1208. def convnextv2_huge(pretrained=False, **kwargs) -> ConvNeXt:
  1209. model_args = dict(depths=[3, 3, 27, 3], dims=[352, 704, 1408, 2816], use_grn=True, ls_init_value=None)
  1210. model = _create_convnext('convnextv2_huge', pretrained=pretrained, **dict(model_args, **kwargs))
  1211. return model
  1212. @register_model
  1213. def test_convnext(pretrained=False, **kwargs) -> ConvNeXt:
  1214. model_args = dict(depths=[1, 2, 4, 2], dims=[24, 32, 48, 64], norm_eps=kwargs.pop('norm_eps', 1e-5), act_layer='gelu_tanh')
  1215. model = _create_convnext('test_convnext', pretrained=pretrained, **dict(model_args, **kwargs))
  1216. return model
  1217. @register_model
  1218. def test_convnext2(pretrained=False, **kwargs) -> ConvNeXt:
  1219. model_args = dict(depths=[1, 1, 1, 1], dims=[32, 64, 96, 128], norm_eps=kwargs.pop('norm_eps', 1e-5), act_layer='gelu_tanh')
  1220. model = _create_convnext('test_convnext2', pretrained=pretrained, **dict(model_args, **kwargs))
  1221. return model
  1222. @register_model
  1223. def test_convnext3(pretrained=False, **kwargs) -> ConvNeXt:
  1224. model_args = dict(
  1225. depths=[1, 1, 1, 1], dims=[32, 64, 96, 128], norm_eps=kwargs.pop('norm_eps', 1e-5), kernel_sizes=(7, 5, 5, 3), act_layer='silu')
  1226. model = _create_convnext('test_convnext3', pretrained=pretrained, **dict(model_args, **kwargs))
  1227. return model
  1228. register_model_deprecations(__name__, {
  1229. 'convnext_tiny_in22ft1k': 'convnext_tiny.fb_in22k_ft_in1k',
  1230. 'convnext_small_in22ft1k': 'convnext_small.fb_in22k_ft_in1k',
  1231. 'convnext_base_in22ft1k': 'convnext_base.fb_in22k_ft_in1k',
  1232. 'convnext_large_in22ft1k': 'convnext_large.fb_in22k_ft_in1k',
  1233. 'convnext_xlarge_in22ft1k': 'convnext_xlarge.fb_in22k_ft_in1k',
  1234. 'convnext_tiny_384_in22ft1k': 'convnext_tiny.fb_in22k_ft_in1k_384',
  1235. 'convnext_small_384_in22ft1k': 'convnext_small.fb_in22k_ft_in1k_384',
  1236. 'convnext_base_384_in22ft1k': 'convnext_base.fb_in22k_ft_in1k_384',
  1237. 'convnext_large_384_in22ft1k': 'convnext_large.fb_in22k_ft_in1k_384',
  1238. 'convnext_xlarge_384_in22ft1k': 'convnext_xlarge.fb_in22k_ft_in1k_384',
  1239. 'convnext_tiny_in22k': 'convnext_tiny.fb_in22k',
  1240. 'convnext_small_in22k': 'convnext_small.fb_in22k',
  1241. 'convnext_base_in22k': 'convnext_base.fb_in22k',
  1242. 'convnext_large_in22k': 'convnext_large.fb_in22k',
  1243. 'convnext_xlarge_in22k': 'convnext_xlarge.fb_in22k',
  1244. })