modeling_swinv2.py 55 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
37047057067077087097107117127137147157167177187197207217227237247257267277287297307317327337347357367377387397407417427437447457467477487497507517527537547557567577587597607617627637647657667677687697707717727737747757767777787797807817827837847857867877887897907917927937947957967977987998008018028038048058068078088098108118128138148158168178188198208218228238248258268278288298308318328338348358368378388398408418428438448458468478488498508518528538548558568578588598608618628638648658668678688698708718728738748758768778788798808818828838848858868878888898908918928938948958968978988999009019029039049059069079089099109119129139149159169179189199209219229239249259269279289299309319329339349359369379389399409419429439449459469479489499509519529539549559569579589599609619629639649659669679689699709719729739749759769779789799809819829839849859869879889899909919929939949959969979989991000100110021003100410051006100710081009101010111012101310141015101610171018101910201021102210231024102510261027102810291030103110321033103410351036103710381039104010411042104310441045104610471048104910501051105210531054105510561057105810591060106110621063106410651066106710681069107010711072107310741075107610771078107910801081108210831084108510861087108810891090109110921093109410951096109710981099110011011102110311041105110611071108110911101111111211131114111511161117111811191120112111221123112411251126112711281129113011311132113311341135113611371138113911401141114211431144114511461147114811491150115111521153115411551156115711581159116011611162116311641165116611671168116911701171117211731174117511761177117811791180118111821183118411851186118711881189119011911192119311941195119611971198119912001201120212031204120512061207120812091210121112121213121412151216121712181219122012211222122312241225122612271228122912301231123212331234123512361237123812391240124112421243124412451246124712481249125012511252125312541255125612571258125912601261126212631264126512661267126812691270127112721273127412751276127
71278127912801281128212831284128512861287
  1. # Copyright 2022 Microsoft Research and The HuggingFace Inc. team. All rights reserved.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. """PyTorch Swinv2 Transformer model."""
  15. import collections.abc
  16. import math
  17. from dataclasses import dataclass
  18. import torch
  19. from torch import Tensor, nn
  20. from ... import initialization as init
  21. from ...activations import ACT2FN
  22. from ...backbone_utils import BackboneMixin, filter_output_hidden_states
  23. from ...modeling_layers import GradientCheckpointingLayer
  24. from ...modeling_outputs import BackboneOutput
  25. from ...modeling_utils import PreTrainedModel
  26. from ...utils import ModelOutput, auto_docstring, logging, torch_int
  27. from ...utils.generic import can_return_tuple
  28. from .configuration_swinv2 import Swinv2Config
  29. logger = logging.get_logger(__name__)
  30. # drop_path, Swinv2PatchEmbeddings, Swinv2PatchMerging and Swinv2DropPath are from https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/swin_transformer_v2.py.
  31. @dataclass
  32. @auto_docstring(
  33. custom_intro="""
  34. Swinv2 encoder's outputs, with potential hidden states and attentions.
  35. """
  36. )
  37. # Copied from transformers.models.swin.modeling_swin.SwinEncoderOutput with Swin->Swinv2
  38. class Swinv2EncoderOutput(ModelOutput):
  39. r"""
  40. reshaped_hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
  41. Tuple of `torch.FloatTensor` (one for the output of the embeddings + one for the output of each stage) of
  42. shape `(batch_size, hidden_size, height, width)`.
  43. Hidden-states of the model at the output of each layer plus the initial embedding outputs reshaped to
  44. include the spatial dimensions.
  45. """
  46. last_hidden_state: torch.FloatTensor | None = None
  47. hidden_states: tuple[torch.FloatTensor, ...] | None = None
  48. attentions: tuple[torch.FloatTensor, ...] | None = None
  49. reshaped_hidden_states: tuple[torch.FloatTensor, ...] | None = None
  50. @dataclass
  51. @auto_docstring(
  52. custom_intro="""
  53. Swinv2 model's outputs that also contains a pooling of the last hidden states.
  54. """
  55. )
  56. # Copied from transformers.models.swin.modeling_swin.SwinModelOutput with Swin->Swinv2
  57. class Swinv2ModelOutput(ModelOutput):
  58. r"""
  59. pooler_output (`torch.FloatTensor` of shape `(batch_size, hidden_size)`, *optional*, returned when `add_pooling_layer=True` is passed):
  60. Average pooling of the last layer hidden-state.
  61. reshaped_hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
  62. Tuple of `torch.FloatTensor` (one for the output of the embeddings + one for the output of each stage) of
  63. shape `(batch_size, hidden_size, height, width)`.
  64. Hidden-states of the model at the output of each layer plus the initial embedding outputs reshaped to
  65. include the spatial dimensions.
  66. """
  67. last_hidden_state: torch.FloatTensor | None = None
  68. pooler_output: torch.FloatTensor | None = None
  69. hidden_states: tuple[torch.FloatTensor, ...] | None = None
  70. attentions: tuple[torch.FloatTensor, ...] | None = None
  71. reshaped_hidden_states: tuple[torch.FloatTensor, ...] | None = None
  72. @dataclass
  73. @auto_docstring(
  74. custom_intro="""
  75. Swinv2 masked image model outputs.
  76. """
  77. )
  78. # Copied from transformers.models.swin.modeling_swin.SwinMaskedImageModelingOutput with Swin->Swinv2
  79. class Swinv2MaskedImageModelingOutput(ModelOutput):
  80. r"""
  81. loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `bool_masked_pos` is provided):
  82. Masked image modeling (MLM) loss.
  83. reconstruction (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`):
  84. Reconstructed pixel values.
  85. reshaped_hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
  86. Tuple of `torch.FloatTensor` (one for the output of the embeddings + one for the output of each stage) of
  87. shape `(batch_size, hidden_size, height, width)`.
  88. Hidden-states of the model at the output of each layer plus the initial embedding outputs reshaped to
  89. include the spatial dimensions.
  90. """
  91. loss: torch.FloatTensor | None = None
  92. reconstruction: torch.FloatTensor | None = None
  93. hidden_states: tuple[torch.FloatTensor, ...] | None = None
  94. attentions: tuple[torch.FloatTensor, ...] | None = None
  95. reshaped_hidden_states: tuple[torch.FloatTensor, ...] | None = None
  96. @dataclass
  97. @auto_docstring(
  98. custom_intro="""
  99. Swinv2 outputs for image classification.
  100. """
  101. )
  102. # Copied from transformers.models.swin.modeling_swin.SwinImageClassifierOutput with Swin->Swinv2
  103. class Swinv2ImageClassifierOutput(ModelOutput):
  104. r"""
  105. loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` is provided):
  106. Classification (or regression if config.num_labels==1) loss.
  107. logits (`torch.FloatTensor` of shape `(batch_size, config.num_labels)`):
  108. Classification (or regression if config.num_labels==1) scores (before SoftMax).
  109. reshaped_hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
  110. Tuple of `torch.FloatTensor` (one for the output of the embeddings + one for the output of each stage) of
  111. shape `(batch_size, hidden_size, height, width)`.
  112. Hidden-states of the model at the output of each layer plus the initial embedding outputs reshaped to
  113. include the spatial dimensions.
  114. """
  115. loss: torch.FloatTensor | None = None
  116. logits: torch.FloatTensor | None = None
  117. hidden_states: tuple[torch.FloatTensor, ...] | None = None
  118. attentions: tuple[torch.FloatTensor, ...] | None = None
  119. reshaped_hidden_states: tuple[torch.FloatTensor, ...] | None = None
  120. # Copied from transformers.models.swin.modeling_swin.window_partition
  121. def window_partition(input_feature, window_size):
  122. """
  123. Partitions the given input into windows.
  124. """
  125. batch_size, height, width, num_channels = input_feature.shape
  126. input_feature = input_feature.view(
  127. batch_size, height // window_size, window_size, width // window_size, window_size, num_channels
  128. )
  129. windows = input_feature.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size, window_size, num_channels)
  130. return windows
  131. # Copied from transformers.models.swin.modeling_swin.window_reverse
  132. def window_reverse(windows, window_size, height, width):
  133. """
  134. Merges windows to produce higher resolution features.
  135. """
  136. num_channels = windows.shape[-1]
  137. windows = windows.view(-1, height // window_size, width // window_size, window_size, window_size, num_channels)
  138. windows = windows.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, height, width, num_channels)
  139. return windows
  140. # Copied from transformers.models.swin.modeling_swin.drop_path
  141. def drop_path(input: torch.Tensor, drop_prob: float = 0.0, training: bool = False) -> torch.Tensor:
  142. """
  143. Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).
  144. """
  145. if drop_prob == 0.0 or not training:
  146. return input
  147. keep_prob = 1 - drop_prob
  148. shape = (input.shape[0],) + (1,) * (input.ndim - 1) # work with diff dim tensors, not just 2D ConvNets
  149. random_tensor = keep_prob + torch.rand(shape, dtype=input.dtype, device=input.device)
  150. random_tensor.floor_() # binarize
  151. output = input.div(keep_prob) * random_tensor
  152. return output
  153. # Copied from transformers.models.swin.modeling_swin.SwinDropPath with Swin->Swinv2
  154. class Swinv2DropPath(nn.Module):
  155. """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks)."""
  156. def __init__(self, drop_prob: float | None = None) -> None:
  157. super().__init__()
  158. self.drop_prob = drop_prob
  159. def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
  160. return drop_path(hidden_states, self.drop_prob, self.training)
  161. def extra_repr(self) -> str:
  162. return f"p={self.drop_prob}"
  163. # Copied from transformers.models.swin.modeling_swin.SwinEmbeddings with Swin->Swinv2
  164. class Swinv2Embeddings(nn.Module):
  165. """
  166. Construct the patch and position embeddings. Optionally, also the mask token.
  167. """
  168. def __init__(self, config, use_mask_token=False):
  169. super().__init__()
  170. self.patch_embeddings = Swinv2PatchEmbeddings(config)
  171. num_patches = self.patch_embeddings.num_patches
  172. self.patch_grid = self.patch_embeddings.grid_size
  173. self.mask_token = nn.Parameter(torch.zeros(1, 1, config.embed_dim)) if use_mask_token else None
  174. if config.use_absolute_embeddings:
  175. self.position_embeddings = nn.Parameter(torch.zeros(1, num_patches + 1, config.embed_dim))
  176. else:
  177. self.position_embeddings = None
  178. self.norm = nn.LayerNorm(config.embed_dim)
  179. self.dropout = nn.Dropout(config.hidden_dropout_prob)
  180. self.patch_size = config.patch_size
  181. self.config = config
  182. # Copied from transformers.models.vit.modeling_vit.ViTEmbeddings.interpolate_pos_encoding
  183. def interpolate_pos_encoding(self, embeddings: torch.Tensor, height: int, width: int) -> torch.Tensor:
  184. """
  185. This method allows to interpolate the pre-trained position encodings, to be able to use the model on higher resolution
  186. images. This method is also adapted to support torch.jit tracing.
  187. Adapted from:
  188. - https://github.com/facebookresearch/dino/blob/de9ee3df6cf39fac952ab558447af1fa1365362a/vision_transformer.py#L174-L194, and
  189. - https://github.com/facebookresearch/dinov2/blob/e1277af2ba9496fbadf7aec6eba56e8d882d1e35/dinov2/models/vision_transformer.py#L179-L211
  190. """
  191. num_patches = embeddings.shape[1] - 1
  192. num_positions = self.position_embeddings.shape[1] - 1
  193. # always interpolate when tracing to ensure the exported model works for dynamic input shapes
  194. if not torch.jit.is_tracing() and num_patches == num_positions and height == width:
  195. return self.position_embeddings
  196. class_pos_embed = self.position_embeddings[:, :1]
  197. patch_pos_embed = self.position_embeddings[:, 1:]
  198. dim = embeddings.shape[-1]
  199. new_height = height // self.patch_size
  200. new_width = width // self.patch_size
  201. sqrt_num_positions = torch_int(num_positions**0.5)
  202. patch_pos_embed = patch_pos_embed.reshape(1, sqrt_num_positions, sqrt_num_positions, dim)
  203. patch_pos_embed = patch_pos_embed.permute(0, 3, 1, 2)
  204. patch_pos_embed = nn.functional.interpolate(
  205. patch_pos_embed,
  206. size=(new_height, new_width),
  207. mode="bicubic",
  208. align_corners=False,
  209. )
  210. patch_pos_embed = patch_pos_embed.permute(0, 2, 3, 1).view(1, -1, dim)
  211. return torch.cat((class_pos_embed, patch_pos_embed), dim=1)
  212. def forward(
  213. self,
  214. pixel_values: torch.FloatTensor | None,
  215. bool_masked_pos: torch.BoolTensor | None = None,
  216. interpolate_pos_encoding: bool = False,
  217. ) -> tuple[torch.Tensor]:
  218. _, num_channels, height, width = pixel_values.shape
  219. embeddings, output_dimensions = self.patch_embeddings(pixel_values)
  220. embeddings = self.norm(embeddings)
  221. batch_size, seq_len, _ = embeddings.size()
  222. if bool_masked_pos is not None:
  223. mask_tokens = self.mask_token.expand(batch_size, seq_len, -1)
  224. # replace the masked visual tokens by mask_tokens
  225. mask = bool_masked_pos.unsqueeze(-1).type_as(mask_tokens)
  226. embeddings = embeddings * (1.0 - mask) + mask_tokens * mask
  227. if self.position_embeddings is not None:
  228. if interpolate_pos_encoding:
  229. embeddings = embeddings + self.interpolate_pos_encoding(embeddings, height, width)
  230. else:
  231. embeddings = embeddings + self.position_embeddings
  232. embeddings = self.dropout(embeddings)
  233. return embeddings, output_dimensions
  234. # Copied from transformers.models.swin.modeling_swin.SwinPatchEmbeddings with Swin->Swinv2
  235. class Swinv2PatchEmbeddings(nn.Module):
  236. """
  237. This class turns `pixel_values` of shape `(batch_size, num_channels, height, width)` into the initial
  238. `hidden_states` (patch embeddings) of shape `(batch_size, seq_length, hidden_size)` to be consumed by a
  239. Transformer.
  240. """
  241. def __init__(self, config):
  242. super().__init__()
  243. image_size, patch_size = config.image_size, config.patch_size
  244. num_channels, hidden_size = config.num_channels, config.embed_dim
  245. image_size = image_size if isinstance(image_size, collections.abc.Iterable) else (image_size, image_size)
  246. patch_size = patch_size if isinstance(patch_size, collections.abc.Iterable) else (patch_size, patch_size)
  247. num_patches = (image_size[1] // patch_size[1]) * (image_size[0] // patch_size[0])
  248. self.image_size = image_size
  249. self.patch_size = patch_size
  250. self.num_channels = num_channels
  251. self.num_patches = num_patches
  252. self.grid_size = (image_size[0] // patch_size[0], image_size[1] // patch_size[1])
  253. self.projection = nn.Conv2d(num_channels, hidden_size, kernel_size=patch_size, stride=patch_size)
  254. def maybe_pad(self, pixel_values, height, width):
  255. if width % self.patch_size[1] != 0:
  256. pad_values = (0, self.patch_size[1] - width % self.patch_size[1])
  257. pixel_values = nn.functional.pad(pixel_values, pad_values)
  258. if height % self.patch_size[0] != 0:
  259. pad_values = (0, 0, 0, self.patch_size[0] - height % self.patch_size[0])
  260. pixel_values = nn.functional.pad(pixel_values, pad_values)
  261. return pixel_values
  262. def forward(self, pixel_values: torch.FloatTensor | None) -> tuple[torch.Tensor, tuple[int]]:
  263. _, num_channels, height, width = pixel_values.shape
  264. # pad the input to be divisible by self.patch_size, if needed
  265. pixel_values = self.maybe_pad(pixel_values, height, width)
  266. embeddings = self.projection(pixel_values)
  267. _, _, height, width = embeddings.shape
  268. output_dimensions = (height, width)
  269. embeddings = embeddings.flatten(2).transpose(1, 2)
  270. return embeddings, output_dimensions
  271. class Swinv2PatchMerging(nn.Module):
  272. """
  273. Patch Merging Layer.
  274. Args:
  275. input_resolution (`tuple[int]`):
  276. Resolution of input feature.
  277. dim (`int`):
  278. Number of input channels.
  279. norm_layer (`nn.Module`, *optional*, defaults to `nn.LayerNorm`):
  280. Normalization layer class.
  281. """
  282. def __init__(self, input_resolution: tuple[int], dim: int, norm_layer: nn.Module = nn.LayerNorm) -> None:
  283. super().__init__()
  284. self.input_resolution = input_resolution
  285. self.dim = dim
  286. self.reduction = nn.Linear(4 * dim, 2 * dim, bias=False)
  287. self.norm = norm_layer(2 * dim)
  288. def maybe_pad(self, input_feature, height, width):
  289. should_pad = (height % 2 == 1) or (width % 2 == 1)
  290. if should_pad:
  291. pad_values = (0, 0, 0, width % 2, 0, height % 2)
  292. input_feature = nn.functional.pad(input_feature, pad_values)
  293. return input_feature
  294. def forward(self, input_feature: torch.Tensor, input_dimensions: tuple[int, int]) -> torch.Tensor:
  295. height, width = input_dimensions
  296. # `dim` is height * width
  297. batch_size, dim, num_channels = input_feature.shape
  298. input_feature = input_feature.view(batch_size, height, width, num_channels)
  299. # pad input to be divisible by width and height, if needed
  300. input_feature = self.maybe_pad(input_feature, height, width)
  301. # [batch_size, height/2, width/2, num_channels]
  302. input_feature_0 = input_feature[:, 0::2, 0::2, :]
  303. # [batch_size, height/2, width/2, num_channels]
  304. input_feature_1 = input_feature[:, 1::2, 0::2, :]
  305. # [batch_size, height/2, width/2, num_channels]
  306. input_feature_2 = input_feature[:, 0::2, 1::2, :]
  307. # [batch_size, height/2, width/2, num_channels]
  308. input_feature_3 = input_feature[:, 1::2, 1::2, :]
  309. # [batch_size, height/2 * width/2, 4*num_channels]
  310. input_feature = torch.cat([input_feature_0, input_feature_1, input_feature_2, input_feature_3], -1)
  311. input_feature = input_feature.view(batch_size, -1, 4 * num_channels) # [batch_size, height/2 * width/2, 4*C]
  312. input_feature = self.reduction(input_feature)
  313. input_feature = self.norm(input_feature)
  314. return input_feature
  315. class Swinv2SelfAttention(nn.Module):
  316. def __init__(self, config, dim, num_heads, window_size, pretrained_window_size=[0, 0]):
  317. super().__init__()
  318. if dim % num_heads != 0:
  319. raise ValueError(
  320. f"The hidden size ({dim}) is not a multiple of the number of attention heads ({num_heads})"
  321. )
  322. self.num_attention_heads = num_heads
  323. self.attention_head_size = int(dim / num_heads)
  324. self.all_head_size = self.num_attention_heads * self.attention_head_size
  325. self.window_size = (
  326. window_size if isinstance(window_size, collections.abc.Iterable) else (window_size, window_size)
  327. )
  328. self.pretrained_window_size = pretrained_window_size
  329. self.logit_scale = nn.Parameter(torch.log(10 * torch.ones((num_heads, 1, 1))))
  330. # mlp to generate continuous relative position bias
  331. self.continuous_position_bias_mlp = nn.Sequential(
  332. nn.Linear(2, 512, bias=True), nn.ReLU(inplace=True), nn.Linear(512, num_heads, bias=False)
  333. )
  334. relative_coords_table, relative_position_index = self.create_coords_table_and_index()
  335. self.register_buffer("relative_coords_table", relative_coords_table, persistent=False)
  336. self.register_buffer("relative_position_index", relative_position_index, persistent=False)
  337. self.query = nn.Linear(self.all_head_size, self.all_head_size, bias=config.qkv_bias)
  338. self.key = nn.Linear(self.all_head_size, self.all_head_size, bias=False)
  339. self.value = nn.Linear(self.all_head_size, self.all_head_size, bias=config.qkv_bias)
  340. self.dropout = nn.Dropout(config.attention_probs_dropout_prob)
  341. def forward(
  342. self,
  343. hidden_states: torch.Tensor,
  344. attention_mask: torch.FloatTensor | None = None,
  345. output_attentions: bool | None = False,
  346. ) -> tuple[torch.Tensor]:
  347. batch_size, dim, num_channels = hidden_states.shape
  348. query_layer = (
  349. self.query(hidden_states)
  350. .view(batch_size, -1, self.num_attention_heads, self.attention_head_size)
  351. .transpose(1, 2)
  352. )
  353. key_layer = (
  354. self.key(hidden_states)
  355. .view(batch_size, -1, self.num_attention_heads, self.attention_head_size)
  356. .transpose(1, 2)
  357. )
  358. value_layer = (
  359. self.value(hidden_states)
  360. .view(batch_size, -1, self.num_attention_heads, self.attention_head_size)
  361. .transpose(1, 2)
  362. )
  363. # cosine attention
  364. attention_scores = nn.functional.normalize(query_layer, dim=-1) @ nn.functional.normalize(
  365. key_layer, dim=-1
  366. ).transpose(-2, -1)
  367. logit_scale = torch.clamp(self.logit_scale, max=math.log(1.0 / 0.01)).exp()
  368. attention_scores = attention_scores * logit_scale
  369. relative_position_bias_table = self.continuous_position_bias_mlp(self.relative_coords_table).view(
  370. -1, self.num_attention_heads
  371. )
  372. # [window_height*window_width,window_height*window_width,num_attention_heads]
  373. relative_position_bias = relative_position_bias_table[self.relative_position_index.view(-1)].view(
  374. self.window_size[0] * self.window_size[1], self.window_size[0] * self.window_size[1], -1
  375. )
  376. # [num_attention_heads,window_height*window_width,window_height*window_width]
  377. relative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous() # nH, Wh*Ww, Wh*Ww
  378. relative_position_bias = 16 * torch.sigmoid(relative_position_bias)
  379. attention_scores = attention_scores + relative_position_bias.unsqueeze(0)
  380. if attention_mask is not None:
  381. # Apply the attention mask is (precomputed for all layers in Swinv2Model forward() function)
  382. mask_shape = attention_mask.shape[0]
  383. attention_scores = attention_scores.view(
  384. batch_size // mask_shape, mask_shape, self.num_attention_heads, dim, dim
  385. ) + attention_mask.unsqueeze(1).unsqueeze(0)
  386. attention_scores = attention_scores + attention_mask.unsqueeze(1).unsqueeze(0)
  387. attention_scores = attention_scores.view(-1, self.num_attention_heads, dim, dim)
  388. # Normalize the attention scores to probabilities.
  389. attention_probs = nn.functional.softmax(attention_scores, dim=-1)
  390. # This is actually dropping out entire tokens to attend to, which might
  391. # seem a bit unusual, but is taken from the original Transformer paper.
  392. attention_probs = self.dropout(attention_probs)
  393. # Mask heads if we want to
  394. context_layer = torch.matmul(attention_probs, value_layer)
  395. context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
  396. new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,)
  397. context_layer = context_layer.view(new_context_layer_shape)
  398. outputs = (context_layer, attention_probs) if output_attentions else (context_layer,)
  399. return outputs
  400. def create_coords_table_and_index(self):
  401. # get relative_coords_table
  402. relative_coords_h = torch.arange(-(self.window_size[0] - 1), self.window_size[0], dtype=torch.int64).float()
  403. relative_coords_w = torch.arange(-(self.window_size[1] - 1), self.window_size[1], dtype=torch.int64).float()
  404. relative_coords_table = (
  405. torch.stack(torch.meshgrid([relative_coords_h, relative_coords_w], indexing="ij"))
  406. .permute(1, 2, 0)
  407. .contiguous()
  408. .unsqueeze(0)
  409. ) # [1, 2*window_height - 1, 2*window_width - 1, 2]
  410. if self.pretrained_window_size[0] > 0:
  411. relative_coords_table[:, :, :, 0] /= self.pretrained_window_size[0] - 1
  412. relative_coords_table[:, :, :, 1] /= self.pretrained_window_size[1] - 1
  413. elif self.window_size[0] > 1:
  414. relative_coords_table[:, :, :, 0] /= self.window_size[0] - 1
  415. relative_coords_table[:, :, :, 1] /= self.window_size[1] - 1
  416. relative_coords_table *= 8 # normalize to -8, 8
  417. relative_coords_table = (
  418. torch.sign(relative_coords_table) * torch.log2(torch.abs(relative_coords_table) + 1.0) / math.log2(8)
  419. )
  420. # set to same dtype as mlp weight
  421. relative_coords_table = relative_coords_table.to(next(self.continuous_position_bias_mlp.parameters()).dtype)
  422. # get pair-wise relative position index for each token inside the window
  423. coords_h = torch.arange(self.window_size[0])
  424. coords_w = torch.arange(self.window_size[1])
  425. coords = torch.stack(torch.meshgrid([coords_h, coords_w], indexing="ij"))
  426. coords_flatten = torch.flatten(coords, 1)
  427. relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :]
  428. relative_coords = relative_coords.permute(1, 2, 0).contiguous()
  429. relative_coords[:, :, 0] += self.window_size[0] - 1
  430. relative_coords[:, :, 1] += self.window_size[1] - 1
  431. relative_coords[:, :, 0] *= 2 * self.window_size[1] - 1
  432. relative_position_index = relative_coords.sum(-1)
  433. return relative_coords_table, relative_position_index
  434. # Copied from transformers.models.swin.modeling_swin.SwinSelfOutput with Swin->Swinv2
  435. class Swinv2SelfOutput(nn.Module):
  436. def __init__(self, config, dim):
  437. super().__init__()
  438. self.dense = nn.Linear(dim, dim)
  439. self.dropout = nn.Dropout(config.attention_probs_dropout_prob)
  440. def forward(self, hidden_states: torch.Tensor, input_tensor: torch.Tensor) -> torch.Tensor:
  441. hidden_states = self.dense(hidden_states)
  442. hidden_states = self.dropout(hidden_states)
  443. return hidden_states
  444. class Swinv2Attention(nn.Module):
  445. def __init__(self, config, dim, num_heads, window_size, pretrained_window_size=0):
  446. super().__init__()
  447. self.self = Swinv2SelfAttention(
  448. config=config,
  449. dim=dim,
  450. num_heads=num_heads,
  451. window_size=window_size,
  452. pretrained_window_size=pretrained_window_size
  453. if isinstance(pretrained_window_size, collections.abc.Iterable)
  454. else (pretrained_window_size, pretrained_window_size),
  455. )
  456. self.output = Swinv2SelfOutput(config, dim)
  457. def forward(
  458. self,
  459. hidden_states: torch.Tensor,
  460. attention_mask: torch.FloatTensor | None = None,
  461. output_attentions: bool | None = False,
  462. ) -> tuple[torch.Tensor]:
  463. self_outputs = self.self(hidden_states, attention_mask, output_attentions)
  464. attention_output = self.output(self_outputs[0], hidden_states)
  465. outputs = (attention_output,) + self_outputs[1:] # add attentions if we output them
  466. return outputs
  467. # Copied from transformers.models.swin.modeling_swin.SwinIntermediate with Swin->Swinv2
  468. class Swinv2Intermediate(nn.Module):
  469. def __init__(self, config, dim):
  470. super().__init__()
  471. self.dense = nn.Linear(dim, int(config.mlp_ratio * dim))
  472. if isinstance(config.hidden_act, str):
  473. self.intermediate_act_fn = ACT2FN[config.hidden_act]
  474. else:
  475. self.intermediate_act_fn = config.hidden_act
  476. def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
  477. hidden_states = self.dense(hidden_states)
  478. hidden_states = self.intermediate_act_fn(hidden_states)
  479. return hidden_states
  480. # Copied from transformers.models.swin.modeling_swin.SwinOutput with Swin->Swinv2
  481. class Swinv2Output(nn.Module):
  482. def __init__(self, config, dim):
  483. super().__init__()
  484. self.dense = nn.Linear(int(config.mlp_ratio * dim), dim)
  485. self.dropout = nn.Dropout(config.hidden_dropout_prob)
  486. def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
  487. hidden_states = self.dense(hidden_states)
  488. hidden_states = self.dropout(hidden_states)
  489. return hidden_states
  490. class Swinv2Layer(nn.Module):
  491. def __init__(
  492. self, config, dim, input_resolution, num_heads, drop_path_rate=0.0, shift_size=0, pretrained_window_size=0
  493. ):
  494. super().__init__()
  495. self.input_resolution = input_resolution
  496. window_size, shift_size = self._compute_window_shift(
  497. (config.window_size, config.window_size), (shift_size, shift_size)
  498. )
  499. self.window_size = window_size[0]
  500. self.shift_size = shift_size[0]
  501. self.attention = Swinv2Attention(
  502. config=config,
  503. dim=dim,
  504. num_heads=num_heads,
  505. window_size=self.window_size,
  506. pretrained_window_size=pretrained_window_size
  507. if isinstance(pretrained_window_size, collections.abc.Iterable)
  508. else (pretrained_window_size, pretrained_window_size),
  509. )
  510. self.layernorm_before = nn.LayerNorm(dim, eps=config.layer_norm_eps)
  511. self.drop_path = Swinv2DropPath(drop_path_rate) if drop_path_rate > 0.0 else nn.Identity()
  512. self.intermediate = Swinv2Intermediate(config, dim)
  513. self.output = Swinv2Output(config, dim)
  514. self.layernorm_after = nn.LayerNorm(dim, eps=config.layer_norm_eps)
  515. def _compute_window_shift(self, target_window_size, target_shift_size) -> tuple[tuple[int, int], tuple[int, int]]:
  516. window_size = [min(r, w) for r, w in zip(self.input_resolution, target_window_size)]
  517. shift_size = [0 if r <= w else s for r, w, s in zip(self.input_resolution, window_size, target_shift_size)]
  518. return window_size, shift_size
  519. def get_attn_mask(self, height, width, dtype):
  520. if self.shift_size > 0:
  521. # calculate attention mask for shifted window multihead self attention
  522. img_mask = torch.zeros((1, height, width, 1), dtype=dtype)
  523. height_slices = (
  524. slice(0, -self.window_size),
  525. slice(-self.window_size, -self.shift_size),
  526. slice(-self.shift_size, None),
  527. )
  528. width_slices = (
  529. slice(0, -self.window_size),
  530. slice(-self.window_size, -self.shift_size),
  531. slice(-self.shift_size, None),
  532. )
  533. count = 0
  534. for height_slice in height_slices:
  535. for width_slice in width_slices:
  536. img_mask[:, height_slice, width_slice, :] = count
  537. count += 1
  538. mask_windows = window_partition(img_mask, self.window_size)
  539. mask_windows = mask_windows.view(-1, self.window_size * self.window_size)
  540. attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2)
  541. attn_mask = attn_mask.masked_fill(attn_mask != 0, -100.0).masked_fill(attn_mask == 0, 0.0)
  542. else:
  543. attn_mask = None
  544. return attn_mask
  545. def maybe_pad(self, hidden_states, height, width):
  546. pad_right = (self.window_size - width % self.window_size) % self.window_size
  547. pad_bottom = (self.window_size - height % self.window_size) % self.window_size
  548. pad_values = (0, 0, 0, pad_right, 0, pad_bottom)
  549. hidden_states = nn.functional.pad(hidden_states, pad_values)
  550. return hidden_states, pad_values
  551. def forward(
  552. self,
  553. hidden_states: torch.Tensor,
  554. input_dimensions: tuple[int, int],
  555. output_attentions: bool | None = False,
  556. ) -> tuple[torch.Tensor, torch.Tensor]:
  557. height, width = input_dimensions
  558. batch_size, _, channels = hidden_states.size()
  559. shortcut = hidden_states
  560. # pad hidden_states to multiples of window size
  561. hidden_states = hidden_states.view(batch_size, height, width, channels)
  562. hidden_states, pad_values = self.maybe_pad(hidden_states, height, width)
  563. _, height_pad, width_pad, _ = hidden_states.shape
  564. # cyclic shift
  565. if self.shift_size > 0:
  566. shifted_hidden_states = torch.roll(hidden_states, shifts=(-self.shift_size, -self.shift_size), dims=(1, 2))
  567. else:
  568. shifted_hidden_states = hidden_states
  569. # partition windows
  570. hidden_states_windows = window_partition(shifted_hidden_states, self.window_size)
  571. hidden_states_windows = hidden_states_windows.view(-1, self.window_size * self.window_size, channels)
  572. attn_mask = self.get_attn_mask(height_pad, width_pad, dtype=hidden_states.dtype)
  573. if attn_mask is not None:
  574. attn_mask = attn_mask.to(hidden_states_windows.device)
  575. attention_outputs = self.attention(hidden_states_windows, attn_mask, output_attentions=output_attentions)
  576. attention_output = attention_outputs[0]
  577. attention_windows = attention_output.view(-1, self.window_size, self.window_size, channels)
  578. shifted_windows = window_reverse(attention_windows, self.window_size, height_pad, width_pad)
  579. # reverse cyclic shift
  580. if self.shift_size > 0:
  581. attention_windows = torch.roll(shifted_windows, shifts=(self.shift_size, self.shift_size), dims=(1, 2))
  582. else:
  583. attention_windows = shifted_windows
  584. was_padded = pad_values[3] > 0 or pad_values[5] > 0
  585. if was_padded:
  586. attention_windows = attention_windows[:, :height, :width, :].contiguous()
  587. attention_windows = attention_windows.view(batch_size, height * width, channels)
  588. hidden_states = self.layernorm_before(attention_windows)
  589. hidden_states = shortcut + self.drop_path(hidden_states)
  590. layer_output = self.intermediate(hidden_states)
  591. layer_output = self.output(layer_output)
  592. layer_output = hidden_states + self.drop_path(self.layernorm_after(layer_output))
  593. layer_outputs = (layer_output, attention_outputs[1]) if output_attentions else (layer_output,)
  594. return layer_outputs
  595. class Swinv2Stage(GradientCheckpointingLayer):
  596. def __init__(
  597. self, config, dim, input_resolution, depth, num_heads, drop_path, downsample, pretrained_window_size=0
  598. ):
  599. super().__init__()
  600. self.config = config
  601. self.dim = dim
  602. blocks = []
  603. for i in range(depth):
  604. block = Swinv2Layer(
  605. config=config,
  606. dim=dim,
  607. input_resolution=input_resolution,
  608. num_heads=num_heads,
  609. drop_path_rate=drop_path[i],
  610. shift_size=0 if (i % 2 == 0) else config.window_size // 2,
  611. pretrained_window_size=pretrained_window_size,
  612. )
  613. blocks.append(block)
  614. self.blocks = nn.ModuleList(blocks)
  615. # patch merging layer
  616. if downsample is not None:
  617. self.downsample = downsample(input_resolution, dim=dim, norm_layer=nn.LayerNorm)
  618. else:
  619. self.downsample = None
  620. self.pointing = False
  621. def forward(
  622. self,
  623. hidden_states: torch.Tensor,
  624. input_dimensions: tuple[int, int],
  625. output_attentions: bool | None = False,
  626. ) -> tuple[torch.Tensor]:
  627. height, width = input_dimensions
  628. for i, layer_module in enumerate(self.blocks):
  629. layer_outputs = layer_module(
  630. hidden_states,
  631. input_dimensions,
  632. output_attentions,
  633. )
  634. hidden_states = layer_outputs[0]
  635. hidden_states_before_downsampling = hidden_states
  636. if self.downsample is not None:
  637. height_downsampled, width_downsampled = (height + 1) // 2, (width + 1) // 2
  638. output_dimensions = (height, width, height_downsampled, width_downsampled)
  639. hidden_states = self.downsample(hidden_states_before_downsampling, input_dimensions)
  640. else:
  641. output_dimensions = (height, width, height, width)
  642. stage_outputs = (hidden_states, hidden_states_before_downsampling, output_dimensions)
  643. if output_attentions:
  644. stage_outputs += layer_outputs[1:]
  645. return stage_outputs
  646. class Swinv2Encoder(nn.Module):
  647. def __init__(self, config, grid_size, pretrained_window_sizes=(0, 0, 0, 0)):
  648. super().__init__()
  649. self.num_layers = len(config.depths)
  650. self.config = config
  651. if self.config.pretrained_window_sizes is not None:
  652. pretrained_window_sizes = config.pretrained_window_sizes
  653. dpr = [x.item() for x in torch.linspace(0, config.drop_path_rate, sum(config.depths), device="cpu")]
  654. layers = []
  655. for i_layer in range(self.num_layers):
  656. stage = Swinv2Stage(
  657. config=config,
  658. dim=int(config.embed_dim * 2**i_layer),
  659. input_resolution=(grid_size[0] // (2**i_layer), grid_size[1] // (2**i_layer)),
  660. depth=config.depths[i_layer],
  661. num_heads=config.num_heads[i_layer],
  662. drop_path=dpr[sum(config.depths[:i_layer]) : sum(config.depths[: i_layer + 1])],
  663. downsample=Swinv2PatchMerging if (i_layer < self.num_layers - 1) else None,
  664. pretrained_window_size=pretrained_window_sizes[i_layer],
  665. )
  666. layers.append(stage)
  667. self.layers = nn.ModuleList(layers)
  668. self.gradient_checkpointing = False
  669. def forward(
  670. self,
  671. hidden_states: torch.Tensor,
  672. input_dimensions: tuple[int, int],
  673. output_attentions: bool | None = False,
  674. output_hidden_states: bool | None = False,
  675. output_hidden_states_before_downsampling: bool | None = False,
  676. return_dict: bool | None = True,
  677. ) -> tuple | Swinv2EncoderOutput:
  678. all_hidden_states = () if output_hidden_states else None
  679. all_reshaped_hidden_states = () if output_hidden_states else None
  680. all_self_attentions = () if output_attentions else None
  681. if output_hidden_states:
  682. batch_size, _, hidden_size = hidden_states.shape
  683. # rearrange b (h w) c -> b c h w
  684. reshaped_hidden_state = hidden_states.view(batch_size, *input_dimensions, hidden_size)
  685. reshaped_hidden_state = reshaped_hidden_state.permute(0, 3, 1, 2)
  686. all_hidden_states += (hidden_states,)
  687. all_reshaped_hidden_states += (reshaped_hidden_state,)
  688. for i, layer_module in enumerate(self.layers):
  689. layer_outputs = layer_module(
  690. hidden_states,
  691. input_dimensions,
  692. output_attentions,
  693. )
  694. hidden_states = layer_outputs[0]
  695. hidden_states_before_downsampling = layer_outputs[1]
  696. output_dimensions = layer_outputs[2]
  697. input_dimensions = (output_dimensions[-2], output_dimensions[-1])
  698. if output_hidden_states and output_hidden_states_before_downsampling:
  699. batch_size, _, hidden_size = hidden_states_before_downsampling.shape
  700. # rearrange b (h w) c -> b c h w
  701. # here we use the original (not downsampled) height and width
  702. reshaped_hidden_state = hidden_states_before_downsampling.view(
  703. batch_size, *(output_dimensions[0], output_dimensions[1]), hidden_size
  704. )
  705. reshaped_hidden_state = reshaped_hidden_state.permute(0, 3, 1, 2)
  706. all_hidden_states += (hidden_states_before_downsampling,)
  707. all_reshaped_hidden_states += (reshaped_hidden_state,)
  708. elif output_hidden_states and not output_hidden_states_before_downsampling:
  709. batch_size, _, hidden_size = hidden_states.shape
  710. # rearrange b (h w) c -> b c h w
  711. reshaped_hidden_state = hidden_states.view(batch_size, *input_dimensions, hidden_size)
  712. reshaped_hidden_state = reshaped_hidden_state.permute(0, 3, 1, 2)
  713. all_hidden_states += (hidden_states,)
  714. all_reshaped_hidden_states += (reshaped_hidden_state,)
  715. if output_attentions:
  716. all_self_attentions += layer_outputs[3:]
  717. if not return_dict:
  718. return tuple(
  719. v
  720. for v in [hidden_states, all_hidden_states, all_self_attentions, all_reshaped_hidden_states]
  721. if v is not None
  722. )
  723. return Swinv2EncoderOutput(
  724. last_hidden_state=hidden_states,
  725. hidden_states=all_hidden_states,
  726. attentions=all_self_attentions,
  727. reshaped_hidden_states=all_reshaped_hidden_states,
  728. )
  729. @auto_docstring
  730. class Swinv2PreTrainedModel(PreTrainedModel):
  731. config: Swinv2Config
  732. base_model_prefix = "swinv2"
  733. main_input_name = "pixel_values"
  734. input_modalities = ("image",)
  735. supports_gradient_checkpointing = True
  736. _no_split_modules = ["Swinv2Stage"]
  737. @torch.no_grad()
  738. def _init_weights(self, module):
  739. """Initialize the weights"""
  740. if isinstance(module, (nn.Linear, nn.Conv2d)):
  741. init.normal_(module.weight, mean=0.0, std=self.config.initializer_range)
  742. if module.bias is not None:
  743. init.zeros_(module.bias)
  744. elif isinstance(module, nn.LayerNorm):
  745. init.zeros_(module.bias)
  746. init.ones_(module.weight)
  747. elif isinstance(module, Swinv2Embeddings):
  748. if module.mask_token is not None:
  749. init.zeros_(module.mask_token)
  750. if module.position_embeddings is not None:
  751. init.zeros_(module.position_embeddings)
  752. elif isinstance(module, Swinv2SelfAttention):
  753. init.constant_(module.logit_scale, math.log(10))
  754. relative_coords_table, relative_position_index = module.create_coords_table_and_index()
  755. init.copy_(module.relative_coords_table, relative_coords_table)
  756. init.copy_(module.relative_position_index, relative_position_index)
  757. @auto_docstring
  758. # Copied from transformers.models.swin.modeling_swin.SwinModel with SWIN->SWINV2,Swin->Swinv2
  759. class Swinv2Model(Swinv2PreTrainedModel):
  760. def __init__(self, config, add_pooling_layer=True, use_mask_token=False):
  761. r"""
  762. add_pooling_layer (`bool`, *optional*, defaults to `True`):
  763. Whether or not to apply pooling layer.
  764. use_mask_token (`bool`, *optional*, defaults to `False`):
  765. Whether or not to create and apply mask tokens in the embedding layer.
  766. """
  767. super().__init__(config)
  768. self.config = config
  769. self.num_layers = len(config.depths)
  770. self.num_features = int(config.embed_dim * 2 ** (self.num_layers - 1))
  771. self.embeddings = Swinv2Embeddings(config, use_mask_token=use_mask_token)
  772. self.encoder = Swinv2Encoder(config, self.embeddings.patch_grid)
  773. self.layernorm = nn.LayerNorm(self.num_features, eps=config.layer_norm_eps)
  774. self.pooler = nn.AdaptiveAvgPool1d(1) if add_pooling_layer else None
  775. # Initialize weights and apply final processing
  776. self.post_init()
  777. def get_input_embeddings(self):
  778. return self.embeddings.patch_embeddings
  779. @auto_docstring
  780. def forward(
  781. self,
  782. pixel_values: torch.FloatTensor | None = None,
  783. bool_masked_pos: torch.BoolTensor | None = None,
  784. output_attentions: bool | None = None,
  785. output_hidden_states: bool | None = None,
  786. interpolate_pos_encoding: bool = False,
  787. return_dict: bool | None = None,
  788. **kwargs,
  789. ) -> tuple | Swinv2ModelOutput:
  790. r"""
  791. bool_masked_pos (`torch.BoolTensor` of shape `(batch_size, num_patches)`, *optional*):
  792. Boolean masked positions. Indicates which patches are masked (1) and which aren't (0).
  793. """
  794. output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
  795. output_hidden_states = (
  796. output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
  797. )
  798. return_dict = return_dict if return_dict is not None else self.config.return_dict
  799. if pixel_values is None:
  800. raise ValueError("You have to specify pixel_values")
  801. embedding_output, input_dimensions = self.embeddings(
  802. pixel_values, bool_masked_pos=bool_masked_pos, interpolate_pos_encoding=interpolate_pos_encoding
  803. )
  804. encoder_outputs = self.encoder(
  805. embedding_output,
  806. input_dimensions,
  807. output_attentions=output_attentions,
  808. output_hidden_states=output_hidden_states,
  809. return_dict=return_dict,
  810. )
  811. sequence_output = encoder_outputs[0]
  812. sequence_output = self.layernorm(sequence_output)
  813. pooled_output = None
  814. if self.pooler is not None:
  815. pooled_output = self.pooler(sequence_output.transpose(1, 2))
  816. pooled_output = torch.flatten(pooled_output, 1)
  817. if not return_dict:
  818. output = (sequence_output, pooled_output) + encoder_outputs[1:]
  819. return output
  820. return Swinv2ModelOutput(
  821. last_hidden_state=sequence_output,
  822. pooler_output=pooled_output,
  823. hidden_states=encoder_outputs.hidden_states,
  824. attentions=encoder_outputs.attentions,
  825. reshaped_hidden_states=encoder_outputs.reshaped_hidden_states,
  826. )
  827. @auto_docstring(
  828. custom_intro="""
  829. Swinv2 Model with a decoder on top for masked image modeling, as proposed in
  830. [SimMIM](https://huggingface.co/papers/2111.09886).
  831. <Tip>
  832. Note that we provide a script to pre-train this model on custom data in our [examples
  833. directory](https://github.com/huggingface/transformers/tree/main/examples/pytorch/image-pretraining).
  834. </Tip>
  835. """
  836. )
  837. # Copied from transformers.models.swin.modeling_swin.SwinForMaskedImageModeling with swin->swinv2, base-simmim-window6-192->tiny-patch4-window8-256,SWIN->SWINV2,Swin->Swinv2,192->256
  838. class Swinv2ForMaskedImageModeling(Swinv2PreTrainedModel):
  839. def __init__(self, config):
  840. super().__init__(config)
  841. self.swinv2 = Swinv2Model(config, add_pooling_layer=False, use_mask_token=True)
  842. num_features = int(config.embed_dim * 2 ** (config.num_layers - 1))
  843. self.decoder = nn.Sequential(
  844. nn.Conv2d(
  845. in_channels=num_features, out_channels=config.encoder_stride**2 * config.num_channels, kernel_size=1
  846. ),
  847. nn.PixelShuffle(config.encoder_stride),
  848. )
  849. # Initialize weights and apply final processing
  850. self.post_init()
  851. @auto_docstring
  852. def forward(
  853. self,
  854. pixel_values: torch.FloatTensor | None = None,
  855. bool_masked_pos: torch.BoolTensor | None = None,
  856. output_attentions: bool | None = None,
  857. output_hidden_states: bool | None = None,
  858. interpolate_pos_encoding: bool = False,
  859. return_dict: bool | None = None,
  860. **kwargs,
  861. ) -> tuple | Swinv2MaskedImageModelingOutput:
  862. r"""
  863. bool_masked_pos (`torch.BoolTensor` of shape `(batch_size, num_patches)`):
  864. Boolean masked positions. Indicates which patches are masked (1) and which aren't (0).
  865. Examples:
  866. ```python
  867. >>> from transformers import AutoImageProcessor, Swinv2ForMaskedImageModeling
  868. >>> import torch
  869. >>> from PIL import Image
  870. >>> import httpx
  871. >>> from io import BytesIO
  872. >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg"
  873. >>> with httpx.stream("GET", url) as response:
  874. ... image = Image.open(BytesIO(response.read()))
  875. >>> image_processor = AutoImageProcessor.from_pretrained("microsoft/swinv2-tiny-patch4-window8-256")
  876. >>> model = Swinv2ForMaskedImageModeling.from_pretrained("microsoft/swinv2-tiny-patch4-window8-256")
  877. >>> num_patches = (model.config.image_size // model.config.patch_size) ** 2
  878. >>> pixel_values = image_processor(images=image, return_tensors="pt").pixel_values
  879. >>> # create random boolean mask of shape (batch_size, num_patches)
  880. >>> bool_masked_pos = torch.randint(low=0, high=2, size=(1, num_patches)).bool()
  881. >>> outputs = model(pixel_values, bool_masked_pos=bool_masked_pos)
  882. >>> loss, reconstructed_pixel_values = outputs.loss, outputs.reconstruction
  883. >>> list(reconstructed_pixel_values.shape)
  884. [1, 3, 256, 256]
  885. ```"""
  886. return_dict = return_dict if return_dict is not None else self.config.return_dict
  887. outputs = self.swinv2(
  888. pixel_values,
  889. bool_masked_pos=bool_masked_pos,
  890. output_attentions=output_attentions,
  891. output_hidden_states=output_hidden_states,
  892. interpolate_pos_encoding=interpolate_pos_encoding,
  893. return_dict=return_dict,
  894. )
  895. sequence_output = outputs[0]
  896. # Reshape to (batch_size, num_channels, height, width)
  897. sequence_output = sequence_output.transpose(1, 2)
  898. batch_size, num_channels, sequence_length = sequence_output.shape
  899. height = width = math.floor(sequence_length**0.5)
  900. sequence_output = sequence_output.reshape(batch_size, num_channels, height, width)
  901. # Reconstruct pixel values
  902. reconstructed_pixel_values = self.decoder(sequence_output)
  903. masked_im_loss = None
  904. if bool_masked_pos is not None:
  905. size = self.config.image_size // self.config.patch_size
  906. bool_masked_pos = bool_masked_pos.reshape(-1, size, size)
  907. mask = (
  908. bool_masked_pos.repeat_interleave(self.config.patch_size, 1)
  909. .repeat_interleave(self.config.patch_size, 2)
  910. .unsqueeze(1)
  911. .contiguous()
  912. )
  913. reconstruction_loss = nn.functional.l1_loss(pixel_values, reconstructed_pixel_values, reduction="none")
  914. masked_im_loss = (reconstruction_loss * mask).sum() / (mask.sum() + 1e-5) / self.config.num_channels
  915. if not return_dict:
  916. output = (reconstructed_pixel_values,) + outputs[2:]
  917. return ((masked_im_loss,) + output) if masked_im_loss is not None else output
  918. return Swinv2MaskedImageModelingOutput(
  919. loss=masked_im_loss,
  920. reconstruction=reconstructed_pixel_values,
  921. hidden_states=outputs.hidden_states,
  922. attentions=outputs.attentions,
  923. reshaped_hidden_states=outputs.reshaped_hidden_states,
  924. )
  925. @auto_docstring(
  926. custom_intro="""
  927. Swinv2 Model transformer with an image classification head on top (a linear layer on top of the final hidden state
  928. of the [CLS] token) e.g. for ImageNet.
  929. <Tip>
  930. Note that it's possible to fine-tune SwinV2 on higher resolution images than the ones it has been trained on, by
  931. setting `interpolate_pos_encoding` to `True` in the forward of the model. This will interpolate the pre-trained
  932. position embeddings to the higher resolution.
  933. </Tip>
  934. """
  935. )
  936. # Copied from transformers.models.swin.modeling_swin.SwinForImageClassification with SWIN->SWINV2,Swin->Swinv2,swin->swinv2
  937. class Swinv2ForImageClassification(Swinv2PreTrainedModel):
  938. def __init__(self, config):
  939. super().__init__(config)
  940. self.num_labels = config.num_labels
  941. self.swinv2 = Swinv2Model(config)
  942. # Classifier head
  943. self.classifier = (
  944. nn.Linear(self.swinv2.num_features, config.num_labels) if config.num_labels > 0 else nn.Identity()
  945. )
  946. # Initialize weights and apply final processing
  947. self.post_init()
  948. @auto_docstring
  949. def forward(
  950. self,
  951. pixel_values: torch.FloatTensor | None = None,
  952. labels: torch.LongTensor | None = None,
  953. output_attentions: bool | None = None,
  954. output_hidden_states: bool | None = None,
  955. interpolate_pos_encoding: bool = False,
  956. return_dict: bool | None = None,
  957. **kwargs,
  958. ) -> tuple | Swinv2ImageClassifierOutput:
  959. r"""
  960. labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
  961. Labels for computing the image classification/regression loss. Indices should be in `[0, ...,
  962. config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If
  963. `config.num_labels > 1` a classification loss is computed (Cross-Entropy).
  964. """
  965. return_dict = return_dict if return_dict is not None else self.config.return_dict
  966. outputs = self.swinv2(
  967. pixel_values,
  968. output_attentions=output_attentions,
  969. output_hidden_states=output_hidden_states,
  970. interpolate_pos_encoding=interpolate_pos_encoding,
  971. return_dict=return_dict,
  972. )
  973. pooled_output = outputs[1]
  974. logits = self.classifier(pooled_output)
  975. loss = None
  976. if labels is not None:
  977. loss = self.loss_function(labels, logits, self.config)
  978. if not return_dict:
  979. output = (logits,) + outputs[2:]
  980. return ((loss,) + output) if loss is not None else output
  981. return Swinv2ImageClassifierOutput(
  982. loss=loss,
  983. logits=logits,
  984. hidden_states=outputs.hidden_states,
  985. attentions=outputs.attentions,
  986. reshaped_hidden_states=outputs.reshaped_hidden_states,
  987. )
  988. @auto_docstring(
  989. custom_intro="""
  990. Swinv2 backbone, to be used with frameworks like DETR and MaskFormer.
  991. """
  992. )
  993. class Swinv2Backbone(BackboneMixin, Swinv2PreTrainedModel):
  994. def __init__(self, config):
  995. super().__init__(config)
  996. self.num_features = [config.embed_dim] + [int(config.embed_dim * 2**i) for i in range(len(config.depths))]
  997. self.embeddings = Swinv2Embeddings(config)
  998. self.encoder = Swinv2Encoder(config, self.embeddings.patch_grid)
  999. # initialize weights and apply final processing
  1000. self.post_init()
  1001. def get_input_embeddings(self):
  1002. return self.embeddings.patch_embeddings
  1003. @can_return_tuple
  1004. @filter_output_hidden_states
  1005. @auto_docstring
  1006. def forward(
  1007. self,
  1008. pixel_values: Tensor,
  1009. output_attentions: bool | None = None,
  1010. output_hidden_states: bool | None = None,
  1011. return_dict: bool | None = None,
  1012. **kwargs,
  1013. ) -> BackboneOutput:
  1014. r"""
  1015. Examples:
  1016. ```python
  1017. >>> from transformers import AutoImageProcessor, AutoBackbone
  1018. >>> import torch
  1019. >>> from PIL import Image
  1020. >>> import httpx
  1021. >>> from io import BytesIO
  1022. >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg"
  1023. >>> with httpx.stream("GET", url) as response:
  1024. ... image = Image.open(BytesIO(response.read()))
  1025. >>> processor = AutoImageProcessor.from_pretrained("microsoft/swinv2-tiny-patch4-window8-256")
  1026. >>> model = AutoBackbone.from_pretrained(
  1027. ... "microsoft/swinv2-tiny-patch4-window8-256", out_features=["stage1", "stage2", "stage3", "stage4"]
  1028. ... )
  1029. >>> inputs = processor(image, return_tensors="pt")
  1030. >>> outputs = model(**inputs)
  1031. >>> feature_maps = outputs.feature_maps
  1032. >>> list(feature_maps[-1].shape)
  1033. [1, 2048, 7, 7]
  1034. ```"""
  1035. return_dict = return_dict if return_dict is not None else self.config.return_dict
  1036. output_hidden_states = (
  1037. output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
  1038. )
  1039. output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
  1040. embedding_output, input_dimensions = self.embeddings(pixel_values)
  1041. outputs = self.encoder(
  1042. embedding_output,
  1043. input_dimensions,
  1044. output_attentions=output_attentions,
  1045. output_hidden_states=True,
  1046. output_hidden_states_before_downsampling=True,
  1047. return_dict=return_dict,
  1048. )
  1049. hidden_states = outputs.reshaped_hidden_states if return_dict else outputs[-1]
  1050. feature_maps = ()
  1051. for stage, hidden_state in zip(self.stage_names, hidden_states):
  1052. if stage in self.out_features:
  1053. feature_maps += (hidden_state,)
  1054. if not return_dict:
  1055. output = (feature_maps,)
  1056. if output_hidden_states:
  1057. output += (outputs[1],)
  1058. if output_attentions:
  1059. output += (outputs[2],)
  1060. return output
  1061. return BackboneOutput(
  1062. feature_maps=feature_maps,
  1063. hidden_states=outputs.hidden_states if output_hidden_states else None,
  1064. attentions=outputs.attentions,
  1065. )
# Names exported by `from <this module> import *`.
__all__ = [
    "Swinv2ForImageClassification",
    "Swinv2ForMaskedImageModeling",
    "Swinv2Model",
    "Swinv2PreTrainedModel",
    "Swinv2Backbone",
]