swin_transformer.py 48 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
37047057067077087097107117127137147157167177187197207217227237247257267277287297307317327337347357367377387397407417427437447457467477487497507517527537547557567577587597607617627637647657667677687697707717727737747757767777787797807817827837847857867877887897907917927937947957967977987998008018028038048058068078088098108118128138148158168178188198208218228238248258268278288298308318328338348358368378388398408418428438448458468478488498508518528538548558568578588598608618628638648658668678688698708718728738748758768778788798808818828838848858868878888898908918928938948958968978988999009019029039049059069079089099109119129139149159169179189199209219229239249259269279289299309319329339349359369379389399409419429439449459469479489499509519529539549559569579589599609619629639649659669679689699709719729739749759769779789799809819829839849859869879889899909919929939949959969979989991000100110021003100410051006100710081009101010111012101310141015101610171018101910201021102210231024102510261027102810291030103110321033103410351036103710381039104010411042104310441045104610471048104910501051105210531054105510561057105810591060106110621063106410651066106710681069107010711072107310741075107610771078107910801081108210831084108510861087108810891090109110921093109410951096109710981099110011011102110311041105110611071108110911101111111211131114111511161117111811191120112111221123112411251126112711281129113011311132113311341135113611371138113911401141114211431144114511461147114811491150115111521153115411551156115711581159116011611162116311641165116611671168116911701171117211731174117511761177117811791180118111821183118411851186118711881189119011911192119311941195119611971198119912001201120212031204120512061207120812091210121112121213121412151216121712181219122012211222122312241225122612271228122912301231123212331234123512361237123812391240124112421243124412451246124712481249125012511252125312541255
  1. """ Swin Transformer
  2. A PyTorch impl of : `Swin Transformer: Hierarchical Vision Transformer using Shifted Windows`
  3. - https://arxiv.org/pdf/2103.14030
  4. Code/weights from https://github.com/microsoft/Swin-Transformer, original copyright/license info below
  5. S3 (AutoFormerV2, https://arxiv.org/abs/2111.14725) Swin weights from
  6. - https://github.com/microsoft/Cream/tree/main/AutoFormerV2
  7. Modifications and additions for timm hacked together by / Copyright 2021, Ross Wightman
  8. """
  9. # --------------------------------------------------------
  10. # Swin Transformer
  11. # Copyright (c) 2021 Microsoft
  12. # Licensed under The MIT License [see LICENSE for details]
  13. # Written by Ze Liu
  14. # --------------------------------------------------------
  15. import logging
  16. import math
  17. from typing import Any, Dict, Callable, List, Optional, Set, Tuple, Union, Type
  18. import torch
  19. import torch.nn as nn
  20. from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
  21. from timm.layers import PatchEmbed, Mlp, DropPath, calculate_drop_path_rates, ClassifierHead, to_2tuple, to_ntuple, trunc_normal_, \
  22. use_fused_attn, resize_rel_pos_bias_table, resample_patch_embed, ndgrid
  23. from ._builder import build_model_with_cfg
  24. from ._features import feature_take_indices
  25. from ._features_fx import register_notrace_function
  26. from ._manipulate import checkpoint_seq, named_apply
  27. from ._registry import generate_default_cfgs, register_model, register_model_deprecations
  28. from .vision_transformer import get_init_weights_vit
__all__ = ['SwinTransformer']  # model_registry will add each entrypoint fn to this
# Logger for this module (named after the module path).
_logger = logging.getLogger(__name__)
# Type alias for size-like arguments that accept either a single int or an (int, int) pair.
_int_or_tuple_2_t = Union[int, Tuple[int, int]]
  32. def window_partition(
  33. x: torch.Tensor,
  34. window_size: Tuple[int, int],
  35. ) -> torch.Tensor:
  36. """Partition into non-overlapping windows.
  37. Args:
  38. x: Input tokens with shape [B, H, W, C].
  39. window_size: Window size.
  40. Returns:
  41. Windows after partition with shape [B * num_windows, window_size, window_size, C].
  42. """
  43. B, H, W, C = x.shape
  44. x = x.view(B, H // window_size[0], window_size[0], W // window_size[1], window_size[1], C)
  45. windows = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size[0], window_size[1], C)
  46. return windows
  47. @register_notrace_function # reason: int argument is a Proxy
  48. def window_reverse(windows: torch.Tensor, window_size: Tuple[int, int], H: int, W: int) -> torch.Tensor:
  49. """Reverse window partition.
  50. Args:
  51. windows: Windows with shape (num_windows*B, window_size, window_size, C).
  52. window_size: Window size.
  53. H: Height of image.
  54. W: Width of image.
  55. Returns:
  56. Tensor with shape (B, H, W, C).
  57. """
  58. C = windows.shape[-1]
  59. x = windows.view(-1, H // window_size[0], W // window_size[1], window_size[0], window_size[1], C)
  60. x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, H, W, C)
  61. return x
  62. def get_relative_position_index(win_h: int, win_w: int, device=None) -> torch.Tensor:
  63. """Get pair-wise relative position index for each token inside the window.
  64. Args:
  65. win_h: Window height.
  66. win_w: Window width.
  67. Returns:
  68. Relative position index tensor.
  69. """
  70. # get pair-wise relative position index for each token inside the window
  71. coords = torch.stack(ndgrid(
  72. torch.arange(win_h, device=device, dtype=torch.long),
  73. torch.arange(win_w, device=device, dtype=torch.long),
  74. )) # 2, Wh, Ww
  75. coords_flatten = torch.flatten(coords, 1) # 2, Wh*Ww
  76. relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :] # 2, Wh*Ww, Wh*Ww
  77. relative_coords = relative_coords.permute(1, 2, 0).contiguous() # Wh*Ww, Wh*Ww, 2
  78. relative_coords[:, :, 0] += win_h - 1 # shift to start from 0
  79. relative_coords[:, :, 1] += win_w - 1
  80. relative_coords[:, :, 0] *= 2 * win_w - 1
  81. return relative_coords.sum(-1) # Wh*Ww, Wh*Ww
  82. class WindowAttention(nn.Module):
  83. """Window based multi-head self attention (W-MSA) module with relative position bias.
  84. Supports both shifted and non-shifted windows.
  85. """
  86. fused_attn: torch.jit.Final[bool]
  87. def __init__(
  88. self,
  89. dim: int,
  90. num_heads: int,
  91. head_dim: Optional[int] = None,
  92. window_size: _int_or_tuple_2_t = 7,
  93. qkv_bias: bool = True,
  94. attn_drop: float = 0.,
  95. proj_drop: float = 0.,
  96. device=None,
  97. dtype=None,
  98. ):
  99. """
  100. Args:
  101. dim: Number of input channels.
  102. num_heads: Number of attention heads.
  103. head_dim: Number of channels per head (dim // num_heads if not set)
  104. window_size: The height and width of the window.
  105. qkv_bias: If True, add a learnable bias to query, key, value.
  106. attn_drop: Dropout ratio of attention weight.
  107. proj_drop: Dropout ratio of output.
  108. """
  109. dd = {'device': device, 'dtype': dtype}
  110. super().__init__()
  111. self.dim = dim
  112. self.window_size = to_2tuple(window_size) # Wh, Ww
  113. win_h, win_w = self.window_size
  114. self.window_area = win_h * win_w
  115. self.num_heads = num_heads
  116. head_dim = head_dim or dim // num_heads
  117. attn_dim = head_dim * num_heads
  118. self.scale = head_dim ** -0.5
  119. self.fused_attn = use_fused_attn(experimental=True) # NOTE not tested for prime-time yet
  120. # define a parameter table of relative position bias, shape: 2*Wh-1 * 2*Ww-1, nH
  121. self.relative_position_bias_table = nn.Parameter(
  122. torch.empty((2 * win_h - 1) * (2 * win_w - 1), num_heads, **dd))
  123. # register empty buffer for relative position index
  124. self.register_buffer(
  125. "relative_position_index",
  126. torch.empty(win_h * win_w, win_h * win_w, device=device, dtype=torch.long),
  127. persistent=False,
  128. )
  129. self.qkv = nn.Linear(dim, attn_dim * 3, bias=qkv_bias, **dd)
  130. self.attn_drop = nn.Dropout(attn_drop)
  131. self.proj = nn.Linear(attn_dim, dim, **dd)
  132. self.proj_drop = nn.Dropout(proj_drop)
  133. self.softmax = nn.Softmax(dim=-1)
  134. # TODO: skip init when on meta device when safe to do so
  135. self.reset_parameters()
  136. def reset_parameters(self) -> None:
  137. """Initialize parameters and buffers."""
  138. trunc_normal_(self.relative_position_bias_table, std=.02)
  139. self._init_buffers()
  140. def _init_buffers(self) -> None:
  141. """Compute and fill non-persistent buffer values."""
  142. win_h, win_w = self.window_size
  143. self.relative_position_index.copy_(
  144. get_relative_position_index(win_h, win_w, device=self.relative_position_index.device)
  145. )
  146. def set_window_size(self, window_size: Tuple[int, int]) -> None:
  147. """Update window size & interpolate position embeddings
  148. Args:
  149. window_size (int): New window size
  150. """
  151. window_size = to_2tuple(window_size)
  152. if window_size == self.window_size:
  153. return
  154. self.window_size = window_size
  155. win_h, win_w = self.window_size
  156. self.window_area = win_h * win_w
  157. with torch.no_grad():
  158. new_bias_shape = (2 * win_h - 1) * (2 * win_w - 1), self.num_heads
  159. self.relative_position_bias_table = nn.Parameter(
  160. resize_rel_pos_bias_table(
  161. self.relative_position_bias_table,
  162. new_window_size=self.window_size,
  163. new_bias_shape=new_bias_shape,
  164. ))
  165. self.register_buffer(
  166. "relative_position_index",
  167. get_relative_position_index(win_h, win_w, device=self.relative_position_bias_table.device),
  168. persistent=False,
  169. )
  170. def _get_rel_pos_bias(self) -> torch.Tensor:
  171. relative_position_bias = self.relative_position_bias_table[
  172. self.relative_position_index.view(-1)].view(self.window_area, self.window_area, -1) # Wh*Ww,Wh*Ww,nH
  173. relative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous() # nH, Wh*Ww, Wh*Ww
  174. return relative_position_bias.unsqueeze(0)
  175. def forward(self, x: torch.Tensor, mask: Optional[torch.Tensor] = None) -> torch.Tensor:
  176. """Forward pass.
  177. Args:
  178. x: Input features with shape of (num_windows*B, N, C).
  179. mask: (0/-inf) mask with shape of (num_windows, Wh*Ww, Wh*Ww) or None.
  180. Returns:
  181. Output features with shape of (num_windows*B, N, C).
  182. """
  183. B_, N, C = x.shape
  184. qkv = self.qkv(x).reshape(B_, N, 3, self.num_heads, -1).permute(2, 0, 3, 1, 4)
  185. q, k, v = qkv.unbind(0)
  186. if self.fused_attn:
  187. attn_mask = self._get_rel_pos_bias()
  188. if mask is not None:
  189. num_win = mask.shape[0]
  190. mask = mask.view(1, num_win, 1, N, N).expand(B_ // num_win, -1, self.num_heads, -1, -1)
  191. attn_mask = attn_mask + mask.reshape(-1, self.num_heads, N, N)
  192. x = torch.nn.functional.scaled_dot_product_attention(
  193. q, k, v,
  194. attn_mask=attn_mask,
  195. dropout_p=self.attn_drop.p if self.training else 0.,
  196. )
  197. else:
  198. q = q * self.scale
  199. attn = q @ k.transpose(-2, -1)
  200. attn = attn + self._get_rel_pos_bias()
  201. if mask is not None:
  202. num_win = mask.shape[0]
  203. attn = attn.view(-1, num_win, self.num_heads, N, N) + mask.unsqueeze(1).unsqueeze(0)
  204. attn = attn.view(-1, self.num_heads, N, N)
  205. attn = self.softmax(attn)
  206. attn = self.attn_drop(attn)
  207. x = attn @ v
  208. x = x.transpose(1, 2).reshape(B_, N, -1)
  209. x = self.proj(x)
  210. x = self.proj_drop(x)
  211. return x
  212. def init_non_persistent_buffers(self) -> None:
  213. """Initialize non-persistent buffers."""
  214. self._init_buffers()
  215. class SwinTransformerBlock(nn.Module):
  216. """Swin Transformer Block.
  217. A transformer block with window-based self-attention and shifted windows.
  218. """
  219. def __init__(
  220. self,
  221. dim: int,
  222. input_resolution: _int_or_tuple_2_t,
  223. num_heads: int = 4,
  224. head_dim: Optional[int] = None,
  225. window_size: _int_or_tuple_2_t = 7,
  226. shift_size: int = 0,
  227. always_partition: bool = False,
  228. dynamic_mask: bool = False,
  229. mlp_ratio: float = 4.,
  230. qkv_bias: bool = True,
  231. proj_drop: float = 0.,
  232. attn_drop: float = 0.,
  233. drop_path: float = 0.,
  234. act_layer: Type[nn.Module] = nn.GELU,
  235. norm_layer: Type[nn.Module] = nn.LayerNorm,
  236. device=None,
  237. dtype=None,
  238. ):
  239. """
  240. Args:
  241. dim: Number of input channels.
  242. input_resolution: Input resolution.
  243. window_size: Window size.
  244. num_heads: Number of attention heads.
  245. head_dim: Enforce the number of channels per head
  246. shift_size: Shift size for SW-MSA.
  247. always_partition: Always partition into full windows and shift
  248. mlp_ratio: Ratio of mlp hidden dim to embedding dim.
  249. qkv_bias: If True, add a learnable bias to query, key, value.
  250. proj_drop: Dropout rate.
  251. attn_drop: Attention dropout rate.
  252. drop_path: Stochastic depth rate.
  253. act_layer: Activation layer.
  254. norm_layer: Normalization layer.
  255. """
  256. dd = {'device': device, 'dtype': dtype}
  257. super().__init__()
  258. self.dim = dim
  259. self.input_resolution = input_resolution
  260. self.target_shift_size = to_2tuple(shift_size) # store for later resize
  261. self.always_partition = always_partition
  262. self.dynamic_mask = dynamic_mask
  263. self.window_size, self.shift_size = self._calc_window_shift(window_size, shift_size)
  264. self.window_area = self.window_size[0] * self.window_size[1]
  265. self.mlp_ratio = mlp_ratio
  266. self.norm1 = norm_layer(dim, **dd)
  267. self.attn = WindowAttention(
  268. dim,
  269. num_heads=num_heads,
  270. head_dim=head_dim,
  271. window_size=self.window_size,
  272. qkv_bias=qkv_bias,
  273. attn_drop=attn_drop,
  274. proj_drop=proj_drop,
  275. **dd,
  276. )
  277. self.drop_path1 = DropPath(drop_path) if drop_path > 0. else nn.Identity()
  278. self.norm2 = norm_layer(dim, **dd)
  279. self.mlp = Mlp(
  280. in_features=dim,
  281. hidden_features=int(dim * mlp_ratio),
  282. act_layer=act_layer,
  283. drop=proj_drop,
  284. **dd,
  285. )
  286. self.drop_path2 = DropPath(drop_path) if drop_path > 0. else nn.Identity()
  287. # Register buffer as None initially, will be computed in reset_parameters if needed
  288. self.register_buffer("attn_mask", None, persistent=False)
  289. # TODO: skip init when on meta device when safe to do so
  290. self.reset_parameters()
  291. def reset_parameters(self) -> None:
  292. """Initialize parameters and buffers."""
  293. self._init_buffers()
  294. def _init_buffers(self) -> None:
  295. """Compute and fill non-persistent buffer values."""
  296. if not self.dynamic_mask:
  297. device = self.norm1.weight.device
  298. dtype = self.norm1.weight.dtype
  299. attn_mask = self.get_attn_mask(device=device, dtype=dtype)
  300. self.register_buffer("attn_mask", attn_mask, persistent=False)
  301. def get_attn_mask(
  302. self,
  303. x: Optional[torch.Tensor] = None,
  304. device: Optional[torch.device] = None,
  305. dtype: Optional[torch.dtype] = None,
  306. ) -> Optional[torch.Tensor]:
  307. if any(self.shift_size):
  308. # calculate attention mask for SW-MSA
  309. if x is not None:
  310. H, W = x.shape[1], x.shape[2]
  311. device = x.device
  312. dtype = x.dtype
  313. else:
  314. H, W = self.input_resolution
  315. device = device
  316. dtype = dtype
  317. H = math.ceil(H / self.window_size[0]) * self.window_size[0]
  318. W = math.ceil(W / self.window_size[1]) * self.window_size[1]
  319. img_mask = torch.zeros((1, H, W, 1), dtype=dtype, device=device) # 1 H W 1
  320. cnt = 0
  321. for h in (
  322. (0, -self.window_size[0]),
  323. (-self.window_size[0], -self.shift_size[0]),
  324. (-self.shift_size[0], None),
  325. ):
  326. for w in (
  327. (0, -self.window_size[1]),
  328. (-self.window_size[1], -self.shift_size[1]),
  329. (-self.shift_size[1], None),
  330. ):
  331. img_mask[:, h[0]:h[1], w[0]:w[1], :] = cnt
  332. cnt += 1
  333. mask_windows = window_partition(img_mask, self.window_size) # nW, window_size, window_size, 1
  334. mask_windows = mask_windows.view(-1, self.window_area)
  335. attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2)
  336. attn_mask = attn_mask.masked_fill(attn_mask != 0, float(-100.0)).masked_fill(attn_mask == 0, float(0.0))
  337. else:
  338. attn_mask = None
  339. return attn_mask
  340. def _calc_window_shift(
  341. self,
  342. target_window_size: Union[int, Tuple[int, int]],
  343. target_shift_size: Optional[Union[int, Tuple[int, int]]] = None,
  344. ) -> Tuple[Tuple[int, int], Tuple[int, int]]:
  345. target_window_size = to_2tuple(target_window_size)
  346. if target_shift_size is None:
  347. # if passed value is None, recalculate from default window_size // 2 if it was previously non-zero
  348. target_shift_size = self.target_shift_size
  349. if any(target_shift_size):
  350. target_shift_size = (target_window_size[0] // 2, target_window_size[1] // 2)
  351. else:
  352. target_shift_size = to_2tuple(target_shift_size)
  353. if self.always_partition:
  354. return target_window_size, target_shift_size
  355. window_size = [r if r <= w else w for r, w in zip(self.input_resolution, target_window_size)]
  356. shift_size = [0 if r <= w else s for r, w, s in zip(self.input_resolution, window_size, target_shift_size)]
  357. return tuple(window_size), tuple(shift_size)
  358. def set_input_size(
  359. self,
  360. feat_size: Tuple[int, int],
  361. window_size: Tuple[int, int],
  362. always_partition: Optional[bool] = None,
  363. ):
  364. """
  365. Args:
  366. feat_size: New input resolution
  367. window_size: New window size
  368. always_partition: Change always_partition attribute if not None
  369. """
  370. self.input_resolution = feat_size
  371. if always_partition is not None:
  372. self.always_partition = always_partition
  373. self.window_size, self.shift_size = self._calc_window_shift(window_size)
  374. self.window_area = self.window_size[0] * self.window_size[1]
  375. self.attn.set_window_size(self.window_size)
  376. device = self.attn_mask.device if self.attn_mask is not None else None
  377. dtype = self.attn_mask.dtype if self.attn_mask is not None else None
  378. self.register_buffer(
  379. "attn_mask",
  380. None if self.dynamic_mask else self.get_attn_mask(device=device, dtype=dtype),
  381. persistent=False,
  382. )
  383. def _attn(self, x):
  384. B, H, W, C = x.shape
  385. # cyclic shift
  386. has_shift = any(self.shift_size)
  387. if has_shift:
  388. shifted_x = torch.roll(x, shifts=(-self.shift_size[0], -self.shift_size[1]), dims=(1, 2))
  389. else:
  390. shifted_x = x
  391. # pad for resolution not divisible by window size
  392. pad_h = (self.window_size[0] - H % self.window_size[0]) % self.window_size[0]
  393. pad_w = (self.window_size[1] - W % self.window_size[1]) % self.window_size[1]
  394. shifted_x = torch.nn.functional.pad(shifted_x, (0, 0, 0, pad_w, 0, pad_h))
  395. _, Hp, Wp, _ = shifted_x.shape
  396. # partition windows
  397. x_windows = window_partition(shifted_x, self.window_size) # nW*B, window_size, window_size, C
  398. x_windows = x_windows.view(-1, self.window_area, C) # nW*B, window_size*window_size, C
  399. # W-MSA/SW-MSA
  400. if getattr(self, 'dynamic_mask', False):
  401. attn_mask = self.get_attn_mask(shifted_x)
  402. else:
  403. attn_mask = self.attn_mask
  404. attn_windows = self.attn(x_windows, mask=attn_mask) # nW*B, window_size*window_size, C
  405. # merge windows
  406. attn_windows = attn_windows.view(-1, self.window_size[0], self.window_size[1], C)
  407. shifted_x = window_reverse(attn_windows, self.window_size, Hp, Wp) # B H' W' C
  408. shifted_x = shifted_x[:, :H, :W, :].contiguous()
  409. # reverse cyclic shift
  410. if has_shift:
  411. x = torch.roll(shifted_x, shifts=self.shift_size, dims=(1, 2))
  412. else:
  413. x = shifted_x
  414. return x
  415. def forward(self, x: torch.Tensor) -> torch.Tensor:
  416. """Forward pass.
  417. Args:
  418. x: Input features with shape (B, H, W, C).
  419. Returns:
  420. Output features with shape (B, H, W, C).
  421. """
  422. B, H, W, C = x.shape
  423. x = x + self.drop_path1(self._attn(self.norm1(x)))
  424. x = x.reshape(B, -1, C)
  425. x = x + self.drop_path2(self.mlp(self.norm2(x)))
  426. x = x.reshape(B, H, W, C)
  427. return x
  428. def init_non_persistent_buffers(self) -> None:
  429. """Initialize non-persistent buffers."""
  430. self._init_buffers()
  431. class PatchMerging(nn.Module):
  432. """Patch Merging Layer.
  433. Downsample features by merging 2x2 neighboring patches.
  434. """
  435. def __init__(
  436. self,
  437. dim: int,
  438. out_dim: Optional[int] = None,
  439. norm_layer: Type[nn.Module] = nn.LayerNorm,
  440. device=None,
  441. dtype=None,
  442. ):
  443. """
  444. Args:
  445. dim: Number of input channels.
  446. out_dim: Number of output channels (or 2 * dim if None)
  447. norm_layer: Normalization layer.
  448. """
  449. dd = {'device': device, 'dtype': dtype}
  450. super().__init__()
  451. self.dim = dim
  452. self.out_dim = out_dim or 2 * dim
  453. self.norm = norm_layer(4 * dim, **dd)
  454. self.reduction = nn.Linear(4 * dim, self.out_dim, bias=False, **dd)
  455. def forward(self, x: torch.Tensor) -> torch.Tensor:
  456. """Forward pass.
  457. Args:
  458. x: Input features with shape (B, H, W, C).
  459. Returns:
  460. Output features with shape (B, H//2, W//2, out_dim).
  461. """
  462. B, H, W, C = x.shape
  463. pad_values = (0, 0, 0, W % 2, 0, H % 2)
  464. x = nn.functional.pad(x, pad_values)
  465. _, H, W, _ = x.shape
  466. x = x.reshape(B, H // 2, 2, W // 2, 2, C).permute(0, 1, 3, 4, 2, 5).flatten(3)
  467. x = self.norm(x)
  468. x = self.reduction(x)
  469. return x
  470. class SwinTransformerStage(nn.Module):
  471. """A basic Swin Transformer layer for one stage.
  472. Contains multiple Swin Transformer blocks and optional downsampling.
  473. """
  474. def __init__(
  475. self,
  476. dim: int,
  477. out_dim: int,
  478. input_resolution: Tuple[int, int],
  479. depth: int,
  480. downsample: bool = True,
  481. num_heads: int = 4,
  482. head_dim: Optional[int] = None,
  483. window_size: _int_or_tuple_2_t = 7,
  484. always_partition: bool = False,
  485. dynamic_mask: bool = False,
  486. mlp_ratio: float = 4.,
  487. qkv_bias: bool = True,
  488. proj_drop: float = 0.,
  489. attn_drop: float = 0.,
  490. drop_path: Union[List[float], float] = 0.,
  491. norm_layer: Type[nn.Module] = nn.LayerNorm,
  492. device=None,
  493. dtype=None,
  494. ):
  495. """
  496. Args:
  497. dim: Number of input channels.
  498. out_dim: Number of output channels.
  499. input_resolution: Input resolution.
  500. depth: Number of blocks.
  501. downsample: Downsample layer at the end of the layer.
  502. num_heads: Number of attention heads.
  503. head_dim: Channels per head (dim // num_heads if not set)
  504. window_size: Local window size.
  505. mlp_ratio: Ratio of mlp hidden dim to embedding dim.
  506. qkv_bias: If True, add a learnable bias to query, key, value.
  507. proj_drop: Projection dropout rate.
  508. attn_drop: Attention dropout rate.
  509. drop_path: Stochastic depth rate.
  510. norm_layer: Normalization layer.
  511. """
  512. dd = {'device': device, 'dtype': dtype}
  513. super().__init__()
  514. self.dim = dim
  515. self.input_resolution = input_resolution
  516. self.output_resolution = tuple(i // 2 for i in input_resolution) if downsample else input_resolution
  517. self.depth = depth
  518. self.grad_checkpointing = False
  519. window_size = to_2tuple(window_size)
  520. shift_size = tuple([w // 2 for w in window_size])
  521. # patch merging layer
  522. if downsample:
  523. self.downsample = PatchMerging(
  524. dim=dim,
  525. out_dim=out_dim,
  526. norm_layer=norm_layer,
  527. **dd,
  528. )
  529. else:
  530. assert dim == out_dim
  531. self.downsample = nn.Identity()
  532. # build blocks
  533. self.blocks = nn.Sequential(*[
  534. SwinTransformerBlock(
  535. dim=out_dim,
  536. input_resolution=self.output_resolution,
  537. num_heads=num_heads,
  538. head_dim=head_dim,
  539. window_size=window_size,
  540. shift_size=0 if (i % 2 == 0) else shift_size,
  541. always_partition=always_partition,
  542. dynamic_mask=dynamic_mask,
  543. mlp_ratio=mlp_ratio,
  544. qkv_bias=qkv_bias,
  545. proj_drop=proj_drop,
  546. attn_drop=attn_drop,
  547. drop_path=drop_path[i] if isinstance(drop_path, list) else drop_path,
  548. norm_layer=norm_layer,
  549. **dd,
  550. )
  551. for i in range(depth)])
  552. def set_input_size(
  553. self,
  554. feat_size: Tuple[int, int],
  555. window_size: int,
  556. always_partition: Optional[bool] = None,
  557. ):
  558. """ Updates the resolution, window size and so the pair-wise relative positions.
  559. Args:
  560. feat_size: New input (feature) resolution
  561. window_size: New window size
  562. always_partition: Always partition / shift the window
  563. """
  564. self.input_resolution = feat_size
  565. if isinstance(self.downsample, nn.Identity):
  566. self.output_resolution = feat_size
  567. else:
  568. self.output_resolution = tuple(i // 2 for i in feat_size)
  569. for block in self.blocks:
  570. block.set_input_size(
  571. feat_size=self.output_resolution,
  572. window_size=window_size,
  573. always_partition=always_partition,
  574. )
  575. def forward(self, x: torch.Tensor) -> torch.Tensor:
  576. """Forward pass.
  577. Args:
  578. x: Input features.
  579. Returns:
  580. Output features.
  581. """
  582. x = self.downsample(x)
  583. if self.grad_checkpointing and not torch.jit.is_scripting():
  584. x = checkpoint_seq(self.blocks, x)
  585. else:
  586. x = self.blocks(x)
  587. return x
  588. class SwinTransformer(nn.Module):
  589. """Swin Transformer.
  590. A PyTorch impl of : `Swin Transformer: Hierarchical Vision Transformer using Shifted Windows` -
  591. https://arxiv.org/pdf/2103.14030
  592. """
  593. def __init__(
  594. self,
  595. img_size: _int_or_tuple_2_t = 224,
  596. patch_size: int = 4,
  597. in_chans: int = 3,
  598. num_classes: int = 1000,
  599. global_pool: str = 'avg',
  600. embed_dim: int = 96,
  601. depths: Tuple[int, ...] = (2, 2, 6, 2),
  602. num_heads: Tuple[int, ...] = (3, 6, 12, 24),
  603. head_dim: Optional[int] = None,
  604. window_size: _int_or_tuple_2_t = 7,
  605. always_partition: bool = False,
  606. strict_img_size: bool = True,
  607. mlp_ratio: float = 4.,
  608. qkv_bias: bool = True,
  609. drop_rate: float = 0.,
  610. proj_drop_rate: float = 0.,
  611. attn_drop_rate: float = 0.,
  612. drop_path_rate: float = 0.1,
  613. embed_layer: Type[nn.Module] = PatchEmbed,
  614. norm_layer: Union[str, Type[nn.Module]] = nn.LayerNorm,
  615. weight_init: str = '',
  616. device=None,
  617. dtype=None,
  618. **kwargs,
  619. ):
  620. """
  621. Args:
  622. img_size: Input image size.
  623. patch_size: Patch size.
  624. in_chans: Number of input image channels.
  625. num_classes: Number of classes for classification head.
  626. embed_dim: Patch embedding dimension.
  627. depths: Depth of each Swin Transformer layer.
  628. num_heads: Number of attention heads in different layers.
  629. head_dim: Dimension of self-attention heads.
  630. window_size: Window size.
  631. mlp_ratio: Ratio of mlp hidden dim to embedding dim.
  632. qkv_bias: If True, add a learnable bias to query, key, value.
  633. drop_rate: Dropout rate.
  634. attn_drop_rate (float): Attention dropout rate.
  635. drop_path_rate (float): Stochastic depth rate.
  636. embed_layer: Patch embedding layer.
  637. norm_layer (nn.Module): Normalization layer.
  638. """
  639. super().__init__()
  640. dd = {'device': device, 'dtype': dtype}
  641. assert global_pool in ('', 'avg')
  642. self.num_classes = num_classes
  643. self.in_chans = in_chans
  644. self.global_pool = global_pool
  645. self.output_fmt = 'NHWC'
  646. self.num_layers = len(depths)
  647. self.embed_dim = embed_dim
  648. self.num_features = self.head_hidden_size = int(embed_dim * 2 ** (self.num_layers - 1))
  649. self.feature_info = []
  650. if not isinstance(embed_dim, (tuple, list)):
  651. embed_dim = [int(embed_dim * 2 ** i) for i in range(self.num_layers)]
  652. # split image into non-overlapping patches
  653. self.patch_embed = embed_layer(
  654. img_size=img_size,
  655. patch_size=patch_size,
  656. in_chans=in_chans,
  657. embed_dim=embed_dim[0],
  658. norm_layer=norm_layer,
  659. strict_img_size=strict_img_size,
  660. output_fmt='NHWC',
  661. **dd,
  662. )
  663. patch_grid = self.patch_embed.grid_size
  664. # build layers
  665. head_dim = to_ntuple(self.num_layers)(head_dim)
  666. if not isinstance(window_size, (list, tuple)):
  667. window_size = to_ntuple(self.num_layers)(window_size)
  668. elif len(window_size) == 2:
  669. window_size = (window_size,) * self.num_layers
  670. assert len(window_size) == self.num_layers
  671. mlp_ratio = to_ntuple(self.num_layers)(mlp_ratio)
  672. dpr = calculate_drop_path_rates(drop_path_rate, depths, stagewise=True)
  673. layers = []
  674. in_dim = embed_dim[0]
  675. scale = 1
  676. for i in range(self.num_layers):
  677. out_dim = embed_dim[i]
  678. layers += [SwinTransformerStage(
  679. dim=in_dim,
  680. out_dim=out_dim,
  681. input_resolution=(
  682. patch_grid[0] // scale,
  683. patch_grid[1] // scale
  684. ),
  685. depth=depths[i],
  686. downsample=i > 0,
  687. num_heads=num_heads[i],
  688. head_dim=head_dim[i],
  689. window_size=window_size[i],
  690. always_partition=always_partition,
  691. dynamic_mask=not strict_img_size,
  692. mlp_ratio=mlp_ratio[i],
  693. qkv_bias=qkv_bias,
  694. proj_drop=proj_drop_rate,
  695. attn_drop=attn_drop_rate,
  696. drop_path=dpr[i],
  697. norm_layer=norm_layer,
  698. **dd,
  699. )]
  700. in_dim = out_dim
  701. if i > 0:
  702. scale *= 2
  703. self.feature_info += [dict(num_chs=out_dim, reduction=patch_size * scale, module=f'layers.{i}')]
  704. self.layers = nn.Sequential(*layers)
  705. self.norm = norm_layer(self.num_features, **dd)
  706. self.head = ClassifierHead(
  707. self.num_features,
  708. num_classes,
  709. pool_type=global_pool,
  710. drop_rate=drop_rate,
  711. input_fmt=self.output_fmt,
  712. **dd,
  713. )
  714. self.weight_init_mode = 'reset' if weight_init == 'skip' else weight_init
  715. # TODO: skip init when on meta device when safe to do so
  716. if weight_init != 'skip':
  717. self.init_weights(needs_reset=False)
  718. @torch.jit.ignore
  719. def init_weights(self, mode: str = '', needs_reset: bool = True) -> None:
  720. """Initialize model weights.
  721. Args:
  722. mode: Weight initialization mode ('jax', 'jax_nlhb', 'moco', or '').
  723. needs_reset: If True, call reset_parameters() on modules that have it.
  724. Set to False when modules have already self-initialized in __init__.
  725. """
  726. mode = mode or self.weight_init_mode
  727. assert mode in ('jax', 'jax_nlhb', 'moco', 'reset', '')
  728. head_bias = -math.log(self.num_classes) if 'nlhb' in mode else 0.
  729. named_apply(get_init_weights_vit(mode, head_bias=head_bias, needs_reset=needs_reset), self)
  730. @torch.jit.ignore
  731. def no_weight_decay(self) -> Set[str]:
  732. """Parameters that should not use weight decay."""
  733. nwd = set()
  734. for n, _ in self.named_parameters():
  735. if 'relative_position_bias_table' in n:
  736. nwd.add(n)
  737. return nwd
  738. def set_input_size(
  739. self,
  740. img_size: Optional[Tuple[int, int]] = None,
  741. patch_size: Optional[Tuple[int, int]] = None,
  742. window_size: Optional[Tuple[int, int]] = None,
  743. window_ratio: int = 8,
  744. always_partition: Optional[bool] = None,
  745. ) -> None:
  746. """Update the image resolution and window size.
  747. Args:
  748. img_size: New input resolution, if None current resolution is used.
  749. patch_size: New patch size, if None use current patch size.
  750. window_size: New window size, if None based on new_img_size // window_div.
  751. window_ratio: Divisor for calculating window size from grid size.
  752. always_partition: Always partition into windows and shift (even if window size < feat size).
  753. """
  754. if img_size is not None or patch_size is not None:
  755. self.patch_embed.set_input_size(img_size=img_size, patch_size=patch_size)
  756. patch_grid = self.patch_embed.grid_size
  757. if window_size is None:
  758. window_size = tuple([pg // window_ratio for pg in patch_grid])
  759. for index, stage in enumerate(self.layers):
  760. stage_scale = 2 ** max(index - 1, 0)
  761. stage.set_input_size(
  762. feat_size=(patch_grid[0] // stage_scale, patch_grid[1] // stage_scale),
  763. window_size=window_size,
  764. always_partition=always_partition,
  765. )
  766. @torch.jit.ignore
  767. def group_matcher(self, coarse: bool = False) -> Dict[str, Any]:
  768. """Group parameters for optimization."""
  769. return dict(
  770. stem=r'^patch_embed', # stem and embed
  771. blocks=r'^layers\.(\d+)' if coarse else [
  772. (r'^layers\.(\d+).downsample', (0,)),
  773. (r'^layers\.(\d+)\.\w+\.(\d+)', None),
  774. (r'^norm', (99999,)),
  775. ]
  776. )
  777. @torch.jit.ignore
  778. def set_grad_checkpointing(self, enable: bool = True) -> None:
  779. """Enable or disable gradient checkpointing."""
  780. for l in self.layers:
  781. l.grad_checkpointing = enable
  782. @torch.jit.ignore
  783. def get_classifier(self) -> nn.Module:
  784. """Get the classifier head."""
  785. return self.head.fc
  786. def reset_classifier(self, num_classes: int, global_pool: Optional[str] = None) -> None:
  787. """Reset the classifier head.
  788. Args:
  789. num_classes: Number of classes for new classifier.
  790. global_pool: Global pooling type.
  791. """
  792. self.num_classes = num_classes
  793. self.head.reset(num_classes, pool_type=global_pool)
  794. def forward_intermediates(
  795. self,
  796. x: torch.Tensor,
  797. indices: Optional[Union[int, List[int]]] = None,
  798. norm: bool = False,
  799. stop_early: bool = False,
  800. output_fmt: str = 'NCHW',
  801. intermediates_only: bool = False,
  802. ) -> Union[List[torch.Tensor], Tuple[torch.Tensor, List[torch.Tensor]]]:
  803. """Forward features that returns intermediates.
  804. Args:
  805. x: Input image tensor.
  806. indices: Take last n blocks if int, all if None, select matching indices if sequence.
  807. norm: Apply norm layer to compatible intermediates.
  808. stop_early: Stop iterating over blocks when last desired intermediate hit.
  809. output_fmt: Shape of intermediate feature outputs.
  810. intermediates_only: Only return intermediate features.
  811. Returns:
  812. List of intermediate features or tuple of (final features, intermediates).
  813. """
  814. assert output_fmt in ('NCHW',), 'Output shape must be NCHW.'
  815. intermediates = []
  816. take_indices, max_index = feature_take_indices(len(self.layers), indices)
  817. # forward pass
  818. x = self.patch_embed(x)
  819. num_stages = len(self.layers)
  820. if torch.jit.is_scripting() or not stop_early: # can't slice blocks in torchscript
  821. stages = self.layers
  822. else:
  823. stages = self.layers[:max_index + 1]
  824. for i, stage in enumerate(stages):
  825. x = stage(x)
  826. if i in take_indices:
  827. if norm and i == num_stages - 1:
  828. x_inter = self.norm(x) # applying final norm last intermediate
  829. else:
  830. x_inter = x
  831. x_inter = x_inter.permute(0, 3, 1, 2).contiguous()
  832. intermediates.append(x_inter)
  833. if intermediates_only:
  834. return intermediates
  835. x = self.norm(x)
  836. return x, intermediates
  837. def prune_intermediate_layers(
  838. self,
  839. indices: Union[int, List[int]] = 1,
  840. prune_norm: bool = False,
  841. prune_head: bool = True,
  842. ) -> List[int]:
  843. """Prune layers not required for specified intermediates.
  844. Args:
  845. indices: Indices of intermediate layers to keep.
  846. prune_norm: Whether to prune normalization layer.
  847. prune_head: Whether to prune the classifier head.
  848. Returns:
  849. List of indices that were kept.
  850. """
  851. take_indices, max_index = feature_take_indices(len(self.layers), indices)
  852. self.layers = self.layers[:max_index + 1] # truncate blocks
  853. if prune_norm:
  854. self.norm = nn.Identity()
  855. if prune_head:
  856. self.reset_classifier(0, '')
  857. return take_indices
  858. def forward_features(self, x: torch.Tensor) -> torch.Tensor:
  859. """Forward pass through feature extraction layers."""
  860. x = self.patch_embed(x)
  861. x = self.layers(x)
  862. x = self.norm(x)
  863. return x
  864. def forward_head(self, x: torch.Tensor, pre_logits: bool = False) -> torch.Tensor:
  865. """Forward pass through classifier head.
  866. Args:
  867. x: Feature tensor.
  868. pre_logits: Return features before final classifier.
  869. Returns:
  870. Output tensor.
  871. """
  872. return self.head(x, pre_logits=True) if pre_logits else self.head(x)
  873. def forward(self, x: torch.Tensor) -> torch.Tensor:
  874. """Forward pass.
  875. Args:
  876. x: Input tensor.
  877. Returns:
  878. Output logits.
  879. """
  880. x = self.forward_features(x)
  881. x = self.forward_head(x)
  882. return x
  883. def checkpoint_filter_fn(state_dict: dict, model: nn.Module) -> Dict[str, torch.Tensor]:
  884. """Convert patch embedding weight from manual patchify + linear proj to conv.
  885. Args:
  886. state_dict: State dictionary from checkpoint.
  887. model: Model instance.
  888. Returns:
  889. Filtered state dictionary.
  890. """
  891. old_weights = True
  892. if 'head.fc.weight' in state_dict:
  893. old_weights = False
  894. import re
  895. out_dict = {}
  896. state_dict = state_dict.get('model', state_dict)
  897. state_dict = state_dict.get('state_dict', state_dict)
  898. for k, v in state_dict.items():
  899. if any([n in k for n in ('relative_position_index', 'attn_mask')]):
  900. continue # skip buffers that should not be persistent
  901. if 'patch_embed.proj.weight' in k:
  902. _, _, H, W = model.patch_embed.proj.weight.shape
  903. if v.shape[-2] != H or v.shape[-1] != W:
  904. v = resample_patch_embed(
  905. v,
  906. (H, W),
  907. interpolation='bicubic',
  908. antialias=True,
  909. verbose=True,
  910. )
  911. if k.endswith('relative_position_bias_table'):
  912. m = model.get_submodule(k[:-29])
  913. if v.shape != m.relative_position_bias_table.shape or m.window_size[0] != m.window_size[1]:
  914. v = resize_rel_pos_bias_table(
  915. v,
  916. new_window_size=m.window_size,
  917. new_bias_shape=m.relative_position_bias_table.shape,
  918. )
  919. if old_weights:
  920. k = re.sub(r'layers.(\d+).downsample', lambda x: f'layers.{int(x.group(1)) + 1}.downsample', k)
  921. k = k.replace('head.', 'head.fc.')
  922. out_dict[k] = v
  923. return out_dict
  924. def _create_swin_transformer(variant: str, pretrained: bool = False, **kwargs) -> SwinTransformer:
  925. """Create a Swin Transformer model.
  926. Args:
  927. variant: Model variant name.
  928. pretrained: Load pretrained weights.
  929. **kwargs: Additional model arguments.
  930. Returns:
  931. SwinTransformer model instance.
  932. """
  933. default_out_indices = tuple(i for i, _ in enumerate(kwargs.get('depths', (1, 1, 3, 1))))
  934. out_indices = kwargs.pop('out_indices', default_out_indices)
  935. model = build_model_with_cfg(
  936. SwinTransformer, variant, pretrained,
  937. pretrained_filter_fn=checkpoint_filter_fn,
  938. feature_cfg=dict(flatten_sequential=True, out_indices=out_indices),
  939. **kwargs)
  940. return model
  941. def _cfg(url: str = '', **kwargs) -> Dict[str, Any]:
  942. """Create default configuration for Swin Transformer models."""
  943. return {
  944. 'url': url,
  945. 'num_classes': 1000, 'input_size': (3, 224, 224), 'pool_size': (7, 7),
  946. 'crop_pct': .9, 'interpolation': 'bicubic', 'fixed_input_size': True,
  947. 'mean': IMAGENET_DEFAULT_MEAN, 'std': IMAGENET_DEFAULT_STD,
  948. 'first_conv': 'patch_embed.proj', 'classifier': 'head.fc',
  949. 'license': 'mit', **kwargs
  950. }
  951. default_cfgs = generate_default_cfgs({
  952. 'swin_small_patch4_window7_224.ms_in22k_ft_in1k': _cfg(
  953. hf_hub_id='timm/',
  954. url='https://github.com/SwinTransformer/storage/releases/download/v1.0.8/swin_small_patch4_window7_224_22kto1k_finetune.pth', ),
  955. 'swin_base_patch4_window7_224.ms_in22k_ft_in1k': _cfg(
  956. hf_hub_id='timm/',
  957. url='https://github.com/SwinTransformer/storage/releases/download/v1.0.0/swin_base_patch4_window7_224_22kto1k.pth',),
  958. 'swin_base_patch4_window12_384.ms_in22k_ft_in1k': _cfg(
  959. hf_hub_id='timm/',
  960. url='https://github.com/SwinTransformer/storage/releases/download/v1.0.0/swin_base_patch4_window12_384_22kto1k.pth',
  961. input_size=(3, 384, 384), pool_size=(12, 12), crop_pct=1.0),
  962. 'swin_large_patch4_window7_224.ms_in22k_ft_in1k': _cfg(
  963. hf_hub_id='timm/',
  964. url='https://github.com/SwinTransformer/storage/releases/download/v1.0.0/swin_large_patch4_window7_224_22kto1k.pth',),
  965. 'swin_large_patch4_window12_384.ms_in22k_ft_in1k': _cfg(
  966. hf_hub_id='timm/',
  967. url='https://github.com/SwinTransformer/storage/releases/download/v1.0.0/swin_large_patch4_window12_384_22kto1k.pth',
  968. input_size=(3, 384, 384), pool_size=(12, 12), crop_pct=1.0),
  969. 'swin_tiny_patch4_window7_224.ms_in1k': _cfg(
  970. hf_hub_id='timm/',
  971. url='https://github.com/SwinTransformer/storage/releases/download/v1.0.0/swin_tiny_patch4_window7_224.pth',),
  972. 'swin_small_patch4_window7_224.ms_in1k': _cfg(
  973. hf_hub_id='timm/',
  974. url='https://github.com/SwinTransformer/storage/releases/download/v1.0.0/swin_small_patch4_window7_224.pth',),
  975. 'swin_base_patch4_window7_224.ms_in1k': _cfg(
  976. hf_hub_id='timm/',
  977. url='https://github.com/SwinTransformer/storage/releases/download/v1.0.0/swin_base_patch4_window7_224.pth',),
  978. 'swin_base_patch4_window12_384.ms_in1k': _cfg(
  979. hf_hub_id='timm/',
  980. url='https://github.com/SwinTransformer/storage/releases/download/v1.0.0/swin_base_patch4_window12_384.pth',
  981. input_size=(3, 384, 384), pool_size=(12, 12), crop_pct=1.0),
  982. # tiny 22k pretrain is worse than 1k, so moved after (untagged priority is based on order)
  983. 'swin_tiny_patch4_window7_224.ms_in22k_ft_in1k': _cfg(
  984. hf_hub_id='timm/',
  985. url='https://github.com/SwinTransformer/storage/releases/download/v1.0.8/swin_tiny_patch4_window7_224_22kto1k_finetune.pth',),
  986. 'swin_tiny_patch4_window7_224.ms_in22k': _cfg(
  987. hf_hub_id='timm/',
  988. url='https://github.com/SwinTransformer/storage/releases/download/v1.0.8/swin_tiny_patch4_window7_224_22k.pth',
  989. num_classes=21841),
  990. 'swin_small_patch4_window7_224.ms_in22k': _cfg(
  991. hf_hub_id='timm/',
  992. url='https://github.com/SwinTransformer/storage/releases/download/v1.0.8/swin_small_patch4_window7_224_22k.pth',
  993. num_classes=21841),
  994. 'swin_base_patch4_window7_224.ms_in22k': _cfg(
  995. hf_hub_id='timm/',
  996. url='https://github.com/SwinTransformer/storage/releases/download/v1.0.0/swin_base_patch4_window7_224_22k.pth',
  997. num_classes=21841),
  998. 'swin_base_patch4_window12_384.ms_in22k': _cfg(
  999. hf_hub_id='timm/',
  1000. url='https://github.com/SwinTransformer/storage/releases/download/v1.0.0/swin_base_patch4_window12_384_22k.pth',
  1001. input_size=(3, 384, 384), pool_size=(12, 12), crop_pct=1.0, num_classes=21841),
  1002. 'swin_large_patch4_window7_224.ms_in22k': _cfg(
  1003. hf_hub_id='timm/',
  1004. url='https://github.com/SwinTransformer/storage/releases/download/v1.0.0/swin_large_patch4_window7_224_22k.pth',
  1005. num_classes=21841),
  1006. 'swin_large_patch4_window12_384.ms_in22k': _cfg(
  1007. hf_hub_id='timm/',
  1008. url='https://github.com/SwinTransformer/storage/releases/download/v1.0.0/swin_large_patch4_window12_384_22k.pth',
  1009. input_size=(3, 384, 384), pool_size=(12, 12), crop_pct=1.0, num_classes=21841),
  1010. 'swin_s3_tiny_224.ms_in1k': _cfg(
  1011. hf_hub_id='timm/',
  1012. url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/s3_t-1d53f6a8.pth'),
  1013. 'swin_s3_small_224.ms_in1k': _cfg(
  1014. hf_hub_id='timm/',
  1015. url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/s3_s-3bb4c69d.pth'),
  1016. 'swin_s3_base_224.ms_in1k': _cfg(
  1017. hf_hub_id='timm/',
  1018. url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/s3_b-a1e95db4.pth'),
  1019. })
  1020. @register_model
  1021. def swin_tiny_patch4_window7_224(pretrained=False, **kwargs) -> SwinTransformer:
  1022. """ Swin-T @ 224x224, trained ImageNet-1k
  1023. """
  1024. model_args = dict(patch_size=4, window_size=7, embed_dim=96, depths=(2, 2, 6, 2), num_heads=(3, 6, 12, 24))
  1025. return _create_swin_transformer(
  1026. 'swin_tiny_patch4_window7_224', pretrained=pretrained, **dict(model_args, **kwargs))
  1027. @register_model
  1028. def swin_small_patch4_window7_224(pretrained=False, **kwargs) -> SwinTransformer:
  1029. """ Swin-S @ 224x224
  1030. """
  1031. model_args = dict(patch_size=4, window_size=7, embed_dim=96, depths=(2, 2, 18, 2), num_heads=(3, 6, 12, 24))
  1032. return _create_swin_transformer(
  1033. 'swin_small_patch4_window7_224', pretrained=pretrained, **dict(model_args, **kwargs))
  1034. @register_model
  1035. def swin_base_patch4_window7_224(pretrained=False, **kwargs) -> SwinTransformer:
  1036. """ Swin-B @ 224x224
  1037. """
  1038. model_args = dict(patch_size=4, window_size=7, embed_dim=128, depths=(2, 2, 18, 2), num_heads=(4, 8, 16, 32))
  1039. return _create_swin_transformer(
  1040. 'swin_base_patch4_window7_224', pretrained=pretrained, **dict(model_args, **kwargs))
  1041. @register_model
  1042. def swin_base_patch4_window12_384(pretrained=False, **kwargs) -> SwinTransformer:
  1043. """ Swin-B @ 384x384
  1044. """
  1045. model_args = dict(patch_size=4, window_size=12, embed_dim=128, depths=(2, 2, 18, 2), num_heads=(4, 8, 16, 32))
  1046. return _create_swin_transformer(
  1047. 'swin_base_patch4_window12_384', pretrained=pretrained, **dict(model_args, **kwargs))
  1048. @register_model
  1049. def swin_large_patch4_window7_224(pretrained=False, **kwargs) -> SwinTransformer:
  1050. """ Swin-L @ 224x224
  1051. """
  1052. model_args = dict(patch_size=4, window_size=7, embed_dim=192, depths=(2, 2, 18, 2), num_heads=(6, 12, 24, 48))
  1053. return _create_swin_transformer(
  1054. 'swin_large_patch4_window7_224', pretrained=pretrained, **dict(model_args, **kwargs))
  1055. @register_model
  1056. def swin_large_patch4_window12_384(pretrained=False, **kwargs) -> SwinTransformer:
  1057. """ Swin-L @ 384x384
  1058. """
  1059. model_args = dict(patch_size=4, window_size=12, embed_dim=192, depths=(2, 2, 18, 2), num_heads=(6, 12, 24, 48))
  1060. return _create_swin_transformer(
  1061. 'swin_large_patch4_window12_384', pretrained=pretrained, **dict(model_args, **kwargs))
  1062. @register_model
  1063. def swin_s3_tiny_224(pretrained=False, **kwargs) -> SwinTransformer:
  1064. """ Swin-S3-T @ 224x224, https://arxiv.org/abs/2111.14725
  1065. """
  1066. model_args = dict(
  1067. patch_size=4, window_size=(7, 7, 14, 7), embed_dim=96, depths=(2, 2, 6, 2), num_heads=(3, 6, 12, 24))
  1068. return _create_swin_transformer('swin_s3_tiny_224', pretrained=pretrained, **dict(model_args, **kwargs))
  1069. @register_model
  1070. def swin_s3_small_224(pretrained=False, **kwargs) -> SwinTransformer:
  1071. """ Swin-S3-S @ 224x224, https://arxiv.org/abs/2111.14725
  1072. """
  1073. model_args = dict(
  1074. patch_size=4, window_size=(14, 14, 14, 7), embed_dim=96, depths=(2, 2, 18, 2), num_heads=(3, 6, 12, 24))
  1075. return _create_swin_transformer('swin_s3_small_224', pretrained=pretrained, **dict(model_args, **kwargs))
  1076. @register_model
  1077. def swin_s3_base_224(pretrained=False, **kwargs) -> SwinTransformer:
  1078. """ Swin-S3-B @ 224x224, https://arxiv.org/abs/2111.14725
  1079. """
  1080. model_args = dict(
  1081. patch_size=4, window_size=(7, 7, 14, 7), embed_dim=96, depths=(2, 2, 30, 2), num_heads=(3, 6, 12, 24))
  1082. return _create_swin_transformer('swin_s3_base_224', pretrained=pretrained, **dict(model_args, **kwargs))
  1083. register_model_deprecations(__name__, {
  1084. 'swin_base_patch4_window7_224_in22k': 'swin_base_patch4_window7_224.ms_in22k',
  1085. 'swin_base_patch4_window12_384_in22k': 'swin_base_patch4_window12_384.ms_in22k',
  1086. 'swin_large_patch4_window7_224_in22k': 'swin_large_patch4_window7_224.ms_in22k',
  1087. 'swin_large_patch4_window12_384_in22k': 'swin_large_patch4_window12_384.ms_in22k',
  1088. })