mlp_mixer.py 33 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
3704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792793794795796797798799800801802803804805806807808809810811812813814815816817818819820821822823824825826827828829830831832833834835836837838839840841842843844845846847848849850851852853854855856857858859860861862863864865866867868869870871872873874875876877878879880
  1. """ MLP-Mixer, ResMLP, and gMLP in PyTorch
  2. This impl originally based on MLP-Mixer paper.
  3. Official JAX impl: https://github.com/google-research/vision_transformer/blob/linen/vit_jax/models_mixer.py
  4. Paper: 'MLP-Mixer: An all-MLP Architecture for Vision' - https://arxiv.org/abs/2105.01601
  5. @article{tolstikhin2021,
  6. title={MLP-Mixer: An all-MLP Architecture for Vision},
  7. author={Tolstikhin, Ilya and Houlsby, Neil and Kolesnikov, Alexander and Beyer, Lucas and Zhai, Xiaohua and Unterthiner,
  8. Thomas and Yung, Jessica and Keysers, Daniel and Uszkoreit, Jakob and Lucic, Mario and Dosovitskiy, Alexey},
  9. journal={arXiv preprint arXiv:2105.01601},
  10. year={2021}
  11. }
  12. Also supporting ResMLP, and a preliminary (not verified) implementation of gMLP
  13. Code: https://github.com/facebookresearch/deit
  14. Paper: `ResMLP: Feedforward networks for image classification...` - https://arxiv.org/abs/2105.03404
  15. @misc{touvron2021resmlp,
  16. title={ResMLP: Feedforward networks for image classification with data-efficient training},
  17. author={Hugo Touvron and Piotr Bojanowski and Mathilde Caron and Matthieu Cord and Alaaeldin El-Nouby and
  18. Edouard Grave and Armand Joulin and Gabriel Synnaeve and Jakob Verbeek and Hervé Jégou},
  19. year={2021},
  20. eprint={2105.03404},
  21. }
  22. Paper: `Pay Attention to MLPs` - https://arxiv.org/abs/2105.08050
  23. @misc{liu2021pay,
  24. title={Pay Attention to MLPs},
  25. author={Hanxiao Liu and Zihang Dai and David R. So and Quoc V. Le},
  26. year={2021},
  27. eprint={2105.08050},
  28. }
  29. A thank you to paper authors for releasing code and weights.
  30. Hacked together by / Copyright 2021 Ross Wightman
  31. """
  32. import math
  33. from functools import partial
  34. from typing import Any, Dict, List, Optional, Type, Union, Tuple
  35. import torch
  36. import torch.nn as nn
  37. from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
  38. from timm.layers import PatchEmbed, Mlp, GluMlp, GatedMlp, DropPath, lecun_normal_, to_2tuple
  39. from ._builder import build_model_with_cfg
  40. from ._features import feature_take_indices
  41. from ._manipulate import named_apply, checkpoint, checkpoint_seq
  42. from ._registry import generate_default_cfgs, register_model, register_model_deprecations
# Public API of this module. Per the original note, the model registry adds each
# @register_model entrypoint to this list, so it must stay a (mutable) list.
__all__ = ['MixerBlock', 'MlpMixer']  # model_registry will add each entrypoint fn to this
  44. class MixerBlock(nn.Module):
  45. """Residual Block w/ token mixing and channel MLPs.
  46. Based on: 'MLP-Mixer: An all-MLP Architecture for Vision' - https://arxiv.org/abs/2105.01601
  47. """
  48. def __init__(
  49. self,
  50. dim: int,
  51. seq_len: int,
  52. mlp_ratio: Union[float, Tuple[float, float]] = (0.5, 4.0),
  53. mlp_layer: Type[nn.Module] = Mlp,
  54. norm_layer: Type[nn.Module] = partial(nn.LayerNorm, eps=1e-6),
  55. act_layer: Type[nn.Module] = nn.GELU,
  56. drop: float = 0.,
  57. drop_path: float = 0.,
  58. device=None,
  59. dtype=None,
  60. ) -> None:
  61. """Initialize MixerBlock.
  62. Args:
  63. dim: Dimension of input features.
  64. seq_len: Sequence length.
  65. mlp_ratio: Expansion ratios for token mixing and channel MLPs.
  66. mlp_layer: MLP layer class.
  67. norm_layer: Normalization layer.
  68. act_layer: Activation layer.
  69. drop: Dropout rate.
  70. drop_path: Drop path rate.
  71. """
  72. dd = {'device': device, 'dtype': dtype}
  73. super().__init__()
  74. tokens_dim, channels_dim = [int(x * dim) for x in to_2tuple(mlp_ratio)]
  75. self.norm1 = norm_layer(dim, **dd)
  76. self.mlp_tokens = mlp_layer(seq_len, tokens_dim, act_layer=act_layer, drop=drop, **dd)
  77. self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
  78. self.norm2 = norm_layer(dim, **dd)
  79. self.mlp_channels = mlp_layer(dim, channels_dim, act_layer=act_layer, drop=drop, **dd)
  80. def forward(self, x: torch.Tensor) -> torch.Tensor:
  81. """Forward pass."""
  82. x = x + self.drop_path(self.mlp_tokens(self.norm1(x).transpose(1, 2)).transpose(1, 2))
  83. x = x + self.drop_path(self.mlp_channels(self.norm2(x)))
  84. return x
  85. class Affine(nn.Module):
  86. """Affine transformation layer."""
  87. def __init__(self, dim: int, device=None, dtype=None) -> None:
  88. """Initialize Affine layer.
  89. Args:
  90. dim: Dimension of features.
  91. """
  92. dd = {'device': device, 'dtype': dtype}
  93. super().__init__()
  94. self.alpha = nn.Parameter(torch.ones((1, 1, dim), **dd))
  95. self.beta = nn.Parameter(torch.zeros((1, 1, dim), **dd))
  96. def forward(self, x: torch.Tensor) -> torch.Tensor:
  97. """Apply affine transformation."""
  98. return torch.addcmul(self.beta, self.alpha, x)
  99. class ResBlock(nn.Module):
  100. """Residual MLP block w/ LayerScale and Affine 'norm'.
  101. Based on: `ResMLP: Feedforward networks for image classification...` - https://arxiv.org/abs/2105.03404
  102. """
  103. def __init__(
  104. self,
  105. dim: int,
  106. seq_len: int,
  107. mlp_ratio: float = 4,
  108. mlp_layer: Type[nn.Module] = Mlp,
  109. norm_layer: Type[nn.Module] = Affine,
  110. act_layer: Type[nn.Module] = nn.GELU,
  111. init_values: float = 1e-4,
  112. drop: float = 0.,
  113. drop_path: float = 0.,
  114. device=None,
  115. dtype=None,
  116. ) -> None:
  117. """Initialize ResBlock.
  118. Args:
  119. dim: Dimension of input features.
  120. seq_len: Sequence length.
  121. mlp_ratio: Channel MLP expansion ratio.
  122. mlp_layer: MLP layer class.
  123. norm_layer: Normalization layer.
  124. act_layer: Activation layer.
  125. init_values: Initial values for layer scale.
  126. drop: Dropout rate.
  127. drop_path: Drop path rate.
  128. """
  129. dd = {'device': device, 'dtype': dtype}
  130. super().__init__()
  131. channel_dim = int(dim * mlp_ratio)
  132. self.norm1 = norm_layer(dim, **dd)
  133. self.linear_tokens = nn.Linear(seq_len, seq_len, **dd)
  134. self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
  135. self.norm2 = norm_layer(dim, **dd)
  136. self.mlp_channels = mlp_layer(dim, channel_dim, act_layer=act_layer, drop=drop, **dd)
  137. self.ls1 = nn.Parameter(init_values * torch.ones(dim, **dd))
  138. self.ls2 = nn.Parameter(init_values * torch.ones(dim, **dd))
  139. def forward(self, x: torch.Tensor) -> torch.Tensor:
  140. """Forward pass."""
  141. x = x + self.drop_path(self.ls1 * self.linear_tokens(self.norm1(x).transpose(1, 2)).transpose(1, 2))
  142. x = x + self.drop_path(self.ls2 * self.mlp_channels(self.norm2(x)))
  143. return x
  144. class SpatialGatingUnit(nn.Module):
  145. """Spatial Gating Unit.
  146. Based on: `Pay Attention to MLPs` - https://arxiv.org/abs/2105.08050
  147. """
  148. def __init__(
  149. self,
  150. dim: int,
  151. seq_len: int,
  152. norm_layer: Type[nn.Module] = nn.LayerNorm,
  153. device=None,
  154. dtype=None,
  155. ) -> None:
  156. """Initialize Spatial Gating Unit.
  157. Args:
  158. dim: Dimension of input features.
  159. seq_len: Sequence length.
  160. norm_layer: Normalization layer.
  161. """
  162. dd = {'device': device, 'dtype': dtype}
  163. super().__init__()
  164. gate_dim = dim // 2
  165. self.norm = norm_layer(gate_dim, **dd)
  166. self.proj = nn.Linear(seq_len, seq_len, **dd)
  167. def init_weights(self) -> None:
  168. """Initialize weights for projection gate."""
  169. # special init for the projection gate, called as override by base model init
  170. nn.init.normal_(self.proj.weight, std=1e-6)
  171. nn.init.ones_(self.proj.bias)
  172. def forward(self, x: torch.Tensor) -> torch.Tensor:
  173. """Apply spatial gating."""
  174. u, v = x.chunk(2, dim=-1)
  175. v = self.norm(v)
  176. v = self.proj(v.transpose(-1, -2))
  177. return u * v.transpose(-1, -2)
  178. class SpatialGatingBlock(nn.Module):
  179. """Residual Block w/ Spatial Gating.
  180. Based on: `Pay Attention to MLPs` - https://arxiv.org/abs/2105.08050
  181. """
  182. def __init__(
  183. self,
  184. dim: int,
  185. seq_len: int,
  186. mlp_ratio: float = 4,
  187. mlp_layer: Type[nn.Module] = GatedMlp,
  188. norm_layer: Type[nn.Module] = partial(nn.LayerNorm, eps=1e-6),
  189. act_layer: Type[nn.Module] = nn.GELU,
  190. drop: float = 0.,
  191. drop_path: float = 0.,
  192. device=None,
  193. dtype=None,
  194. ) -> None:
  195. """Initialize SpatialGatingBlock.
  196. Args:
  197. dim: Dimension of input features.
  198. seq_len: Sequence length.
  199. mlp_ratio: Channel MLP expansion ratio.
  200. mlp_layer: MLP layer class.
  201. norm_layer: Normalization layer.
  202. act_layer: Activation layer.
  203. drop: Dropout rate.
  204. drop_path: Drop path rate.
  205. """
  206. dd = {'device': device, 'dtype': dtype}
  207. super().__init__()
  208. channel_dim = int(dim * mlp_ratio)
  209. self.norm = norm_layer(dim, **dd)
  210. sgu = partial(SpatialGatingUnit, seq_len=seq_len, **dd)
  211. self.mlp_channels = mlp_layer(
  212. dim,
  213. channel_dim,
  214. act_layer=act_layer,
  215. gate_layer=sgu,
  216. drop=drop,
  217. **dd,
  218. )
  219. self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
  220. def forward(self, x: torch.Tensor) -> torch.Tensor:
  221. """Forward pass."""
  222. x = x + self.drop_path(self.mlp_channels(self.norm(x)))
  223. return x
  224. class MlpMixer(nn.Module):
  225. """MLP-Mixer model architecture.
  226. Based on: 'MLP-Mixer: An all-MLP Architecture for Vision' - https://arxiv.org/abs/2105.01601
  227. """
  228. def __init__(
  229. self,
  230. num_classes: int = 1000,
  231. img_size: int = 224,
  232. in_chans: int = 3,
  233. patch_size: int = 16,
  234. num_blocks: int = 8,
  235. embed_dim: int = 512,
  236. mlp_ratio: Union[float, Tuple[float, float]] = (0.5, 4.0),
  237. block_layer: Type[nn.Module] = MixerBlock,
  238. mlp_layer: Type[nn.Module] = Mlp,
  239. norm_layer: Type[nn.Module] = partial(nn.LayerNorm, eps=1e-6),
  240. act_layer: Type[nn.Module] = nn.GELU,
  241. drop_rate: float = 0.,
  242. proj_drop_rate: float = 0.,
  243. drop_path_rate: float = 0.,
  244. nlhb: bool = False,
  245. stem_norm: bool = False,
  246. global_pool: str = 'avg',
  247. device=None,
  248. dtype=None,
  249. ) -> None:
  250. """Initialize MLP-Mixer.
  251. Args:
  252. num_classes: Number of classes for classification.
  253. img_size: Input image size.
  254. in_chans: Number of input channels.
  255. patch_size: Patch size.
  256. num_blocks: Number of mixer blocks.
  257. embed_dim: Embedding dimension.
  258. mlp_ratio: MLP expansion ratio(s).
  259. block_layer: Block layer class.
  260. mlp_layer: MLP layer class.
  261. norm_layer: Normalization layer.
  262. act_layer: Activation layer.
  263. drop_rate: Head dropout rate.
  264. proj_drop_rate: Projection dropout rate.
  265. drop_path_rate: Drop path rate.
  266. nlhb: Use negative log bias initialization.
  267. stem_norm: Apply normalization to stem.
  268. global_pool: Global pooling type.
  269. """
  270. super().__init__()
  271. dd = {'device': device, 'dtype': dtype}
  272. self.num_classes = num_classes
  273. self.in_chans = in_chans
  274. self.global_pool = global_pool
  275. self.num_features = self.head_hidden_size = self.embed_dim = embed_dim # for consistency with other models
  276. self.grad_checkpointing = False
  277. self.stem = PatchEmbed(
  278. img_size=img_size,
  279. patch_size=patch_size,
  280. in_chans=in_chans,
  281. embed_dim=embed_dim,
  282. norm_layer=norm_layer if stem_norm else None,
  283. **dd,
  284. )
  285. reduction = self.stem.feat_ratio() if hasattr(self.stem, 'feat_ratio') else patch_size
  286. # FIXME drop_path (stochastic depth scaling rule or all the same?)
  287. self.blocks = nn.Sequential(*[
  288. block_layer(
  289. embed_dim,
  290. self.stem.num_patches,
  291. mlp_ratio,
  292. mlp_layer=mlp_layer,
  293. norm_layer=norm_layer,
  294. act_layer=act_layer,
  295. drop=proj_drop_rate,
  296. drop_path=drop_path_rate,
  297. **dd,
  298. )
  299. for _ in range(num_blocks)])
  300. self.feature_info = [
  301. dict(module=f'blocks.{i}', num_chs=embed_dim, reduction=reduction) for i in range(num_blocks)]
  302. self.norm = norm_layer(embed_dim, **dd)
  303. self.head_drop = nn.Dropout(drop_rate)
  304. self.head = nn.Linear(embed_dim, self.num_classes, **dd) if num_classes > 0 else nn.Identity()
  305. self.init_weights(nlhb=nlhb)
  306. @torch.jit.ignore
  307. def init_weights(self, nlhb: bool = False) -> None:
  308. """Initialize model weights.
  309. Args:
  310. nlhb: Use negative log bias initialization for head.
  311. """
  312. head_bias = -math.log(self.num_classes) if nlhb else 0.
  313. named_apply(partial(_init_weights, head_bias=head_bias), module=self) # depth-first
  314. @torch.jit.ignore
  315. def group_matcher(self, coarse: bool = False) -> Dict[str, Any]:
  316. """Create regex patterns for parameter grouping.
  317. Args:
  318. coarse: Use coarse grouping.
  319. Returns:
  320. Dictionary mapping group names to regex patterns.
  321. """
  322. return dict(
  323. stem=r'^stem', # stem and embed
  324. blocks=[(r'^blocks\.(\d+)', None), (r'^norm', (99999,))]
  325. )
  326. @torch.jit.ignore
  327. def set_grad_checkpointing(self, enable: bool = True) -> None:
  328. """Enable or disable gradient checkpointing.
  329. Args:
  330. enable: Whether to enable gradient checkpointing.
  331. """
  332. self.grad_checkpointing = enable
  333. @torch.jit.ignore
  334. def get_classifier(self) -> nn.Module:
  335. """Get the classifier module."""
  336. return self.head
  337. def reset_classifier(self, num_classes: int, global_pool: Optional[str] = None) -> None:
  338. """Reset the classifier head.
  339. Args:
  340. num_classes: Number of classes for new classifier.
  341. global_pool: Global pooling type.
  342. """
  343. self.num_classes = num_classes
  344. if global_pool is not None:
  345. assert global_pool in ('', 'avg')
  346. self.global_pool = global_pool
  347. device, dtype = self.head.weight.device, self.head.weight.dtype if hasattr(self.head, 'weight') else (None, None)
  348. self.head = nn.Linear(self.embed_dim, num_classes, device=device, dtype=dtype) if num_classes > 0 else nn.Identity()
  349. def forward_intermediates(
  350. self,
  351. x: torch.Tensor,
  352. indices: Optional[Union[int, List[int]]] = None,
  353. norm: bool = False,
  354. stop_early: bool = False,
  355. output_fmt: str = 'NCHW',
  356. intermediates_only: bool = False,
  357. ) -> Union[List[torch.Tensor], Tuple[torch.Tensor, List[torch.Tensor]]]:
  358. """Forward features that returns intermediates.
  359. Args:
  360. x: Input image tensor.
  361. indices: Take last n blocks if int, all if None, select matching indices if sequence.
  362. norm: Apply norm layer to all intermediates.
  363. stop_early: Stop iterating over blocks when last desired intermediate hit.
  364. output_fmt: Shape of intermediate feature outputs ('NCHW' or 'NLC').
  365. intermediates_only: Only return intermediate features.
  366. Returns:
  367. List of intermediate features or tuple of (final features, intermediates).
  368. """
  369. assert output_fmt in ('NCHW', 'NLC'), 'Output format must be one of NCHW or NLC.'
  370. reshape = output_fmt == 'NCHW'
  371. intermediates = []
  372. take_indices, max_index = feature_take_indices(len(self.blocks), indices)
  373. # forward pass
  374. B, _, height, width = x.shape
  375. x = self.stem(x)
  376. if torch.jit.is_scripting() or not stop_early: # can't slice blocks in torchscript
  377. blocks = self.blocks
  378. else:
  379. blocks = self.blocks[:max_index + 1]
  380. for i, blk in enumerate(blocks):
  381. if self.grad_checkpointing and not torch.jit.is_scripting():
  382. x = checkpoint(blk, x)
  383. else:
  384. x = blk(x)
  385. if i in take_indices:
  386. # normalize intermediates with final norm layer if enabled
  387. intermediates.append(self.norm(x) if norm else x)
  388. # process intermediates
  389. if reshape:
  390. # reshape to BCHW output format
  391. H, W = self.stem.dynamic_feat_size((height, width))
  392. intermediates = [y.reshape(B, H, W, -1).permute(0, 3, 1, 2).contiguous() for y in intermediates]
  393. if intermediates_only:
  394. return intermediates
  395. x = self.norm(x)
  396. return x, intermediates
  397. def prune_intermediate_layers(
  398. self,
  399. indices: Union[int, List[int]] = 1,
  400. prune_norm: bool = False,
  401. prune_head: bool = True,
  402. ) -> List[int]:
  403. """Prune layers not required for specified intermediates.
  404. Args:
  405. indices: Indices of intermediate layers to keep.
  406. prune_norm: Whether to prune normalization layer.
  407. prune_head: Whether to prune the classifier head.
  408. Returns:
  409. List of indices that were kept.
  410. """
  411. take_indices, max_index = feature_take_indices(len(self.blocks), indices)
  412. self.blocks = self.blocks[:max_index + 1] # truncate blocks
  413. if prune_norm:
  414. self.norm = nn.Identity()
  415. if prune_head:
  416. self.reset_classifier(0, '')
  417. return take_indices
  418. def forward_features(self, x: torch.Tensor) -> torch.Tensor:
  419. """Forward pass through feature extraction layers."""
  420. x = self.stem(x)
  421. if self.grad_checkpointing and not torch.jit.is_scripting():
  422. x = checkpoint_seq(self.blocks, x)
  423. else:
  424. x = self.blocks(x)
  425. x = self.norm(x)
  426. return x
  427. def forward_head(self, x: torch.Tensor, pre_logits: bool = False) -> torch.Tensor:
  428. """Forward pass through classifier head.
  429. Args:
  430. x: Feature tensor.
  431. pre_logits: Return features before final classifier.
  432. Returns:
  433. Output tensor.
  434. """
  435. if self.global_pool == 'avg':
  436. x = x.mean(dim=1)
  437. x = self.head_drop(x)
  438. return x if pre_logits else self.head(x)
  439. def forward(self, x: torch.Tensor) -> torch.Tensor:
  440. """Forward pass."""
  441. x = self.forward_features(x)
  442. x = self.forward_head(x)
  443. return x
  444. def _init_weights(module: nn.Module, name: str, head_bias: float = 0., flax: bool = False) -> None:
  445. """Mixer weight initialization (trying to match Flax defaults).
  446. Args:
  447. module: Module to initialize.
  448. name: Module name.
  449. head_bias: Bias value for head layer.
  450. flax: Use Flax-style initialization.
  451. """
  452. if isinstance(module, nn.Linear):
  453. if name.startswith('head'):
  454. nn.init.zeros_(module.weight)
  455. nn.init.constant_(module.bias, head_bias)
  456. else:
  457. if flax:
  458. # Flax defaults
  459. lecun_normal_(module.weight)
  460. if module.bias is not None:
  461. nn.init.zeros_(module.bias)
  462. else:
  463. # like MLP init in vit (my original init)
  464. nn.init.xavier_uniform_(module.weight)
  465. if module.bias is not None:
  466. if 'mlp' in name:
  467. nn.init.normal_(module.bias, std=1e-6)
  468. else:
  469. nn.init.zeros_(module.bias)
  470. elif isinstance(module, nn.Conv2d):
  471. lecun_normal_(module.weight)
  472. if module.bias is not None:
  473. nn.init.zeros_(module.bias)
  474. elif isinstance(module, (nn.LayerNorm, nn.BatchNorm2d, nn.GroupNorm)):
  475. nn.init.ones_(module.weight)
  476. nn.init.zeros_(module.bias)
  477. elif hasattr(module, 'init_weights'):
  478. # NOTE if a parent module contains init_weights method, it can override the init of the
  479. # child modules as this will be called in depth-first order.
  480. module.init_weights()
  481. def checkpoint_filter_fn(state_dict, model):
  482. """ Remap checkpoints if needed """
  483. if 'patch_embed.proj.weight' in state_dict:
  484. # Remap FB ResMlp models -> timm
  485. out_dict = {}
  486. for k, v in state_dict.items():
  487. k = k.replace('patch_embed.', 'stem.')
  488. k = k.replace('attn.', 'linear_tokens.')
  489. k = k.replace('mlp.', 'mlp_channels.')
  490. k = k.replace('gamma_', 'ls')
  491. if k.endswith('.alpha') or k.endswith('.beta'):
  492. v = v.reshape(1, 1, -1)
  493. out_dict[k] = v
  494. return out_dict
  495. return state_dict
  496. def _create_mixer(variant, pretrained=False, **kwargs) -> MlpMixer:
  497. out_indices = kwargs.pop('out_indices', 3)
  498. model = build_model_with_cfg(
  499. MlpMixer,
  500. variant,
  501. pretrained,
  502. pretrained_filter_fn=checkpoint_filter_fn,
  503. feature_cfg=dict(out_indices=out_indices, feature_cls='getter'),
  504. **kwargs,
  505. )
  506. return model
  507. def _cfg(url='', **kwargs) -> Dict[str, Any]:
  508. return {
  509. 'url': url,
  510. 'num_classes': 1000, 'input_size': (3, 224, 224), 'pool_size': None,
  511. 'crop_pct': 0.875, 'interpolation': 'bicubic', 'fixed_input_size': True,
  512. 'mean': (0.5, 0.5, 0.5), 'std': (0.5, 0.5, 0.5),
  513. 'first_conv': 'stem.proj', 'classifier': 'head',
  514. 'license': 'apache-2.0',
  515. **kwargs
  516. }
# Pretrained configs keyed as '<architecture>.<pretrained_tag>'; each value is built by _cfg().
# '.untrained' entries carry no weight URL and only register the architecture's default config.
# NOTE(review): entry order is likely significant (timm generally treats the first tag listed for
# an architecture as its default pretrained tag) -- confirm before reordering entries.
default_cfgs = generate_default_cfgs({
    'mixer_s32_224.untrained': _cfg(),
    'mixer_s16_224.untrained': _cfg(),
    'mixer_b32_224.untrained': _cfg(),
    # Google (JAX) weights: ImageNet-21k pretrain, fine-tuned on ImageNet-1k or 21k-only (21843 classes).
    'mixer_b16_224.goog_in21k_ft_in1k': _cfg(
        hf_hub_id='timm/',
        url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-vitjx/jx_mixer_b16_224-76587d61.pth',
    ),
    'mixer_b16_224.goog_in21k': _cfg(
        hf_hub_id='timm/',
        url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-vitjx/jx_mixer_b16_224_in21k-617b3de2.pth',
        num_classes=21843
    ),
    'mixer_l32_224.untrained': _cfg(),
    'mixer_l16_224.goog_in21k_ft_in1k': _cfg(
        hf_hub_id='timm/',
        url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-vitjx/jx_mixer_l16_224-92f9adc4.pth',
    ),
    'mixer_l16_224.goog_in21k': _cfg(
        hf_hub_id='timm/',
        url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-vitjx/jx_mixer_l16_224_in21k-846aa33c.pth',
        num_classes=21843
    ),

    # Mixer ImageNet-21K-P pretraining (11221 classes); zero-mean/unit-std input normalization.
    'mixer_b16_224.miil_in21k': _cfg(
        hf_hub_id='timm/',
        url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-tresnet/mixer_b16_224_miil_in21k-2a558a71.pth',
        mean=(0., 0., 0.), std=(1., 1., 1.), crop_pct=0.875, interpolation='bilinear', num_classes=11221,
    ),
    'mixer_b16_224.miil_in21k_ft_in1k': _cfg(
        hf_hub_id='timm/',
        url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-tresnet/mixer_b16_224_miil-9229a591.pth',
        mean=(0., 0., 0.), std=(1., 1., 1.), crop_pct=0.875, interpolation='bilinear',
    ),

    # GLU-Mixer experiments: ImageNet default mean/std instead of _cfg's 0.5/0.5.
    'gmixer_12_224.untrained': _cfg(mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD),
    'gmixer_24_224.ra3_in1k': _cfg(
        hf_hub_id='timm/',
        url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/gmixer_24_224_raa-7daf7ae6.pth',
        mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD),

    # ResMLP weights from the FB DeiT release; keys are remapped to this naming by checkpoint_filter_fn.
    'resmlp_12_224.fb_in1k': _cfg(
        hf_hub_id='timm/',
        url='https://dl.fbaipublicfiles.com/deit/resmlp_12_no_dist.pth',
        mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD),
    'resmlp_24_224.fb_in1k': _cfg(
        hf_hub_id='timm/',
        url='https://dl.fbaipublicfiles.com/deit/resmlp_24_no_dist.pth',
        #url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/resmlp_24_224_raa-a8256759.pth',
        mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD),
    'resmlp_36_224.fb_in1k': _cfg(
        hf_hub_id='timm/',
        url='https://dl.fbaipublicfiles.com/deit/resmlp_36_no_dist.pth',
        mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD),
    'resmlp_big_24_224.fb_in1k': _cfg(
        hf_hub_id='timm/',
        url='https://dl.fbaipublicfiles.com/deit/resmlpB_24_no_dist.pth',
        mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD),

    # Distilled ResMLP variants.
    'resmlp_12_224.fb_distilled_in1k': _cfg(
        hf_hub_id='timm/',
        url='https://dl.fbaipublicfiles.com/deit/resmlp_12_dist.pth',
        mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD),
    'resmlp_24_224.fb_distilled_in1k': _cfg(
        hf_hub_id='timm/',
        url='https://dl.fbaipublicfiles.com/deit/resmlp_24_dist.pth',
        mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD),
    'resmlp_36_224.fb_distilled_in1k': _cfg(
        hf_hub_id='timm/',
        url='https://dl.fbaipublicfiles.com/deit/resmlp_36_dist.pth',
        mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD),
    'resmlp_big_24_224.fb_distilled_in1k': _cfg(
        hf_hub_id='timm/',
        url='https://dl.fbaipublicfiles.com/deit/resmlpB_24_dist.pth',
        mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD),

    'resmlp_big_24_224.fb_in22k_ft_in1k': _cfg(
        hf_hub_id='timm/',
        url='https://dl.fbaipublicfiles.com/deit/resmlpB_24_22k.pth',
        mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD),

    # DINO (self-supervised) ResMLP checkpoints.
    'resmlp_12_224.fb_dino': _cfg(
        hf_hub_id='timm/',
        url='https://dl.fbaipublicfiles.com/deit/resmlp_12_dino.pth',
        mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD),
    'resmlp_24_224.fb_dino': _cfg(
        hf_hub_id='timm/',
        url='https://dl.fbaipublicfiles.com/deit/resmlp_24_dino.pth',
        mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD),

    # gMLP (only the small variant has released weights here).
    'gmlp_ti16_224.untrained': _cfg(),
    'gmlp_s16_224.ra3_in1k': _cfg(
        hf_hub_id='timm/',
        url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/gmlp_s16_224_raa-10536d42.pth',
    ),
    'gmlp_b16_224.untrained': _cfg(),
})
  608. @register_model
  609. def mixer_s32_224(pretrained=False, **kwargs) -> MlpMixer:
  610. """ Mixer-S/32 224x224
  611. Paper: 'MLP-Mixer: An all-MLP Architecture for Vision' - https://arxiv.org/abs/2105.01601
  612. """
  613. model_args = dict(patch_size=32, num_blocks=8, embed_dim=512, **kwargs)
  614. model = _create_mixer('mixer_s32_224', pretrained=pretrained, **model_args)
  615. return model
  616. @register_model
  617. def mixer_s16_224(pretrained=False, **kwargs) -> MlpMixer:
  618. """ Mixer-S/16 224x224
  619. Paper: 'MLP-Mixer: An all-MLP Architecture for Vision' - https://arxiv.org/abs/2105.01601
  620. """
  621. model_args = dict(patch_size=16, num_blocks=8, embed_dim=512, **kwargs)
  622. model = _create_mixer('mixer_s16_224', pretrained=pretrained, **model_args)
  623. return model
  624. @register_model
  625. def mixer_b32_224(pretrained=False, **kwargs) -> MlpMixer:
  626. """ Mixer-B/32 224x224
  627. Paper: 'MLP-Mixer: An all-MLP Architecture for Vision' - https://arxiv.org/abs/2105.01601
  628. """
  629. model_args = dict(patch_size=32, num_blocks=12, embed_dim=768, **kwargs)
  630. model = _create_mixer('mixer_b32_224', pretrained=pretrained, **model_args)
  631. return model
  632. @register_model
  633. def mixer_b16_224(pretrained=False, **kwargs) -> MlpMixer:
  634. """ Mixer-B/16 224x224. ImageNet-1k pretrained weights.
  635. Paper: 'MLP-Mixer: An all-MLP Architecture for Vision' - https://arxiv.org/abs/2105.01601
  636. """
  637. model_args = dict(patch_size=16, num_blocks=12, embed_dim=768, **kwargs)
  638. model = _create_mixer('mixer_b16_224', pretrained=pretrained, **model_args)
  639. return model
  640. @register_model
  641. def mixer_l32_224(pretrained=False, **kwargs) -> MlpMixer:
  642. """ Mixer-L/32 224x224.
  643. Paper: 'MLP-Mixer: An all-MLP Architecture for Vision' - https://arxiv.org/abs/2105.01601
  644. """
  645. model_args = dict(patch_size=32, num_blocks=24, embed_dim=1024, **kwargs)
  646. model = _create_mixer('mixer_l32_224', pretrained=pretrained, **model_args)
  647. return model
  648. @register_model
  649. def mixer_l16_224(pretrained=False, **kwargs) -> MlpMixer:
  650. """ Mixer-L/16 224x224. ImageNet-1k pretrained weights.
  651. Paper: 'MLP-Mixer: An all-MLP Architecture for Vision' - https://arxiv.org/abs/2105.01601
  652. """
  653. model_args = dict(patch_size=16, num_blocks=24, embed_dim=1024, **kwargs)
  654. model = _create_mixer('mixer_l16_224', pretrained=pretrained, **model_args)
  655. return model
  656. @register_model
  657. def gmixer_12_224(pretrained=False, **kwargs) -> MlpMixer:
  658. """ Glu-Mixer-12 224x224
  659. Experiment by Ross Wightman, adding SwiGLU to MLP-Mixer
  660. """
  661. model_args = dict(
  662. patch_size=16, num_blocks=12, embed_dim=384, mlp_ratio=(1.0, 4.0),
  663. mlp_layer=GluMlp, act_layer=nn.SiLU, **kwargs)
  664. model = _create_mixer('gmixer_12_224', pretrained=pretrained, **model_args)
  665. return model
  666. @register_model
  667. def gmixer_24_224(pretrained=False, **kwargs) -> MlpMixer:
  668. """ Glu-Mixer-24 224x224
  669. Experiment by Ross Wightman, adding SwiGLU to MLP-Mixer
  670. """
  671. model_args = dict(
  672. patch_size=16, num_blocks=24, embed_dim=384, mlp_ratio=(1.0, 4.0),
  673. mlp_layer=GluMlp, act_layer=nn.SiLU, **kwargs)
  674. model = _create_mixer('gmixer_24_224', pretrained=pretrained, **model_args)
  675. return model
  676. @register_model
  677. def resmlp_12_224(pretrained=False, **kwargs) -> MlpMixer:
  678. """ ResMLP-12
  679. Paper: `ResMLP: Feedforward networks for image classification...` - https://arxiv.org/abs/2105.03404
  680. """
  681. model_args = dict(
  682. patch_size=16, num_blocks=12, embed_dim=384, mlp_ratio=4, block_layer=ResBlock, norm_layer=Affine, **kwargs)
  683. model = _create_mixer('resmlp_12_224', pretrained=pretrained, **model_args)
  684. return model
  685. @register_model
  686. def resmlp_24_224(pretrained=False, **kwargs) -> MlpMixer:
  687. """ ResMLP-24
  688. Paper: `ResMLP: Feedforward networks for image classification...` - https://arxiv.org/abs/2105.03404
  689. """
  690. model_args = dict(
  691. patch_size=16, num_blocks=24, embed_dim=384, mlp_ratio=4,
  692. block_layer=partial(ResBlock, init_values=1e-5), norm_layer=Affine, **kwargs)
  693. model = _create_mixer('resmlp_24_224', pretrained=pretrained, **model_args)
  694. return model
  695. @register_model
  696. def resmlp_36_224(pretrained=False, **kwargs) -> MlpMixer:
  697. """ ResMLP-36
  698. Paper: `ResMLP: Feedforward networks for image classification...` - https://arxiv.org/abs/2105.03404
  699. """
  700. model_args = dict(
  701. patch_size=16, num_blocks=36, embed_dim=384, mlp_ratio=4,
  702. block_layer=partial(ResBlock, init_values=1e-6), norm_layer=Affine, **kwargs)
  703. model = _create_mixer('resmlp_36_224', pretrained=pretrained, **model_args)
  704. return model
  705. @register_model
  706. def resmlp_big_24_224(pretrained=False, **kwargs) -> MlpMixer:
  707. """ ResMLP-B-24
  708. Paper: `ResMLP: Feedforward networks for image classification...` - https://arxiv.org/abs/2105.03404
  709. """
  710. model_args = dict(
  711. patch_size=8, num_blocks=24, embed_dim=768, mlp_ratio=4,
  712. block_layer=partial(ResBlock, init_values=1e-6), norm_layer=Affine, **kwargs)
  713. model = _create_mixer('resmlp_big_24_224', pretrained=pretrained, **model_args)
  714. return model
  715. @register_model
  716. def gmlp_ti16_224(pretrained=False, **kwargs) -> MlpMixer:
  717. """ gMLP-Tiny
  718. Paper: `Pay Attention to MLPs` - https://arxiv.org/abs/2105.08050
  719. """
  720. model_args = dict(
  721. patch_size=16, num_blocks=30, embed_dim=128, mlp_ratio=6, block_layer=SpatialGatingBlock,
  722. mlp_layer=GatedMlp, **kwargs)
  723. model = _create_mixer('gmlp_ti16_224', pretrained=pretrained, **model_args)
  724. return model
  725. @register_model
  726. def gmlp_s16_224(pretrained=False, **kwargs) -> MlpMixer:
  727. """ gMLP-Small
  728. Paper: `Pay Attention to MLPs` - https://arxiv.org/abs/2105.08050
  729. """
  730. model_args = dict(
  731. patch_size=16, num_blocks=30, embed_dim=256, mlp_ratio=6, block_layer=SpatialGatingBlock,
  732. mlp_layer=GatedMlp, **kwargs)
  733. model = _create_mixer('gmlp_s16_224', pretrained=pretrained, **model_args)
  734. return model
  735. @register_model
  736. def gmlp_b16_224(pretrained=False, **kwargs) -> MlpMixer:
  737. """ gMLP-Base
  738. Paper: `Pay Attention to MLPs` - https://arxiv.org/abs/2105.08050
  739. """
  740. model_args = dict(
  741. patch_size=16, num_blocks=30, embed_dim=512, mlp_ratio=6, block_layer=SpatialGatingBlock,
  742. mlp_layer=GatedMlp, **kwargs)
  743. model = _create_mixer('gmlp_b16_224', pretrained=pretrained, **model_args)
  744. return model
  745. register_model_deprecations(__name__, {
  746. 'mixer_b16_224_in21k': 'mixer_b16_224.goog_in21k_ft_in1k',
  747. 'mixer_l16_224_in21k': 'mixer_l16_224.goog_in21k_ft_in1k',
  748. 'mixer_b16_224_miil': 'mixer_b16_224.miil_in21k_ft_in1k',
  749. 'mixer_b16_224_miil_in21k': 'mixer_b16_224.miil_in21k',
  750. 'resmlp_12_distilled_224': 'resmlp_12_224.fb_distilled_in1k',
  751. 'resmlp_24_distilled_224': 'resmlp_24_224.fb_distilled_in1k',
  752. 'resmlp_36_distilled_224': 'resmlp_36_224.fb_distilled_in1k',
  753. 'resmlp_big_24_distilled_224': 'resmlp_big_24_224.fb_distilled_in1k',
  754. 'resmlp_big_24_224_in22ft1k': 'resmlp_big_24_224.fb_in22k_ft_in1k',
  755. 'resmlp_12_224_dino': 'resmlp_12_224',
  756. 'resmlp_24_224_dino': 'resmlp_24_224',
  757. })