| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656657658659660661662663664665666667668669670671672673674675676677678679680681682683684685686687688689690691692693694695696697698699700701702
703704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792793794795796797798799800801802803804805806807808809810811812813814815816817818819820821822823824825826827828829830831832833834835836837838839840841842843844845846847848849850851852853854855856857858859860861862863864865866867868869870871872873874875876877878879880881882883884885886887888889890891892893894895896897898899900901902903904905906907908909910911912913914915916917918919920921922923924925926927928929930931932933934935936937938939940941942943944945946947948949950951952953954955956957958959960961962963964965966967968969970971972973974975976977978979980981982983984985986987988989990991992993994995996997998999100010011002100310041005100610071008100910101011101210131014101510161017101810191020102110221023102410251026102710281029103010311032103310341035103610371038103910401041104210431044104510461047104810491050105110521053105410551056105710581059106010611062106310641065106610671068106910701071107210731074107510761077107810791080108110821083108410851086108710881089109010911092109310941095109610971098109911001101110211031104110511061107110811091110111111121113111411151116111711181119112011211122112311241125112611271128112911301131113211331134113511361137113811391140114111421143114411451146114711481149115011511152115311541155115611571158115911601161116211631164116511661167116811691170117111721173117411751176117711781179118011811182118311841185118611871188118911901191119211931194119511961197119811991200120112021203120412051206120712081209 |
- # Copyright 2022 Microsoft Research and The HuggingFace Inc. team. All rights reserved.
- #
- # Licensed under the Apache License, Version 2.0 (the "License");
- # you may not use this file except in compliance with the License.
- # You may obtain a copy of the License at
- #
- # http://www.apache.org/licenses/LICENSE-2.0
- #
- # Unless required by applicable law or agreed to in writing, software
- # distributed under the License is distributed on an "AS IS" BASIS,
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- # See the License for the specific language governing permissions and
- # limitations under the License.
- """PyTorch Swin Transformer model."""
- import collections.abc
- import math
- from dataclasses import dataclass
- import torch
- from torch import nn
- from ... import initialization as init
- from ...activations import ACT2FN
- from ...backbone_utils import BackboneMixin, filter_output_hidden_states
- from ...modeling_layers import GradientCheckpointingLayer
- from ...modeling_outputs import BackboneOutput
- from ...modeling_utils import PreTrainedModel
- from ...utils import ModelOutput, auto_docstring, logging, torch_int
- from ...utils.generic import can_return_tuple
- from .configuration_swin import SwinConfig
# Module-level logger for this file, created through the library's logging utilities.
logger = logging.get_logger(__name__)
# drop_path, SwinPatchEmbeddings, SwinPatchMerging and SwinDropPath are from the timm library.
@dataclass
@auto_docstring(
    custom_intro="""
    Swin encoder's outputs, with potential hidden states and attentions.
    """
)
class SwinEncoderOutput(ModelOutput):
    r"""
    reshaped_hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
        Tuple of `torch.FloatTensor` (one for the output of the embeddings + one for the output of each stage) of
        shape `(batch_size, hidden_size, height, width)`.

        Hidden-states of the model at the output of each layer plus the initial embedding outputs reshaped to
        include the spatial dimensions.
    """

    # Every field defaults to None; the optional tuples are only populated when the matching output flag is set.
    last_hidden_state: torch.FloatTensor | None = None
    # Token layout `(batch_size, seq_len, hidden_size)`, one entry for the embeddings plus one per stage.
    hidden_states: tuple[torch.FloatTensor, ...] | None = None
    attentions: tuple[torch.FloatTensor, ...] | None = None
    # Channels-first spatial view of the hidden states, as described in the class docstring.
    reshaped_hidden_states: tuple[torch.FloatTensor, ...] | None = None
@dataclass
@auto_docstring(
    custom_intro="""
    Swin model's outputs that also contains a pooling of the last hidden states.
    """
)
class SwinModelOutput(ModelOutput):
    r"""
    pooler_output (`torch.FloatTensor` of shape `(batch_size, hidden_size)`, *optional*, returned when `add_pooling_layer=True` is passed):
        Average pooling of the last layer hidden-state.
    reshaped_hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
        Tuple of `torch.FloatTensor` (one for the output of the embeddings + one for the output of each stage) of
        shape `(batch_size, hidden_size, height, width)`.

        Hidden-states of the model at the output of each layer plus the initial embedding outputs reshaped to
        include the spatial dimensions.
    """

    # Every field defaults to None; optional entries are only filled in when the model produces them.
    last_hidden_state: torch.FloatTensor | None = None
    # Average-pooled last hidden state; stays None when the model is built without a pooling layer.
    pooler_output: torch.FloatTensor | None = None
    hidden_states: tuple[torch.FloatTensor, ...] | None = None
    attentions: tuple[torch.FloatTensor, ...] | None = None
    # Channels-first spatial view of the hidden states, as described in the class docstring.
    reshaped_hidden_states: tuple[torch.FloatTensor, ...] | None = None
@dataclass
@auto_docstring(
    custom_intro="""
    Swin masked image model outputs.
    """
)
class SwinMaskedImageModelingOutput(ModelOutput):
    r"""
    loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `bool_masked_pos` is provided):
        Masked image modeling (MIM) loss.
    reconstruction (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`):
        Reconstructed pixel values.
    reshaped_hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
        Tuple of `torch.FloatTensor` (one for the output of the embeddings + one for the output of each stage) of
        shape `(batch_size, hidden_size, height, width)`.

        Hidden-states of the model at the output of each layer plus the initial embedding outputs reshaped to
        include the spatial dimensions.
    """

    # Only set when `bool_masked_pos` is given (see docstring); every field defaults to None.
    loss: torch.FloatTensor | None = None
    # Reconstructed pixels, `(batch_size, num_channels, height, width)`.
    reconstruction: torch.FloatTensor | None = None
    hidden_states: tuple[torch.FloatTensor, ...] | None = None
    attentions: tuple[torch.FloatTensor, ...] | None = None
    # Channels-first spatial view of the hidden states, as described in the class docstring.
    reshaped_hidden_states: tuple[torch.FloatTensor, ...] | None = None
@dataclass
@auto_docstring(
    custom_intro="""
    Swin outputs for image classification.
    """
)
class SwinImageClassifierOutput(ModelOutput):
    r"""
    loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` is provided):
        Classification (or regression if config.num_labels==1) loss.
    logits (`torch.FloatTensor` of shape `(batch_size, config.num_labels)`):
        Classification (or regression if config.num_labels==1) scores (before SoftMax).
    reshaped_hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
        Tuple of `torch.FloatTensor` (one for the output of the embeddings + one for the output of each stage) of
        shape `(batch_size, hidden_size, height, width)`.

        Hidden-states of the model at the output of each layer plus the initial embedding outputs reshaped to
        include the spatial dimensions.
    """

    # Only set when `labels` are provided (see docstring); every field defaults to None.
    loss: torch.FloatTensor | None = None
    # Pre-softmax scores, `(batch_size, config.num_labels)`.
    logits: torch.FloatTensor | None = None
    hidden_states: tuple[torch.FloatTensor, ...] | None = None
    attentions: tuple[torch.FloatTensor, ...] | None = None
    # Channels-first spatial view of the hidden states, as described in the class docstring.
    reshaped_hidden_states: tuple[torch.FloatTensor, ...] | None = None
def window_partition(input_feature, window_size):
    """
    Split a channels-last feature map into non-overlapping square windows.

    Args:
        input_feature (`torch.Tensor`):
            Tensor of shape `(batch_size, height, width, num_channels)`. `height` and `width` must be multiples of
            `window_size`, otherwise the reshape fails.
        window_size (`int`):
            Side length of every window.

    Returns:
        `torch.Tensor` of shape `(batch_size * num_windows, window_size, window_size, num_channels)`. Within each
        batch element the windows are ordered row-major over the window grid.
    """
    batch, rows, cols, channels = input_feature.shape
    tiles = input_feature.view(batch, rows // window_size, window_size, cols // window_size, window_size, channels)
    # Put the two window-grid axes side by side so they can be folded into the leading (batch) axis.
    tiles = tiles.permute(0, 1, 3, 2, 4, 5).contiguous()
    return tiles.view(-1, window_size, window_size, channels)
def window_reverse(windows, window_size, height, width):
    """
    Stitch windows (as produced by `window_partition`) back into a channels-last feature map.

    Args:
        windows (`torch.Tensor`):
            Tensor of shape `(batch_size * num_windows, window_size, window_size, num_channels)`.
        window_size (`int`):
            Side length of every window.
        height (`int`), width (`int`):
            Size of the reconstructed map; both must be multiples of `window_size`.

    Returns:
        `torch.Tensor` of shape `(batch_size, height, width, num_channels)`.
    """
    channels = windows.shape[-1]
    grid_rows = height // window_size
    grid_cols = width // window_size
    tiles = windows.view(-1, grid_rows, grid_cols, window_size, window_size, channels)
    # Interleave window-grid axes with in-window axes: (rows, win_h, cols, win_w) flattens to (height, width).
    tiles = tiles.permute(0, 1, 3, 2, 4, 5).contiguous()
    return tiles.view(-1, height, width, channels)
class SwinEmbeddings(nn.Module):
    """
    Construct the patch and position embeddings. Optionally, also the mask token.

    Pixel values are turned into patch tokens by `SwinPatchEmbeddings` and layer-normalised. Masked positions can
    then be swapped for a learned mask token, and absolute position embeddings can be added. Dropout is applied
    last.
    """

    def __init__(self, config, use_mask_token=False):
        super().__init__()
        self.patch_embeddings = SwinPatchEmbeddings(config)
        patch_count = self.patch_embeddings.num_patches
        self.patch_grid = self.patch_embeddings.grid_size
        embed_dim = config.embed_dim

        # Learned replacement vector for masked patches; only allocated when masking is requested.
        if use_mask_token:
            self.mask_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
        else:
            self.mask_token = None

        # NOTE(review): this table has `num_patches + 1` rows (the usual CLS-token layout), but `forward` never
        # prepends a CLS token, so adding it to the `num_patches`-long sequence looks like it cannot broadcast when
        # `use_absolute_embeddings` is enabled — confirm before relying on that option.
        if config.use_absolute_embeddings:
            self.position_embeddings = nn.Parameter(torch.zeros(1, patch_count + 1, embed_dim))
        else:
            self.position_embeddings = None

        self.norm = nn.LayerNorm(embed_dim)
        self.dropout = nn.Dropout(config.hidden_dropout_prob)
        self.patch_size = config.patch_size
        self.config = config

    # Copied from transformers.models.vit.modeling_vit.ViTEmbeddings.interpolate_pos_encoding
    def interpolate_pos_encoding(self, embeddings: torch.Tensor, height: int, width: int) -> torch.Tensor:
        """
        This method allows to interpolate the pre-trained position encodings, to be able to use the model on higher resolution
        images. This method is also adapted to support torch.jit tracing.

        Adapted from:
        - https://github.com/facebookresearch/dino/blob/de9ee3df6cf39fac952ab558447af1fa1365362a/vision_transformer.py#L174-L194, and
        - https://github.com/facebookresearch/dinov2/blob/e1277af2ba9496fbadf7aec6eba56e8d882d1e35/dinov2/models/vision_transformer.py#L179-L211
        """

        num_patches = embeddings.shape[1] - 1
        num_positions = self.position_embeddings.shape[1] - 1

        # always interpolate when tracing to ensure the exported model works for dynamic input shapes
        if not torch.jit.is_tracing() and num_patches == num_positions and height == width:
            return self.position_embeddings

        class_pos_embed = self.position_embeddings[:, :1]
        patch_pos_embed = self.position_embeddings[:, 1:]

        dim = embeddings.shape[-1]

        new_height = height // self.patch_size
        new_width = width // self.patch_size

        sqrt_num_positions = torch_int(num_positions**0.5)
        patch_pos_embed = patch_pos_embed.reshape(1, sqrt_num_positions, sqrt_num_positions, dim)
        patch_pos_embed = patch_pos_embed.permute(0, 3, 1, 2)

        patch_pos_embed = nn.functional.interpolate(
            patch_pos_embed,
            size=(new_height, new_width),
            mode="bicubic",
            align_corners=False,
        )

        patch_pos_embed = patch_pos_embed.permute(0, 2, 3, 1).view(1, -1, dim)

        return torch.cat((class_pos_embed, patch_pos_embed), dim=1)

    def forward(
        self,
        pixel_values: torch.FloatTensor | None,
        bool_masked_pos: torch.BoolTensor | None = None,
        interpolate_pos_encoding: bool = False,
    ) -> tuple[torch.Tensor]:
        """
        Embed `pixel_values` of shape `(batch_size, num_channels, height, width)`.

        Returns the embedded tokens `(batch_size, seq_len, embed_dim)` together with the `(height, width)` of the
        patch grid reported by `SwinPatchEmbeddings`.
        """
        _, _, height, width = pixel_values.shape
        tokens, output_dimensions = self.patch_embeddings(pixel_values)
        tokens = self.norm(tokens)
        batch_size, seq_len, _ = tokens.size()

        if bool_masked_pos is not None:
            # Blend per position: keep the patch token where the mask is 0, use the mask token where it is 1.
            mask_tokens = self.mask_token.expand(batch_size, seq_len, -1)
            mask = bool_masked_pos.unsqueeze(-1).type_as(mask_tokens)
            tokens = tokens * (1.0 - mask) + mask_tokens * mask

        if self.position_embeddings is not None:
            position_term = (
                self.interpolate_pos_encoding(tokens, height, width)
                if interpolate_pos_encoding
                else self.position_embeddings
            )
            tokens = tokens + position_term

        return self.dropout(tokens), output_dimensions
- class SwinPatchEmbeddings(nn.Module):
- """
- This class turns `pixel_values` of shape `(batch_size, num_channels, height, width)` into the initial
- `hidden_states` (patch embeddings) of shape `(batch_size, seq_length, hidden_size)` to be consumed by a
- Transformer.
- """
- def __init__(self, config):
- super().__init__()
- image_size, patch_size = config.image_size, config.patch_size
- num_channels, hidden_size = config.num_channels, config.embed_dim
- image_size = image_size if isinstance(image_size, collections.abc.Iterable) else (image_size, image_size)
- patch_size = patch_size if isinstance(patch_size, collections.abc.Iterable) else (patch_size, patch_size)
- num_patches = (image_size[1] // patch_size[1]) * (image_size[0] // patch_size[0])
- self.image_size = image_size
- self.patch_size = patch_size
- self.num_channels = num_channels
- self.num_patches = num_patches
- self.grid_size = (image_size[0] // patch_size[0], image_size[1] // patch_size[1])
- self.projection = nn.Conv2d(num_channels, hidden_size, kernel_size=patch_size, stride=patch_size)
- def maybe_pad(self, pixel_values, height, width):
- if width % self.patch_size[1] != 0:
- pad_values = (0, self.patch_size[1] - width % self.patch_size[1])
- pixel_values = nn.functional.pad(pixel_values, pad_values)
- if height % self.patch_size[0] != 0:
- pad_values = (0, 0, 0, self.patch_size[0] - height % self.patch_size[0])
- pixel_values = nn.functional.pad(pixel_values, pad_values)
- return pixel_values
- def forward(self, pixel_values: torch.FloatTensor | None) -> tuple[torch.Tensor, tuple[int]]:
- _, num_channels, height, width = pixel_values.shape
- # pad the input to be divisible by self.patch_size, if needed
- pixel_values = self.maybe_pad(pixel_values, height, width)
- embeddings = self.projection(pixel_values)
- _, _, height, width = embeddings.shape
- output_dimensions = (height, width)
- embeddings = embeddings.flatten(2).transpose(1, 2)
- return embeddings, output_dimensions
class SwinPatchMerging(nn.Module):
    """
    Patch Merging Layer: halves the spatial resolution and doubles the channel count.

    The four tokens of every 2x2 neighbourhood are concatenated along the channel axis (4 * dim channels),
    normalised, and linearly projected to 2 * dim channels. Odd heights/widths are zero-padded first.

    Args:
        input_resolution (`tuple[int]`):
            Resolution of input feature. Stored only; `forward` receives the actual dimensions explicitly.
        dim (`int`):
            Number of input channels.
        norm_layer (`nn.Module`, *optional*, defaults to `nn.LayerNorm`):
            Normalization layer class.
    """

    def __init__(self, input_resolution: tuple[int], dim: int, norm_layer: nn.Module = nn.LayerNorm) -> None:
        super().__init__()
        self.input_resolution = input_resolution
        self.dim = dim
        merged_channels = 4 * dim
        self.reduction = nn.Linear(merged_channels, 2 * dim, bias=False)
        self.norm = norm_layer(merged_channels)

    def maybe_pad(self, input_feature, height, width):
        """Zero-pad a `(batch, height, width, channels)` map by one row/column when the size is odd."""
        extra_rows, extra_cols = height % 2, width % 2
        if extra_rows == 1 or extra_cols == 1:
            # F.pad order starts at the last axis: (channels, channels, width, width, height, height)
            input_feature = nn.functional.pad(input_feature, (0, 0, 0, extra_cols, 0, extra_rows))
        return input_feature

    def forward(self, input_feature: torch.Tensor, input_dimensions: tuple[int, int]) -> torch.Tensor:
        """Merge `(batch, height * width, C)` tokens into `(batch, ceil(h/2) * ceil(w/2), 2 * C)` tokens."""
        height, width = input_dimensions
        batch_size, _, num_channels = input_feature.shape  # middle axis is the flattened height * width
        grid = input_feature.view(batch_size, height, width, num_channels)
        grid = self.maybe_pad(grid, height, width)
        # Neighbourhood members in the order top-left, bottom-left, top-right, bottom-right; changing this order
        # would change which input channels each column of `reduction` sees.
        quadrants = [grid[:, row::2, col::2, :] for col in (0, 1) for row in (0, 1)]
        merged = torch.cat(quadrants, -1).view(batch_size, -1, 4 * num_channels)
        return self.reduction(self.norm(merged))
# Adapted from transformers.models.beit.modeling_beit.drop_path (adds a guard for drop_prob >= 1.0)
def drop_path(input: torch.Tensor, drop_prob: float = 0.0, training: bool = False) -> torch.Tensor:
    """
    Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).

    In training mode, each sample along the first axis is either zeroed (probability `drop_prob`) or kept and scaled
    by `1 / (1 - drop_prob)`, which keeps the expected value unchanged. Outside training, or when `drop_prob` is 0,
    the input object itself is returned.

    Args:
        input (`torch.Tensor`): tensor whose first axis is the batch axis; any rank is accepted.
        drop_prob (`float`): probability of dropping a sample's path; values >= 1.0 drop every path.
        training (`bool`): whether the calling module is in training mode.

    Returns:
        `torch.Tensor` with the same shape as `input`.
    """
    if drop_prob == 0.0 or not training:
        return input
    keep_prob = 1 - drop_prob
    if keep_prob <= 0.0:
        # Every path is dropped. The generic formula below would divide by zero, and inf/NaN * 0 yields NaN.
        return torch.zeros_like(input)
    # One Bernoulli(keep_prob) draw per sample, broadcast over all remaining axes.
    shape = (input.shape[0],) + (1,) * (input.ndim - 1)  # work with diff dim tensors, not just 2D ConvNets
    random_tensor = keep_prob + torch.rand(shape, dtype=input.dtype, device=input.device)
    random_tensor.floor_()  # binarize
    output = input.div(keep_prob) * random_tensor
    return output
# Adapted from transformers.models.beit.modeling_beit.BeitDropPath with Beit->Swin (treats drop_prob=None as 0.0)
class SwinDropPath(nn.Module):
    """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks)."""

    def __init__(self, drop_prob: float | None = None) -> None:
        super().__init__()
        # Stored exactly as given, so `extra_repr` still reports `p=None` for the default.
        self.drop_prob = drop_prob

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        # Passing None through would make `drop_path` evaluate `1 - None` and raise TypeError in training mode,
        # so the default None is interpreted as "never drop".
        drop_prob = 0.0 if self.drop_prob is None else self.drop_prob
        return drop_path(hidden_states, drop_prob, self.training)

    def extra_repr(self) -> str:
        return f"p={self.drop_prob}"
- class SwinSelfAttention(nn.Module):
- def __init__(self, config, dim, num_heads, window_size):
- super().__init__()
- if dim % num_heads != 0:
- raise ValueError(
- f"The hidden size ({dim}) is not a multiple of the number of attention heads ({num_heads})"
- )
- self.num_attention_heads = num_heads
- self.attention_head_size = int(dim / num_heads)
- self.all_head_size = self.num_attention_heads * self.attention_head_size
- self.window_size = (
- window_size if isinstance(window_size, collections.abc.Iterable) else (window_size, window_size)
- )
- self.relative_position_bias_table = nn.Parameter(
- torch.zeros((2 * self.window_size[0] - 1) * (2 * self.window_size[1] - 1), num_heads)
- )
- self.register_buffer("relative_position_index", self.create_relative_position_index())
- self.query = nn.Linear(self.all_head_size, self.all_head_size, bias=config.qkv_bias)
- self.key = nn.Linear(self.all_head_size, self.all_head_size, bias=config.qkv_bias)
- self.value = nn.Linear(self.all_head_size, self.all_head_size, bias=config.qkv_bias)
- self.dropout = nn.Dropout(config.attention_probs_dropout_prob)
- def forward(
- self,
- hidden_states: torch.Tensor,
- attention_mask: torch.FloatTensor | None = None,
- output_attentions: bool | None = False,
- ) -> tuple[torch.Tensor]:
- batch_size, dim, num_channels = hidden_states.shape
- hidden_shape = (batch_size, dim, -1, self.attention_head_size)
- query_layer = self.query(hidden_states).view(hidden_shape).transpose(1, 2)
- key_layer = self.key(hidden_states).view(hidden_shape).transpose(1, 2)
- value_layer = self.value(hidden_states).view(hidden_shape).transpose(1, 2)
- # Take the dot product between "query" and "key" to get the raw attention scores.
- attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2))
- attention_scores = attention_scores / math.sqrt(self.attention_head_size)
- relative_position_bias = self.relative_position_bias_table[self.relative_position_index.view(-1)]
- relative_position_bias = relative_position_bias.view(
- self.window_size[0] * self.window_size[1], self.window_size[0] * self.window_size[1], -1
- )
- relative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous()
- attention_scores = attention_scores + relative_position_bias.unsqueeze(0)
- if attention_mask is not None:
- # Apply the attention mask is (precomputed for all layers in SwinModel forward() function)
- mask_shape = attention_mask.shape[0]
- attention_scores = attention_scores.view(
- batch_size // mask_shape, mask_shape, self.num_attention_heads, dim, dim
- )
- attention_scores = attention_scores + attention_mask.unsqueeze(1).unsqueeze(0)
- attention_scores = attention_scores.view(-1, self.num_attention_heads, dim, dim)
- # Normalize the attention scores to probabilities.
- attention_probs = nn.functional.softmax(attention_scores, dim=-1)
- # This is actually dropping out entire tokens to attend to, which might
- # seem a bit unusual, but is taken from the original Transformer paper.
- attention_probs = self.dropout(attention_probs)
- context_layer = torch.matmul(attention_probs, value_layer)
- context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
- new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,)
- context_layer = context_layer.view(new_context_layer_shape)
- outputs = (context_layer, attention_probs) if output_attentions else (context_layer,)
- return outputs
- def create_relative_position_index(self):
- # get pair-wise relative position index for each token inside the window
- coords_h = torch.arange(self.window_size[0])
- coords_w = torch.arange(self.window_size[1])
- coords = torch.stack(torch.meshgrid([coords_h, coords_w], indexing="ij"))
- coords_flatten = torch.flatten(coords, 1)
- relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :]
- relative_coords = relative_coords.permute(1, 2, 0).contiguous()
- relative_coords[:, :, 0] += self.window_size[0] - 1
- relative_coords[:, :, 1] += self.window_size[1] - 1
- relative_coords[:, :, 0] *= 2 * self.window_size[1] - 1
- relative_position_index = relative_coords.sum(-1)
- return relative_position_index
class SwinSelfOutput(nn.Module):
    """
    Output projection after window attention: a `dim -> dim` linear layer followed by dropout.

    The residual connection is added by `SwinLayer`, so `input_tensor` is accepted for interface compatibility but
    not used here.
    """

    def __init__(self, config, dim):
        super().__init__()
        self.dense = nn.Linear(dim, dim)
        self.dropout = nn.Dropout(config.attention_probs_dropout_prob)

    def forward(self, hidden_states: torch.Tensor, input_tensor: torch.Tensor) -> torch.Tensor:
        return self.dropout(self.dense(hidden_states))
class SwinAttention(nn.Module):
    """Window self-attention (`SwinSelfAttention`) followed by its output projection (`SwinSelfOutput`)."""

    def __init__(self, config, dim, num_heads, window_size):
        super().__init__()
        self.self = SwinSelfAttention(config, dim, num_heads, window_size)
        self.output = SwinSelfOutput(config, dim)

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: torch.FloatTensor | None = None,
        output_attentions: bool | None = False,
    ) -> tuple[torch.Tensor]:
        context, *extras = self.self(hidden_states, attention_mask, output_attentions)
        projected = self.output(context, hidden_states)
        # `extras` carries the attention probabilities when `output_attentions` is set and is empty otherwise.
        return (projected, *extras)
class SwinIntermediate(nn.Module):
    """First half of the Swin MLP: expand `dim` to `int(mlp_ratio * dim)` features and apply the activation."""

    def __init__(self, config, dim):
        super().__init__()
        self.dense = nn.Linear(dim, int(config.mlp_ratio * dim))
        activation = config.hidden_act
        # `hidden_act` is either a name looked up in ACT2FN or a callable used as-is.
        self.intermediate_act_fn = ACT2FN[activation] if isinstance(activation, str) else activation

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        return self.intermediate_act_fn(self.dense(hidden_states))
class SwinOutput(nn.Module):
    """Second half of the Swin MLP: project `int(mlp_ratio * dim)` features back to `dim`, then apply dropout."""

    def __init__(self, config, dim):
        super().__init__()
        hidden_features = int(config.mlp_ratio * dim)
        self.dense = nn.Linear(hidden_features, dim)
        self.dropout = nn.Dropout(config.hidden_dropout_prob)

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        return self.dropout(self.dense(hidden_states))
class SwinLayer(nn.Module):
    """
    One Swin Transformer block: (shifted) window attention followed by an MLP, each in a pre-LayerNorm residual.

    Args:
        config: model configuration; `chunk_size_feed_forward`, `window_size` and `layer_norm_eps` are read here and
            `config` is passed on to the sub-modules.
        dim (`int`): channels per token.
        input_resolution (`tuple[int, int]`): nominal (height, width) of the token grid; stored only.
        num_heads (`int`): number of attention heads.
        drop_path_rate (`float`): stochastic-depth rate for the attention branch; `0.0` disables it.
        shift_size (`int`): cyclic shift applied before window partitioning (`0` means regular windows).
    """

    def __init__(self, config, dim, input_resolution, num_heads, drop_path_rate=0.0, shift_size=0):
        super().__init__()
        self.chunk_size_feed_forward = config.chunk_size_feed_forward
        self.shift_size = shift_size
        self.window_size = config.window_size
        # `shift_size` / `window_size` above are rewritten at the start of each forward pass to fit the input.
        # Keep the configured values so every call starts from them rather than from what a previous call left.
        self._configured_shift_size = shift_size
        self._configured_window_size = config.window_size
        self.input_resolution = input_resolution
        self.layernorm_before = nn.LayerNorm(dim, eps=config.layer_norm_eps)
        self.attention = SwinAttention(config, dim, num_heads, window_size=self.window_size)
        self.drop_path = SwinDropPath(drop_path_rate) if drop_path_rate > 0.0 else nn.Identity()
        self.layernorm_after = nn.LayerNorm(dim, eps=config.layer_norm_eps)
        self.intermediate = SwinIntermediate(config, dim)
        self.output = SwinOutput(config, dim)

    def _reset_shift_and_window_size(self):
        """Restore the shift and window size this layer was constructed with."""
        self.shift_size = self._configured_shift_size
        self.window_size = self._configured_window_size

    def set_shift_and_window_size(self, input_resolution):
        """
        Fit the window to the current input. If the token grid is no larger than the configured window, use one
        window covering the smaller side and disable the shift.

        Always starts from the configured values. Otherwise a small input seen earlier would permanently shrink the
        window and disable the shift for every later, larger input.
        """
        self._reset_shift_and_window_size()
        if min(input_resolution) <= self.window_size:
            # if window size is larger than input resolution, we don't partition windows
            self.shift_size = torch_int(0)
            self.window_size = (
                torch.min(torch.tensor(input_resolution)) if torch.jit.is_tracing() else min(input_resolution)
            )

    def get_attn_mask(self, height, width, dtype, device):
        """
        Build the additive mask for shifted-window attention, or return None when `shift_size` is 0.

        The padded grid is split into 3 x 3 regions by the window/shift slices, and each region gets its own label.
        The label map is then partitioned like the hidden states. Token pairs with different labels receive -100.0
        and pairs with equal labels receive 0.0. The result has shape
        `(num_windows, window_size * window_size, window_size * window_size)`.
        """
        if self.shift_size > 0:
            # calculate attention mask for SW-MSA
            img_mask = torch.zeros((1, height, width, 1), dtype=dtype, device=device)
            height_slices = (
                slice(0, -self.window_size),
                slice(-self.window_size, -self.shift_size),
                slice(-self.shift_size, None),
            )
            width_slices = (
                slice(0, -self.window_size),
                slice(-self.window_size, -self.shift_size),
                slice(-self.shift_size, None),
            )
            count = 0
            for height_slice in height_slices:
                for width_slice in width_slices:
                    img_mask[:, height_slice, width_slice, :] = count
                    count += 1

            mask_windows = window_partition(img_mask, self.window_size)
            mask_windows = mask_windows.view(-1, self.window_size * self.window_size)
            attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2)
            attn_mask = attn_mask.masked_fill(attn_mask != 0, -100.0).masked_fill(attn_mask == 0, 0.0)
        else:
            attn_mask = None
        return attn_mask

    def maybe_pad(self, hidden_states, height, width):
        """Zero-pad a `(batch, height, width, C)` map on the right/bottom to multiples of the window size.

        Returns the padded tensor and the `F.pad` tuple used (index 3 = right padding, index 5 = bottom padding).
        """
        pad_right = (self.window_size - width % self.window_size) % self.window_size
        pad_bottom = (self.window_size - height % self.window_size) % self.window_size
        pad_values = (0, 0, 0, pad_right, 0, pad_bottom)
        hidden_states = nn.functional.pad(hidden_states, pad_values)
        return hidden_states, pad_values

    def forward(
        self,
        hidden_states: torch.Tensor,
        input_dimensions: tuple[int, int],
        output_attentions: bool | None = False,
        always_partition: bool | None = False,
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """
        Apply the block to `(batch, height * width, C)` tokens laid out on an `input_dimensions` grid.

        Returns `(layer_output,)`, plus the attention probabilities when `output_attentions` is set.
        """
        if always_partition:
            # Partition with the configured window and shift regardless of the input size.
            self._reset_shift_and_window_size()
        else:
            self.set_shift_and_window_size(input_dimensions)
        height, width = input_dimensions
        batch_size, _, channels = hidden_states.size()
        shortcut = hidden_states

        hidden_states = self.layernorm_before(hidden_states)
        hidden_states = hidden_states.view(batch_size, height, width, channels)

        # pad hidden_states to multiples of window size
        hidden_states, pad_values = self.maybe_pad(hidden_states, height, width)
        _, height_pad, width_pad, _ = hidden_states.shape

        # cyclic shift
        if self.shift_size > 0:
            shifted_hidden_states = torch.roll(hidden_states, shifts=(-self.shift_size, -self.shift_size), dims=(1, 2))
        else:
            shifted_hidden_states = hidden_states

        # partition windows: (batch * num_windows, window_size * window_size, channels)
        hidden_states_windows = window_partition(shifted_hidden_states, self.window_size)
        hidden_states_windows = hidden_states_windows.view(-1, self.window_size * self.window_size, channels)
        attn_mask = self.get_attn_mask(
            height_pad, width_pad, dtype=hidden_states.dtype, device=hidden_states_windows.device
        )

        attention_outputs = self.attention(hidden_states_windows, attn_mask, output_attentions=output_attentions)
        attention_output = attention_outputs[0]

        # merge windows back into the padded (batch, height_pad, width_pad, channels) grid
        attention_windows = attention_output.view(-1, self.window_size, self.window_size, channels)
        shifted_windows = window_reverse(attention_windows, self.window_size, height_pad, width_pad)

        # reverse cyclic shift
        if self.shift_size > 0:
            attention_windows = torch.roll(shifted_windows, shifts=(self.shift_size, self.shift_size), dims=(1, 2))
        else:
            attention_windows = shifted_windows

        # drop the padding added by maybe_pad
        was_padded = pad_values[3] > 0 or pad_values[5] > 0
        if was_padded:
            attention_windows = attention_windows[:, :height, :width, :].contiguous()

        attention_windows = attention_windows.view(batch_size, height * width, channels)

        # first residual: attention branch (with stochastic depth)
        hidden_states = shortcut + self.drop_path(attention_windows)

        # second residual: MLP branch.
        # NOTE(review): stochastic depth is not applied to this branch; the reference Swin implementation applies
        # it to both branches — confirm this is intentional.
        layer_output = self.layernorm_after(hidden_states)
        layer_output = self.intermediate(layer_output)
        layer_output = hidden_states + self.output(layer_output)

        layer_outputs = (layer_output, attention_outputs[1]) if output_attentions else (layer_output,)
        return layer_outputs
class SwinStage(GradientCheckpointingLayer):
    """
    One stage of the Swin encoder: `depth` `SwinLayer` blocks at a fixed resolution, optionally followed by a
    downsampling layer.

    Blocks alternate between regular windows (even index, shift 0) and shifted windows (odd index, shift
    `config.window_size // 2`).
    """

    def __init__(self, config, dim, input_resolution, depth, num_heads, drop_path, downsample):
        super().__init__()
        self.config = config
        self.dim = dim
        self.blocks = nn.ModuleList(
            SwinLayer(
                config=config,
                dim=dim,
                input_resolution=input_resolution,
                num_heads=num_heads,
                drop_path_rate=drop_path[index],
                shift_size=config.window_size // 2 if index % 2 else 0,
            )
            for index in range(depth)
        )

        # patch merging layer (None for a stage without downsampling)
        self.downsample = (
            downsample(input_resolution, dim=dim, norm_layer=nn.LayerNorm) if downsample is not None else None
        )

        self.pointing = False

    def forward(
        self,
        hidden_states: torch.Tensor,
        input_dimensions: tuple[int, int],
        output_attentions: bool | None = False,
        always_partition: bool | None = False,
    ) -> tuple[torch.Tensor]:
        """
        Run all blocks, then downsample if configured.

        Returns `(hidden_states, hidden_states_before_downsampling, output_dimensions)`, where
        `output_dimensions` is `(height, width, new_height, new_width)`. When `output_attentions` is set, the
        attention probabilities of the last block are appended.
        """
        height, width = input_dimensions
        for block in self.blocks:
            layer_outputs = block(hidden_states, input_dimensions, output_attentions, always_partition)
            hidden_states = layer_outputs[0]

        pre_downsample = hidden_states
        if self.downsample is None:
            output_dimensions = (height, width, height, width)
        else:
            # ceil-halving, consistent with SwinPatchMerging padding odd sizes before merging 2x2 neighbourhoods
            output_dimensions = (height, width, (height + 1) // 2, (width + 1) // 2)
            hidden_states = self.downsample(pre_downsample, input_dimensions)

        stage_outputs = (hidden_states, pre_downsample, output_dimensions)
        if output_attentions:
            # only the attention maps of the stage's last block are returned
            stage_outputs += layer_outputs[1:]
        return stage_outputs
class SwinEncoder(nn.Module):
    """Hierarchical stack of `SwinStage`s.

    Stage `i` runs at channel width `embed_dim * 2**i` on a grid of `grid_size // 2**i`; every stage except
    the last one ends with a `SwinPatchMerging` downsampling step.

    Args:
        config: model configuration; `depths`, `embed_dim`, `num_heads` and `drop_path_rate` are read here.
        grid_size: `(height, width)` of the patch grid produced by the embeddings.
    """

    def __init__(self, config, grid_size):
        super().__init__()
        self.num_layers = len(config.depths)
        self.config = config

        # One stochastic-depth rate per block, rising linearly from 0 to `drop_path_rate` over the whole
        # network; each stage receives the contiguous slice that belongs to its own blocks.
        drop_path_rates = [
            rate.item() for rate in torch.linspace(0, config.drop_path_rate, sum(config.depths), device="cpu")
        ]

        stages = []
        first_block = 0
        for stage_index, stage_depth in enumerate(config.depths):
            scale = 2**stage_index
            is_last_stage = stage_index == self.num_layers - 1
            stages.append(
                SwinStage(
                    config=config,
                    dim=int(config.embed_dim * scale),
                    input_resolution=(grid_size[0] // scale, grid_size[1] // scale),
                    depth=stage_depth,
                    num_heads=config.num_heads[stage_index],
                    drop_path=drop_path_rates[first_block : first_block + stage_depth],
                    downsample=None if is_last_stage else SwinPatchMerging,
                )
            )
            first_block += stage_depth
        self.layers = nn.ModuleList(stages)

        self.gradient_checkpointing = False

    @staticmethod
    def _tokens_to_feature_map(tokens: torch.Tensor, spatial_size) -> torch.Tensor:
        """Turn a `(batch, height * width, channels)` token sequence into a `(batch, channels, height, width)` map."""
        batch_size, _, channels = tokens.shape
        # rearrange b (h w) c -> b c h w
        return tokens.view(batch_size, *spatial_size, channels).permute(0, 3, 1, 2)

    def forward(
        self,
        hidden_states: torch.Tensor,
        input_dimensions: tuple[int, int],
        output_attentions: bool | None = False,
        output_hidden_states: bool | None = False,
        output_hidden_states_before_downsampling: bool | None = False,
        always_partition: bool | None = False,
        return_dict: bool | None = True,
    ) -> tuple | SwinEncoderOutput:
        """Run all stages in order.

        Args:
            hidden_states: embedded tokens laid out as `(batch, height * width, channels)`.
            input_dimensions: `(height, width)` of the token grid entering the first stage.
            output_attentions: collect the attention outputs returned by each stage.
            output_hidden_states: record the input tokens and each stage's tokens, both as sequences and as
                `(batch, channels, height, width)` maps.
            output_hidden_states_before_downsampling: when recording hidden states, record each stage's
                tokens *before* its patch-merging step (at the stage's own resolution) instead of after it.
            always_partition: forwarded to every stage.
            return_dict: return a `SwinEncoderOutput` instead of a plain tuple.

        Returns:
            `SwinEncoderOutput`, or the tuple `(last_hidden_state, [hidden_states], [attentions])` with the
            disabled entries left out; the reshaped hidden states are only available in the dict form.
        """
        all_hidden_states = () if output_hidden_states else None
        all_reshaped_hidden_states = () if output_hidden_states else None
        all_self_attentions = () if output_attentions else None

        if output_hidden_states:
            all_hidden_states += (hidden_states,)
            all_reshaped_hidden_states += (self._tokens_to_feature_map(hidden_states, input_dimensions),)

        for stage in self.layers:
            stage_outputs = stage(hidden_states, input_dimensions, output_attentions, always_partition)
            hidden_states, pre_merge_states, output_dimensions = stage_outputs[:3]
            # output_dimensions is (height, width, merged_height, merged_width) as returned by the stage;
            # the next stage consumes the merged grid.
            input_dimensions = (output_dimensions[-2], output_dimensions[-1])

            if output_hidden_states:
                if output_hidden_states_before_downsampling:
                    # Pre-merge tokens still live on the stage's original (not downsampled) grid.
                    recorded = pre_merge_states
                    recorded_size = (output_dimensions[0], output_dimensions[1])
                else:
                    recorded = hidden_states
                    recorded_size = input_dimensions
                all_hidden_states += (recorded,)
                all_reshaped_hidden_states += (self._tokens_to_feature_map(recorded, recorded_size),)

            if output_attentions:
                all_self_attentions += stage_outputs[3:]

        if not return_dict:
            return tuple(
                value for value in (hidden_states, all_hidden_states, all_self_attentions) if value is not None
            )

        return SwinEncoderOutput(
            last_hidden_state=hidden_states,
            hidden_states=all_hidden_states,
            attentions=all_self_attentions,
            reshaped_hidden_states=all_reshaped_hidden_states,
        )
@auto_docstring
class SwinPreTrainedModel(PreTrainedModel):
    # Common base for every Swin model/head: declares the config class, the checkpoint prefix, the main
    # input name, and adds the Swin-specific weight initialization below.
    config: SwinConfig
    base_model_prefix = "swin"
    main_input_name = "pixel_values"
    input_modalities = ("image",)
    supports_gradient_checkpointing = True
    _no_split_modules = ["SwinStage"]

    @torch.no_grad()
    def _init_weights(self, module):
        """Run the generic initialization, then apply the Swin-specific overrides for `module`.

        - `SwinEmbeddings`: zero the mask token and the position embeddings, each only when present.
        - `SwinSelfAttention`: zero the relative position bias table and refill the relative position index
          buffer from `module.create_relative_position_index()`.
        """
        super()._init_weights(module)

        if isinstance(module, SwinEmbeddings):
            for optional_param in (module.mask_token, module.position_embeddings):
                if optional_param is not None:
                    init.zeros_(optional_param)
        elif isinstance(module, SwinSelfAttention):
            init.zeros_(module.relative_position_bias_table)
            init.copy_(module.relative_position_index, module.create_relative_position_index())
@auto_docstring
class SwinModel(SwinPreTrainedModel):
    def __init__(self, config, add_pooling_layer=True, use_mask_token=False):
        r"""
        add_pooling_layer (`bool`, *optional*, defaults to `True`):
            Whether or not to apply pooling layer.
        use_mask_token (`bool`, *optional*, defaults to `False`):
            Whether or not to create and apply mask tokens in the embedding layer.
        """
        super().__init__(config)
        self.config = config
        self.num_layers = len(config.depths)
        # Channel width of the final stage: embed_dim scaled by 2 for each of the num_layers - 1 transitions.
        self.num_features = int(config.embed_dim * 2 ** (self.num_layers - 1))

        self.embeddings = SwinEmbeddings(config, use_mask_token=use_mask_token)
        self.encoder = SwinEncoder(config, self.embeddings.patch_grid)

        self.layernorm = nn.LayerNorm(self.num_features, eps=config.layer_norm_eps)
        if add_pooling_layer:
            # Global average over the token axis, producing one vector per image.
            self.pooler = nn.AdaptiveAvgPool1d(1)
        else:
            self.pooler = None

        # Initialize weights and apply final processing
        self.post_init()

    def get_input_embeddings(self):
        """Return the patch-embedding module that turns pixel values into tokens."""
        return self.embeddings.patch_embeddings

    @auto_docstring
    def forward(
        self,
        pixel_values: torch.FloatTensor | None = None,
        bool_masked_pos: torch.BoolTensor | None = None,
        output_attentions: bool | None = None,
        output_hidden_states: bool | None = None,
        interpolate_pos_encoding: bool = False,
        return_dict: bool | None = None,
        **kwargs,
    ) -> tuple | SwinModelOutput:
        r"""
        bool_masked_pos (`torch.BoolTensor` of shape `(batch_size, num_patches)`, *optional*):
            Boolean masked positions. Indicates which patches are masked (1) and which aren't (0).
        """
        config = self.config
        if output_attentions is None:
            output_attentions = config.output_attentions
        if output_hidden_states is None:
            output_hidden_states = config.output_hidden_states
        if return_dict is None:
            return_dict = config.return_dict

        if pixel_values is None:
            raise ValueError("You have to specify pixel_values")

        embedding_output, input_dimensions = self.embeddings(
            pixel_values, bool_masked_pos=bool_masked_pos, interpolate_pos_encoding=interpolate_pos_encoding
        )
        encoder_outputs = self.encoder(
            embedding_output,
            input_dimensions,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        sequence_output = self.layernorm(encoder_outputs[0])

        pooled_output = None
        if self.pooler is not None:
            # The 1D pooler averages over the last axis, so move the token axis there first, then drop the
            # trailing singleton dimension.
            pooled_output = torch.flatten(self.pooler(sequence_output.transpose(1, 2)), 1)

        if return_dict:
            return SwinModelOutput(
                last_hidden_state=sequence_output,
                pooler_output=pooled_output,
                hidden_states=encoder_outputs.hidden_states,
                attentions=encoder_outputs.attentions,
                reshaped_hidden_states=encoder_outputs.reshaped_hidden_states,
            )
        # pooled_output is kept even when None so downstream tuple indices stay fixed.
        return (sequence_output, pooled_output) + encoder_outputs[1:]
@auto_docstring(
    custom_intro="""
Swin Model with a decoder on top for masked image modeling, as proposed in [SimMIM](https://huggingface.co/papers/2111.09886).
<Tip>
Note that we provide a script to pre-train this model on custom data in our [examples
directory](https://github.com/huggingface/transformers/tree/main/examples/pytorch/image-pretraining).
</Tip>
"""
)
class SwinForMaskedImageModeling(SwinPreTrainedModel):
    def __init__(self, config):
        super().__init__(config)

        self.swin = SwinModel(config, add_pooling_layer=False, use_mask_token=True)

        # Channel width of the last encoder stage.
        num_features = int(config.embed_dim * 2 ** (config.num_layers - 1))
        # Lightweight decoder: a 1x1 conv predicts `encoder_stride**2 * num_channels` values per token, and
        # PixelShuffle rearranges them into an image `encoder_stride` times larger than the token grid.
        self.decoder = nn.Sequential(
            nn.Conv2d(
                in_channels=num_features, out_channels=config.encoder_stride**2 * config.num_channels, kernel_size=1
            ),
            nn.PixelShuffle(config.encoder_stride),
        )

        # Initialize weights and apply final processing
        self.post_init()

    @auto_docstring
    def forward(
        self,
        pixel_values: torch.FloatTensor | None = None,
        bool_masked_pos: torch.BoolTensor | None = None,
        output_attentions: bool | None = None,
        output_hidden_states: bool | None = None,
        interpolate_pos_encoding: bool = False,
        return_dict: bool | None = None,
        **kwargs,
    ) -> tuple | SwinMaskedImageModelingOutput:
        r"""
        bool_masked_pos (`torch.BoolTensor` of shape `(batch_size, num_patches)`):
            Boolean masked positions. Indicates which patches are masked (1) and which aren't (0). The patches
            are laid out row-major over a `(height // patch_size, width // patch_size)` grid, where `height`
            and `width` are the spatial size of `pixel_values`.

        Examples:
        ```python
        >>> from transformers import AutoImageProcessor, SwinForMaskedImageModeling
        >>> import torch
        >>> from PIL import Image
        >>> import httpx
        >>> from io import BytesIO

        >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg"
        >>> with httpx.stream("GET", url) as response:
        ...     image = Image.open(BytesIO(response.read()))

        >>> image_processor = AutoImageProcessor.from_pretrained("microsoft/swin-base-simmim-window6-192")
        >>> model = SwinForMaskedImageModeling.from_pretrained("microsoft/swin-base-simmim-window6-192")

        >>> num_patches = (model.config.image_size // model.config.patch_size) ** 2
        >>> pixel_values = image_processor(images=image, return_tensors="pt").pixel_values
        >>> # create random boolean mask of shape (batch_size, num_patches)
        >>> bool_masked_pos = torch.randint(low=0, high=2, size=(1, num_patches)).bool()

        >>> outputs = model(pixel_values, bool_masked_pos=bool_masked_pos)
        >>> loss, reconstructed_pixel_values = outputs.loss, outputs.reconstruction
        >>> list(reconstructed_pixel_values.shape)
        [1, 3, 192, 192]
        ```"""
        return_dict = return_dict if return_dict is not None else self.config.return_dict

        outputs = self.swin(
            pixel_values,
            bool_masked_pos=bool_masked_pos,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            interpolate_pos_encoding=interpolate_pos_encoding,
            return_dict=return_dict,
        )

        sequence_output = outputs[0]
        # Reshape to (batch_size, num_channels, height, width); the final token grid is assumed square.
        sequence_output = sequence_output.transpose(1, 2)
        batch_size, num_channels, sequence_length = sequence_output.shape
        height = width = math.floor(sequence_length**0.5)
        sequence_output = sequence_output.reshape(batch_size, num_channels, height, width)

        # Reconstruct pixel values
        reconstructed_pixel_values = self.decoder(sequence_output)

        masked_im_loss = None
        if bool_masked_pos is not None:
            patch_size = self.config.patch_size
            # Derive the patch grid from the actual input instead of `config.image_size`, so the mask still
            # lines up with the pixels when `interpolate_pos_encoding=True` feeds a different resolution.
            # For config-sized inputs this is the same `image_size // patch_size` grid as before.
            grid_height = pixel_values.shape[-2] // patch_size
            grid_width = pixel_values.shape[-1] // patch_size
            bool_masked_pos = bool_masked_pos.reshape(-1, grid_height, grid_width)
            # Expand each patch flag to a patch_size x patch_size block of pixels, adding a channel axis.
            mask = (
                bool_masked_pos.repeat_interleave(patch_size, 1)
                .repeat_interleave(patch_size, 2)
                .unsqueeze(1)
                .contiguous()
            )
            reconstruction_loss = nn.functional.l1_loss(pixel_values, reconstructed_pixel_values, reduction="none")
            # Mean L1 error over masked pixels and channels; 1e-5 guards against an all-unmasked batch.
            masked_im_loss = (reconstruction_loss * mask).sum() / (mask.sum() + 1e-5) / self.config.num_channels

        if not return_dict:
            output = (reconstructed_pixel_values,) + outputs[2:]
            return ((masked_im_loss,) + output) if masked_im_loss is not None else output

        return SwinMaskedImageModelingOutput(
            loss=masked_im_loss,
            reconstruction=reconstructed_pixel_values,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
            reshaped_hidden_states=outputs.reshaped_hidden_states,
        )
@auto_docstring(
    custom_intro="""
Swin Model transformer with an image classification head on top (a linear layer on top of the average-pooled final
hidden states) e.g. for ImageNet.
<Tip>
Note that it's possible to fine-tune Swin on higher resolution images than the ones it has been trained on, by
setting `interpolate_pos_encoding` to `True` in the forward of the model. This will interpolate the pre-trained
position embeddings to the higher resolution.
</Tip>
"""
)
class SwinForImageClassification(SwinPreTrainedModel):
    def __init__(self, config):
        super().__init__(config)

        self.num_labels = config.num_labels
        self.swin = SwinModel(config)

        # Classifier head over the pooled features; with `num_labels <= 0` the pooled features pass through.
        if config.num_labels > 0:
            self.classifier = nn.Linear(self.swin.num_features, config.num_labels)
        else:
            self.classifier = nn.Identity()

        # Initialize weights and apply final processing
        self.post_init()

    @auto_docstring
    def forward(
        self,
        pixel_values: torch.FloatTensor | None = None,
        labels: torch.LongTensor | None = None,
        output_attentions: bool | None = None,
        output_hidden_states: bool | None = None,
        interpolate_pos_encoding: bool = False,
        return_dict: bool | None = None,
        **kwargs,
    ) -> tuple | SwinImageClassifierOutput:
        r"""
        labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
            Labels for computing the image classification/regression loss. Indices should be in `[0, ...,
            config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If
            `config.num_labels > 1` a classification loss is computed (Cross-Entropy).
        """
        return_dict = self.config.return_dict if return_dict is None else return_dict

        swin_outputs = self.swin(
            pixel_values,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            interpolate_pos_encoding=interpolate_pos_encoding,
            return_dict=return_dict,
        )

        # Index 1 is the pooled output in both the tuple and the dataclass form of SwinModel's result.
        logits = self.classifier(swin_outputs[1])

        loss = self.loss_function(labels, logits, self.config) if labels is not None else None

        if return_dict:
            return SwinImageClassifierOutput(
                loss=loss,
                logits=logits,
                hidden_states=swin_outputs.hidden_states,
                attentions=swin_outputs.attentions,
                reshaped_hidden_states=swin_outputs.reshaped_hidden_states,
            )

        output = (logits,) + swin_outputs[2:]
        return output if loss is None else (loss,) + output
@auto_docstring(
    custom_intro="""
Swin backbone, to be used with frameworks like DETR and MaskFormer.
"""
)
class SwinBackbone(BackboneMixin, SwinPreTrainedModel):
    def __init__(self, config: SwinConfig):
        super().__init__(config)

        # Channel counts: `embed_dim` for the embedding output, then `embed_dim * 2**i` for stage i.
        self.num_features = [config.embed_dim] + [int(config.embed_dim * 2**i) for i in range(len(config.depths))]
        self.embeddings = SwinEmbeddings(config)
        self.encoder = SwinEncoder(config, self.embeddings.patch_grid)

        # Add layer norms to hidden states of out_features
        hidden_states_norms = {}
        for stage, num_channels in zip(self.out_features, self.channels):
            hidden_states_norms[stage] = nn.LayerNorm(num_channels)
        self.hidden_states_norms = nn.ModuleDict(hidden_states_norms)

        # Initialize weights and apply final processing
        self.post_init()

    def get_input_embeddings(self):
        """Return the patch-embedding module that turns pixel values into tokens."""
        return self.embeddings.patch_embeddings

    @can_return_tuple
    @filter_output_hidden_states
    def forward(
        self,
        pixel_values: torch.Tensor,
        output_hidden_states: bool | None = None,
        output_attentions: bool | None = None,
        return_dict: bool | None = None,
        **kwargs,
    ) -> BackboneOutput:
        """
        Returns:
            `BackboneOutput` (or a tuple when `return_dict=False`) holding `feature_maps`: one layer-normalized
            `(batch_size, num_channels, height, width)` map per entry of `out_features`, taken before each
            stage's patch merging. `hidden_states` is filled when `output_hidden_states` is set, and
            `attentions` when `output_attentions` is set.

        Examples:

        ```python
        >>> from transformers import AutoImageProcessor, AutoBackbone
        >>> import torch
        >>> from PIL import Image
        >>> import httpx
        >>> from io import BytesIO

        >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg"
        >>> with httpx.stream("GET", url) as response:
        ...     image = Image.open(BytesIO(response.read()))

        >>> processor = AutoImageProcessor.from_pretrained("microsoft/swin-tiny-patch4-window7-224")
        >>> model = AutoBackbone.from_pretrained(
        ...     "microsoft/swin-tiny-patch4-window7-224", out_features=["stage1", "stage2", "stage3", "stage4"]
        ... )

        >>> inputs = processor(image, return_tensors="pt")
        >>> outputs = model(**inputs)
        >>> feature_maps = outputs.feature_maps
        >>> list(feature_maps[-1].shape)
        [1, 768, 7, 7]
        ```"""
        return_dict = return_dict if return_dict is not None else self.config.return_dict
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions

        embedding_output, input_dimensions = self.embeddings(pixel_values)

        # Always ask the encoder for the pre-merge hidden states: the feature maps are built from them.
        outputs = self.encoder(
            embedding_output,
            input_dimensions,
            output_attentions=output_attentions,
            output_hidden_states=True,
            output_hidden_states_before_downsampling=True,
            always_partition=True,
            return_dict=True,
        )

        hidden_states = outputs.reshaped_hidden_states

        feature_maps = ()
        for stage, hidden_state in zip(self.stage_names, hidden_states):
            if stage in self.out_features:
                # LayerNorm normalizes the channel axis, so move channels last, normalize, and move them back.
                batch_size, num_channels, height, width = hidden_state.shape
                hidden_state = hidden_state.permute(0, 2, 3, 1).contiguous()
                hidden_state = hidden_state.view(batch_size, height * width, num_channels)
                hidden_state = self.hidden_states_norms[stage](hidden_state)
                hidden_state = hidden_state.view(batch_size, height, width, num_channels)
                hidden_state = hidden_state.permute(0, 3, 1, 2).contiguous()
                feature_maps += (hidden_state,)

        if not return_dict:
            output = (feature_maps,)
            if output_hidden_states:
                output += (outputs.hidden_states,)
            if output_attentions:
                # Keep the tuple form in step with the `BackboneOutput` below, which reports the attentions;
                # appending at the end leaves every existing tuple position unchanged.
                output += (outputs.attentions,)
            return output

        return BackboneOutput(
            feature_maps=feature_maps,
            hidden_states=outputs.hidden_states if output_hidden_states else None,
            attentions=outputs.attentions,
        )
# Public API of this module; kept as a literal list so tooling can read it without importing the module.
__all__ = [
    "SwinForImageClassification",
    "SwinForMaskedImageModeling",
    "SwinModel",
    "SwinPreTrainedModel",
    "SwinBackbone",
]
|