# modeling_swin.py
  1. # Copyright 2022 Microsoft Research and The HuggingFace Inc. team. All rights reserved.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. """PyTorch Swin Transformer model."""
  15. import collections.abc
  16. import math
  17. from dataclasses import dataclass
  18. import torch
  19. from torch import nn
  20. from ... import initialization as init
  21. from ...activations import ACT2FN
  22. from ...backbone_utils import BackboneMixin, filter_output_hidden_states
  23. from ...modeling_layers import GradientCheckpointingLayer
  24. from ...modeling_outputs import BackboneOutput
  25. from ...modeling_utils import PreTrainedModel
  26. from ...utils import ModelOutput, auto_docstring, logging, torch_int
  27. from ...utils.generic import can_return_tuple
  28. from .configuration_swin import SwinConfig
  29. logger = logging.get_logger(__name__)
  30. # drop_path, SwinPatchEmbeddings, SwinPatchMerging and SwinDropPath are from the timm library.
@dataclass
@auto_docstring(
    custom_intro="""
    Swin encoder's outputs, with potential hidden states and attentions.
    """
)
class SwinEncoderOutput(ModelOutput):
    r"""
    reshaped_hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
        Tuple of `torch.FloatTensor` (one for the output of the embeddings + one for the output of each stage) of
        shape `(batch_size, hidden_size, height, width)`.

        Hidden-states of the model at the output of each layer plus the initial embedding outputs reshaped to
        include the spatial dimensions.
    """

    # Every field defaults to None. Only `reshaped_hidden_states` is documented by hand above;
    # the remaining fields are presumably documented by `auto_docstring` - TODO confirm.
    last_hidden_state: torch.FloatTensor | None = None
    hidden_states: tuple[torch.FloatTensor, ...] | None = None
    attentions: tuple[torch.FloatTensor, ...] | None = None
    reshaped_hidden_states: tuple[torch.FloatTensor, ...] | None = None
@dataclass
@auto_docstring(
    custom_intro="""
    Swin model's outputs that also contains a pooling of the last hidden states.
    """
)
class SwinModelOutput(ModelOutput):
    r"""
    pooler_output (`torch.FloatTensor` of shape `(batch_size, hidden_size)`, *optional*, returned when `add_pooling_layer=True` is passed):
        Average pooling of the last layer hidden-state.
    reshaped_hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
        Tuple of `torch.FloatTensor` (one for the output of the embeddings + one for the output of each stage) of
        shape `(batch_size, hidden_size, height, width)`.

        Hidden-states of the model at the output of each layer plus the initial embedding outputs reshaped to
        include the spatial dimensions.
    """

    # Every field defaults to None. Fields not described in the docstring above are presumably
    # documented by `auto_docstring` - TODO confirm.
    last_hidden_state: torch.FloatTensor | None = None
    pooler_output: torch.FloatTensor | None = None
    hidden_states: tuple[torch.FloatTensor, ...] | None = None
    attentions: tuple[torch.FloatTensor, ...] | None = None
    reshaped_hidden_states: tuple[torch.FloatTensor, ...] | None = None
@dataclass
@auto_docstring(
    custom_intro="""
    Swin masked image model outputs.
    """
)
class SwinMaskedImageModelingOutput(ModelOutput):
    r"""
    loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `bool_masked_pos` is provided):
        Masked image modeling (MIM) loss.
    reconstruction (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`):
        Reconstructed pixel values.
    reshaped_hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
        Tuple of `torch.FloatTensor` (one for the output of the embeddings + one for the output of each stage) of
        shape `(batch_size, hidden_size, height, width)`.

        Hidden-states of the model at the output of each layer plus the initial embedding outputs reshaped to
        include the spatial dimensions.
    """

    # Every field defaults to None. Fields not described in the docstring above are presumably
    # documented by `auto_docstring` - TODO confirm.
    loss: torch.FloatTensor | None = None
    reconstruction: torch.FloatTensor | None = None
    hidden_states: tuple[torch.FloatTensor, ...] | None = None
    attentions: tuple[torch.FloatTensor, ...] | None = None
    reshaped_hidden_states: tuple[torch.FloatTensor, ...] | None = None
@dataclass
@auto_docstring(
    custom_intro="""
    Swin outputs for image classification.
    """
)
class SwinImageClassifierOutput(ModelOutput):
    r"""
    loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` is provided):
        Classification (or regression if config.num_labels==1) loss.
    logits (`torch.FloatTensor` of shape `(batch_size, config.num_labels)`):
        Classification (or regression if config.num_labels==1) scores (before SoftMax).
    reshaped_hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
        Tuple of `torch.FloatTensor` (one for the output of the embeddings + one for the output of each stage) of
        shape `(batch_size, hidden_size, height, width)`.

        Hidden-states of the model at the output of each layer plus the initial embedding outputs reshaped to
        include the spatial dimensions.
    """

    # Every field defaults to None. Fields not described in the docstring above are presumably
    # documented by `auto_docstring` - TODO confirm.
    loss: torch.FloatTensor | None = None
    logits: torch.FloatTensor | None = None
    hidden_states: tuple[torch.FloatTensor, ...] | None = None
    attentions: tuple[torch.FloatTensor, ...] | None = None
    reshaped_hidden_states: tuple[torch.FloatTensor, ...] | None = None
  116. def window_partition(input_feature, window_size):
  117. """
  118. Partitions the given input into windows.
  119. """
  120. batch_size, height, width, num_channels = input_feature.shape
  121. input_feature = input_feature.view(
  122. batch_size, height // window_size, window_size, width // window_size, window_size, num_channels
  123. )
  124. windows = input_feature.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size, window_size, num_channels)
  125. return windows
  126. def window_reverse(windows, window_size, height, width):
  127. """
  128. Merges windows to produce higher resolution features.
  129. """
  130. num_channels = windows.shape[-1]
  131. windows = windows.view(-1, height // window_size, width // window_size, window_size, window_size, num_channels)
  132. windows = windows.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, height, width, num_channels)
  133. return windows
  134. class SwinEmbeddings(nn.Module):
  135. """
  136. Construct the patch and position embeddings. Optionally, also the mask token.
  137. """
  138. def __init__(self, config, use_mask_token=False):
  139. super().__init__()
  140. self.patch_embeddings = SwinPatchEmbeddings(config)
  141. num_patches = self.patch_embeddings.num_patches
  142. self.patch_grid = self.patch_embeddings.grid_size
  143. self.mask_token = nn.Parameter(torch.zeros(1, 1, config.embed_dim)) if use_mask_token else None
  144. if config.use_absolute_embeddings:
  145. self.position_embeddings = nn.Parameter(torch.zeros(1, num_patches + 1, config.embed_dim))
  146. else:
  147. self.position_embeddings = None
  148. self.norm = nn.LayerNorm(config.embed_dim)
  149. self.dropout = nn.Dropout(config.hidden_dropout_prob)
  150. self.patch_size = config.patch_size
  151. self.config = config
  152. # Copied from transformers.models.vit.modeling_vit.ViTEmbeddings.interpolate_pos_encoding
  153. def interpolate_pos_encoding(self, embeddings: torch.Tensor, height: int, width: int) -> torch.Tensor:
  154. """
  155. This method allows to interpolate the pre-trained position encodings, to be able to use the model on higher resolution
  156. images. This method is also adapted to support torch.jit tracing.
  157. Adapted from:
  158. - https://github.com/facebookresearch/dino/blob/de9ee3df6cf39fac952ab558447af1fa1365362a/vision_transformer.py#L174-L194, and
  159. - https://github.com/facebookresearch/dinov2/blob/e1277af2ba9496fbadf7aec6eba56e8d882d1e35/dinov2/models/vision_transformer.py#L179-L211
  160. """
  161. num_patches = embeddings.shape[1] - 1
  162. num_positions = self.position_embeddings.shape[1] - 1
  163. # always interpolate when tracing to ensure the exported model works for dynamic input shapes
  164. if not torch.jit.is_tracing() and num_patches == num_positions and height == width:
  165. return self.position_embeddings
  166. class_pos_embed = self.position_embeddings[:, :1]
  167. patch_pos_embed = self.position_embeddings[:, 1:]
  168. dim = embeddings.shape[-1]
  169. new_height = height // self.patch_size
  170. new_width = width // self.patch_size
  171. sqrt_num_positions = torch_int(num_positions**0.5)
  172. patch_pos_embed = patch_pos_embed.reshape(1, sqrt_num_positions, sqrt_num_positions, dim)
  173. patch_pos_embed = patch_pos_embed.permute(0, 3, 1, 2)
  174. patch_pos_embed = nn.functional.interpolate(
  175. patch_pos_embed,
  176. size=(new_height, new_width),
  177. mode="bicubic",
  178. align_corners=False,
  179. )
  180. patch_pos_embed = patch_pos_embed.permute(0, 2, 3, 1).view(1, -1, dim)
  181. return torch.cat((class_pos_embed, patch_pos_embed), dim=1)
  182. def forward(
  183. self,
  184. pixel_values: torch.FloatTensor | None,
  185. bool_masked_pos: torch.BoolTensor | None = None,
  186. interpolate_pos_encoding: bool = False,
  187. ) -> tuple[torch.Tensor]:
  188. _, num_channels, height, width = pixel_values.shape
  189. embeddings, output_dimensions = self.patch_embeddings(pixel_values)
  190. embeddings = self.norm(embeddings)
  191. batch_size, seq_len, _ = embeddings.size()
  192. if bool_masked_pos is not None:
  193. mask_tokens = self.mask_token.expand(batch_size, seq_len, -1)
  194. # replace the masked visual tokens by mask_tokens
  195. mask = bool_masked_pos.unsqueeze(-1).type_as(mask_tokens)
  196. embeddings = embeddings * (1.0 - mask) + mask_tokens * mask
  197. if self.position_embeddings is not None:
  198. if interpolate_pos_encoding:
  199. embeddings = embeddings + self.interpolate_pos_encoding(embeddings, height, width)
  200. else:
  201. embeddings = embeddings + self.position_embeddings
  202. embeddings = self.dropout(embeddings)
  203. return embeddings, output_dimensions
  204. class SwinPatchEmbeddings(nn.Module):
  205. """
  206. This class turns `pixel_values` of shape `(batch_size, num_channels, height, width)` into the initial
  207. `hidden_states` (patch embeddings) of shape `(batch_size, seq_length, hidden_size)` to be consumed by a
  208. Transformer.
  209. """
  210. def __init__(self, config):
  211. super().__init__()
  212. image_size, patch_size = config.image_size, config.patch_size
  213. num_channels, hidden_size = config.num_channels, config.embed_dim
  214. image_size = image_size if isinstance(image_size, collections.abc.Iterable) else (image_size, image_size)
  215. patch_size = patch_size if isinstance(patch_size, collections.abc.Iterable) else (patch_size, patch_size)
  216. num_patches = (image_size[1] // patch_size[1]) * (image_size[0] // patch_size[0])
  217. self.image_size = image_size
  218. self.patch_size = patch_size
  219. self.num_channels = num_channels
  220. self.num_patches = num_patches
  221. self.grid_size = (image_size[0] // patch_size[0], image_size[1] // patch_size[1])
  222. self.projection = nn.Conv2d(num_channels, hidden_size, kernel_size=patch_size, stride=patch_size)
  223. def maybe_pad(self, pixel_values, height, width):
  224. if width % self.patch_size[1] != 0:
  225. pad_values = (0, self.patch_size[1] - width % self.patch_size[1])
  226. pixel_values = nn.functional.pad(pixel_values, pad_values)
  227. if height % self.patch_size[0] != 0:
  228. pad_values = (0, 0, 0, self.patch_size[0] - height % self.patch_size[0])
  229. pixel_values = nn.functional.pad(pixel_values, pad_values)
  230. return pixel_values
  231. def forward(self, pixel_values: torch.FloatTensor | None) -> tuple[torch.Tensor, tuple[int]]:
  232. _, num_channels, height, width = pixel_values.shape
  233. # pad the input to be divisible by self.patch_size, if needed
  234. pixel_values = self.maybe_pad(pixel_values, height, width)
  235. embeddings = self.projection(pixel_values)
  236. _, _, height, width = embeddings.shape
  237. output_dimensions = (height, width)
  238. embeddings = embeddings.flatten(2).transpose(1, 2)
  239. return embeddings, output_dimensions
  240. class SwinPatchMerging(nn.Module):
  241. """
  242. Patch Merging Layer.
  243. Args:
  244. input_resolution (`tuple[int]`):
  245. Resolution of input feature.
  246. dim (`int`):
  247. Number of input channels.
  248. norm_layer (`nn.Module`, *optional*, defaults to `nn.LayerNorm`):
  249. Normalization layer class.
  250. """
  251. def __init__(self, input_resolution: tuple[int], dim: int, norm_layer: nn.Module = nn.LayerNorm) -> None:
  252. super().__init__()
  253. self.input_resolution = input_resolution
  254. self.dim = dim
  255. self.reduction = nn.Linear(4 * dim, 2 * dim, bias=False)
  256. self.norm = norm_layer(4 * dim)
  257. def maybe_pad(self, input_feature, height, width):
  258. should_pad = (height % 2 == 1) or (width % 2 == 1)
  259. if should_pad:
  260. pad_values = (0, 0, 0, width % 2, 0, height % 2)
  261. input_feature = nn.functional.pad(input_feature, pad_values)
  262. return input_feature
  263. def forward(self, input_feature: torch.Tensor, input_dimensions: tuple[int, int]) -> torch.Tensor:
  264. height, width = input_dimensions
  265. # `dim` is height * width
  266. batch_size, dim, num_channels = input_feature.shape
  267. input_feature = input_feature.view(batch_size, height, width, num_channels)
  268. # pad input to be divisible by width and height, if needed
  269. input_feature = self.maybe_pad(input_feature, height, width)
  270. # [batch_size, height/2, width/2, num_channels]
  271. input_feature_0 = input_feature[:, 0::2, 0::2, :]
  272. # [batch_size, height/2, width/2, num_channels]
  273. input_feature_1 = input_feature[:, 1::2, 0::2, :]
  274. # [batch_size, height/2, width/2, num_channels]
  275. input_feature_2 = input_feature[:, 0::2, 1::2, :]
  276. # [batch_size, height/2, width/2, num_channels]
  277. input_feature_3 = input_feature[:, 1::2, 1::2, :]
  278. # batch_size height/2 width/2 4*num_channels
  279. input_feature = torch.cat([input_feature_0, input_feature_1, input_feature_2, input_feature_3], -1)
  280. input_feature = input_feature.view(batch_size, -1, 4 * num_channels) # batch_size height/2*width/2 4*C
  281. input_feature = self.norm(input_feature)
  282. input_feature = self.reduction(input_feature)
  283. return input_feature
  284. # Copied from transformers.models.beit.modeling_beit.drop_path
  285. def drop_path(input: torch.Tensor, drop_prob: float = 0.0, training: bool = False) -> torch.Tensor:
  286. """
  287. Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).
  288. """
  289. if drop_prob == 0.0 or not training:
  290. return input
  291. keep_prob = 1 - drop_prob
  292. shape = (input.shape[0],) + (1,) * (input.ndim - 1) # work with diff dim tensors, not just 2D ConvNets
  293. random_tensor = keep_prob + torch.rand(shape, dtype=input.dtype, device=input.device)
  294. random_tensor.floor_() # binarize
  295. output = input.div(keep_prob) * random_tensor
  296. return output
  297. # Copied from transformers.models.beit.modeling_beit.BeitDropPath with Beit->Swin
  298. class SwinDropPath(nn.Module):
  299. """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks)."""
  300. def __init__(self, drop_prob: float | None = None) -> None:
  301. super().__init__()
  302. self.drop_prob = drop_prob
  303. def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
  304. return drop_path(hidden_states, self.drop_prob, self.training)
  305. def extra_repr(self) -> str:
  306. return f"p={self.drop_prob}"
  307. class SwinSelfAttention(nn.Module):
  308. def __init__(self, config, dim, num_heads, window_size):
  309. super().__init__()
  310. if dim % num_heads != 0:
  311. raise ValueError(
  312. f"The hidden size ({dim}) is not a multiple of the number of attention heads ({num_heads})"
  313. )
  314. self.num_attention_heads = num_heads
  315. self.attention_head_size = int(dim / num_heads)
  316. self.all_head_size = self.num_attention_heads * self.attention_head_size
  317. self.window_size = (
  318. window_size if isinstance(window_size, collections.abc.Iterable) else (window_size, window_size)
  319. )
  320. self.relative_position_bias_table = nn.Parameter(
  321. torch.zeros((2 * self.window_size[0] - 1) * (2 * self.window_size[1] - 1), num_heads)
  322. )
  323. self.register_buffer("relative_position_index", self.create_relative_position_index())
  324. self.query = nn.Linear(self.all_head_size, self.all_head_size, bias=config.qkv_bias)
  325. self.key = nn.Linear(self.all_head_size, self.all_head_size, bias=config.qkv_bias)
  326. self.value = nn.Linear(self.all_head_size, self.all_head_size, bias=config.qkv_bias)
  327. self.dropout = nn.Dropout(config.attention_probs_dropout_prob)
  328. def forward(
  329. self,
  330. hidden_states: torch.Tensor,
  331. attention_mask: torch.FloatTensor | None = None,
  332. output_attentions: bool | None = False,
  333. ) -> tuple[torch.Tensor]:
  334. batch_size, dim, num_channels = hidden_states.shape
  335. hidden_shape = (batch_size, dim, -1, self.attention_head_size)
  336. query_layer = self.query(hidden_states).view(hidden_shape).transpose(1, 2)
  337. key_layer = self.key(hidden_states).view(hidden_shape).transpose(1, 2)
  338. value_layer = self.value(hidden_states).view(hidden_shape).transpose(1, 2)
  339. # Take the dot product between "query" and "key" to get the raw attention scores.
  340. attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2))
  341. attention_scores = attention_scores / math.sqrt(self.attention_head_size)
  342. relative_position_bias = self.relative_position_bias_table[self.relative_position_index.view(-1)]
  343. relative_position_bias = relative_position_bias.view(
  344. self.window_size[0] * self.window_size[1], self.window_size[0] * self.window_size[1], -1
  345. )
  346. relative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous()
  347. attention_scores = attention_scores + relative_position_bias.unsqueeze(0)
  348. if attention_mask is not None:
  349. # Apply the attention mask is (precomputed for all layers in SwinModel forward() function)
  350. mask_shape = attention_mask.shape[0]
  351. attention_scores = attention_scores.view(
  352. batch_size // mask_shape, mask_shape, self.num_attention_heads, dim, dim
  353. )
  354. attention_scores = attention_scores + attention_mask.unsqueeze(1).unsqueeze(0)
  355. attention_scores = attention_scores.view(-1, self.num_attention_heads, dim, dim)
  356. # Normalize the attention scores to probabilities.
  357. attention_probs = nn.functional.softmax(attention_scores, dim=-1)
  358. # This is actually dropping out entire tokens to attend to, which might
  359. # seem a bit unusual, but is taken from the original Transformer paper.
  360. attention_probs = self.dropout(attention_probs)
  361. context_layer = torch.matmul(attention_probs, value_layer)
  362. context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
  363. new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,)
  364. context_layer = context_layer.view(new_context_layer_shape)
  365. outputs = (context_layer, attention_probs) if output_attentions else (context_layer,)
  366. return outputs
  367. def create_relative_position_index(self):
  368. # get pair-wise relative position index for each token inside the window
  369. coords_h = torch.arange(self.window_size[0])
  370. coords_w = torch.arange(self.window_size[1])
  371. coords = torch.stack(torch.meshgrid([coords_h, coords_w], indexing="ij"))
  372. coords_flatten = torch.flatten(coords, 1)
  373. relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :]
  374. relative_coords = relative_coords.permute(1, 2, 0).contiguous()
  375. relative_coords[:, :, 0] += self.window_size[0] - 1
  376. relative_coords[:, :, 1] += self.window_size[1] - 1
  377. relative_coords[:, :, 0] *= 2 * self.window_size[1] - 1
  378. relative_position_index = relative_coords.sum(-1)
  379. return relative_position_index
  380. class SwinSelfOutput(nn.Module):
  381. def __init__(self, config, dim):
  382. super().__init__()
  383. self.dense = nn.Linear(dim, dim)
  384. self.dropout = nn.Dropout(config.attention_probs_dropout_prob)
  385. def forward(self, hidden_states: torch.Tensor, input_tensor: torch.Tensor) -> torch.Tensor:
  386. hidden_states = self.dense(hidden_states)
  387. hidden_states = self.dropout(hidden_states)
  388. return hidden_states
  389. class SwinAttention(nn.Module):
  390. def __init__(self, config, dim, num_heads, window_size):
  391. super().__init__()
  392. self.self = SwinSelfAttention(config, dim, num_heads, window_size)
  393. self.output = SwinSelfOutput(config, dim)
  394. def forward(
  395. self,
  396. hidden_states: torch.Tensor,
  397. attention_mask: torch.FloatTensor | None = None,
  398. output_attentions: bool | None = False,
  399. ) -> tuple[torch.Tensor]:
  400. self_outputs = self.self(hidden_states, attention_mask, output_attentions)
  401. attention_output = self.output(self_outputs[0], hidden_states)
  402. outputs = (attention_output,) + self_outputs[1:] # add attentions if we output them
  403. return outputs
  404. class SwinIntermediate(nn.Module):
  405. def __init__(self, config, dim):
  406. super().__init__()
  407. self.dense = nn.Linear(dim, int(config.mlp_ratio * dim))
  408. if isinstance(config.hidden_act, str):
  409. self.intermediate_act_fn = ACT2FN[config.hidden_act]
  410. else:
  411. self.intermediate_act_fn = config.hidden_act
  412. def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
  413. hidden_states = self.dense(hidden_states)
  414. hidden_states = self.intermediate_act_fn(hidden_states)
  415. return hidden_states
  416. class SwinOutput(nn.Module):
  417. def __init__(self, config, dim):
  418. super().__init__()
  419. self.dense = nn.Linear(int(config.mlp_ratio * dim), dim)
  420. self.dropout = nn.Dropout(config.hidden_dropout_prob)
  421. def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
  422. hidden_states = self.dense(hidden_states)
  423. hidden_states = self.dropout(hidden_states)
  424. return hidden_states
  425. class SwinLayer(nn.Module):
  426. def __init__(self, config, dim, input_resolution, num_heads, drop_path_rate=0.0, shift_size=0):
  427. super().__init__()
  428. self.chunk_size_feed_forward = config.chunk_size_feed_forward
  429. self.shift_size = shift_size
  430. self.window_size = config.window_size
  431. self.input_resolution = input_resolution
  432. self.layernorm_before = nn.LayerNorm(dim, eps=config.layer_norm_eps)
  433. self.attention = SwinAttention(config, dim, num_heads, window_size=self.window_size)
  434. self.drop_path = SwinDropPath(drop_path_rate) if drop_path_rate > 0.0 else nn.Identity()
  435. self.layernorm_after = nn.LayerNorm(dim, eps=config.layer_norm_eps)
  436. self.intermediate = SwinIntermediate(config, dim)
  437. self.output = SwinOutput(config, dim)
  438. def set_shift_and_window_size(self, input_resolution):
  439. if min(input_resolution) <= self.window_size:
  440. # if window size is larger than input resolution, we don't partition windows
  441. self.shift_size = torch_int(0)
  442. self.window_size = (
  443. torch.min(torch.tensor(input_resolution)) if torch.jit.is_tracing() else min(input_resolution)
  444. )
  445. def get_attn_mask(self, height, width, dtype, device):
  446. if self.shift_size > 0:
  447. # calculate attention mask for SW-MSA
  448. img_mask = torch.zeros((1, height, width, 1), dtype=dtype, device=device)
  449. height_slices = (
  450. slice(0, -self.window_size),
  451. slice(-self.window_size, -self.shift_size),
  452. slice(-self.shift_size, None),
  453. )
  454. width_slices = (
  455. slice(0, -self.window_size),
  456. slice(-self.window_size, -self.shift_size),
  457. slice(-self.shift_size, None),
  458. )
  459. count = 0
  460. for height_slice in height_slices:
  461. for width_slice in width_slices:
  462. img_mask[:, height_slice, width_slice, :] = count
  463. count += 1
  464. mask_windows = window_partition(img_mask, self.window_size)
  465. mask_windows = mask_windows.view(-1, self.window_size * self.window_size)
  466. attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2)
  467. attn_mask = attn_mask.masked_fill(attn_mask != 0, -100.0).masked_fill(attn_mask == 0, 0.0)
  468. else:
  469. attn_mask = None
  470. return attn_mask
  471. def maybe_pad(self, hidden_states, height, width):
  472. pad_right = (self.window_size - width % self.window_size) % self.window_size
  473. pad_bottom = (self.window_size - height % self.window_size) % self.window_size
  474. pad_values = (0, 0, 0, pad_right, 0, pad_bottom)
  475. hidden_states = nn.functional.pad(hidden_states, pad_values)
  476. return hidden_states, pad_values
  477. def forward(
  478. self,
  479. hidden_states: torch.Tensor,
  480. input_dimensions: tuple[int, int],
  481. output_attentions: bool | None = False,
  482. always_partition: bool | None = False,
  483. ) -> tuple[torch.Tensor, torch.Tensor]:
  484. if not always_partition:
  485. self.set_shift_and_window_size(input_dimensions)
  486. else:
  487. pass
  488. height, width = input_dimensions
  489. batch_size, _, channels = hidden_states.size()
  490. shortcut = hidden_states
  491. hidden_states = self.layernorm_before(hidden_states)
  492. hidden_states = hidden_states.view(batch_size, height, width, channels)
  493. # pad hidden_states to multiples of window size
  494. hidden_states, pad_values = self.maybe_pad(hidden_states, height, width)
  495. _, height_pad, width_pad, _ = hidden_states.shape
  496. # cyclic shift
  497. if self.shift_size > 0:
  498. shifted_hidden_states = torch.roll(hidden_states, shifts=(-self.shift_size, -self.shift_size), dims=(1, 2))
  499. else:
  500. shifted_hidden_states = hidden_states
  501. # partition windows
  502. hidden_states_windows = window_partition(shifted_hidden_states, self.window_size)
  503. hidden_states_windows = hidden_states_windows.view(-1, self.window_size * self.window_size, channels)
  504. attn_mask = self.get_attn_mask(
  505. height_pad, width_pad, dtype=hidden_states.dtype, device=hidden_states_windows.device
  506. )
  507. attention_outputs = self.attention(hidden_states_windows, attn_mask, output_attentions=output_attentions)
  508. attention_output = attention_outputs[0]
  509. attention_windows = attention_output.view(-1, self.window_size, self.window_size, channels)
  510. shifted_windows = window_reverse(attention_windows, self.window_size, height_pad, width_pad)
  511. # reverse cyclic shift
  512. if self.shift_size > 0:
  513. attention_windows = torch.roll(shifted_windows, shifts=(self.shift_size, self.shift_size), dims=(1, 2))
  514. else:
  515. attention_windows = shifted_windows
  516. was_padded = pad_values[3] > 0 or pad_values[5] > 0
  517. if was_padded:
  518. attention_windows = attention_windows[:, :height, :width, :].contiguous()
  519. attention_windows = attention_windows.view(batch_size, height * width, channels)
  520. hidden_states = shortcut + self.drop_path(attention_windows)
  521. layer_output = self.layernorm_after(hidden_states)
  522. layer_output = self.intermediate(layer_output)
  523. layer_output = hidden_states + self.output(layer_output)
  524. layer_outputs = (layer_output, attention_outputs[1]) if output_attentions else (layer_output,)
  525. return layer_outputs
  526. class SwinStage(GradientCheckpointingLayer):
  527. def __init__(self, config, dim, input_resolution, depth, num_heads, drop_path, downsample):
  528. super().__init__()
  529. self.config = config
  530. self.dim = dim
  531. self.blocks = nn.ModuleList(
  532. [
  533. SwinLayer(
  534. config=config,
  535. dim=dim,
  536. input_resolution=input_resolution,
  537. num_heads=num_heads,
  538. drop_path_rate=drop_path[i],
  539. shift_size=0 if (i % 2 == 0) else config.window_size // 2,
  540. )
  541. for i in range(depth)
  542. ]
  543. )
  544. # patch merging layer
  545. if downsample is not None:
  546. self.downsample = downsample(input_resolution, dim=dim, norm_layer=nn.LayerNorm)
  547. else:
  548. self.downsample = None
  549. self.pointing = False
  550. def forward(
  551. self,
  552. hidden_states: torch.Tensor,
  553. input_dimensions: tuple[int, int],
  554. output_attentions: bool | None = False,
  555. always_partition: bool | None = False,
  556. ) -> tuple[torch.Tensor]:
  557. height, width = input_dimensions
  558. for i, layer_module in enumerate(self.blocks):
  559. layer_outputs = layer_module(hidden_states, input_dimensions, output_attentions, always_partition)
  560. hidden_states = layer_outputs[0]
  561. hidden_states_before_downsampling = hidden_states
  562. if self.downsample is not None:
  563. height_downsampled, width_downsampled = (height + 1) // 2, (width + 1) // 2
  564. output_dimensions = (height, width, height_downsampled, width_downsampled)
  565. hidden_states = self.downsample(hidden_states_before_downsampling, input_dimensions)
  566. else:
  567. output_dimensions = (height, width, height, width)
  568. stage_outputs = (hidden_states, hidden_states_before_downsampling, output_dimensions)
  569. if output_attentions:
  570. stage_outputs += layer_outputs[1:]
  571. return stage_outputs
  572. class SwinEncoder(nn.Module):
  573. def __init__(self, config, grid_size):
  574. super().__init__()
  575. self.num_layers = len(config.depths)
  576. self.config = config
  577. dpr = [x.item() for x in torch.linspace(0, config.drop_path_rate, sum(config.depths), device="cpu")]
  578. self.layers = nn.ModuleList(
  579. [
  580. SwinStage(
  581. config=config,
  582. dim=int(config.embed_dim * 2**i_layer),
  583. input_resolution=(grid_size[0] // (2**i_layer), grid_size[1] // (2**i_layer)),
  584. depth=config.depths[i_layer],
  585. num_heads=config.num_heads[i_layer],
  586. drop_path=dpr[sum(config.depths[:i_layer]) : sum(config.depths[: i_layer + 1])],
  587. downsample=SwinPatchMerging if (i_layer < self.num_layers - 1) else None,
  588. )
  589. for i_layer in range(self.num_layers)
  590. ]
  591. )
  592. self.gradient_checkpointing = False
  593. def forward(
  594. self,
  595. hidden_states: torch.Tensor,
  596. input_dimensions: tuple[int, int],
  597. output_attentions: bool | None = False,
  598. output_hidden_states: bool | None = False,
  599. output_hidden_states_before_downsampling: bool | None = False,
  600. always_partition: bool | None = False,
  601. return_dict: bool | None = True,
  602. ) -> tuple | SwinEncoderOutput:
  603. all_hidden_states = () if output_hidden_states else None
  604. all_reshaped_hidden_states = () if output_hidden_states else None
  605. all_self_attentions = () if output_attentions else None
  606. if output_hidden_states:
  607. batch_size, _, hidden_size = hidden_states.shape
  608. # rearrange b (h w) c -> b c h w
  609. reshaped_hidden_state = hidden_states.view(batch_size, *input_dimensions, hidden_size)
  610. reshaped_hidden_state = reshaped_hidden_state.permute(0, 3, 1, 2)
  611. all_hidden_states += (hidden_states,)
  612. all_reshaped_hidden_states += (reshaped_hidden_state,)
  613. for i, layer_module in enumerate(self.layers):
  614. layer_outputs = layer_module(hidden_states, input_dimensions, output_attentions, always_partition)
  615. hidden_states = layer_outputs[0]
  616. hidden_states_before_downsampling = layer_outputs[1]
  617. output_dimensions = layer_outputs[2]
  618. input_dimensions = (output_dimensions[-2], output_dimensions[-1])
  619. if output_hidden_states and output_hidden_states_before_downsampling:
  620. batch_size, _, hidden_size = hidden_states_before_downsampling.shape
  621. # rearrange b (h w) c -> b c h w
  622. # here we use the original (not downsampled) height and width
  623. reshaped_hidden_state = hidden_states_before_downsampling.view(
  624. batch_size, *(output_dimensions[0], output_dimensions[1]), hidden_size
  625. )
  626. reshaped_hidden_state = reshaped_hidden_state.permute(0, 3, 1, 2)
  627. all_hidden_states += (hidden_states_before_downsampling,)
  628. all_reshaped_hidden_states += (reshaped_hidden_state,)
  629. elif output_hidden_states and not output_hidden_states_before_downsampling:
  630. batch_size, _, hidden_size = hidden_states.shape
  631. # rearrange b (h w) c -> b c h w
  632. reshaped_hidden_state = hidden_states.view(batch_size, *input_dimensions, hidden_size)
  633. reshaped_hidden_state = reshaped_hidden_state.permute(0, 3, 1, 2)
  634. all_hidden_states += (hidden_states,)
  635. all_reshaped_hidden_states += (reshaped_hidden_state,)
  636. if output_attentions:
  637. all_self_attentions += layer_outputs[3:]
  638. if not return_dict:
  639. return tuple(v for v in [hidden_states, all_hidden_states, all_self_attentions] if v is not None)
  640. return SwinEncoderOutput(
  641. last_hidden_state=hidden_states,
  642. hidden_states=all_hidden_states,
  643. attentions=all_self_attentions,
  644. reshaped_hidden_states=all_reshaped_hidden_states,
  645. )
  646. @auto_docstring
  647. class SwinPreTrainedModel(PreTrainedModel):
  648. config: SwinConfig
  649. base_model_prefix = "swin"
  650. main_input_name = "pixel_values"
  651. input_modalities = ("image",)
  652. supports_gradient_checkpointing = True
  653. _no_split_modules = ["SwinStage"]
  654. @torch.no_grad()
  655. def _init_weights(self, module):
  656. """Initialize the weights"""
  657. super()._init_weights(module)
  658. if isinstance(module, SwinEmbeddings):
  659. if module.mask_token is not None:
  660. init.zeros_(module.mask_token)
  661. if module.position_embeddings is not None:
  662. init.zeros_(module.position_embeddings)
  663. elif isinstance(module, SwinSelfAttention):
  664. init.zeros_(module.relative_position_bias_table)
  665. init.copy_(module.relative_position_index, module.create_relative_position_index())
  666. @auto_docstring
  667. class SwinModel(SwinPreTrainedModel):
  668. def __init__(self, config, add_pooling_layer=True, use_mask_token=False):
  669. r"""
  670. add_pooling_layer (`bool`, *optional*, defaults to `True`):
  671. Whether or not to apply pooling layer.
  672. use_mask_token (`bool`, *optional*, defaults to `False`):
  673. Whether or not to create and apply mask tokens in the embedding layer.
  674. """
  675. super().__init__(config)
  676. self.config = config
  677. self.num_layers = len(config.depths)
  678. self.num_features = int(config.embed_dim * 2 ** (self.num_layers - 1))
  679. self.embeddings = SwinEmbeddings(config, use_mask_token=use_mask_token)
  680. self.encoder = SwinEncoder(config, self.embeddings.patch_grid)
  681. self.layernorm = nn.LayerNorm(self.num_features, eps=config.layer_norm_eps)
  682. self.pooler = nn.AdaptiveAvgPool1d(1) if add_pooling_layer else None
  683. # Initialize weights and apply final processing
  684. self.post_init()
  685. def get_input_embeddings(self):
  686. return self.embeddings.patch_embeddings
  687. @auto_docstring
  688. def forward(
  689. self,
  690. pixel_values: torch.FloatTensor | None = None,
  691. bool_masked_pos: torch.BoolTensor | None = None,
  692. output_attentions: bool | None = None,
  693. output_hidden_states: bool | None = None,
  694. interpolate_pos_encoding: bool = False,
  695. return_dict: bool | None = None,
  696. **kwargs,
  697. ) -> tuple | SwinModelOutput:
  698. r"""
  699. bool_masked_pos (`torch.BoolTensor` of shape `(batch_size, num_patches)`, *optional*):
  700. Boolean masked positions. Indicates which patches are masked (1) and which aren't (0).
  701. """
  702. output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
  703. output_hidden_states = (
  704. output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
  705. )
  706. return_dict = return_dict if return_dict is not None else self.config.return_dict
  707. if pixel_values is None:
  708. raise ValueError("You have to specify pixel_values")
  709. embedding_output, input_dimensions = self.embeddings(
  710. pixel_values, bool_masked_pos=bool_masked_pos, interpolate_pos_encoding=interpolate_pos_encoding
  711. )
  712. encoder_outputs = self.encoder(
  713. embedding_output,
  714. input_dimensions,
  715. output_attentions=output_attentions,
  716. output_hidden_states=output_hidden_states,
  717. return_dict=return_dict,
  718. )
  719. sequence_output = encoder_outputs[0]
  720. sequence_output = self.layernorm(sequence_output)
  721. pooled_output = None
  722. if self.pooler is not None:
  723. pooled_output = self.pooler(sequence_output.transpose(1, 2))
  724. pooled_output = torch.flatten(pooled_output, 1)
  725. if not return_dict:
  726. output = (sequence_output, pooled_output) + encoder_outputs[1:]
  727. return output
  728. return SwinModelOutput(
  729. last_hidden_state=sequence_output,
  730. pooler_output=pooled_output,
  731. hidden_states=encoder_outputs.hidden_states,
  732. attentions=encoder_outputs.attentions,
  733. reshaped_hidden_states=encoder_outputs.reshaped_hidden_states,
  734. )
  735. @auto_docstring(
  736. custom_intro="""
  737. Swin Model with a decoder on top for masked image modeling, as proposed in [SimMIM](https://huggingface.co/papers/2111.09886).
  738. <Tip>
  739. Note that we provide a script to pre-train this model on custom data in our [examples
  740. directory](https://github.com/huggingface/transformers/tree/main/examples/pytorch/image-pretraining).
  741. </Tip>
  742. """
  743. )
  744. class SwinForMaskedImageModeling(SwinPreTrainedModel):
  745. def __init__(self, config):
  746. super().__init__(config)
  747. self.swin = SwinModel(config, add_pooling_layer=False, use_mask_token=True)
  748. num_features = int(config.embed_dim * 2 ** (config.num_layers - 1))
  749. self.decoder = nn.Sequential(
  750. nn.Conv2d(
  751. in_channels=num_features, out_channels=config.encoder_stride**2 * config.num_channels, kernel_size=1
  752. ),
  753. nn.PixelShuffle(config.encoder_stride),
  754. )
  755. # Initialize weights and apply final processing
  756. self.post_init()
  757. @auto_docstring
  758. def forward(
  759. self,
  760. pixel_values: torch.FloatTensor | None = None,
  761. bool_masked_pos: torch.BoolTensor | None = None,
  762. output_attentions: bool | None = None,
  763. output_hidden_states: bool | None = None,
  764. interpolate_pos_encoding: bool = False,
  765. return_dict: bool | None = None,
  766. **kwargs,
  767. ) -> tuple | SwinMaskedImageModelingOutput:
  768. r"""
  769. bool_masked_pos (`torch.BoolTensor` of shape `(batch_size, num_patches)`):
  770. Boolean masked positions. Indicates which patches are masked (1) and which aren't (0).
  771. Examples:
  772. ```python
  773. >>> from transformers import AutoImageProcessor, SwinForMaskedImageModeling
  774. >>> import torch
  775. >>> from PIL import Image
  776. >>> import httpx
  777. >>> from io import BytesIO
  778. >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg"
  779. >>> with httpx.stream("GET", url) as response:
  780. ... image = Image.open(BytesIO(response.read()))
  781. >>> image_processor = AutoImageProcessor.from_pretrained("microsoft/swin-base-simmim-window6-192")
  782. >>> model = SwinForMaskedImageModeling.from_pretrained("microsoft/swin-base-simmim-window6-192")
  783. >>> num_patches = (model.config.image_size // model.config.patch_size) ** 2
  784. >>> pixel_values = image_processor(images=image, return_tensors="pt").pixel_values
  785. >>> # create random boolean mask of shape (batch_size, num_patches)
  786. >>> bool_masked_pos = torch.randint(low=0, high=2, size=(1, num_patches)).bool()
  787. >>> outputs = model(pixel_values, bool_masked_pos=bool_masked_pos)
  788. >>> loss, reconstructed_pixel_values = outputs.loss, outputs.reconstruction
  789. >>> list(reconstructed_pixel_values.shape)
  790. [1, 3, 192, 192]
  791. ```"""
  792. return_dict = return_dict if return_dict is not None else self.config.return_dict
  793. outputs = self.swin(
  794. pixel_values,
  795. bool_masked_pos=bool_masked_pos,
  796. output_attentions=output_attentions,
  797. output_hidden_states=output_hidden_states,
  798. interpolate_pos_encoding=interpolate_pos_encoding,
  799. return_dict=return_dict,
  800. )
  801. sequence_output = outputs[0]
  802. # Reshape to (batch_size, num_channels, height, width)
  803. sequence_output = sequence_output.transpose(1, 2)
  804. batch_size, num_channels, sequence_length = sequence_output.shape
  805. height = width = math.floor(sequence_length**0.5)
  806. sequence_output = sequence_output.reshape(batch_size, num_channels, height, width)
  807. # Reconstruct pixel values
  808. reconstructed_pixel_values = self.decoder(sequence_output)
  809. masked_im_loss = None
  810. if bool_masked_pos is not None:
  811. size = self.config.image_size // self.config.patch_size
  812. bool_masked_pos = bool_masked_pos.reshape(-1, size, size)
  813. mask = (
  814. bool_masked_pos.repeat_interleave(self.config.patch_size, 1)
  815. .repeat_interleave(self.config.patch_size, 2)
  816. .unsqueeze(1)
  817. .contiguous()
  818. )
  819. reconstruction_loss = nn.functional.l1_loss(pixel_values, reconstructed_pixel_values, reduction="none")
  820. masked_im_loss = (reconstruction_loss * mask).sum() / (mask.sum() + 1e-5) / self.config.num_channels
  821. if not return_dict:
  822. output = (reconstructed_pixel_values,) + outputs[2:]
  823. return ((masked_im_loss,) + output) if masked_im_loss is not None else output
  824. return SwinMaskedImageModelingOutput(
  825. loss=masked_im_loss,
  826. reconstruction=reconstructed_pixel_values,
  827. hidden_states=outputs.hidden_states,
  828. attentions=outputs.attentions,
  829. reshaped_hidden_states=outputs.reshaped_hidden_states,
  830. )
  831. @auto_docstring(
  832. custom_intro="""
  833. Swin Model transformer with an image classification head on top (a linear layer on top of the final hidden state of
  834. the [CLS] token) e.g. for ImageNet.
  835. <Tip>
  836. Note that it's possible to fine-tune Swin on higher resolution images than the ones it has been trained on, by
  837. setting `interpolate_pos_encoding` to `True` in the forward of the model. This will interpolate the pre-trained
  838. position embeddings to the higher resolution.
  839. </Tip>
  840. """
  841. )
  842. class SwinForImageClassification(SwinPreTrainedModel):
  843. def __init__(self, config):
  844. super().__init__(config)
  845. self.num_labels = config.num_labels
  846. self.swin = SwinModel(config)
  847. # Classifier head
  848. self.classifier = (
  849. nn.Linear(self.swin.num_features, config.num_labels) if config.num_labels > 0 else nn.Identity()
  850. )
  851. # Initialize weights and apply final processing
  852. self.post_init()
  853. @auto_docstring
  854. def forward(
  855. self,
  856. pixel_values: torch.FloatTensor | None = None,
  857. labels: torch.LongTensor | None = None,
  858. output_attentions: bool | None = None,
  859. output_hidden_states: bool | None = None,
  860. interpolate_pos_encoding: bool = False,
  861. return_dict: bool | None = None,
  862. **kwargs,
  863. ) -> tuple | SwinImageClassifierOutput:
  864. r"""
  865. labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
  866. Labels for computing the image classification/regression loss. Indices should be in `[0, ...,
  867. config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If
  868. `config.num_labels > 1` a classification loss is computed (Cross-Entropy).
  869. """
  870. return_dict = return_dict if return_dict is not None else self.config.return_dict
  871. outputs = self.swin(
  872. pixel_values,
  873. output_attentions=output_attentions,
  874. output_hidden_states=output_hidden_states,
  875. interpolate_pos_encoding=interpolate_pos_encoding,
  876. return_dict=return_dict,
  877. )
  878. pooled_output = outputs[1]
  879. logits = self.classifier(pooled_output)
  880. loss = None
  881. if labels is not None:
  882. loss = self.loss_function(labels, logits, self.config)
  883. if not return_dict:
  884. output = (logits,) + outputs[2:]
  885. return ((loss,) + output) if loss is not None else output
  886. return SwinImageClassifierOutput(
  887. loss=loss,
  888. logits=logits,
  889. hidden_states=outputs.hidden_states,
  890. attentions=outputs.attentions,
  891. reshaped_hidden_states=outputs.reshaped_hidden_states,
  892. )
  893. @auto_docstring(
  894. custom_intro="""
  895. Swin backbone, to be used with frameworks like DETR and MaskFormer.
  896. """
  897. )
  898. class SwinBackbone(BackboneMixin, SwinPreTrainedModel):
  899. def __init__(self, config: SwinConfig):
  900. super().__init__(config)
  901. self.num_features = [config.embed_dim] + [int(config.embed_dim * 2**i) for i in range(len(config.depths))]
  902. self.embeddings = SwinEmbeddings(config)
  903. self.encoder = SwinEncoder(config, self.embeddings.patch_grid)
  904. # Add layer norms to hidden states of out_features
  905. hidden_states_norms = {}
  906. for stage, num_channels in zip(self.out_features, self.channels):
  907. hidden_states_norms[stage] = nn.LayerNorm(num_channels)
  908. self.hidden_states_norms = nn.ModuleDict(hidden_states_norms)
  909. # Initialize weights and apply final processing
  910. self.post_init()
  911. def get_input_embeddings(self):
  912. return self.embeddings.patch_embeddings
  913. @can_return_tuple
  914. @filter_output_hidden_states
  915. def forward(
  916. self,
  917. pixel_values: torch.Tensor,
  918. output_hidden_states: bool | None = None,
  919. output_attentions: bool | None = None,
  920. return_dict: bool | None = None,
  921. **kwargs,
  922. ) -> BackboneOutput:
  923. """
  924. Returns:
  925. Examples:
  926. ```python
  927. >>> from transformers import AutoImageProcessor, AutoBackbone
  928. >>> import torch
  929. >>> from PIL import Image
  930. >>> import httpx
  931. >>> from io import BytesIO
  932. >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg"
  933. >>> with httpx.stream("GET", url) as response:
  934. ... image = Image.open(BytesIO(response.read()))
  935. >>> processor = AutoImageProcessor.from_pretrained("shi-labs/nat-mini-in1k-224")
  936. >>> model = AutoBackbone.from_pretrained(
  937. ... "microsoft/swin-tiny-patch4-window7-224", out_features=["stage1", "stage2", "stage3", "stage4"]
  938. ... )
  939. >>> inputs = processor(image, return_tensors="pt")
  940. >>> outputs = model(**inputs)
  941. >>> feature_maps = outputs.feature_maps
  942. >>> list(feature_maps[-1].shape)
  943. [1, 768, 7, 7]
  944. ```"""
  945. return_dict = return_dict if return_dict is not None else self.config.return_dict
  946. output_hidden_states = (
  947. output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
  948. )
  949. output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
  950. embedding_output, input_dimensions = self.embeddings(pixel_values)
  951. outputs = self.encoder(
  952. embedding_output,
  953. input_dimensions,
  954. output_attentions=output_attentions,
  955. output_hidden_states=True,
  956. output_hidden_states_before_downsampling=True,
  957. always_partition=True,
  958. return_dict=True,
  959. )
  960. hidden_states = outputs.reshaped_hidden_states
  961. feature_maps = ()
  962. for stage, hidden_state in zip(self.stage_names, hidden_states):
  963. if stage in self.out_features:
  964. batch_size, num_channels, height, width = hidden_state.shape
  965. hidden_state = hidden_state.permute(0, 2, 3, 1).contiguous()
  966. hidden_state = hidden_state.view(batch_size, height * width, num_channels)
  967. hidden_state = self.hidden_states_norms[stage](hidden_state)
  968. hidden_state = hidden_state.view(batch_size, height, width, num_channels)
  969. hidden_state = hidden_state.permute(0, 3, 1, 2).contiguous()
  970. feature_maps += (hidden_state,)
  971. if not return_dict:
  972. output = (feature_maps,)
  973. if output_hidden_states:
  974. output += (outputs.hidden_states,)
  975. return output
  976. return BackboneOutput(
  977. feature_maps=feature_maps,
  978. hidden_states=outputs.hidden_states if output_hidden_states else None,
  979. attentions=outputs.attentions,
  980. )
# Names exported by a star-import of this module.
__all__ = [
    "SwinForImageClassification",
    "SwinForMaskedImageModeling",
    "SwinModel",
    "SwinPreTrainedModel",
    "SwinBackbone",
]