beit.py 42 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
3704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792793794795796797798799800801802803804805806807808809810811812813814815816817818819820821822823824825826827828829830831832833834835836837838839840841842843844845846847848849850851852853854855856857858859860861862863864865866867868869870871872873874875876877878879880881882883884885886887888889890891892893894895896897898899900901902903904905906907908909910911912913914915916917918919920921922923924925926927928929930931932933934935936937938939940941942943944945946947948949950951952953954955956957958959960961962963964965966967968969970971972973974975976977978979980981982983984985986987988989990991992993994995996997998999100010011002100310041005100610071008100910101011101210131014101510161017101810191020102110221023102410251026102710281029103010311032103310341035103610371038103910401041104210431044104510461047104810491050105110521053105410551056105710581059106010611062
  1. """ BEiT: BERT Pre-Training of Image Transformers (https://arxiv.org/abs/2106.08254)
  2. Model from official source: https://github.com/microsoft/unilm/tree/master/beit
  3. @inproceedings{beit,
  4. title={{BEiT}: {BERT} Pre-Training of Image Transformers},
  5. author={Hangbo Bao and Li Dong and Songhao Piao and Furu Wei},
  6. booktitle={International Conference on Learning Representations},
  7. year={2022},
  8. url={https://openreview.net/forum?id=p-BhZSz59o4}
  9. }
  10. BEiT-v2 from https://github.com/microsoft/unilm/tree/master/beit2
  11. @article{beitv2,
  12. title={{BEiT v2}: Masked Image Modeling with Vector-Quantized Visual Tokenizers},
  13. author={Zhiliang Peng and Li Dong and Hangbo Bao and Qixiang Ye and Furu Wei},
  14. year={2022},
  15. eprint={2208.06366},
  16. archivePrefix={arXiv},
  17. primaryClass={cs.CV}
  18. }
  19. At this point only the 1k fine-tuned classification weights and model configs have been added,
  20. see original source above for pre-training models and procedure.
  21. Modifications by / Copyright 2021 Ross Wightman, original copyrights below
  22. """
  23. # --------------------------------------------------------
  24. # BEIT: BERT Pre-Training of Image Transformers (https://arxiv.org/abs/2106.08254)
  25. # Github source: https://github.com/microsoft/unilm/tree/master/beit
  26. # Copyright (c) 2021 Microsoft
  27. # Licensed under The MIT License [see LICENSE for details]
  28. # By Hangbo Bao
  29. # Based on timm and DeiT code bases
  30. # https://github.com/rwightman/pytorch-image-models/tree/master/timm
  31. # https://github.com/facebookresearch/deit/
  32. # https://github.com/facebookresearch/dino
  33. # --------------------------------------------------------
  34. import math
  35. from functools import partial
  36. from typing import Any, Callable, Dict, List, Optional, Set, Tuple, Type, Union
  37. import torch
  38. import torch.nn as nn
  39. import torch.nn.functional as F
  40. from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
  41. from timm.layers import (
  42. PatchEmbed,
  43. Mlp,
  44. SwiGLU,
  45. LayerNorm,
  46. DropPath,
  47. calculate_drop_path_rates,
  48. trunc_normal_,
  49. use_fused_attn,
  50. resample_patch_embed,
  51. resample_abs_pos_embed,
  52. resize_rel_pos_bias_table,
  53. ndgrid,
  54. )
  55. from ._builder import build_model_with_cfg
  56. from ._features import feature_take_indices
  57. from ._manipulate import checkpoint
  58. from ._registry import generate_default_cfgs, register_model
  59. __all__ = ['Beit']
  60. def gen_relative_position_index(window_size: Tuple[int, int], device=None) -> torch.Tensor:
  61. """Generate relative position index for window-based attention.
  62. Creates a lookup table for relative position indices between all pairs of positions
  63. within a window, including special handling for cls token interactions.
  64. Args:
  65. window_size: Height and width of the attention window.
  66. Returns:
  67. Relative position index tensor of shape (window_area+1, window_area+1)
  68. where +1 accounts for the cls token.
  69. """
  70. num_relative_distance = (2 * window_size[0] - 1) * (2 * window_size[1] - 1) + 3
  71. # cls to token & token 2 cls & cls to cls
  72. # get pair-wise relative position index for each token inside the window
  73. window_area = window_size[0] * window_size[1]
  74. coords = torch.stack(ndgrid(
  75. torch.arange(window_size[0], device=device, dtype=torch.long),
  76. torch.arange(window_size[1], device=device, dtype=torch.long),
  77. )) # 2, Wh, Ww
  78. coords_flatten = torch.flatten(coords, 1) # 2, Wh*Ww
  79. relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :] # 2, Wh*Ww, Wh*Ww
  80. relative_coords = relative_coords.permute(1, 2, 0).contiguous() # Wh*Ww, Wh*Ww, 2
  81. relative_coords[:, :, 0] += window_size[0] - 1 # shift to start from 0
  82. relative_coords[:, :, 1] += window_size[1] - 1
  83. relative_coords[:, :, 0] *= 2 * window_size[1] - 1
  84. relative_position_index = torch.zeros(size=(window_area + 1,) * 2, device=device, dtype=relative_coords.dtype)
  85. relative_position_index[1:, 1:] = relative_coords.sum(-1) # Wh*Ww, Wh*Ww
  86. relative_position_index[0, 0:] = num_relative_distance - 3
  87. relative_position_index[0:, 0] = num_relative_distance - 2
  88. relative_position_index[0, 0] = num_relative_distance - 1
  89. return relative_position_index
  90. class Attention(nn.Module):
  91. """Multi-head attention module with optional relative position bias.
  92. Implements multi-head self-attention with support for relative position bias
  93. and fused attention operations. Can use either standard or custom head dimensions.
  94. """
  95. fused_attn: torch.jit.Final[bool]
  96. def __init__(
  97. self,
  98. dim: int,
  99. num_heads: int = 8,
  100. qkv_bias: bool = False,
  101. qkv_bias_separate: bool = False,
  102. attn_drop: float = 0.,
  103. proj_drop: float = 0.,
  104. window_size: Optional[Tuple[int, int]] = None,
  105. attn_head_dim: Optional[int] = None,
  106. device=None,
  107. dtype=None,
  108. ):
  109. """Initialize attention module.
  110. Args:
  111. dim: Input feature dimension.
  112. num_heads: Number of attention heads.
  113. qkv_bias: If True, add learnable bias to query, key, value projections.
  114. qkv_bias_separate: If True, use separate bias for q, k, v projections.
  115. attn_drop: Dropout rate for attention weights.
  116. proj_drop: Dropout rate for output projection.
  117. window_size: Window size for relative position bias. If None, no relative position bias.
  118. attn_head_dim: Dimension per attention head. If None, uses dim // num_heads.
  119. """
  120. dd = {'device': device, 'dtype': dtype}
  121. super().__init__()
  122. self.num_heads = num_heads
  123. head_dim = dim // num_heads
  124. if attn_head_dim is not None:
  125. head_dim = attn_head_dim
  126. all_head_dim = head_dim * self.num_heads
  127. self.scale = head_dim ** -0.5
  128. self.fused_attn = use_fused_attn()
  129. self.qkv_bias_separate = qkv_bias_separate
  130. self.qkv = nn.Linear(dim, all_head_dim * 3, bias=False, **dd)
  131. if qkv_bias:
  132. self.q_bias = nn.Parameter(torch.empty(all_head_dim, **dd))
  133. self.register_buffer('k_bias', torch.empty(all_head_dim, **dd), persistent=False)
  134. self.v_bias = nn.Parameter(torch.empty(all_head_dim, **dd))
  135. else:
  136. self.q_bias = None
  137. self.k_bias = None
  138. self.v_bias = None
  139. if window_size:
  140. self.window_size = window_size
  141. self.num_relative_distance = (2 * window_size[0] - 1) * (2 * window_size[1] - 1) + 3
  142. window_area = window_size[0] * window_size[1]
  143. self.relative_position_bias_table = nn.Parameter(
  144. torch.empty(self.num_relative_distance, num_heads, **dd)) # 2*Wh-1 * 2*Ww-1, nH
  145. self.register_buffer(
  146. "relative_position_index",
  147. torch.empty((window_area + 1, window_area + 1), device=device, dtype=torch.long),
  148. persistent=False,
  149. )
  150. else:
  151. self.window_size = None
  152. self.relative_position_bias_table = None
  153. self.relative_position_index = None
  154. self.attn_drop = nn.Dropout(attn_drop)
  155. self.proj = nn.Linear(all_head_dim, dim, **dd)
  156. self.proj_drop = nn.Dropout(proj_drop)
  157. # TODO: skip init when on meta device when safe to do so
  158. self.reset_parameters()
  159. def _get_rel_pos_bias(self) -> torch.Tensor:
  160. """Get relative position bias for the attention window.
  161. Returns:
  162. Relative position bias tensor of shape (1, num_heads, window_area+1, window_area+1).
  163. """
  164. relative_position_bias = self.relative_position_bias_table[
  165. self.relative_position_index.view(-1)].view(
  166. self.window_size[0] * self.window_size[1] + 1,
  167. self.window_size[0] * self.window_size[1] + 1, -1) # Wh*Ww,Wh*Ww,nH
  168. relative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous() # nH, Wh*Ww, Wh*Ww
  169. return relative_position_bias.unsqueeze(0)
  170. def forward(self, x: torch.Tensor, shared_rel_pos_bias: Optional[torch.Tensor] = None) -> torch.Tensor:
  171. """Forward pass of attention module.
  172. Args:
  173. x: Input tensor of shape (batch_size, num_tokens, dim).
  174. shared_rel_pos_bias: Optional shared relative position bias from parent module.
  175. Returns:
  176. Output tensor of shape (batch_size, num_tokens, dim).
  177. """
  178. B, N, C = x.shape
  179. if self.q_bias is None:
  180. qkv = self.qkv(x)
  181. else:
  182. qkv_bias = torch.cat((self.q_bias, self.k_bias, self.v_bias))
  183. if self.qkv_bias_separate:
  184. qkv = self.qkv(x)
  185. qkv += qkv_bias
  186. else:
  187. qkv = F.linear(x, weight=self.qkv.weight, bias=qkv_bias)
  188. qkv = qkv.reshape(B, N, 3, self.num_heads, -1).permute(2, 0, 3, 1, 4)
  189. q, k, v = qkv.unbind(0) # B, num_heads, N, head_dim
  190. if self.fused_attn:
  191. rel_pos_bias = None
  192. if self.relative_position_bias_table is not None:
  193. rel_pos_bias = self._get_rel_pos_bias()
  194. if shared_rel_pos_bias is not None:
  195. rel_pos_bias = rel_pos_bias + shared_rel_pos_bias
  196. elif shared_rel_pos_bias is not None:
  197. rel_pos_bias = shared_rel_pos_bias
  198. x = F.scaled_dot_product_attention(
  199. q, k, v,
  200. attn_mask=rel_pos_bias,
  201. dropout_p=self.attn_drop.p if self.training else 0.,
  202. )
  203. else:
  204. q = q * self.scale
  205. attn = (q @ k.transpose(-2, -1))
  206. if self.relative_position_bias_table is not None:
  207. attn = attn + self._get_rel_pos_bias()
  208. if shared_rel_pos_bias is not None:
  209. attn = attn + shared_rel_pos_bias
  210. attn = attn.softmax(dim=-1)
  211. attn = self.attn_drop(attn)
  212. x = attn @ v
  213. x = x.transpose(1, 2).reshape(B, N, C)
  214. x = self.proj(x)
  215. x = self.proj_drop(x)
  216. return x
  217. def reset_parameters(self) -> None:
  218. """Initialize parameters and buffers."""
  219. if self.q_bias is not None:
  220. nn.init.zeros_(self.q_bias)
  221. nn.init.zeros_(self.v_bias)
  222. if self.relative_position_bias_table is not None:
  223. nn.init.zeros_(self.relative_position_bias_table)
  224. self._init_buffers()
  225. def _init_buffers(self) -> None:
  226. """Compute and fill non-persistent buffer values."""
  227. if self.k_bias is not None:
  228. self.k_bias.zero_()
  229. if self.relative_position_index is not None:
  230. self.relative_position_index.copy_(
  231. gen_relative_position_index(self.window_size, device=self.relative_position_index.device)
  232. )
  233. def init_non_persistent_buffers(self) -> None:
  234. """Initialize non-persistent buffers."""
  235. self._init_buffers()
  236. class Block(nn.Module):
  237. """Transformer block with attention and MLP.
  238. Standard transformer block consisting of multi-head self-attention and MLP
  239. with residual connections and layer normalization. Supports layer scale and
  240. stochastic depth regularization.
  241. """
  242. def __init__(
  243. self,
  244. dim: int,
  245. num_heads: int,
  246. qkv_bias: bool = False,
  247. mlp_ratio: float = 4.,
  248. scale_mlp: bool = False,
  249. swiglu_mlp: bool = False,
  250. proj_drop: float = 0.,
  251. attn_drop: float = 0.,
  252. drop_path: float = 0.,
  253. init_values: Optional[float] = None,
  254. act_layer: Type[nn.Module] = nn.GELU,
  255. norm_layer: Type[nn.Module] = LayerNorm,
  256. window_size: Optional[Tuple[int, int]] = None,
  257. attn_head_dim: Optional[int] = None,
  258. device=None,
  259. dtype=None,
  260. ):
  261. """Initialize transformer block.
  262. Args:
  263. dim: Input feature dimension.
  264. num_heads: Number of attention heads.
  265. qkv_bias: If True, add learnable bias to query, key, value projections.
  266. mlp_ratio: Ratio of MLP hidden dimension to input dimension.
  267. scale_mlp: If True, apply layer normalization in MLP.
  268. swiglu_mlp: If True, use SwiGLU activation in MLP.
  269. proj_drop: Dropout rate for projections.
  270. attn_drop: Dropout rate for attention.
  271. drop_path: Drop path rate for stochastic depth.
  272. init_values: Initial values for layer scale. If None, no layer scale.
  273. act_layer: Activation function class.
  274. norm_layer: Normalization layer class.
  275. window_size: Window size for relative position bias in attention.
  276. attn_head_dim: Dimension per attention head.
  277. """
  278. dd = {'device': device, 'dtype': dtype}
  279. super().__init__()
  280. self.norm1 = norm_layer(dim, **dd)
  281. self.attn = Attention(
  282. dim,
  283. num_heads=num_heads,
  284. qkv_bias=qkv_bias,
  285. attn_drop=attn_drop,
  286. proj_drop=proj_drop,
  287. window_size=window_size,
  288. attn_head_dim=attn_head_dim,
  289. **dd,
  290. )
  291. # NOTE: drop path for stochastic depth, we shall see if this is better than dropout here
  292. self.drop_path1 = DropPath(drop_path) if drop_path > 0. else nn.Identity()
  293. self.norm2 = norm_layer(dim, **dd)
  294. if swiglu_mlp:
  295. self.mlp = SwiGLU(
  296. in_features=dim,
  297. hidden_features=int(dim * mlp_ratio),
  298. norm_layer=norm_layer if scale_mlp else None,
  299. drop=proj_drop,
  300. **dd,
  301. )
  302. else:
  303. self.mlp = Mlp(
  304. in_features=dim,
  305. hidden_features=int(dim * mlp_ratio),
  306. act_layer=act_layer,
  307. norm_layer=norm_layer if scale_mlp else None,
  308. drop=proj_drop,
  309. **dd,
  310. )
  311. self.drop_path2 = DropPath(drop_path) if drop_path > 0. else nn.Identity()
  312. self.init_values = init_values
  313. if init_values:
  314. self.gamma_1 = nn.Parameter(torch.empty(dim, **dd))
  315. self.gamma_2 = nn.Parameter(torch.empty(dim, **dd))
  316. else:
  317. self.gamma_1, self.gamma_2 = None, None
  318. # TODO: skip init when on meta device when safe to do so
  319. self.reset_parameters()
  320. def reset_parameters(self) -> None:
  321. """Initialize parameters."""
  322. if self.gamma_1 is not None:
  323. nn.init.constant_(self.gamma_1, self.init_values)
  324. nn.init.constant_(self.gamma_2, self.init_values)
  325. def forward(self, x: torch.Tensor, shared_rel_pos_bias: Optional[torch.Tensor] = None) -> torch.Tensor:
  326. """Forward pass of transformer block.
  327. Args:
  328. x: Input tensor of shape (batch_size, num_tokens, dim).
  329. shared_rel_pos_bias: Optional shared relative position bias.
  330. Returns:
  331. Output tensor of shape (batch_size, num_tokens, dim).
  332. """
  333. if self.gamma_1 is None:
  334. x = x + self.drop_path1(self.attn(self.norm1(x), shared_rel_pos_bias=shared_rel_pos_bias))
  335. x = x + self.drop_path2(self.mlp(self.norm2(x)))
  336. else:
  337. x = x + self.drop_path1(self.gamma_1 * self.attn(self.norm1(x), shared_rel_pos_bias=shared_rel_pos_bias))
  338. x = x + self.drop_path2(self.gamma_2 * self.mlp(self.norm2(x)))
  339. return x
  340. class RelativePositionBias(nn.Module):
  341. """Relative position bias module for window-based attention.
  342. Generates learnable relative position biases for all pairs of positions
  343. within a window, including special handling for cls token.
  344. """
  345. def __init__(self, window_size: Tuple[int, int], num_heads: int, device=None, dtype=None):
  346. """Initialize relative position bias module.
  347. Args:
  348. window_size: Height and width of the attention window.
  349. num_heads: Number of attention heads.
  350. """
  351. dd = {'device': device, 'dtype': dtype}
  352. super().__init__()
  353. self.window_size = window_size
  354. self.window_area = window_size[0] * window_size[1]
  355. num_relative_distance = (2 * window_size[0] - 1) * (2 * window_size[1] - 1) + 3
  356. self.relative_position_bias_table = nn.Parameter(torch.empty(num_relative_distance, num_heads, **dd))
  357. self.register_buffer(
  358. "relative_position_index",
  359. torch.empty((self.window_area + 1, self.window_area + 1), device=device, dtype=torch.long),
  360. persistent=False,
  361. )
  362. # TODO: skip init when on meta device when safe to do so
  363. self.reset_parameters()
  364. def reset_parameters(self) -> None:
  365. """Initialize parameters and buffers."""
  366. nn.init.zeros_(self.relative_position_bias_table)
  367. self._init_buffers()
  368. def _init_buffers(self) -> None:
  369. """Compute and fill non-persistent buffer values."""
  370. self.relative_position_index.copy_(
  371. gen_relative_position_index(self.window_size, device=self.relative_position_index.device)
  372. )
  373. def init_non_persistent_buffers(self) -> None:
  374. """Initialize non-persistent buffers."""
  375. self._init_buffers()
  376. def forward(self) -> torch.Tensor:
  377. """Generate relative position bias.
  378. Returns:
  379. Relative position bias tensor of shape (num_heads, window_area+1, window_area+1).
  380. """
  381. relative_position_bias = self.relative_position_bias_table[self.relative_position_index.view(-1)].view(
  382. self.window_area + 1, self.window_area + 1, -1) # Wh*Ww,Wh*Ww,nH
  383. return relative_position_bias.permute(2, 0, 1).contiguous() # nH, Wh*Ww, Wh*Ww
  384. class Beit(nn.Module):
  385. """BEiT: BERT Pre-Training of Image Transformers.
  386. Vision Transformer model with support for relative position bias and
  387. shared relative position bias across layers. Implements both BEiT v1 and v2
  388. architectures with flexible configuration options.
  389. """
  390. def __init__(
  391. self,
  392. img_size: Union[int, Tuple[int, int]] = 224,
  393. patch_size: Union[int, Tuple[int, int]] = 16,
  394. in_chans: int = 3,
  395. num_classes: int = 1000,
  396. global_pool: str = 'avg',
  397. embed_dim: int = 768,
  398. depth: int = 12,
  399. num_heads: int = 12,
  400. qkv_bias: bool = True,
  401. mlp_ratio: float = 4.,
  402. swiglu_mlp: bool = False,
  403. scale_mlp: bool = False,
  404. drop_rate: float = 0.,
  405. pos_drop_rate: float = 0.,
  406. proj_drop_rate: float = 0.,
  407. attn_drop_rate: float = 0.,
  408. drop_path_rate: float = 0.,
  409. norm_layer: Type[nn.Module] = LayerNorm,
  410. init_values: Optional[float] = None,
  411. use_abs_pos_emb: bool = True,
  412. use_rel_pos_bias: bool = False,
  413. use_shared_rel_pos_bias: bool = False,
  414. head_init_scale: float = 0.001,
  415. device=None,
  416. dtype=None,
  417. ):
  418. """Initialize BEiT model.
  419. Args:
  420. img_size: Input image size.
  421. patch_size: Patch size for patch embedding.
  422. in_chans: Number of input image channels.
  423. num_classes: Number of classes for classification head.
  424. global_pool: Type of global pooling ('avg' or '').
  425. embed_dim: Embedding dimension.
  426. depth: Number of transformer blocks.
  427. num_heads: Number of attention heads.
  428. qkv_bias: If True, add learnable bias to query, key, value projections.
  429. mlp_ratio: Ratio of MLP hidden dimension to embedding dimension.
  430. swiglu_mlp: If True, use SwiGLU activation in MLP.
  431. scale_mlp: If True, apply layer normalization in MLP.
  432. drop_rate: Dropout rate.
  433. pos_drop_rate: Dropout rate for position embeddings.
  434. proj_drop_rate: Dropout rate for projections.
  435. attn_drop_rate: Dropout rate for attention.
  436. drop_path_rate: Stochastic depth rate.
  437. norm_layer: Normalization layer class.
  438. init_values: Initial values for layer scale.
  439. use_abs_pos_emb: If True, use absolute position embeddings.
  440. use_rel_pos_bias: If True, use relative position bias in attention.
  441. use_shared_rel_pos_bias: If True, share relative position bias across layers.
  442. head_init_scale: Scale factor for head initialization.
  443. """
  444. dd = {'device': device, 'dtype': dtype}
  445. super().__init__()
  446. self.num_classes = num_classes
  447. self.in_chans = in_chans
  448. self.global_pool = global_pool
  449. self.num_features = self.head_hidden_size = self.embed_dim = embed_dim # for consistency with other models
  450. self.num_prefix_tokens = 1
  451. self.grad_checkpointing = False
  452. self.patch_embed = PatchEmbed(
  453. img_size=img_size,
  454. patch_size=patch_size,
  455. in_chans=in_chans,
  456. embed_dim=embed_dim,
  457. **dd,
  458. )
  459. num_patches = self.patch_embed.num_patches
  460. r = self.patch_embed.feat_ratio() if hasattr(self.patch_embed, 'feat_ratio') else patch_size
  461. self.cls_token = nn.Parameter(torch.empty(1, 1, embed_dim, **dd))
  462. # self.mask_token = nn.Parameter(torch.empty(1, 1, embed_dim))
  463. self.pos_embed = nn.Parameter(torch.empty(1, num_patches + 1, embed_dim, **dd)) if use_abs_pos_emb else None
  464. self.pos_drop = nn.Dropout(p=pos_drop_rate)
  465. if use_shared_rel_pos_bias:
  466. self.rel_pos_bias = RelativePositionBias(
  467. window_size=self.patch_embed.grid_size,
  468. num_heads=num_heads,
  469. **dd,
  470. )
  471. else:
  472. self.rel_pos_bias = None
  473. dpr = calculate_drop_path_rates(drop_path_rate, depth) # stochastic depth decay rule
  474. self.blocks = nn.ModuleList([
  475. Block(
  476. dim=embed_dim,
  477. num_heads=num_heads,
  478. qkv_bias=qkv_bias,
  479. mlp_ratio=mlp_ratio,
  480. scale_mlp=scale_mlp,
  481. swiglu_mlp=swiglu_mlp,
  482. proj_drop=proj_drop_rate,
  483. attn_drop=attn_drop_rate,
  484. drop_path=dpr[i],
  485. norm_layer=norm_layer,
  486. init_values=init_values,
  487. window_size=self.patch_embed.grid_size if use_rel_pos_bias else None,
  488. **dd,
  489. )
  490. for i in range(depth)])
  491. self.feature_info = [
  492. dict(module=f'blocks.{i}', num_chs=embed_dim, reduction=r) for i in range(depth)]
  493. use_fc_norm = self.global_pool == 'avg'
  494. self.norm = nn.Identity() if use_fc_norm else norm_layer(embed_dim, **dd)
  495. self.fc_norm = norm_layer(embed_dim, **dd) if use_fc_norm else nn.Identity()
  496. self.head_drop = nn.Dropout(drop_rate)
  497. self.head = nn.Linear(embed_dim, num_classes, **dd) if num_classes > 0 else nn.Identity()
  498. self.head_init_scale = head_init_scale
  499. # TODO: skip init when on meta device when safe to do so
  500. self.init_weights(needs_reset=False)
  501. def init_weights(self, needs_reset: bool = True) -> None:
  502. """Initialize model weights.
  503. Args:
  504. needs_reset: If True, call reset_parameters() on modules that have it.
  505. Set to False when modules have already self-initialized in __init__.
  506. """
  507. self.apply(partial(self._init_weights, needs_reset=needs_reset))
  508. if self.pos_embed is not None:
  509. trunc_normal_(self.pos_embed, std=.02)
  510. trunc_normal_(self.cls_token, std=.02)
  511. self.fix_init_weight()
  512. if self.head_init_scale and isinstance(self.head, nn.Linear):
  513. trunc_normal_(self.head.weight, std=.02)
  514. with torch.no_grad():
  515. self.head.weight.mul_(self.head_init_scale)
  516. self.head.bias.mul_(self.head_init_scale)
  517. def fix_init_weight(self) -> None:
  518. """Fix initialization weights according to BEiT paper.
  519. Rescales attention and MLP weights based on layer depth to improve
  520. training stability.
  521. """
  522. with torch.no_grad():
  523. for layer_id, layer in enumerate(self.blocks):
  524. scale = math.sqrt(2.0 * (layer_id + 1))
  525. layer.attn.proj.weight.div_(scale)
  526. layer.mlp.fc2.weight.div_(scale)
  527. def _init_weights(self, m: nn.Module, needs_reset: bool = True):
  528. """Initialize model weights.
  529. Args:
  530. m: Module to initialize.
  531. needs_reset: If True, call reset_parameters() on modules that have it.
  532. """
  533. if isinstance(m, nn.Linear):
  534. trunc_normal_(m.weight, std=.02)
  535. if m.bias is not None:
  536. nn.init.constant_(m.bias, 0)
  537. elif needs_reset and hasattr(m, 'reset_parameters'):
  538. m.reset_parameters()
  539. @torch.jit.ignore
  540. def no_weight_decay(self) -> Set[str]:
  541. """Get parameter names that should not use weight decay.
  542. Returns:
  543. Set of parameter names to exclude from weight decay.
  544. """
  545. nwd = {'pos_embed', 'cls_token'}
  546. for n, _ in self.named_parameters():
  547. if 'relative_position_bias_table' in n:
  548. nwd.add(n)
  549. return nwd
  550. @torch.jit.ignore
  551. def set_grad_checkpointing(self, enable: bool = True):
  552. """Enable or disable gradient checkpointing.
  553. Args:
  554. enable: If True, enable gradient checkpointing.
  555. """
  556. self.grad_checkpointing = enable
  557. @torch.jit.ignore
  558. def group_matcher(self, coarse: bool = False) -> Dict[str, Any]:
  559. """Create parameter group matcher for optimizer parameter groups.
  560. Args:
  561. coarse: If True, use coarse grouping.
  562. Returns:
  563. Dictionary mapping group names to regex patterns.
  564. """
  565. matcher = dict(
  566. stem=r'^cls_token|pos_embed|patch_embed|rel_pos_bias', # stem and embed
  567. blocks=[(r'^blocks\.(\d+)', None), (r'^norm', (99999,))],
  568. )
  569. return matcher
  570. @torch.jit.ignore
  571. def get_classifier(self) -> nn.Module:
  572. """Get the classifier head.
  573. Returns:
  574. The classification head module.
  575. """
  576. return self.head
  577. def reset_classifier(self, num_classes: int, global_pool: Optional[str] = None):
  578. """Reset the classification head.
  579. Args:
  580. num_classes: Number of classes for new head.
  581. global_pool: Global pooling type.
  582. """
  583. self.num_classes = num_classes
  584. if global_pool is not None:
  585. self.global_pool = global_pool
  586. self.head = nn.Linear(self.embed_dim, num_classes) if num_classes > 0 else nn.Identity()
  587. def forward_intermediates(
  588. self,
  589. x: torch.Tensor,
  590. indices: Optional[Union[int, List[int]]] = None,
  591. return_prefix_tokens: bool = False,
  592. norm: bool = False,
  593. stop_early: bool = False,
  594. output_fmt: str = 'NCHW',
  595. intermediates_only: bool = False,
  596. ) -> Union[List[torch.Tensor], Tuple[torch.Tensor, List[torch.Tensor]]]:
  597. """Forward pass that returns intermediate feature maps.
  598. Args:
  599. x: Input image tensor of shape (batch_size, channels, height, width).
  600. indices: Block indices to return features from. If int, returns last n blocks.
  601. return_prefix_tokens: If True, return both prefix and spatial tokens.
  602. norm: If True, apply normalization to intermediate features.
  603. stop_early: If True, stop at last selected intermediate.
  604. output_fmt: Output format ('NCHW' or 'NLC').
  605. intermediates_only: If True, only return intermediate features.
  606. Returns:
  607. If intermediates_only is True, returns list of intermediate tensors.
  608. Otherwise, returns tuple of (final_features, intermediates).
  609. """
  610. assert output_fmt in ('NCHW', 'NLC'), 'Output format must be one of NCHW or NLC.'
  611. reshape = output_fmt == 'NCHW'
  612. intermediates = []
  613. take_indices, max_index = feature_take_indices(len(self.blocks), indices)
  614. # forward pass
  615. B, _, height, width = x.shape
  616. x = self.patch_embed(x)
  617. x = torch.cat((self.cls_token.expand(x.shape[0], -1, -1), x), dim=1)
  618. if self.pos_embed is not None:
  619. x = x + self.pos_embed
  620. x = self.pos_drop(x)
  621. rel_pos_bias = self.rel_pos_bias() if self.rel_pos_bias is not None else None
  622. if torch.jit.is_scripting() or not stop_early: # can't slice blocks in torchscript
  623. blocks = self.blocks
  624. else:
  625. blocks = self.blocks[:max_index + 1]
  626. for i, blk in enumerate(blocks):
  627. if self.grad_checkpointing and not torch.jit.is_scripting():
  628. x = checkpoint(blk, x, shared_rel_pos_bias=rel_pos_bias)
  629. else:
  630. x = blk(x, shared_rel_pos_bias=rel_pos_bias)
  631. if i in take_indices:
  632. # normalize intermediates with final norm layer if enabled
  633. intermediates.append(self.norm(x) if norm else x)
  634. # process intermediates
  635. if self.num_prefix_tokens:
  636. # split prefix (e.g. class, distill) and spatial feature tokens
  637. prefix_tokens = [y[:, 0:self.num_prefix_tokens] for y in intermediates]
  638. intermediates = [y[:, self.num_prefix_tokens:] for y in intermediates]
  639. if reshape:
  640. # reshape to BCHW output format
  641. H, W = self.patch_embed.dynamic_feat_size((height, width))
  642. intermediates = [y.reshape(B, H, W, -1).permute(0, 3, 1, 2).contiguous() for y in intermediates]
  643. if not torch.jit.is_scripting() and return_prefix_tokens:
  644. # return_prefix not support in torchscript due to poor type handling
  645. intermediates = list(zip(intermediates, prefix_tokens))
  646. if intermediates_only:
  647. return intermediates
  648. x = self.norm(x)
  649. return x, intermediates
  650. def prune_intermediate_layers(
  651. self,
  652. indices: Union[int, List[int]] = 1,
  653. prune_norm: bool = False,
  654. prune_head: bool = True,
  655. ) -> List[int]:
  656. """Prune layers not required for specified intermediate outputs.
  657. Args:
  658. indices: Indices of blocks to keep.
  659. prune_norm: If True, remove final normalization.
  660. prune_head: If True, remove classification head.
  661. Returns:
  662. List of indices that were kept.
  663. """
  664. take_indices, max_index = feature_take_indices(len(self.blocks), indices)
  665. self.blocks = self.blocks[:max_index + 1] # truncate blocks
  666. if prune_norm:
  667. self.norm = nn.Identity()
  668. if prune_head:
  669. self.fc_norm = nn.Identity()
  670. self.reset_classifier(0, '')
  671. return take_indices
  672. def forward_features(self, x: torch.Tensor) -> torch.Tensor:
  673. """Forward pass through feature extraction layers.
  674. Args:
  675. x: Input tensor of shape (batch_size, channels, height, width).
  676. Returns:
  677. Feature tensor of shape (batch_size, num_tokens, embed_dim).
  678. """
  679. x = self.patch_embed(x)
  680. x = torch.cat((self.cls_token.expand(x.shape[0], -1, -1), x), dim=1)
  681. if self.pos_embed is not None:
  682. x = x + self.pos_embed
  683. x = self.pos_drop(x)
  684. rel_pos_bias = self.rel_pos_bias() if self.rel_pos_bias is not None else None
  685. for blk in self.blocks:
  686. if self.grad_checkpointing and not torch.jit.is_scripting():
  687. x = checkpoint(blk, x, shared_rel_pos_bias=rel_pos_bias)
  688. else:
  689. x = blk(x, shared_rel_pos_bias=rel_pos_bias)
  690. x = self.norm(x)
  691. return x
  692. def forward_head(self, x: torch.Tensor, pre_logits: bool = False) -> torch.Tensor:
  693. """Forward pass through classification head.
  694. Args:
  695. x: Feature tensor of shape (batch_size, num_tokens, embed_dim).
  696. pre_logits: If True, return features before final linear layer.
  697. Returns:
  698. Logits tensor of shape (batch_size, num_classes) or pre-logits.
  699. """
  700. if self.global_pool:
  701. x = x[:, self.num_prefix_tokens:].mean(dim=1) if self.global_pool == 'avg' else x[:, 0]
  702. x = self.fc_norm(x)
  703. x = self.head_drop(x)
  704. return x if pre_logits else self.head(x)
  705. def forward(self, x: torch.Tensor) -> torch.Tensor:
  706. """Forward pass through the model.
  707. Args:
  708. x: Input tensor of shape (batch_size, channels, height, width).
  709. Returns:
  710. Logits tensor of shape (batch_size, num_classes).
  711. """
  712. x = self.forward_features(x)
  713. x = self.forward_head(x)
  714. return x
  715. def _cfg(url: str = '', **kwargs) -> Dict[str, Any]:
  716. """Create a default configuration dictionary for BEiT models.
  717. Args:
  718. url: Model weights URL.
  719. **kwargs: Additional configuration parameters.
  720. Returns:
  721. Configuration dictionary.
  722. """
  723. return {
  724. 'url': url,
  725. 'num_classes': 1000, 'input_size': (3, 224, 224), 'pool_size': None,
  726. 'crop_pct': .9, 'interpolation': 'bicubic', 'fixed_input_size': True,
  727. 'mean': (0.5, 0.5, 0.5), 'std': (0.5, 0.5, 0.5),
  728. 'first_conv': 'patch_embed.proj', 'classifier': 'head',
  729. 'license': 'apache-2.0',
  730. **kwargs
  731. }
  732. default_cfgs = generate_default_cfgs({
  733. 'beit_base_patch16_224.in22k_ft_in22k_in1k': _cfg(
  734. #url='https://conversationhub.blob.core.windows.net/beit-share-public/beit/beit_base_patch16_224_pt22k_ft22kto1k.pth',
  735. hf_hub_id='timm/'),
  736. 'beit_base_patch16_384.in22k_ft_in22k_in1k': _cfg(
  737. #url='https://conversationhub.blob.core.windows.net/beit-share-public/beit/beit_base_patch16_384_pt22k_ft22kto1k.pth',
  738. hf_hub_id='timm/',
  739. input_size=(3, 384, 384), crop_pct=1.0,
  740. ),
  741. 'beit_base_patch16_224.in22k_ft_in22k': _cfg(
  742. #url='https://conversationhub.blob.core.windows.net/beit-share-public/beit/beit_base_patch16_224_pt22k_ft22k.pth',
  743. hf_hub_id='timm/',
  744. num_classes=21841,
  745. ),
  746. 'beit_large_patch16_224.in22k_ft_in22k_in1k': _cfg(
  747. #url='https://conversationhub.blob.core.windows.net/beit-share-public/beit/beit_large_patch16_224_pt22k_ft22kto1k.pth',
  748. hf_hub_id='timm/'),
  749. 'beit_large_patch16_384.in22k_ft_in22k_in1k': _cfg(
  750. #url='https://conversationhub.blob.core.windows.net/beit-share-public/beit/beit_large_patch16_384_pt22k_ft22kto1k.pth',
  751. hf_hub_id='timm/',
  752. input_size=(3, 384, 384), crop_pct=1.0,
  753. ),
  754. 'beit_large_patch16_512.in22k_ft_in22k_in1k': _cfg(
  755. #url='https://conversationhub.blob.core.windows.net/beit-share-public/beit/beit_large_patch16_512_pt22k_ft22kto1k.pth',
  756. hf_hub_id='timm/',
  757. input_size=(3, 512, 512), crop_pct=1.0,
  758. ),
  759. 'beit_large_patch16_224.in22k_ft_in22k': _cfg(
  760. #url='https://conversationhub.blob.core.windows.net/beit-share-public/beit/beit_large_patch16_224_pt22k_ft22k.pth',
  761. hf_hub_id='timm/',
  762. num_classes=21841,
  763. ),
  764. 'beitv2_base_patch16_224.in1k_ft_in22k_in1k': _cfg(
  765. #url='https://conversationhub.blob.core.windows.net/beit-share-public/beitv2/beitv2_base_patch16_224_pt1k_ft21kto1k.pth',
  766. hf_hub_id='timm/',
  767. mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD
  768. ),
  769. 'beitv2_base_patch16_224.in1k_ft_in1k': _cfg(
  770. #url='https://conversationhub.blob.core.windows.net/beit-share-public/beitv2/beitv2_base_patch16_224_pt1k_ft1k.pth',
  771. hf_hub_id='timm/',
  772. mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD
  773. ),
  774. 'beitv2_base_patch16_224.in1k_ft_in22k': _cfg(
  775. #url='https://conversationhub.blob.core.windows.net/beit-share-public/beitv2/beitv2_base_patch16_224_pt1k_ft21k.pth',
  776. hf_hub_id='timm/',
  777. num_classes=21841, mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD
  778. ),
  779. 'beitv2_large_patch16_224.in1k_ft_in22k_in1k': _cfg(
  780. #url='https://conversationhub.blob.core.windows.net/beit-share-public/beitv2/beitv2_large_patch16_224_pt1k_ft21kto1k.pth',
  781. hf_hub_id='timm/',
  782. crop_pct=0.95, mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD
  783. ),
  784. 'beitv2_large_patch16_224.in1k_ft_in1k': _cfg(
  785. #url='https://conversationhub.blob.core.windows.net/beit-share-public/beitv2/beitv2_large_patch16_224_pt1k_ft1k.pth',
  786. hf_hub_id='timm/',
  787. crop_pct=0.95, mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD
  788. ),
  789. 'beitv2_large_patch16_224.in1k_ft_in22k': _cfg(
  790. #url='https://conversationhub.blob.core.windows.net/beit-share-public/beitv2/beitv2_large_patch16_224_pt1k_ft21k.pth',
  791. hf_hub_id='timm/',
  792. num_classes=21841, mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD
  793. ),
  794. })
  795. def checkpoint_filter_fn(state_dict: Dict[str, torch.Tensor], model: nn.Module, interpolation: str = 'bicubic', antialias: bool = True) -> Dict[str, torch.Tensor]:
  796. """Filter and process checkpoint state dict for loading.
  797. Handles resizing of patch embeddings, position embeddings, and relative position
  798. bias tables when model size differs from checkpoint.
  799. Args:
  800. state_dict: Checkpoint state dictionary.
  801. model: Target model to load weights into.
  802. interpolation: Interpolation method for resizing.
  803. antialias: If True, use antialiasing when resizing.
  804. Returns:
  805. Filtered state dictionary.
  806. """
  807. state_dict = state_dict.get('model', state_dict)
  808. state_dict = state_dict.get('module', state_dict)
  809. # beit v2 didn't strip module
  810. out_dict = {}
  811. for k, v in state_dict.items():
  812. if 'relative_position_index' in k:
  813. continue
  814. if 'patch_embed.proj.weight' in k:
  815. O, I, H, W = model.patch_embed.proj.weight.shape
  816. if v.shape[-1] != W or v.shape[-2] != H:
  817. v = resample_patch_embed(
  818. v,
  819. (H, W),
  820. interpolation=interpolation,
  821. antialias=antialias,
  822. verbose=True,
  823. )
  824. elif k == 'pos_embed' and v.shape[1] != model.pos_embed.shape[1]:
  825. # To resize pos embedding when using model at different size from pretrained weights
  826. num_prefix_tokens = 1
  827. v = resample_abs_pos_embed(
  828. v,
  829. new_size=model.patch_embed.grid_size,
  830. num_prefix_tokens=num_prefix_tokens,
  831. interpolation=interpolation,
  832. antialias=antialias,
  833. verbose=True,
  834. )
  835. elif k.endswith('relative_position_bias_table'):
  836. m = model.get_submodule(k[:-29])
  837. if v.shape != m.relative_position_bias_table.shape or m.window_size[0] != m.window_size[1]:
  838. v = resize_rel_pos_bias_table(
  839. v,
  840. new_window_size=m.window_size,
  841. new_bias_shape=m.relative_position_bias_table.shape,
  842. )
  843. out_dict[k] = v
  844. return out_dict
  845. def _create_beit(variant: str, pretrained: bool = False, **kwargs) -> Beit:
  846. """Create a BEiT model.
  847. Args:
  848. variant: Model variant name.
  849. pretrained: If True, load pretrained weights.
  850. **kwargs: Additional model arguments.
  851. Returns:
  852. BEiT model instance.
  853. """
  854. out_indices = kwargs.pop('out_indices', 3)
  855. model = build_model_with_cfg(
  856. Beit, variant, pretrained,
  857. pretrained_filter_fn=checkpoint_filter_fn,
  858. feature_cfg=dict(out_indices=out_indices, feature_cls='getter'),
  859. **kwargs,
  860. )
  861. return model
  862. @register_model
  863. def beit_base_patch16_224(pretrained: bool = False, **kwargs) -> Beit:
  864. """BEiT base model @ 224x224 with patch size 16x16."""
  865. model_args = dict(
  866. patch_size=16, embed_dim=768, depth=12, num_heads=12, mlp_ratio=4,
  867. use_abs_pos_emb=False, use_rel_pos_bias=True, init_values=0.1)
  868. model = _create_beit('beit_base_patch16_224', pretrained=pretrained, **dict(model_args, **kwargs))
  869. return model
  870. @register_model
  871. def beit_base_patch16_384(pretrained: bool = False, **kwargs) -> Beit:
  872. """BEiT base model @ 384x384 with patch size 16x16."""
  873. model_args = dict(
  874. img_size=384, patch_size=16, embed_dim=768, depth=12, num_heads=12,
  875. use_abs_pos_emb=False, use_rel_pos_bias=True, init_values=0.1)
  876. model = _create_beit('beit_base_patch16_384', pretrained=pretrained, **dict(model_args, **kwargs))
  877. return model
  878. @register_model
  879. def beit_large_patch16_224(pretrained: bool = False, **kwargs) -> Beit:
  880. """BEiT large model @ 224x224 with patch size 16x16."""
  881. model_args = dict(
  882. patch_size=16, embed_dim=1024, depth=24, num_heads=16,
  883. use_abs_pos_emb=False, use_rel_pos_bias=True, init_values=1e-5)
  884. model = _create_beit('beit_large_patch16_224', pretrained=pretrained, **dict(model_args, **kwargs))
  885. return model
  886. @register_model
  887. def beit_large_patch16_384(pretrained: bool = False, **kwargs) -> Beit:
  888. """BEiT large model @ 384x384 with patch size 16x16."""
  889. model_args = dict(
  890. img_size=384, patch_size=16, embed_dim=1024, depth=24, num_heads=16,
  891. use_abs_pos_emb=False, use_rel_pos_bias=True, init_values=1e-5)
  892. model = _create_beit('beit_large_patch16_384', pretrained=pretrained, **dict(model_args, **kwargs))
  893. return model
  894. @register_model
  895. def beit_large_patch16_512(pretrained: bool = False, **kwargs) -> Beit:
  896. """BEiT large model @ 512x512 with patch size 16x16."""
  897. model_args = dict(
  898. img_size=512, patch_size=16, embed_dim=1024, depth=24, num_heads=16,
  899. use_abs_pos_emb=False, use_rel_pos_bias=True, init_values=1e-5)
  900. model = _create_beit('beit_large_patch16_512', pretrained=pretrained, **dict(model_args, **kwargs))
  901. return model
  902. @register_model
  903. def beitv2_base_patch16_224(pretrained: bool = False, **kwargs) -> Beit:
  904. """BEiT v2 base model @ 224x224 with patch size 16x16."""
  905. model_args = dict(
  906. patch_size=16, embed_dim=768, depth=12, num_heads=12, mlp_ratio=4,
  907. use_abs_pos_emb=False, use_rel_pos_bias=True, init_values=1e-5)
  908. model = _create_beit('beitv2_base_patch16_224', pretrained=pretrained, **dict(model_args, **kwargs))
  909. return model
  910. @register_model
  911. def beitv2_large_patch16_224(pretrained: bool = False, **kwargs) -> Beit:
  912. """BEiT v2 large model @ 224x224 with patch size 16x16."""
  913. model_args = dict(
  914. patch_size=16, embed_dim=1024, depth=24, num_heads=16,
  915. use_abs_pos_emb=False, use_rel_pos_bias=True, init_values=1e-5)
  916. model = _create_beit('beitv2_large_patch16_224', pretrained=pretrained, **dict(model_args, **kwargs))
  917. return model