swin_transformer_v2.py 52 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270370470570670770870971071171271371471571671771871972072172272372472572672772872973073173273373473573673773873974074174274374474574674774874975075175275375475575675775875976076176276376476576676776876977077177277377477577677777877978078178278378478578678778878979079179279379479579679779879980080180280380480580680780880981081181281381481581681781881982082182282382482582682782882983083183283383483583683783883984084184284384484584684784884985085185285385485585685785885986086186286386486586686786886987087187287387487587687787887988088188288388488588688788888989089189289389489589689789889990090190290390490590690790890991091191291391491591691791891992092192292392492592692792892993093193293393493593693793893994094194294394494594694794894995095195295395495595695795895996096196296396496596696796896997097197297397497597697797897998098198298398498598698798898999099199299399499599699799899910001001100210031004100510061007100810091010101110121013101410151016101710181019102010211022102310241025102610271028102910301031103210331034103510361037103810391040104110421043104410451046104710481049105010511052105310541055105610571058105910601061106210631064106510661067106810691070107110721073107410751076107710781079108010811082108310841085108610871088108910901091109210931094109510961097109810991100110111021103110411051106110711081109111011111112111311141115111611171118111911201121112211231124112511261127112811291130113111321133113411351136113711381139114011411142114311441145114611471148114911501151115211531154115511561157115811591160116111621163116411651166116711681169117011711172117311741175117611771178117911801181118211831184118511861187118811891190119111921193119411951196119711981199120012011202120312041205120612071208120912101211121212131214121512161217121812191220122112221223122412251226122712281229123012311232123312341235123612371238123912401241124212431244124512461247124812491250125112521253125412551256125712581259126012611262126312641265126612671268126912701271127212731274127512761277127812791280128112821283128412851286128712881289129012911292129312941295129612971298129913001301130213031304130513061307130813091310131113121313131413151316131713181319
  1. """ Swin Transformer V2
  2. A PyTorch impl of : `Swin Transformer V2: Scaling Up Capacity and Resolution`
  3. - https://arxiv.org/abs/2111.09883
  4. Code/weights from https://github.com/microsoft/Swin-Transformer, original copyright/license info below
  5. Modifications and additions for timm hacked together by / Copyright 2022, Ross Wightman
  6. """
  7. # --------------------------------------------------------
  8. # Swin Transformer V2
  9. # Copyright (c) 2022 Microsoft
  10. # Licensed under The MIT License [see LICENSE for details]
  11. # Written by Ze Liu
  12. # --------------------------------------------------------
  13. import math
  14. from functools import partial
  15. from typing import Any, Callable, Dict, List, Optional, Set, Tuple, Type, Union
  16. import torch
  17. import torch.nn as nn
  18. import torch.nn.functional as F
  19. from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
  20. from timm.layers import PatchEmbed, Mlp, DropPath, calculate_drop_path_rates, to_2tuple, trunc_normal_, ClassifierHead,\
  21. resample_patch_embed, ndgrid, get_act_layer, LayerType
  22. from ._builder import build_model_with_cfg
  23. from ._features import feature_take_indices
  24. from ._features_fx import register_notrace_function
  25. from ._manipulate import checkpoint
  26. from ._registry import generate_default_cfgs, register_model, register_model_deprecations
  27. __all__ = ['SwinTransformerV2'] # model_registry will add each entrypoint fn to this
  28. _int_or_tuple_2_t = Union[int, Tuple[int, int]]
  29. def window_partition(
  30. x: torch.Tensor,
  31. window_size: Tuple[int, int],
  32. ) -> torch.Tensor:
  33. """Partition into non-overlapping windows.
  34. Args:
  35. x: Input tensor of shape (B, H, W, C).
  36. window_size: Window size (height, width).
  37. Returns:
  38. Windows tensor of shape (num_windows*B, window_size[0], window_size[1], C).
  39. """
  40. B, H, W, C = x.shape
  41. x = x.view(B, H // window_size[0], window_size[0], W // window_size[1], window_size[1], C)
  42. windows = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size[0], window_size[1], C)
  43. return windows
  44. @register_notrace_function # reason: int argument is a Proxy
  45. def window_reverse(
  46. windows: torch.Tensor,
  47. window_size: Tuple[int, int],
  48. img_size: Tuple[int, int],
  49. ) -> torch.Tensor:
  50. """Merge windows back to feature map.
  51. Args:
  52. windows: Windows tensor of shape (num_windows * B, window_size[0], window_size[1], C).
  53. window_size: Window size (height, width).
  54. img_size: Image size (height, width).
  55. Returns:
  56. Feature map tensor of shape (B, H, W, C).
  57. """
  58. H, W = img_size
  59. C = windows.shape[-1]
  60. x = windows.view(-1, H // window_size[0], W // window_size[1], window_size[0], window_size[1], C)
  61. x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, H, W, C)
  62. return x
  63. class WindowAttention(nn.Module):
  64. """Window based multi-head self attention (W-MSA) module with relative position bias.
  65. Supports both shifted and non-shifted window attention with continuous relative
  66. position bias and cosine attention.
  67. """
  68. def __init__(
  69. self,
  70. dim: int,
  71. window_size: Tuple[int, int],
  72. num_heads: int,
  73. qkv_bias: bool = True,
  74. qkv_bias_separate: bool = False,
  75. attn_drop: float = 0.,
  76. proj_drop: float = 0.,
  77. pretrained_window_size: Tuple[int, int] = (0, 0),
  78. device=None,
  79. dtype=None,
  80. ) -> None:
  81. """Initialize window attention module.
  82. Args:
  83. dim: Number of input channels.
  84. window_size: The height and width of the window.
  85. num_heads: Number of attention heads.
  86. qkv_bias: If True, add a learnable bias to query, key, value.
  87. qkv_bias_separate: If True, use separate bias for q, k, v projections.
  88. attn_drop: Dropout ratio of attention weight.
  89. proj_drop: Dropout ratio of output.
  90. pretrained_window_size: The height and width of the window in pre-training.
  91. """
  92. dd = {'device': device, 'dtype': dtype}
  93. super().__init__()
  94. self.dim = dim
  95. self.window_size = window_size # Wh, Ww
  96. self.pretrained_window_size = to_2tuple(pretrained_window_size)
  97. self.num_heads = num_heads
  98. self.qkv_bias_separate = qkv_bias_separate
  99. self.logit_scale = nn.Parameter(torch.empty((num_heads, 1, 1), **dd))
  100. # mlp to generate continuous relative position bias
  101. self.cpb_mlp = nn.Sequential(
  102. nn.Linear(2, 512, bias=True, **dd),
  103. nn.ReLU(inplace=True),
  104. nn.Linear(512, num_heads, bias=False, **dd)
  105. )
  106. self.qkv = nn.Linear(dim, dim * 3, bias=False, **dd)
  107. if qkv_bias:
  108. self.q_bias = nn.Parameter(torch.empty(dim, **dd))
  109. self.register_buffer('k_bias', torch.empty(dim, **dd), persistent=False)
  110. self.v_bias = nn.Parameter(torch.empty(dim, **dd))
  111. else:
  112. self.q_bias = None
  113. self.k_bias = None
  114. self.v_bias = None
  115. self.attn_drop = nn.Dropout(attn_drop)
  116. self.proj = nn.Linear(dim, dim, **dd)
  117. self.proj_drop = nn.Dropout(proj_drop)
  118. self.softmax = nn.Softmax(dim=-1)
  119. # Register empty buffers with correct shapes
  120. win_h, win_w = self.window_size
  121. self.register_buffer(
  122. "relative_coords_table",
  123. torch.empty(1, 2 * win_h - 1, 2 * win_w - 1, 2, **dd),
  124. persistent=False,
  125. )
  126. self.register_buffer(
  127. "relative_position_index",
  128. torch.empty(win_h * win_w, win_h * win_w, device=device, dtype=torch.long),
  129. persistent=False,
  130. )
  131. # TODO: skip init when on meta device when safe to do so
  132. self.reset_parameters()
  133. def reset_parameters(self) -> None:
  134. """Initialize parameters and buffers."""
  135. nn.init.constant_(self.logit_scale, math.log(10))
  136. if self.q_bias is not None:
  137. nn.init.zeros_(self.q_bias)
  138. nn.init.zeros_(self.v_bias)
  139. self._init_buffers()
  140. def _init_buffers(self) -> None:
  141. """Compute and fill non-persistent buffer values."""
  142. if self.k_bias is not None:
  143. self.k_bias.zero_()
  144. relative_coords_table, relative_position_index = self._make_pair_wise_relative_positions(
  145. device=self.proj.weight.device, dtype=self.proj.weight.dtype
  146. )
  147. self.relative_coords_table.copy_(relative_coords_table)
  148. self.relative_position_index.copy_(relative_position_index)
  149. def _make_pair_wise_relative_positions(
  150. self,
  151. device=None,
  152. dtype=None,
  153. ) -> Tuple[torch.Tensor, torch.Tensor]:
  154. """Compute pair-wise relative position index and coordinates table.
  155. Returns:
  156. Tuple of (relative_coords_table, relative_position_index)
  157. """
  158. # get relative_coords_table
  159. relative_coords_h = torch.arange(
  160. -(self.window_size[0] - 1), self.window_size[0], device=device, dtype=torch.float32)
  161. relative_coords_w = torch.arange(
  162. -(self.window_size[1] - 1), self.window_size[1], device=device, dtype=torch.float32)
  163. relative_coords_table = torch.stack(ndgrid(relative_coords_h, relative_coords_w))
  164. relative_coords_table = relative_coords_table.permute(1, 2, 0).contiguous().unsqueeze(0) # 1, 2*Wh-1, 2*Ww-1, 2
  165. if self.pretrained_window_size[0] > 0:
  166. relative_coords_table[:, :, :, 0] /= (self.pretrained_window_size[0] - 1)
  167. relative_coords_table[:, :, :, 1] /= (self.pretrained_window_size[1] - 1)
  168. else:
  169. relative_coords_table[:, :, :, 0] /= (self.window_size[0] - 1)
  170. relative_coords_table[:, :, :, 1] /= (self.window_size[1] - 1)
  171. relative_coords_table *= 8 # normalize to -8, 8
  172. relative_coords_table = torch.sign(relative_coords_table) * torch.log2(
  173. torch.abs(relative_coords_table) + 1.0) / math.log2(8)
  174. relative_coords_table = relative_coords_table.to(dtype=dtype)
  175. # get pair-wise relative position index for each token inside the window
  176. coords_h = torch.arange(self.window_size[0], device=device, dtype=torch.long)
  177. coords_w = torch.arange(self.window_size[1], device=device, dtype=torch.long)
  178. coords = torch.stack(ndgrid(coords_h, coords_w)) # 2, Wh, Ww
  179. coords_flatten = torch.flatten(coords, 1) # 2, Wh*Ww
  180. relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :] # 2, Wh*Ww, Wh*Ww
  181. relative_coords = relative_coords.permute(1, 2, 0).contiguous() # Wh*Ww, Wh*Ww, 2
  182. relative_coords[:, :, 0] += self.window_size[0] - 1 # shift to start from 0
  183. relative_coords[:, :, 1] += self.window_size[1] - 1
  184. relative_coords[:, :, 0] *= 2 * self.window_size[1] - 1
  185. relative_position_index = relative_coords.sum(-1) # Wh*Ww, Wh*Ww
  186. return relative_coords_table, relative_position_index
  187. def set_window_size(self, window_size: Tuple[int, int]) -> None:
  188. """Update window size and regenerate relative position tables.
  189. Args:
  190. window_size: New window size (height, width).
  191. """
  192. window_size = to_2tuple(window_size)
  193. if window_size != self.window_size:
  194. assert self.relative_coords_table is not None
  195. device = self.relative_coords_table.device
  196. dtype = self.relative_coords_table.dtype
  197. self.window_size = window_size
  198. relative_coords_table, relative_position_index = \
  199. self._make_pair_wise_relative_positions(device=device, dtype=dtype)
  200. self.register_buffer("relative_coords_table", relative_coords_table, persistent=False)
  201. self.register_buffer("relative_position_index", relative_position_index, persistent=False)
  202. def init_non_persistent_buffers(self) -> None:
  203. """Initialize non-persistent buffers."""
  204. self._init_buffers()
  205. def forward(self, x: torch.Tensor, mask: Optional[torch.Tensor] = None) -> torch.Tensor:
  206. """Forward pass of window attention.
  207. Args:
  208. x: Input features with shape of (num_windows*B, N, C).
  209. mask: Attention mask with shape of (num_windows, Wh*Ww, Wh*Ww) or None.
  210. Returns:
  211. Output features with shape of (num_windows*B, N, C).
  212. """
  213. B_, N, C = x.shape
  214. if self.q_bias is None:
  215. qkv = self.qkv(x)
  216. else:
  217. qkv_bias = torch.cat((self.q_bias, self.k_bias, self.v_bias))
  218. if self.qkv_bias_separate:
  219. qkv = self.qkv(x)
  220. qkv += qkv_bias
  221. else:
  222. qkv = F.linear(x, weight=self.qkv.weight, bias=qkv_bias)
  223. qkv = qkv.reshape(B_, N, 3, self.num_heads, -1).permute(2, 0, 3, 1, 4)
  224. q, k, v = qkv.unbind(0)
  225. # cosine attention
  226. attn = (F.normalize(q, dim=-1) @ F.normalize(k, dim=-1).transpose(-2, -1))
  227. logit_scale = torch.clamp(self.logit_scale, max=math.log(1. / 0.01)).exp()
  228. attn = attn * logit_scale
  229. relative_position_bias_table = self.cpb_mlp(self.relative_coords_table).view(-1, self.num_heads)
  230. relative_position_bias = relative_position_bias_table[self.relative_position_index.view(-1)].view(
  231. self.window_size[0] * self.window_size[1], self.window_size[0] * self.window_size[1], -1) # Wh*Ww,Wh*Ww,nH
  232. relative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous() # nH, Wh*Ww, Wh*Ww
  233. relative_position_bias = 16 * torch.sigmoid(relative_position_bias)
  234. attn = attn + relative_position_bias.unsqueeze(0)
  235. if mask is not None:
  236. num_win = mask.shape[0]
  237. attn = attn.view(-1, num_win, self.num_heads, N, N) + mask.unsqueeze(1).unsqueeze(0)
  238. attn = attn.view(-1, self.num_heads, N, N)
  239. attn = self.softmax(attn)
  240. else:
  241. attn = self.softmax(attn)
  242. attn = self.attn_drop(attn)
  243. x = (attn @ v).transpose(1, 2).reshape(B_, N, C)
  244. x = self.proj(x)
  245. x = self.proj_drop(x)
  246. return x
  247. class SwinTransformerV2Block(nn.Module):
  248. """Swin Transformer V2 Block.
  249. A standard transformer block with window attention and shifted window attention
  250. for modeling long-range dependencies efficiently.
  251. """
  252. def __init__(
  253. self,
  254. dim: int,
  255. input_resolution: _int_or_tuple_2_t,
  256. num_heads: int,
  257. window_size: _int_or_tuple_2_t = 7,
  258. shift_size: _int_or_tuple_2_t = 0,
  259. always_partition: bool = False,
  260. dynamic_mask: bool = False,
  261. mlp_ratio: float = 4.,
  262. qkv_bias: bool = True,
  263. proj_drop: float = 0.,
  264. attn_drop: float = 0.,
  265. drop_path: float = 0.,
  266. act_layer: LayerType = "gelu",
  267. norm_layer: Type[nn.Module] = nn.LayerNorm,
  268. pretrained_window_size: _int_or_tuple_2_t = 0,
  269. device=None,
  270. dtype=None,
  271. ):
  272. """
  273. Args:
  274. dim: Number of input channels.
  275. input_resolution: Input resolution.
  276. num_heads: Number of attention heads.
  277. window_size: Window size.
  278. shift_size: Shift size for SW-MSA.
  279. always_partition: Always partition into full windows and shift
  280. mlp_ratio: Ratio of mlp hidden dim to embedding dim.
  281. qkv_bias: If True, add a learnable bias to query, key, value.
  282. proj_drop: Dropout rate.
  283. attn_drop: Attention dropout rate.
  284. drop_path: Stochastic depth rate.
  285. act_layer: Activation layer.
  286. norm_layer: Normalization layer.
  287. pretrained_window_size: Window size in pretraining.
  288. """
  289. dd = {'device': device, 'dtype': dtype}
  290. super().__init__()
  291. self.dim = dim
  292. self.input_resolution = to_2tuple(input_resolution)
  293. self.num_heads = num_heads
  294. self.target_shift_size = to_2tuple(shift_size) # store for later resize
  295. self.always_partition = always_partition
  296. self.dynamic_mask = dynamic_mask
  297. self.window_size, self.shift_size = self._calc_window_shift(window_size, shift_size)
  298. self.window_area = self.window_size[0] * self.window_size[1]
  299. self.mlp_ratio = mlp_ratio
  300. act_layer = get_act_layer(act_layer)
  301. self.attn = WindowAttention(
  302. dim,
  303. window_size=to_2tuple(self.window_size),
  304. num_heads=num_heads,
  305. qkv_bias=qkv_bias,
  306. attn_drop=attn_drop,
  307. proj_drop=proj_drop,
  308. pretrained_window_size=to_2tuple(pretrained_window_size),
  309. **dd,
  310. )
  311. self.norm1 = norm_layer(dim, **dd)
  312. self.drop_path1 = DropPath(drop_path) if drop_path > 0. else nn.Identity()
  313. self.mlp = Mlp(
  314. in_features=dim,
  315. hidden_features=int(dim * mlp_ratio),
  316. act_layer=act_layer,
  317. drop=proj_drop,
  318. **dd,
  319. )
  320. self.norm2 = norm_layer(dim, **dd)
  321. self.drop_path2 = DropPath(drop_path) if drop_path > 0. else nn.Identity()
  322. self.register_buffer(
  323. "attn_mask",
  324. None if self.dynamic_mask else self.get_attn_mask(**dd),
  325. persistent=False,
  326. )
  327. def get_attn_mask(
  328. self,
  329. x: Optional[torch.Tensor] = None,
  330. device: Optional[torch.device] = None,
  331. dtype: Optional[torch.dtype] = None,
  332. ) -> Optional[torch.Tensor]:
  333. """Generate attention mask for shifted window attention.
  334. Args:
  335. x: Input tensor for dynamic shape calculation.
  336. Returns:
  337. Attention mask or None if no shift.
  338. """
  339. if any(self.shift_size):
  340. # calculate attention mask for SW-MSA
  341. if x is None:
  342. img_mask = torch.zeros((1, *self.input_resolution, 1), device=device, dtype=dtype) # 1 H W 1
  343. else:
  344. img_mask = torch.zeros((1, x.shape[1], x.shape[2], 1), device=x.device, dtype=x.dtype) # 1 H W 1
  345. cnt = 0
  346. for h in (
  347. (0, -self.window_size[0]),
  348. (-self.window_size[0], -self.shift_size[0]),
  349. (-self.shift_size[0], None),
  350. ):
  351. for w in (
  352. (0, -self.window_size[1]),
  353. (-self.window_size[1], -self.shift_size[1]),
  354. (-self.shift_size[1], None),
  355. ):
  356. img_mask[:, h[0]:h[1], w[0]:w[1], :] = cnt
  357. cnt += 1
  358. mask_windows = window_partition(img_mask, self.window_size) # nW, window_size, window_size, 1
  359. mask_windows = mask_windows.view(-1, self.window_area)
  360. attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2)
  361. attn_mask = attn_mask.masked_fill(attn_mask != 0, float(-100.0)).masked_fill(attn_mask == 0, float(0.0))
  362. else:
  363. attn_mask = None
  364. return attn_mask
  365. def _calc_window_shift(
  366. self,
  367. target_window_size: _int_or_tuple_2_t,
  368. target_shift_size: Optional[_int_or_tuple_2_t] = None,
  369. ) -> Tuple[Tuple[int, int], Tuple[int, int]]:
  370. """Calculate window size and shift size based on input resolution.
  371. Args:
  372. target_window_size: Target window size.
  373. target_shift_size: Target shift size.
  374. Returns:
  375. Tuple of (adjusted_window_size, adjusted_shift_size).
  376. """
  377. target_window_size = to_2tuple(target_window_size)
  378. if target_shift_size is None:
  379. # if passed value is None, recalculate from default window_size // 2 if it was active
  380. target_shift_size = self.target_shift_size
  381. if any(target_shift_size):
  382. # if there was previously a non-zero shift, recalculate based on current window_size
  383. target_shift_size = (target_window_size[0] // 2, target_window_size[1] // 2)
  384. else:
  385. target_shift_size = to_2tuple(target_shift_size)
  386. if self.always_partition:
  387. return target_window_size, target_shift_size
  388. target_window_size = to_2tuple(target_window_size)
  389. target_shift_size = to_2tuple(target_shift_size)
  390. window_size = [r if r <= w else w for r, w in zip(self.input_resolution, target_window_size)]
  391. shift_size = [0 if r <= w else s for r, w, s in zip(self.input_resolution, window_size, target_shift_size)]
  392. return tuple(window_size), tuple(shift_size)
  393. def set_input_size(
  394. self,
  395. feat_size: Tuple[int, int],
  396. window_size: Tuple[int, int],
  397. always_partition: Optional[bool] = None,
  398. ) -> None:
  399. """Set input size and update window configuration.
  400. Args:
  401. feat_size: New feature map size.
  402. window_size: New window size.
  403. always_partition: Override always_partition setting.
  404. """
  405. # Update input resolution
  406. self.input_resolution = feat_size
  407. if always_partition is not None:
  408. self.always_partition = always_partition
  409. self.window_size, self.shift_size = self._calc_window_shift(to_2tuple(window_size))
  410. self.window_area = self.window_size[0] * self.window_size[1]
  411. self.attn.set_window_size(self.window_size)
  412. device = self.attn_mask.device if self.attn_mask is not None else None
  413. dtype = self.attn_mask.dtype if self.attn_mask is not None else None
  414. self.register_buffer(
  415. "attn_mask",
  416. None if self.dynamic_mask else self.get_attn_mask(device=device, dtype=dtype),
  417. persistent=False,
  418. )
  419. def _attn(self, x: torch.Tensor) -> torch.Tensor:
  420. """Apply windowed attention with optional shift.
  421. Args:
  422. x: Input tensor of shape (B, H, W, C).
  423. Returns:
  424. Output tensor of shape (B, H, W, C).
  425. """
  426. B, H, W, C = x.shape
  427. # cyclic shift
  428. has_shift = any(self.shift_size)
  429. if has_shift:
  430. shifted_x = torch.roll(x, shifts=(-self.shift_size[0], -self.shift_size[1]), dims=(1, 2))
  431. else:
  432. shifted_x = x
  433. pad_h = (self.window_size[0] - H % self.window_size[0]) % self.window_size[0]
  434. pad_w = (self.window_size[1] - W % self.window_size[1]) % self.window_size[1]
  435. shifted_x = torch.nn.functional.pad(shifted_x, (0, 0, 0, pad_w, 0, pad_h))
  436. _, Hp, Wp, _ = shifted_x.shape
  437. # partition windows
  438. x_windows = window_partition(shifted_x, self.window_size) # nW*B, window_size, window_size, C
  439. x_windows = x_windows.view(-1, self.window_area, C) # nW*B, window_size*window_size, C
  440. # W-MSA/SW-MSA
  441. if getattr(self, 'dynamic_mask', False):
  442. attn_mask = self.get_attn_mask(shifted_x)
  443. else:
  444. attn_mask = self.attn_mask
  445. attn_windows = self.attn(x_windows, mask=attn_mask) # nW*B, window_size*window_size, C
  446. # merge windows
  447. attn_windows = attn_windows.view(-1, self.window_size[0], self.window_size[1], C)
  448. shifted_x = window_reverse(attn_windows, self.window_size, (Hp, Wp)) # B H' W' C
  449. shifted_x = shifted_x[:, :H, :W, :].contiguous()
  450. # reverse cyclic shift
  451. if has_shift:
  452. x = torch.roll(shifted_x, shifts=self.shift_size, dims=(1, 2))
  453. else:
  454. x = shifted_x
  455. return x
  456. def forward(self, x: torch.Tensor) -> torch.Tensor:
  457. B, H, W, C = x.shape
  458. x = x + self.drop_path1(self.norm1(self._attn(x)))
  459. x = x.reshape(B, -1, C)
  460. x = x + self.drop_path2(self.norm2(self.mlp(x)))
  461. x = x.reshape(B, H, W, C)
  462. return x
  463. class PatchMerging(nn.Module):
  464. """Patch Merging Layer.
  465. Merges 2x2 neighboring patches and projects to higher dimension,
  466. effectively downsampling the feature maps.
  467. """
  468. def __init__(
  469. self,
  470. dim: int,
  471. out_dim: Optional[int] = None,
  472. norm_layer: Type[nn.Module] = nn.LayerNorm,
  473. device=None,
  474. dtype=None,
  475. ):
  476. """
  477. Args:
  478. dim (int): Number of input channels.
  479. out_dim (int): Number of output channels (or 2 * dim if None)
  480. norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm
  481. """
  482. dd = {'device': device, 'dtype': dtype}
  483. super().__init__()
  484. self.dim = dim
  485. self.out_dim = out_dim or 2 * dim
  486. self.reduction = nn.Linear(4 * dim, self.out_dim, bias=False, **dd)
  487. self.norm = norm_layer(self.out_dim, **dd)
  488. def forward(self, x: torch.Tensor) -> torch.Tensor:
  489. B, H, W, C = x.shape
  490. pad_values = (0, 0, 0, W % 2, 0, H % 2)
  491. x = nn.functional.pad(x, pad_values)
  492. _, H, W, _ = x.shape
  493. x = x.reshape(B, H // 2, 2, W // 2, 2, C).permute(0, 1, 3, 4, 2, 5).flatten(3)
  494. x = self.reduction(x)
  495. x = self.norm(x)
  496. return x
  497. class SwinTransformerV2Stage(nn.Module):
  498. """A Swin Transformer V2 Stage.
  499. A single stage consisting of multiple Swin Transformer blocks with
  500. optional downsampling at the beginning.
  501. """
  502. def __init__(
  503. self,
  504. dim: int,
  505. out_dim: int,
  506. input_resolution: _int_or_tuple_2_t,
  507. depth: int,
  508. num_heads: int,
  509. window_size: _int_or_tuple_2_t,
  510. always_partition: bool = False,
  511. dynamic_mask: bool = False,
  512. downsample: bool = False,
  513. mlp_ratio: float = 4.,
  514. qkv_bias: bool = True,
  515. proj_drop: float = 0.,
  516. attn_drop: float = 0.,
  517. drop_path: float = 0.,
  518. act_layer: Union[str, Type[nn.Module]] = 'gelu',
  519. norm_layer: Type[nn.Module] = nn.LayerNorm,
  520. pretrained_window_size: _int_or_tuple_2_t = 0,
  521. output_nchw: bool = False,
  522. device=None,
  523. dtype=None,
  524. ) -> None:
  525. """
  526. Args:
  527. dim: Number of input channels.
  528. out_dim: Number of output channels.
  529. input_resolution: Input resolution.
  530. depth: Number of blocks.
  531. num_heads: Number of attention heads.
  532. window_size: Local window size.
  533. always_partition: Always partition into full windows and shift
  534. dynamic_mask: Create attention mask in forward based on current input size
  535. downsample: Use downsample layer at start of the block.
  536. mlp_ratio: Ratio of mlp hidden dim to embedding dim.
  537. qkv_bias: If True, add a learnable bias to query, key, value.
  538. proj_drop: Projection dropout rate
  539. attn_drop: Attention dropout rate.
  540. drop_path: Stochastic depth rate.
  541. act_layer: Activation layer type.
  542. norm_layer: Normalization layer.
  543. pretrained_window_size: Local window size in pretraining.
  544. output_nchw: Output tensors on NCHW format instead of NHWC.
  545. """
  546. dd = {'device': device, 'dtype': dtype}
  547. super().__init__()
  548. self.dim = dim
  549. self.input_resolution = input_resolution
  550. self.output_resolution = tuple(i // 2 for i in input_resolution) if downsample else input_resolution
  551. self.depth = depth
  552. self.output_nchw = output_nchw
  553. self.grad_checkpointing = False
  554. window_size = to_2tuple(window_size)
  555. shift_size = tuple([w // 2 for w in window_size])
  556. # patch merging / downsample layer
  557. if downsample:
  558. self.downsample = PatchMerging(dim=dim, out_dim=out_dim, norm_layer=norm_layer, **dd)
  559. else:
  560. assert dim == out_dim
  561. self.downsample = nn.Identity()
  562. # build blocks
  563. self.blocks = nn.ModuleList([
  564. SwinTransformerV2Block(
  565. dim=out_dim,
  566. input_resolution=self.output_resolution,
  567. num_heads=num_heads,
  568. window_size=window_size,
  569. shift_size=0 if (i % 2 == 0) else shift_size,
  570. always_partition=always_partition,
  571. dynamic_mask=dynamic_mask,
  572. mlp_ratio=mlp_ratio,
  573. qkv_bias=qkv_bias,
  574. proj_drop=proj_drop,
  575. attn_drop=attn_drop,
  576. drop_path=drop_path[i] if isinstance(drop_path, list) else drop_path,
  577. act_layer=act_layer,
  578. norm_layer=norm_layer,
  579. pretrained_window_size=pretrained_window_size,
  580. **dd,
  581. )
  582. for i in range(depth)])
  583. def set_input_size(
  584. self,
  585. feat_size: Tuple[int, int],
  586. window_size: int,
  587. always_partition: Optional[bool] = None,
  588. ) -> None:
  589. """Update resolution, window size and relative positions.
  590. Args:
  591. feat_size: New input (feature) resolution.
  592. window_size: New window size.
  593. always_partition: Always partition / shift the window.
  594. """
  595. self.input_resolution = feat_size
  596. if isinstance(self.downsample, nn.Identity):
  597. self.output_resolution = feat_size
  598. else:
  599. assert isinstance(self.downsample, PatchMerging)
  600. self.output_resolution = tuple(i // 2 for i in feat_size)
  601. for block in self.blocks:
  602. block.set_input_size(
  603. feat_size=self.output_resolution,
  604. window_size=window_size,
  605. always_partition=always_partition,
  606. )
  607. def forward(self, x: torch.Tensor) -> torch.Tensor:
  608. """Forward pass through the stage.
  609. Args:
  610. x: Input tensor of shape (B, H, W, C).
  611. Returns:
  612. Output tensor of shape (B, H', W', C').
  613. """
  614. x = self.downsample(x)
  615. for blk in self.blocks:
  616. if self.grad_checkpointing and not torch.jit.is_scripting():
  617. x = checkpoint(blk, x)
  618. else:
  619. x = blk(x)
  620. return x
  621. def _init_respostnorm(self) -> None:
  622. """Initialize residual post-normalization weights."""
  623. for blk in self.blocks:
  624. nn.init.constant_(blk.norm1.bias, 0)
  625. nn.init.constant_(blk.norm1.weight, 0)
  626. nn.init.constant_(blk.norm2.bias, 0)
  627. nn.init.constant_(blk.norm2.weight, 0)
  628. class SwinTransformerV2(nn.Module):
  629. """Swin Transformer V2.
  630. A hierarchical vision transformer using shifted windows for efficient
  631. self-attention computation with continuous position bias.
  632. A PyTorch impl of : `Swin Transformer V2: Scaling Up Capacity and Resolution`
  633. - https://arxiv.org/abs/2111.09883
  634. """
  635. def __init__(
  636. self,
  637. img_size: _int_or_tuple_2_t = 224,
  638. patch_size: int = 4,
  639. in_chans: int = 3,
  640. num_classes: int = 1000,
  641. global_pool: str = 'avg',
  642. embed_dim: int = 96,
  643. depths: Tuple[int, ...] = (2, 2, 6, 2),
  644. num_heads: Tuple[int, ...] = (3, 6, 12, 24),
  645. window_size: _int_or_tuple_2_t = 7,
  646. always_partition: bool = False,
  647. strict_img_size: bool = True,
  648. mlp_ratio: float = 4.,
  649. qkv_bias: bool = True,
  650. drop_rate: float = 0.,
  651. proj_drop_rate: float = 0.,
  652. attn_drop_rate: float = 0.,
  653. drop_path_rate: float = 0.1,
  654. act_layer: Union[str, Callable] = 'gelu',
  655. norm_layer: Type[nn.Module] = nn.LayerNorm,
  656. pretrained_window_sizes: Tuple[int, ...] = (0, 0, 0, 0),
  657. device=None,
  658. dtype=None,
  659. **kwargs,
  660. ):
  661. """
  662. Args:
  663. img_size: Input image size.
  664. patch_size: Patch size.
  665. in_chans: Number of input image channels.
  666. num_classes: Number of classes for classification head.
  667. embed_dim: Patch embedding dimension.
  668. depths: Depth of each Swin Transformer stage (layer).
  669. num_heads: Number of attention heads in different layers.
  670. window_size: Window size.
  671. mlp_ratio: Ratio of mlp hidden dim to embedding dim.
  672. qkv_bias: If True, add a learnable bias to query, key, value.
  673. drop_rate: Head dropout rate.
  674. proj_drop_rate: Projection dropout rate.
  675. attn_drop_rate: Attention dropout rate.
  676. drop_path_rate: Stochastic depth rate.
  677. norm_layer: Normalization layer.
  678. act_layer: Activation layer type.
  679. patch_norm: If True, add normalization after patch embedding.
  680. pretrained_window_sizes: Pretrained window sizes of each layer.
  681. output_fmt: Output tensor format if not None, otherwise output 'NHWC' by default.
  682. """
  683. super().__init__()
  684. dd = {'device': device, 'dtype': dtype}
  685. self.num_classes = num_classes
  686. self.in_chans = in_chans
  687. assert global_pool in ('', 'avg')
  688. self.global_pool = global_pool
  689. self.output_fmt = 'NHWC'
  690. self.num_layers = len(depths)
  691. self.embed_dim = embed_dim
  692. self.num_features = self.head_hidden_size = int(embed_dim * 2 ** (self.num_layers - 1))
  693. self.feature_info = []
  694. if not isinstance(embed_dim, (tuple, list)):
  695. embed_dim = [int(embed_dim * 2 ** i) for i in range(self.num_layers)]
  696. # split image into non-overlapping patches
  697. self.patch_embed = PatchEmbed(
  698. img_size=img_size,
  699. patch_size=patch_size,
  700. in_chans=in_chans,
  701. embed_dim=embed_dim[0],
  702. norm_layer=norm_layer,
  703. strict_img_size=strict_img_size,
  704. output_fmt='NHWC',
  705. **dd,
  706. )
  707. grid_size = self.patch_embed.grid_size
  708. dpr = calculate_drop_path_rates(drop_path_rate, depths, stagewise=True)
  709. layers = []
  710. in_dim = embed_dim[0]
  711. scale = 1
  712. for i in range(self.num_layers):
  713. out_dim = embed_dim[i]
  714. layers += [SwinTransformerV2Stage(
  715. dim=in_dim,
  716. out_dim=out_dim,
  717. input_resolution=(grid_size[0] // scale, grid_size[1] // scale),
  718. depth=depths[i],
  719. downsample=i > 0,
  720. num_heads=num_heads[i],
  721. window_size=window_size,
  722. always_partition=always_partition,
  723. dynamic_mask=not strict_img_size,
  724. mlp_ratio=mlp_ratio,
  725. qkv_bias=qkv_bias,
  726. proj_drop=proj_drop_rate,
  727. attn_drop=attn_drop_rate,
  728. drop_path=dpr[i],
  729. act_layer=act_layer,
  730. norm_layer=norm_layer,
  731. pretrained_window_size=pretrained_window_sizes[i],
  732. **dd,
  733. )]
  734. in_dim = out_dim
  735. if i > 0:
  736. scale *= 2
  737. self.feature_info += [dict(num_chs=out_dim, reduction=4 * scale, module=f'layers.{i}')]
  738. self.layers = nn.Sequential(*layers)
  739. self.norm = norm_layer(self.num_features, **dd)
  740. self.head = ClassifierHead(
  741. self.num_features,
  742. num_classes,
  743. pool_type=global_pool,
  744. drop_rate=drop_rate,
  745. input_fmt=self.output_fmt,
  746. **dd,
  747. )
  748. # TODO: skip init when on meta device when safe to do so
  749. self.init_weights(needs_reset=False)
  750. def init_weights(self, needs_reset: bool = True) -> None:
  751. """Initialize model weights.
  752. Args:
  753. needs_reset: If True, call reset_parameters() on modules (default for after to_empty()).
  754. If False, skip reset_parameters() (for __init__ where modules already self-initialized).
  755. """
  756. self.apply(partial(self._init_weights, needs_reset=needs_reset))
  757. for bly in self.layers:
  758. bly._init_respostnorm()
  759. def _init_weights(self, m: nn.Module, needs_reset: bool = True) -> None:
  760. """Initialize weights for Linear layers.
  761. Args:
  762. m: Module to initialize.
  763. needs_reset: Whether to call reset_parameters() on modules.
  764. """
  765. if isinstance(m, nn.Linear):
  766. trunc_normal_(m.weight, std=.02)
  767. if m.bias is not None:
  768. nn.init.constant_(m.bias, 0)
  769. elif needs_reset and hasattr(m, 'reset_parameters'):
  770. m.reset_parameters()
  771. def set_input_size(
  772. self,
  773. img_size: Optional[Tuple[int, int]] = None,
  774. patch_size: Optional[Tuple[int, int]] = None,
  775. window_size: Optional[Tuple[int, int]] = None,
  776. window_ratio: Optional[int] = 8,
  777. always_partition: Optional[bool] = None,
  778. ):
  779. """Updates the image resolution, window size, and so the pair-wise relative positions.
  780. Args:
  781. img_size (Optional[Tuple[int, int]]): New input resolution, if None current resolution is used
  782. patch_size (Optional[Tuple[int, int]): New patch size, if None use current patch size
  783. window_size (Optional[int]): New window size, if None based on new_img_size // window_div
  784. window_ratio (int): divisor for calculating window size from patch grid size
  785. always_partition: always partition / shift windows even if feat size is < window
  786. """
  787. if img_size is not None or patch_size is not None:
  788. self.patch_embed.set_input_size(img_size=img_size, patch_size=patch_size)
  789. grid_size = self.patch_embed.grid_size
  790. if window_size is None and window_ratio is not None:
  791. window_size = tuple([s // window_ratio for s in grid_size])
  792. for index, stage in enumerate(self.layers):
  793. stage_scale = 2 ** max(index - 1, 0)
  794. stage.set_input_size(
  795. feat_size=(grid_size[0] // stage_scale, grid_size[1] // stage_scale),
  796. window_size=window_size,
  797. always_partition=always_partition,
  798. )
  799. @torch.jit.ignore
  800. def no_weight_decay(self) -> Set[str]:
  801. """Get parameter names that should not use weight decay.
  802. Returns:
  803. Set of parameter names to exclude from weight decay.
  804. """
  805. nod = set()
  806. for n, m in self.named_modules():
  807. if any([kw in n for kw in ("cpb_mlp", "logit_scale")]):
  808. nod.add(n)
  809. return nod
  810. @torch.jit.ignore
  811. def group_matcher(self, coarse: bool = False) -> Dict[str, Any]:
  812. """Create parameter group matcher for optimizer parameter groups.
  813. Args:
  814. coarse: If True, use coarse grouping.
  815. Returns:
  816. Dictionary mapping group names to regex patterns.
  817. """
  818. return dict(
  819. stem=r'^absolute_pos_embed|patch_embed', # stem and embed
  820. blocks=r'^layers\.(\d+)' if coarse else [
  821. (r'^layers\.(\d+).downsample', (0,)),
  822. (r'^layers\.(\d+)\.\w+\.(\d+)', None),
  823. (r'^norm', (99999,)),
  824. ]
  825. )
  826. @torch.jit.ignore
  827. def set_grad_checkpointing(self, enable: bool = True) -> None:
  828. """Enable or disable gradient checkpointing.
  829. Args:
  830. enable: If True, enable gradient checkpointing.
  831. """
  832. for l in self.layers:
  833. l.grad_checkpointing = enable
  834. @torch.jit.ignore
  835. def get_classifier(self) -> nn.Module:
  836. """Get the classifier head.
  837. Returns:
  838. The classification head module.
  839. """
  840. return self.head.fc
  841. def reset_classifier(self, num_classes: int, global_pool: Optional[str] = None) -> None:
  842. """Reset the classification head.
  843. Args:
  844. num_classes: Number of classes for new head.
  845. global_pool: Global pooling type.
  846. """
  847. self.num_classes = num_classes
  848. self.head.reset(num_classes, global_pool)
  849. def forward_intermediates(
  850. self,
  851. x: torch.Tensor,
  852. indices: Optional[Union[int, List[int]]] = None,
  853. norm: bool = False,
  854. stop_early: bool = False,
  855. output_fmt: str = 'NCHW',
  856. intermediates_only: bool = False,
  857. ) -> Union[List[torch.Tensor], Tuple[torch.Tensor, List[torch.Tensor]]]:
  858. """ Forward features that returns intermediates.
  859. Args:
  860. x: Input image tensor
  861. indices: Take last n blocks if int, all if None, select matching indices if sequence
  862. norm: Apply norm layer to compatible intermediates
  863. stop_early: Stop iterating over blocks when last desired intermediate hit
  864. output_fmt: Shape of intermediate feature outputs
  865. intermediates_only: Only return intermediate features
  866. Returns:
  867. """
  868. assert output_fmt in ('NCHW',), 'Output shape must be NCHW.'
  869. intermediates = []
  870. take_indices, max_index = feature_take_indices(len(self.layers), indices)
  871. # forward pass
  872. x = self.patch_embed(x)
  873. num_stages = len(self.layers)
  874. if torch.jit.is_scripting() or not stop_early: # can't slice blocks in torchscript
  875. stages = self.layers
  876. else:
  877. stages = self.layers[:max_index + 1]
  878. for i, stage in enumerate(stages):
  879. x = stage(x)
  880. if i in take_indices:
  881. if norm and i == num_stages - 1:
  882. x_inter = self.norm(x) # applying final norm last intermediate
  883. else:
  884. x_inter = x
  885. x_inter = x_inter.permute(0, 3, 1, 2).contiguous()
  886. intermediates.append(x_inter)
  887. if intermediates_only:
  888. return intermediates
  889. x = self.norm(x)
  890. return x, intermediates
  891. def prune_intermediate_layers(
  892. self,
  893. indices: Union[int, List[int]] = 1,
  894. prune_norm: bool = False,
  895. prune_head: bool = True,
  896. ):
  897. """ Prune layers not required for specified intermediates.
  898. """
  899. take_indices, max_index = feature_take_indices(len(self.layers), indices)
  900. self.layers = self.layers[:max_index + 1] # truncate blocks
  901. if prune_norm:
  902. self.norm = nn.Identity()
  903. if prune_head:
  904. self.reset_classifier(0, '')
  905. return take_indices
  906. def forward_features(self, x: torch.Tensor) -> torch.Tensor:
  907. """Forward pass through feature extraction layers.
  908. Args:
  909. x: Input tensor of shape (B, C, H, W).
  910. Returns:
  911. Feature tensor of shape (B, H', W', C).
  912. """
  913. x = self.patch_embed(x)
  914. x = self.layers(x)
  915. x = self.norm(x)
  916. return x
  917. def forward_head(self, x: torch.Tensor, pre_logits: bool = False) -> torch.Tensor:
  918. """Forward pass through classification head.
  919. Args:
  920. x: Feature tensor of shape (B, H, W, C).
  921. pre_logits: If True, return features before final linear layer.
  922. Returns:
  923. Logits tensor of shape (B, num_classes) or pre-logits.
  924. """
  925. return self.head(x, pre_logits=True) if pre_logits else self.head(x)
  926. def forward(self, x: torch.Tensor) -> torch.Tensor:
  927. """Forward pass through the model.
  928. Args:
  929. x: Input tensor of shape (B, C, H, W).
  930. Returns:
  931. Logits tensor of shape (B, num_classes).
  932. """
  933. x = self.forward_features(x)
  934. x = self.forward_head(x)
  935. return x
  936. def checkpoint_filter_fn(state_dict: Dict[str, torch.Tensor], model: nn.Module) -> Dict[str, torch.Tensor]:
  937. """Filter and process checkpoint state dict for loading.
  938. Handles resizing of patch embeddings and relative position tables
  939. when model size differs from checkpoint.
  940. Args:
  941. state_dict: Checkpoint state dictionary.
  942. model: Target model to load weights into.
  943. Returns:
  944. Filtered state dictionary.
  945. """
  946. state_dict = state_dict.get('model', state_dict)
  947. state_dict = state_dict.get('state_dict', state_dict)
  948. native_checkpoint = 'head.fc.weight' in state_dict
  949. out_dict = {}
  950. import re
  951. for k, v in state_dict.items():
  952. if any([n in k for n in ('relative_position_index', 'relative_coords_table', 'attn_mask')]):
  953. continue # skip buffers that should not be persistent
  954. if 'patch_embed.proj.weight' in k:
  955. _, _, H, W = model.patch_embed.proj.weight.shape
  956. if v.shape[-2] != H or v.shape[-1] != W:
  957. v = resample_patch_embed(
  958. v,
  959. (H, W),
  960. interpolation='bicubic',
  961. antialias=True,
  962. verbose=True,
  963. )
  964. if not native_checkpoint:
  965. # skip layer remapping for updated checkpoints
  966. k = re.sub(r'layers.(\d+).downsample', lambda x: f'layers.{int(x.group(1)) + 1}.downsample', k)
  967. k = k.replace('head.', 'head.fc.')
  968. out_dict[k] = v
  969. return out_dict
  970. def _create_swin_transformer_v2(variant: str, pretrained: bool = False, **kwargs) -> SwinTransformerV2:
  971. """Create a Swin Transformer V2 model.
  972. Args:
  973. variant: Model variant name.
  974. pretrained: If True, load pretrained weights.
  975. **kwargs: Additional model arguments.
  976. Returns:
  977. SwinTransformerV2 model instance.
  978. """
  979. default_out_indices = tuple(i for i, _ in enumerate(kwargs.get('depths', (1, 1, 1, 1))))
  980. out_indices = kwargs.pop('out_indices', default_out_indices)
  981. model = build_model_with_cfg(
  982. SwinTransformerV2, variant, pretrained,
  983. pretrained_filter_fn=checkpoint_filter_fn,
  984. feature_cfg=dict(flatten_sequential=True, out_indices=out_indices),
  985. **kwargs)
  986. return model
  987. def _cfg(url='', **kwargs):
  988. return {
  989. 'url': url,
  990. 'num_classes': 1000, 'input_size': (3, 256, 256), 'pool_size': (8, 8),
  991. 'crop_pct': .9, 'interpolation': 'bicubic', 'fixed_input_size': True,
  992. 'mean': IMAGENET_DEFAULT_MEAN, 'std': IMAGENET_DEFAULT_STD,
  993. 'first_conv': 'patch_embed.proj', 'classifier': 'head.fc',
  994. 'license': 'mit', **kwargs
  995. }
  996. default_cfgs = generate_default_cfgs({
  997. 'swinv2_base_window12to16_192to256.ms_in22k_ft_in1k': _cfg(
  998. hf_hub_id='timm/',
  999. url='https://github.com/SwinTransformer/storage/releases/download/v2.0.0/swinv2_base_patch4_window12to16_192to256_22kto1k_ft.pth',
  1000. ),
  1001. 'swinv2_base_window12to24_192to384.ms_in22k_ft_in1k': _cfg(
  1002. hf_hub_id='timm/',
  1003. url='https://github.com/SwinTransformer/storage/releases/download/v2.0.0/swinv2_base_patch4_window12to24_192to384_22kto1k_ft.pth',
  1004. input_size=(3, 384, 384), pool_size=(12, 12), crop_pct=1.0,
  1005. ),
  1006. 'swinv2_large_window12to16_192to256.ms_in22k_ft_in1k': _cfg(
  1007. hf_hub_id='timm/',
  1008. url='https://github.com/SwinTransformer/storage/releases/download/v2.0.0/swinv2_large_patch4_window12to16_192to256_22kto1k_ft.pth',
  1009. ),
  1010. 'swinv2_large_window12to24_192to384.ms_in22k_ft_in1k': _cfg(
  1011. hf_hub_id='timm/',
  1012. url='https://github.com/SwinTransformer/storage/releases/download/v2.0.0/swinv2_large_patch4_window12to24_192to384_22kto1k_ft.pth',
  1013. input_size=(3, 384, 384), pool_size=(12, 12), crop_pct=1.0,
  1014. ),
  1015. 'swinv2_tiny_window8_256.ms_in1k': _cfg(
  1016. hf_hub_id='timm/',
  1017. url='https://github.com/SwinTransformer/storage/releases/download/v2.0.0/swinv2_tiny_patch4_window8_256.pth',
  1018. ),
  1019. 'swinv2_tiny_window16_256.ms_in1k': _cfg(
  1020. hf_hub_id='timm/',
  1021. url='https://github.com/SwinTransformer/storage/releases/download/v2.0.0/swinv2_tiny_patch4_window16_256.pth',
  1022. ),
  1023. 'swinv2_small_window8_256.ms_in1k': _cfg(
  1024. hf_hub_id='timm/',
  1025. url='https://github.com/SwinTransformer/storage/releases/download/v2.0.0/swinv2_small_patch4_window8_256.pth',
  1026. ),
  1027. 'swinv2_small_window16_256.ms_in1k': _cfg(
  1028. hf_hub_id='timm/',
  1029. url='https://github.com/SwinTransformer/storage/releases/download/v2.0.0/swinv2_small_patch4_window16_256.pth',
  1030. ),
  1031. 'swinv2_base_window8_256.ms_in1k': _cfg(
  1032. hf_hub_id='timm/',
  1033. url='https://github.com/SwinTransformer/storage/releases/download/v2.0.0/swinv2_base_patch4_window8_256.pth',
  1034. ),
  1035. 'swinv2_base_window16_256.ms_in1k': _cfg(
  1036. hf_hub_id='timm/',
  1037. url='https://github.com/SwinTransformer/storage/releases/download/v2.0.0/swinv2_base_patch4_window16_256.pth',
  1038. ),
  1039. 'swinv2_base_window12_192.ms_in22k': _cfg(
  1040. hf_hub_id='timm/',
  1041. url='https://github.com/SwinTransformer/storage/releases/download/v2.0.0/swinv2_base_patch4_window12_192_22k.pth',
  1042. num_classes=21841, input_size=(3, 192, 192), pool_size=(6, 6)
  1043. ),
  1044. 'swinv2_large_window12_192.ms_in22k': _cfg(
  1045. hf_hub_id='timm/',
  1046. url='https://github.com/SwinTransformer/storage/releases/download/v2.0.0/swinv2_large_patch4_window12_192_22k.pth',
  1047. num_classes=21841, input_size=(3, 192, 192), pool_size=(6, 6)
  1048. ),
  1049. })
  1050. @register_model
  1051. def swinv2_tiny_window16_256(pretrained: bool = False, **kwargs) -> SwinTransformerV2:
  1052. """Swin-T V2 @ 256x256, window 16x16."""
  1053. model_args = dict(window_size=16, embed_dim=96, depths=(2, 2, 6, 2), num_heads=(3, 6, 12, 24))
  1054. return _create_swin_transformer_v2(
  1055. 'swinv2_tiny_window16_256', pretrained=pretrained, **dict(model_args, **kwargs))
  1056. @register_model
  1057. def swinv2_tiny_window8_256(pretrained: bool = False, **kwargs) -> SwinTransformerV2:
  1058. """Swin-T V2 @ 256x256, window 8x8."""
  1059. model_args = dict(window_size=8, embed_dim=96, depths=(2, 2, 6, 2), num_heads=(3, 6, 12, 24))
  1060. return _create_swin_transformer_v2(
  1061. 'swinv2_tiny_window8_256', pretrained=pretrained, **dict(model_args, **kwargs))
  1062. @register_model
  1063. def swinv2_small_window16_256(pretrained: bool = False, **kwargs) -> SwinTransformerV2:
  1064. """Swin-S V2 @ 256x256, window 16x16."""
  1065. model_args = dict(window_size=16, embed_dim=96, depths=(2, 2, 18, 2), num_heads=(3, 6, 12, 24))
  1066. return _create_swin_transformer_v2(
  1067. 'swinv2_small_window16_256', pretrained=pretrained, **dict(model_args, **kwargs))
  1068. @register_model
  1069. def swinv2_small_window8_256(pretrained: bool = False, **kwargs) -> SwinTransformerV2:
  1070. """Swin-S V2 @ 256x256, window 8x8."""
  1071. model_args = dict(window_size=8, embed_dim=96, depths=(2, 2, 18, 2), num_heads=(3, 6, 12, 24))
  1072. return _create_swin_transformer_v2(
  1073. 'swinv2_small_window8_256', pretrained=pretrained, **dict(model_args, **kwargs))
  1074. @register_model
  1075. def swinv2_base_window16_256(pretrained: bool = False, **kwargs) -> SwinTransformerV2:
  1076. """Swin-B V2 @ 256x256, window 16x16."""
  1077. model_args = dict(window_size=16, embed_dim=128, depths=(2, 2, 18, 2), num_heads=(4, 8, 16, 32))
  1078. return _create_swin_transformer_v2(
  1079. 'swinv2_base_window16_256', pretrained=pretrained, **dict(model_args, **kwargs))
  1080. @register_model
  1081. def swinv2_base_window8_256(pretrained: bool = False, **kwargs) -> SwinTransformerV2:
  1082. """Swin-B V2 @ 256x256, window 8x8."""
  1083. model_args = dict(window_size=8, embed_dim=128, depths=(2, 2, 18, 2), num_heads=(4, 8, 16, 32))
  1084. return _create_swin_transformer_v2(
  1085. 'swinv2_base_window8_256', pretrained=pretrained, **dict(model_args, **kwargs))
  1086. @register_model
  1087. def swinv2_base_window12_192(pretrained: bool = False, **kwargs) -> SwinTransformerV2:
  1088. """Swin-B V2 @ 192x192, window 12x12."""
  1089. model_args = dict(window_size=12, embed_dim=128, depths=(2, 2, 18, 2), num_heads=(4, 8, 16, 32))
  1090. return _create_swin_transformer_v2(
  1091. 'swinv2_base_window12_192', pretrained=pretrained, **dict(model_args, **kwargs))
  1092. @register_model
  1093. def swinv2_base_window12to16_192to256(pretrained: bool = False, **kwargs) -> SwinTransformerV2:
  1094. """Swin-B V2 @ 192x192, trained at window 12x12, fine-tuned to 256x256 window 16x16."""
  1095. model_args = dict(
  1096. window_size=16, embed_dim=128, depths=(2, 2, 18, 2), num_heads=(4, 8, 16, 32),
  1097. pretrained_window_sizes=(12, 12, 12, 6))
  1098. return _create_swin_transformer_v2(
  1099. 'swinv2_base_window12to16_192to256', pretrained=pretrained, **dict(model_args, **kwargs))
  1100. @register_model
  1101. def swinv2_base_window12to24_192to384(pretrained: bool = False, **kwargs) -> SwinTransformerV2:
  1102. """Swin-B V2 @ 192x192, trained at window 12x12, fine-tuned to 384x384 window 24x24."""
  1103. model_args = dict(
  1104. window_size=24, embed_dim=128, depths=(2, 2, 18, 2), num_heads=(4, 8, 16, 32),
  1105. pretrained_window_sizes=(12, 12, 12, 6))
  1106. return _create_swin_transformer_v2(
  1107. 'swinv2_base_window12to24_192to384', pretrained=pretrained, **dict(model_args, **kwargs))
  1108. @register_model
  1109. def swinv2_large_window12_192(pretrained: bool = False, **kwargs) -> SwinTransformerV2:
  1110. """Swin-L V2 @ 192x192, window 12x12."""
  1111. model_args = dict(window_size=12, embed_dim=192, depths=(2, 2, 18, 2), num_heads=(6, 12, 24, 48))
  1112. return _create_swin_transformer_v2(
  1113. 'swinv2_large_window12_192', pretrained=pretrained, **dict(model_args, **kwargs))
  1114. @register_model
  1115. def swinv2_large_window12to16_192to256(pretrained: bool = False, **kwargs) -> SwinTransformerV2:
  1116. """Swin-L V2 @ 192x192, trained at window 12x12, fine-tuned to 256x256 window 16x16."""
  1117. model_args = dict(
  1118. window_size=16, embed_dim=192, depths=(2, 2, 18, 2), num_heads=(6, 12, 24, 48),
  1119. pretrained_window_sizes=(12, 12, 12, 6))
  1120. return _create_swin_transformer_v2(
  1121. 'swinv2_large_window12to16_192to256', pretrained=pretrained, **dict(model_args, **kwargs))
  1122. @register_model
  1123. def swinv2_large_window12to24_192to384(pretrained: bool = False, **kwargs) -> SwinTransformerV2:
  1124. """Swin-L V2 @ 192x192, trained at window 12x12, fine-tuned to 384x384 window 24x24."""
  1125. model_args = dict(
  1126. window_size=24, embed_dim=192, depths=(2, 2, 18, 2), num_heads=(6, 12, 24, 48),
  1127. pretrained_window_sizes=(12, 12, 12, 6))
  1128. return _create_swin_transformer_v2(
  1129. 'swinv2_large_window12to24_192to384', pretrained=pretrained, **dict(model_args, **kwargs))
  1130. register_model_deprecations(__name__, {
  1131. 'swinv2_base_window12_192_22k': 'swinv2_base_window12_192.ms_in22k',
  1132. 'swinv2_base_window12to16_192to256_22kft1k': 'swinv2_base_window12to16_192to256.ms_in22k_ft_in1k',
  1133. 'swinv2_base_window12to24_192to384_22kft1k': 'swinv2_base_window12to24_192to384.ms_in22k_ft_in1k',
  1134. 'swinv2_large_window12_192_22k': 'swinv2_large_window12_192.ms_in22k',
  1135. 'swinv2_large_window12to16_192to256_22kft1k': 'swinv2_large_window12to16_192to256.ms_in22k_ft_in1k',
  1136. 'swinv2_large_window12to24_192to384_22kft1k': 'swinv2_large_window12to24_192to384.ms_in22k_ft_in1k',
  1137. })