| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656657658659660661662663664665666667668669670671672673674675676677678679680681682683684685686687688689690691692693694695696697698699700701702
70370470570670770870971071171271371471571671771871972072172272372472572672772872973073173273373473573673773873974074174274374474574674774874975075175275375475575675775875976076176276376476576676776876977077177277377477577677777877978078178278378478578678778878979079179279379479579679779879980080180280380480580680780880981081181281381481581681781881982082182282382482582682782882983083183283383483583683783883984084184284384484584684784884985085185285385485585685785885986086186286386486586686786886987087187287387487587687787887988088188288388488588688788888989089189289389489589689789889990090190290390490590690790890991091191291391491591691791891992092192292392492592692792892993093193293393493593693793893994094194294394494594694794894995095195295395495595695795895996096196296396496596696796896997097197297397497597697797897998098198298398498598698798898999099199299399499599699799899910001001100210031004100510061007100810091010101110121013101410151016101710181019102010211022102310241025102610271028102910301031103210331034103510361037103810391040104110421043104410451046104710481049105010511052105310541055105610571058105910601061106210631064106510661067106810691070107110721073107410751076107710781079108010811082108310841085108610871088108910901091109210931094109510961097109810991100110111021103110411051106110711081109111011111112111311141115111611171118111911201121112211231124112511261127112811291130113111321133113411351136113711381139114011411142114311441145114611471148114911501151115211531154115511561157115811591160116111621163116411651166116711681169117011711172117311741175117611771178117911801181118211831184118511861187118811891190119111921193119411951196119711981199120012011202120312041205120612071208120912101211121212131214121512161217121812191220122112221223122412251226122712281229123012311232123312341235123612371238123912401241124212431244124512461247124812491250125112521253125412551256125712581259126012611262126312641265126612671268126912701271127212731274127512761
277127812791280128112821283128412851286128712881289129012911292129312941295129612971298129913001301130213031304130513061307130813091310131113121313131413151316131713181319132013211322132313241325132613271328132913301331133213331334133513361337133813391340134113421343134413451346134713481349135013511352135313541355135613571358135913601361136213631364136513661367136813691370137113721373137413751376137713781379138013811382138313841385138613871388138913901391139213931394139513961397139813991400140114021403140414051406140714081409141014111412141314141415141614171418141914201421142214231424142514261427142814291430143114321433143414351436143714381439144014411442144314441445144614471448 |
- # Copyright 2021 Microsoft Research and The HuggingFace Inc. team. All rights reserved.
- #
- # Licensed under the Apache License, Version 2.0 (the "License");
- # you may not use this file except in compliance with the License.
- # You may obtain a copy of the License at
- #
- # http://www.apache.org/licenses/LICENSE-2.0
- #
- # Unless required by applicable law or agreed to in writing, software
- # distributed under the License is distributed on an "AS IS" BASIS,
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- # See the License for the specific language governing permissions and
- # limitations under the License.
- """PyTorch BEiT model."""
- import collections.abc
- import math
- from dataclasses import dataclass
- import torch
- from torch import Tensor, nn
- from torch.nn import CrossEntropyLoss
- from ... import initialization as init
- from ...activations import ACT2FN
- from ...backbone_utils import BackboneMixin, filter_output_hidden_states
- from ...modeling_layers import GradientCheckpointingLayer
- from ...modeling_outputs import (
- BackboneOutput,
- BaseModelOutput,
- BaseModelOutputWithPooling,
- ImageClassifierOutput,
- MaskedLMOutput,
- SemanticSegmenterOutput,
- )
- from ...modeling_utils import PreTrainedModel
- from ...pytorch_utils import compile_compatible_method_lru_cache
- from ...utils import auto_docstring, logging, torch_int
- from ...utils.generic import can_return_tuple
- from .configuration_beit import BeitConfig
# Module-level logger; `BeitSdpaSelfAttention` uses it to warn (once) that attention weights cannot be returned.
logger = logging.get_logger(__name__)
# Output container returned by `BeitModel.forward` when `return_dict=True`.
# `pooler_output` comes from `BeitPooler` and is `None` when the model was built with `add_pooling_layer=False`;
# in the mean-pooling case the averaged patch tokens are additionally passed through the pooler's layer norm.
@dataclass
@auto_docstring(
    custom_intro="""
    Class for outputs of [`BeitModel`].
    """
)
class BeitModelOutputWithPooling(BaseModelOutputWithPooling):
    r"""
    pooler_output (`torch.FloatTensor` of shape `(batch_size, hidden_size)`):
        Average of the last layer hidden states of the patch tokens (excluding the *[CLS]* token) if
        *config.use_mean_pooling* is set to True. If set to False, then the final hidden state of the *[CLS]* token
        will be returned.
    """
- def drop_path(input: torch.Tensor, drop_prob: float = 0.0, training: bool = False) -> torch.Tensor:
- """
- Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).
- """
- if drop_prob == 0.0 or not training:
- return input
- keep_prob = 1 - drop_prob
- shape = (input.shape[0],) + (1,) * (input.ndim - 1) # work with diff dim tensors, not just 2D ConvNets
- random_tensor = keep_prob + torch.rand(shape, dtype=input.dtype, device=input.device)
- random_tensor.floor_() # binarize
- output = input.div(keep_prob) * random_tensor
- return output
- class BeitDropPath(nn.Module):
- """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks)."""
- def __init__(self, drop_prob: float | None = None) -> None:
- super().__init__()
- self.drop_prob = drop_prob
- def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
- return drop_path(hidden_states, self.drop_prob, self.training)
- def extra_repr(self) -> str:
- return f"p={self.drop_prob}"
# Based on timm implementation, which can be found here:
# https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/vision_transformer.py
class BeitEmbeddings(nn.Module):
    """
    Construct the CLS token, position and patch embeddings. Optionally, also the mask token.
    """

    def __init__(self, config: BeitConfig) -> None:
        super().__init__()

        self.cls_token = nn.Parameter(torch.zeros(1, 1, config.hidden_size))
        # Only created when `config.use_mask_token` is set; it is required to use `bool_masked_pos` in `forward`.
        if config.use_mask_token:
            self.mask_token = nn.Parameter(torch.zeros(1, 1, config.hidden_size))
        else:
            self.mask_token = None
        self.patch_embeddings = BeitPatchEmbeddings(config)
        self.patch_size = config.patch_size
        self.image_size = (
            config.image_size
            if isinstance(config.image_size, collections.abc.Iterable)
            else (config.image_size, config.image_size)
        )
        num_patches = self.patch_embeddings.num_patches
        # One extra position for the CLS token, which `forward` prepends to the patch sequence.
        if config.use_absolute_position_embeddings:
            self.position_embeddings = nn.Parameter(torch.zeros(1, num_patches + 1, config.hidden_size))
        else:
            self.position_embeddings = None
        self.dropout = nn.Dropout(config.hidden_dropout_prob)

    # Copied from transformers.models.vit.modeling_vit.ViTEmbeddings.interpolate_pos_encoding
    def interpolate_pos_encoding(self, embeddings: torch.Tensor, height: int, width: int) -> torch.Tensor:
        """
        This method allows to interpolate the pre-trained position encodings, to be able to use the model on higher resolution
        images. This method is also adapted to support torch.jit tracing.

        Adapted from:
        - https://github.com/facebookresearch/dino/blob/de9ee3df6cf39fac952ab558447af1fa1365362a/vision_transformer.py#L174-L194, and
        - https://github.com/facebookresearch/dinov2/blob/e1277af2ba9496fbadf7aec6eba56e8d882d1e35/dinov2/models/vision_transformer.py#L179-L211
        """

        num_patches = embeddings.shape[1] - 1
        num_positions = self.position_embeddings.shape[1] - 1

        # always interpolate when tracing to ensure the exported model works for dynamic input shapes
        if not torch.jit.is_tracing() and num_patches == num_positions and height == width:
            return self.position_embeddings

        class_pos_embed = self.position_embeddings[:, :1]
        patch_pos_embed = self.position_embeddings[:, 1:]

        dim = embeddings.shape[-1]

        new_height = height // self.patch_size
        new_width = width // self.patch_size

        sqrt_num_positions = torch_int(num_positions**0.5)
        patch_pos_embed = patch_pos_embed.reshape(1, sqrt_num_positions, sqrt_num_positions, dim)
        patch_pos_embed = patch_pos_embed.permute(0, 3, 1, 2)

        patch_pos_embed = nn.functional.interpolate(
            patch_pos_embed,
            size=(new_height, new_width),
            mode="bicubic",
            align_corners=False,
        )

        patch_pos_embed = patch_pos_embed.permute(0, 2, 3, 1).view(1, -1, dim)

        return torch.cat((class_pos_embed, patch_pos_embed), dim=1)

    def forward(
        self,
        pixel_values: torch.Tensor,
        bool_masked_pos: torch.BoolTensor | None = None,
    ) -> tuple[torch.Tensor, tuple[int, int]]:
        """
        Embed a batch of images.

        Args:
            pixel_values (`torch.Tensor` of shape `(batch_size, num_channels, height, width)`):
                Input images.
            bool_masked_pos (`torch.BoolTensor` of shape `(batch_size, num_patches)`, *optional*):
                Patches where this is true get their embedding replaced by the learned mask token.

        Returns:
            `tuple`: the embeddings of shape `(batch_size, num_patches + 1, hidden_size)` (CLS token first), and the
            `(patch_height, patch_width)` grid size produced by the patch embedding.

        Raises:
            `ValueError`: if `bool_masked_pos` is given but the module was built without a mask token.
        """
        _, _, height, width = pixel_values.shape
        embeddings, (patch_height, patch_width) = self.patch_embeddings(pixel_values)
        batch_size, seq_len, _ = embeddings.size()

        if bool_masked_pos is not None:
            if self.mask_token is None:
                raise ValueError(
                    "`bool_masked_pos` was provided but the model has no mask token; set `use_mask_token=True` in "
                    "the configuration to use masked image modeling."
                )
            mask_tokens = self.mask_token.expand(batch_size, seq_len, -1)
            # replace the masked visual tokens by mask_tokens
            w = bool_masked_pos.unsqueeze(-1).type_as(mask_tokens)
            embeddings = embeddings * (1 - w) + mask_tokens * w

        cls_tokens = self.cls_token.expand(batch_size, -1, -1)
        embeddings = torch.cat((cls_tokens, embeddings), dim=1)
        # Position embeddings always go through `interpolate_pos_encoding`, which returns them as-is when no
        # resizing is needed.
        if self.position_embeddings is not None:
            embeddings = embeddings + self.interpolate_pos_encoding(embeddings, height, width)

        embeddings = self.dropout(embeddings)

        return embeddings, (patch_height, patch_width)
class BeitPatchEmbeddings(nn.Module):
    """
    This class turns `pixel_values` of shape `(batch_size, num_channels, height, width)` into the initial
    `hidden_states` (patch embeddings) of shape `(batch_size, seq_length, hidden_size)` to be consumed by a
    Transformer.
    """

    def __init__(self, config):
        super().__init__()
        image_size, patch_size = config.image_size, config.patch_size
        num_channels, hidden_size = config.num_channels, config.hidden_size

        # Accept either a single int (square) or a (height, width) pair for both sizes.
        image_size = image_size if isinstance(image_size, collections.abc.Iterable) else (image_size, image_size)
        patch_size = patch_size if isinstance(patch_size, collections.abc.Iterable) else (patch_size, patch_size)
        # Patch grid (rows, columns) at the configured image size. A remainder smaller than one patch is dropped,
        # matching the stride == kernel_size convolution below.
        patch_shape = (image_size[0] // patch_size[0], image_size[1] // patch_size[1])
        num_patches = patch_shape[0] * patch_shape[1]

        self.image_size = image_size
        self.patch_size = patch_size
        self.num_channels = num_channels
        self.num_patches = num_patches
        self.patch_shape = patch_shape

        self.projection = nn.Conv2d(num_channels, hidden_size, kernel_size=patch_size, stride=patch_size)

    def forward(self, pixel_values: torch.Tensor) -> tuple[torch.Tensor, tuple[int, int]]:
        """
        Args:
            pixel_values (`torch.Tensor` of shape `(batch_size, num_channels, height, width)`): input images.

        Returns:
            `tuple`: patch embeddings of shape `(batch_size, patch_height * patch_width, hidden_size)` and the
            `(patch_height, patch_width)` grid actually produced for this input.

        Raises:
            `ValueError`: if the channel dimension does not match `config.num_channels`.
        """
        # Unpacking exactly four dimensions also rejects inputs that are not 4-D.
        _, num_channels, _, _ = pixel_values.shape
        if num_channels != self.num_channels:
            raise ValueError(
                "Make sure that the channel dimension of the pixel values match with the one set in the configuration."
            )

        # Cast to the projection's dtype so inputs of a different float dtype still match the weights.
        embeddings = self.projection(pixel_values.to(self.projection.weight.dtype))
        patch_height, patch_width = embeddings.shape[2], embeddings.shape[3]
        # (batch, hidden, patch_height, patch_width) -> (batch, patch_height * patch_width, hidden)
        embeddings = embeddings.flatten(2).transpose(1, 2)

        return embeddings, (patch_height, patch_width)
class BeitSelfAttention(nn.Module):
    """
    Eager multi-head self-attention for BEiT.

    When `window_size` is given, the layer owns a `BeitRelativePositionBias` that is added to the attention scores;
    a bias shared across layers can additionally be passed to `forward`.
    """

    def __init__(self, config: BeitConfig, window_size: tuple | None = None) -> None:
        super().__init__()
        self.config = config
        divisible = config.hidden_size % config.num_attention_heads == 0
        if not divisible and not hasattr(config, "embedding_size"):
            raise ValueError(
                f"The hidden size {config.hidden_size} is not a multiple of the number of attention "
                f"heads {config.num_attention_heads}."
            )

        heads = config.num_attention_heads
        self.num_attention_heads = heads
        self.attention_head_size = int(config.hidden_size / heads)
        self.all_head_size = heads * self.attention_head_size

        self.query = nn.Linear(config.hidden_size, self.all_head_size)
        # The key projection has no bias term.
        self.key = nn.Linear(config.hidden_size, self.all_head_size, bias=False)
        self.value = nn.Linear(config.hidden_size, self.all_head_size)

        self.dropout = nn.Dropout(config.attention_probs_dropout_prob)

        self.has_relative_position_bias = bool(window_size)
        if self.has_relative_position_bias:
            self.relative_position_bias = BeitRelativePositionBias(config, window_size=window_size)

    def forward(
        self,
        hidden_states: torch.Tensor,
        output_attentions: bool = False,
        relative_position_bias: torch.Tensor | None = None,
        interpolate_pos_encoding: bool = False,
        resolution: tuple[int] | None = None,
    ) -> tuple[torch.Tensor] | tuple[torch.Tensor, torch.Tensor]:
        # Split the last dimension into (heads, head_size) and move heads before the sequence axis.
        head_view = (*hidden_states.shape[:-1], -1, self.attention_head_size)
        queries = self.query(hidden_states).view(head_view).transpose(1, 2)
        keys = self.key(hidden_states).view(head_view).transpose(1, 2)
        values = self.value(hidden_states).view(head_view).transpose(1, 2)

        # Scaled dot-product scores.
        scores = queries @ keys.transpose(-1, -2)
        scores = scores / math.sqrt(self.attention_head_size)

        # Layer-specific relative position bias, sized for the current input resolution.
        if self.has_relative_position_bias:
            height, width = resolution
            patch_size = self.config.patch_size
            scores = scores + self.relative_position_bias(
                (height // patch_size, width // patch_size),
                interpolate_pos_encoding,
                dim_size=hidden_states.shape[1],
            )

        # Bias shared by all layers, computed by the encoder.
        if relative_position_bias is not None:
            scores = scores + relative_position_bias

        # Dropout on the probabilities removes whole attended-to tokens, as in the original Transformer.
        probs = self.dropout(scores.softmax(dim=-1))

        context = (probs @ values).permute(0, 2, 1, 3).contiguous()
        context = context.view(*context.shape[:-2], self.all_head_size)

        if output_attentions:
            return context, probs
        return (context,)
class BeitSdpaSelfAttention(BeitSelfAttention):
    """
    Self-attention computed with `torch.nn.functional.scaled_dot_product_attention`.

    Attention probabilities are not materialized, so the second returned element is always `None`; a warning is
    logged once if they are requested.
    """

    def forward(
        self,
        hidden_states: torch.Tensor,
        output_attentions: bool = False,
        relative_position_bias: torch.Tensor | None = None,
        interpolate_pos_encoding: bool = False,
        resolution: tuple[int] | None = None,
    ) -> tuple[torch.Tensor] | tuple[torch.Tensor, torch.Tensor]:
        if output_attentions:
            logger.warning_once(
                f"{self.__class__.__name__} does not support `output_attentions=True`. The returned attention weights will "
                "be `None`. If you want to get attention weights, please set `attn_implementation='eager'` when loading the model."
            )

        input_shape = hidden_states.shape[:-1]
        hidden_shape = (*input_shape, -1, self.attention_head_size)

        # Split the last dimension into (heads, head_size) and move heads before the sequence axis.
        query_layer = self.query(hidden_states).view(hidden_shape).transpose(1, 2)
        key_layer = self.key(hidden_states).view(hidden_shape).transpose(1, 2)
        value_layer = self.value(hidden_states).view(hidden_shape).transpose(1, 2)

        attn_bias = None
        if self.has_relative_position_bias:
            height, width = resolution
            window_size = (height // self.config.patch_size, width // self.config.patch_size)
            attn_bias = self.relative_position_bias(
                window_size, interpolate_pos_encoding, dim_size=hidden_states.shape[1]
            )

        # Add shared relative position bias if provided. This is an out-of-place addition (as in the eager path), so
        # the tensor returned by the layer's own bias module is never mutated and the shapes may broadcast.
        if relative_position_bias is not None:
            if attn_bias is None:
                attn_bias = relative_position_bias
            else:
                attn_bias = attn_bias + relative_position_bias

        scaling = 1 / math.sqrt(self.attention_head_size)
        context_layer = torch.nn.functional.scaled_dot_product_attention(
            query_layer,
            key_layer,
            value_layer,
            attn_mask=attn_bias,
            dropout_p=self.config.attention_probs_dropout_prob if self.training else 0.0,
            is_causal=False,
            scale=scaling,
        )
        context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
        new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,)
        context_layer = context_layer.view(*new_context_layer_shape)
        return context_layer, None
class BeitSelfOutput(nn.Module):
    """
    Output projection of the attention sub-block (linear layer followed by dropout).

    Unlike many other models, no residual connection is added here: BEiT applies a layernorm before each block, so
    the residual is handled in `BeitLayer`.
    """

    def __init__(self, config: BeitConfig) -> None:
        super().__init__()
        self.dense = nn.Linear(config.hidden_size, config.hidden_size)
        self.dropout = nn.Dropout(config.hidden_dropout_prob)

    def forward(self, hidden_states: torch.Tensor, input_tensor: torch.Tensor, gamma=None) -> torch.Tensor:
        # `input_tensor` and `gamma` are accepted for interface compatibility but are not used.
        return self.dropout(self.dense(hidden_states))
# Maps `config._attn_implementation` to the self-attention class instantiated by `BeitAttention`.
BEIT_SELF_ATTENTION_CLASSES = dict(
    eager=BeitSelfAttention,
    sdpa=BeitSdpaSelfAttention,
)
class BeitAttention(nn.Module):
    """Attention sub-block: the configured self-attention implementation followed by its output projection."""

    def __init__(self, config: BeitConfig, window_size: tuple | None = None) -> None:
        super().__init__()
        attention_cls = BEIT_SELF_ATTENTION_CLASSES[config._attn_implementation]
        self.attention = attention_cls(config, window_size=window_size)
        self.output = BeitSelfOutput(config)

    def forward(
        self,
        hidden_states: torch.Tensor,
        output_attentions: bool = False,
        relative_position_bias: torch.Tensor | None = None,
        interpolate_pos_encoding: bool = False,
        resolution: tuple[int] | None = None,
    ) -> tuple[torch.Tensor] | tuple[torch.Tensor, torch.Tensor]:
        context, *extras = self.attention(
            hidden_states, output_attentions, relative_position_bias, interpolate_pos_encoding, resolution
        )
        projected = self.output(context, hidden_states)
        # Re-attach whatever else the attention returned (attention probabilities, or `None` from SDPA).
        return (projected, *extras)
class BeitIntermediate(nn.Module):
    """Feed-forward expansion: linear projection to `intermediate_size` followed by the configured activation."""

    def __init__(self, config: BeitConfig) -> None:
        super().__init__()
        self.dense = nn.Linear(config.hidden_size, config.intermediate_size)
        # `hidden_act` is either a key into ACT2FN or an already-callable activation.
        activation = config.hidden_act
        self.intermediate_act_fn = ACT2FN[activation] if isinstance(activation, str) else activation

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        return self.intermediate_act_fn(self.dense(hidden_states))
class BeitOutput(nn.Module):
    """Feed-forward contraction: projects `intermediate_size` back to `hidden_size`, then applies dropout."""

    def __init__(self, config: BeitConfig) -> None:
        super().__init__()
        self.dense = nn.Linear(config.intermediate_size, config.hidden_size)
        self.dropout = nn.Dropout(config.hidden_dropout_prob)

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        projected = self.dense(hidden_states)
        return self.dropout(projected)
class BeitLayer(GradientCheckpointingLayer):
    """
    One pre-norm transformer block (the `Block` class in timm): attention and MLP sub-blocks, each preceded by a
    layernorm, optionally scaled by learnable per-channel factors (`lambda_1`/`lambda_2`), passed through stochastic
    depth and added back to the residual stream.
    """

    def __init__(self, config: BeitConfig, window_size: tuple | None = None, drop_path_rate: float = 0.0) -> None:
        super().__init__()
        self.chunk_size_feed_forward = config.chunk_size_feed_forward
        self.seq_len_dim = 1
        self.attention = BeitAttention(config, window_size=window_size)
        self.intermediate = BeitIntermediate(config)
        self.output = BeitOutput(config)
        self.layernorm_before = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
        # Stochastic depth is only instantiated for a positive rate.
        if drop_path_rate > 0.0:
            self.drop_path = BeitDropPath(drop_path_rate)
        else:
            self.drop_path = nn.Identity()
        self.layernorm_after = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)

        # Layer scale factors exist only when `layer_scale_init_value` is positive.
        scale = config.layer_scale_init_value
        if scale > 0:
            self.lambda_1 = nn.Parameter(scale * torch.ones(config.hidden_size), requires_grad=True)
            self.lambda_2 = nn.Parameter(scale * torch.ones(config.hidden_size), requires_grad=True)
        else:
            self.lambda_1, self.lambda_2 = None, None

    def forward(
        self,
        hidden_states: torch.Tensor,
        output_attentions: bool = False,
        relative_position_bias: torch.Tensor | None = None,
        interpolate_pos_encoding: bool = False,
        resolution: tuple[int, int] | None = None,
    ) -> tuple[torch.Tensor] | tuple[torch.Tensor, torch.Tensor]:
        # Attention sub-block (layernorm is applied before self-attention).
        normed = self.layernorm_before(hidden_states)
        attn_out, *extras = self.attention(
            normed,
            output_attentions=output_attentions,
            relative_position_bias=relative_position_bias,
            interpolate_pos_encoding=interpolate_pos_encoding,
            resolution=resolution,
        )
        if self.lambda_1 is not None:
            attn_out = self.lambda_1 * attn_out
        residual = self.drop_path(attn_out) + hidden_states

        # MLP sub-block (layernorm is also applied before it).
        mlp_out = self.output(self.intermediate(self.layernorm_after(residual)))
        if self.lambda_2 is not None:
            mlp_out = self.lambda_2 * mlp_out
        layer_output = self.drop_path(mlp_out) + residual

        # Forward any extra attention outputs (e.g. attention probabilities) after the hidden states.
        return (layer_output, *extras)
class BeitRelativePositionBias(nn.Module):
    """
    Learnable relative position bias for attention over a `window_size` grid of patches plus one CLS token.

    The bias table has one row per relative offset between two patches, `(2 * Wh - 1) * (2 * Ww - 1)` rows in total,
    followed by 3 rows for the cls-to-token, token-to-cls and cls-to-cls pairs; each row holds one value per
    attention head.
    """

    def __init__(self, config: BeitConfig, window_size: tuple) -> None:
        super().__init__()
        self.window_size = window_size
        # patch-to-patch offsets + 3 CLS entries (cls to token & token 2 cls & cls to cls)
        self.num_relative_distance = (2 * window_size[0] - 1) * (2 * window_size[1] - 1) + 3
        self.relative_position_bias_table = nn.Parameter(
            torch.zeros(self.num_relative_distance, config.num_attention_heads)
        )  # 2*Wh-1 * 2*Ww-1 + 3, nH

    @compile_compatible_method_lru_cache(maxsize=10)
    def generate_relative_position_index(self, window_size: tuple[int, int]) -> torch.Tensor:
        """
        This method creates the relative position index, modified to support arbitrary window sizes,
        as introduced in [MiDaS v3.1](https://huggingface.co/papers/2307.14460).

        Returns:
            `torch.Tensor` of shape `(Wh*Ww + 1, Wh*Ww + 1)`: for every (query, key) pair, the row of the bias table
            to use. Position 0 is the CLS token.
        """
        num_relative_distance = (2 * window_size[0] - 1) * (2 * window_size[1] - 1) + 3
        # get pair-wise relative position index for each token inside the window
        window_area = window_size[0] * window_size[1]
        grid = torch.meshgrid(torch.arange(window_size[0]), torch.arange(window_size[1]), indexing="ij")
        coords = torch.stack(grid)  # 2, Wh, Ww
        coords_flatten = torch.flatten(coords, 1)  # 2, Wh*Ww
        relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :]  # 2, Wh*Ww, Wh*Ww
        relative_coords = relative_coords.permute(1, 2, 0).contiguous()  # Wh*Ww, Wh*Ww, 2
        relative_coords[:, :, 0] += window_size[0] - 1  # shift to start from 0
        relative_coords[:, :, 1] += window_size[1] - 1
        # Row-major flattening over a (2*Wh-1, 2*Ww-1) grid: index = dy * (2*Ww-1) + dx.
        relative_coords[:, :, 0] *= 2 * window_size[1] - 1
        relative_position_index = torch.zeros(size=(window_area + 1,) * 2, dtype=relative_coords.dtype)
        relative_position_index[1:, 1:] = relative_coords.sum(-1)  # Wh*Ww, Wh*Ww
        relative_position_index[0, 0:] = num_relative_distance - 3  # cls (query) to token
        relative_position_index[0:, 0] = num_relative_distance - 2  # token to cls (key)
        relative_position_index[0, 0] = num_relative_distance - 1  # cls to cls

        return relative_position_index

    def forward(self, window_size, interpolate_pos_encoding: bool = False, dim_size=None) -> torch.Tensor:
        """
        Modification of timm.models.beit.py: Attention._get_rel_pos_bias to support arbitrary window sizes.

        The patch-to-patch part of the learned table (built for `self.window_size`) is bilinearly resized to the
        requested `window_size`, then gathered into a dense bias.

        Args:
            window_size (`tuple[int, int]`):
                Patch grid `(Wh, Ww)` of the current input.
            interpolate_pos_encoding (`bool`, *optional*, defaults to `False`):
                Additionally resize the dense bias to `(dim_size, dim_size)`.
            dim_size (`int`, *optional*):
                Target size for that resize; required when `interpolate_pos_encoding` is set.

        Returns:
            `torch.Tensor` of shape `(1, num_attention_heads, Wh*Ww + 1, Wh*Ww + 1)`, or
            `(1, num_attention_heads, dim_size, dim_size)` when `interpolate_pos_encoding` is set.
        """
        old_height = 2 * self.window_size[0] - 1
        old_width = 2 * self.window_size[1] - 1
        new_height = 2 * window_size[0] - 1
        new_width = 2 * window_size[1] - 1

        old_relative_position_bias_table = self.relative_position_bias_table

        old_num_relative_distance = self.num_relative_distance
        new_num_relative_distance = new_height * new_width + 3

        old_sub_table = old_relative_position_bias_table[: old_num_relative_distance - 3]

        # The patch-to-patch rows are laid out row-major over a (2*Wh-1, 2*Ww-1) grid (see
        # `generate_relative_position_index`), so the grid must be rebuilt as (height, width). Reshaping it as
        # (width, height) scrambles the table for non-square windows, even at the original resolution.
        old_sub_table = old_sub_table.reshape(1, old_height, old_width, -1).permute(0, 3, 1, 2)
        new_sub_table = nn.functional.interpolate(
            old_sub_table, size=(torch_int(new_height), torch_int(new_width)), mode="bilinear"
        )
        new_sub_table = new_sub_table.permute(0, 2, 3, 1).reshape(new_num_relative_distance - 3, -1)

        # The 3 CLS-related rows do not depend on the window size and are kept as-is.
        new_relative_position_bias_table = torch.cat(
            [new_sub_table, old_relative_position_bias_table[old_num_relative_distance - 3 :]]
        )

        relative_position_index = self.generate_relative_position_index(window_size)
        relative_position_bias = new_relative_position_bias_table[relative_position_index.view(-1)]

        # Wh*Ww + 1, Wh*Ww + 1, num_attention_heads
        relative_position_bias = relative_position_bias.view(
            window_size[0] * window_size[1] + 1, window_size[0] * window_size[1] + 1, -1
        )
        # num_attention_heads, Wh*Ww + 1, Wh*Ww + 1
        relative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous()

        if interpolate_pos_encoding:
            relative_position_bias = nn.functional.interpolate(
                relative_position_bias.unsqueeze(1),
                size=(dim_size, dim_size),
                mode="bilinear",
                align_corners=False,
            ).squeeze(1)

        return relative_position_bias.unsqueeze(0)
class BeitEncoder(nn.Module):
    """Stack of `BeitLayer`s, optionally sharing a single relative position bias across all layers."""

    def __init__(self, config: BeitConfig, window_size: tuple | None = None) -> None:
        super().__init__()
        self.config = config
        self.has_relative_position_bias = config.use_shared_relative_position_bias
        if self.has_relative_position_bias:
            self.relative_position_bias = BeitRelativePositionBias(config, window_size=window_size)

        # stochastic depth decay rule: the drop-path rate grows linearly from 0 to `config.drop_path_rate`
        dpr = [x.item() for x in torch.linspace(0, config.drop_path_rate, config.num_hidden_layers, device="cpu")]
        self.layer = nn.ModuleList(
            [
                BeitLayer(
                    config,
                    window_size=window_size if config.use_relative_position_bias else None,
                    drop_path_rate=dpr[i],
                )
                for i in range(config.num_hidden_layers)
            ]
        )
        self.gradient_checkpointing = False

    def forward(
        self,
        hidden_states: torch.Tensor,
        output_attentions: bool = False,
        output_hidden_states: bool = False,
        interpolate_pos_encoding: bool = False,
        resolution: tuple[int, int] | None = None,
        return_dict: bool = True,
    ) -> tuple | BaseModelOutput:
        """
        Run `hidden_states` through every layer.

        Args:
            hidden_states (`torch.Tensor`): embedding output, including the CLS token.
            output_attentions (`bool`): collect each layer's second output (attention probabilities).
            output_hidden_states (`bool`): collect the input of every layer plus the final output.
            interpolate_pos_encoding (`bool`): forwarded to the relative position bias modules.
            resolution (`tuple[int, int]`): input image `(height, width)`; required when relative position bias
                (shared or per layer) is used.
            return_dict (`bool`): return a `BaseModelOutput` instead of a tuple.
        """
        all_hidden_states = () if output_hidden_states else None
        all_self_attentions = () if output_attentions else None

        # The shared bias depends only on `resolution`, `interpolate_pos_encoding` and the sequence length, none of
        # which change between layers (each layer returns a tensor of its input's shape), so compute it once.
        relative_position_bias = None
        if self.has_relative_position_bias:
            height, width = resolution
            window_size = (height // self.config.patch_size, width // self.config.patch_size)
            relative_position_bias = self.relative_position_bias(
                window_size, interpolate_pos_encoding=interpolate_pos_encoding, dim_size=hidden_states.shape[1]
            )

        for layer_module in self.layer:
            if output_hidden_states:
                all_hidden_states = all_hidden_states + (hidden_states,)

            layer_outputs = layer_module(
                hidden_states,
                output_attentions=output_attentions,
                relative_position_bias=relative_position_bias,
                interpolate_pos_encoding=interpolate_pos_encoding,
                resolution=resolution,
            )

            hidden_states = layer_outputs[0]

            if output_attentions:
                all_self_attentions = all_self_attentions + (layer_outputs[1],)

        if output_hidden_states:
            all_hidden_states = all_hidden_states + (hidden_states,)

        if not return_dict:
            return tuple(v for v in [hidden_states, all_hidden_states, all_self_attentions] if v is not None)
        return BaseModelOutput(
            last_hidden_state=hidden_states,
            hidden_states=all_hidden_states,
            attentions=all_self_attentions,
        )
# Base class holding the BEiT-specific loading metadata and weight-initialization rules shared by all BEiT heads.
@auto_docstring
class BeitPreTrainedModel(PreTrainedModel):
    config: BeitConfig
    base_model_prefix = "beit"
    input_modalities = ("image",)
    main_input_name = "pixel_values"
    supports_gradient_checkpointing = True
    _no_split_modules = ["BeitLayer"]
    _keys_to_ignore_on_load_unexpected = [r".*relative_position_index.*"]
    _supports_sdpa = True

    @torch.no_grad()
    def _init_weights(self, module):
        """Initialize the weights"""
        super()._init_weights(module)
        if isinstance(module, BeitEmbeddings):
            # Learned CLS/mask tokens and absolute position embeddings start at zero (the latter two are optional).
            for param in (module.cls_token, module.mask_token, module.position_embeddings):
                if param is not None:
                    init.zeros_(param)
        elif isinstance(module, BeitRelativePositionBias):
            init.zeros_(module.relative_position_bias_table)
        elif isinstance(module, BeitLayer) and module.lambda_1 is not None:
            # Layer-scale factors are (re)set to the configured initial value.
            scale = self.config.layer_scale_init_value
            init.constant_(module.lambda_1, scale)
            init.constant_(module.lambda_2, scale)
# Bare BEiT encoder: embeddings -> transformer layers -> final layernorm (or identity) -> optional pooler.
@auto_docstring
class BeitModel(BeitPreTrainedModel):
    def __init__(self, config: BeitConfig, add_pooling_layer: bool = True) -> None:
        r"""
        add_pooling_layer (bool, *optional*, defaults to `True`):
            Whether to add a pooling layer
        """
        super().__init__(config)
        self.config = config

        self.embeddings = BeitEmbeddings(config)
        patch_grid = self.embeddings.patch_embeddings.patch_shape
        self.encoder = BeitEncoder(config, window_size=patch_grid)

        # With mean pooling the final layernorm is applied inside `BeitPooler` instead (when a pooler is present).
        if config.use_mean_pooling:
            self.layernorm = nn.Identity()
        else:
            self.layernorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
        self.pooler = BeitPooler(config) if add_pooling_layer else None

        # Initialize weights and apply final processing
        self.post_init()

    def get_input_embeddings(self):
        return self.embeddings.patch_embeddings

    @auto_docstring
    def forward(
        self,
        pixel_values: torch.Tensor,
        bool_masked_pos: torch.BoolTensor | None = None,
        output_attentions: bool | None = None,
        output_hidden_states: bool | None = None,
        interpolate_pos_encoding: bool = False,
        return_dict: bool | None = None,
        **kwargs,
    ) -> tuple | BeitModelOutputWithPooling:
        r"""
        bool_masked_pos (`torch.BoolTensor` of shape `(batch_size, num_patches)`, *optional*):
            Boolean masked positions. Indicates which patches are masked (1) and which aren't (0).
        """
        # Unset flags fall back to the configuration.
        if output_attentions is None:
            output_attentions = self.config.output_attentions
        if output_hidden_states is None:
            output_hidden_states = self.config.output_hidden_states
        if return_dict is None:
            return_dict = self.config.return_dict

        embedding_output, _ = self.embeddings(pixel_values, bool_masked_pos=bool_masked_pos)
        encoder_outputs = self.encoder(
            embedding_output,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            resolution=pixel_values.shape[2:],
            return_dict=return_dict,
            interpolate_pos_encoding=interpolate_pos_encoding,
        )
        sequence_output = self.layernorm(encoder_outputs[0])
        pooled_output = None if self.pooler is None else self.pooler(sequence_output)

        if return_dict:
            return BeitModelOutputWithPooling(
                last_hidden_state=sequence_output,
                pooler_output=pooled_output,
                hidden_states=encoder_outputs.hidden_states,
                attentions=encoder_outputs.attentions,
            )

        # Tuple output: (sequence_output[, pooled_output], *encoder extras)
        if pooled_output is None:
            head_outputs = (sequence_output,)
        else:
            head_outputs = (sequence_output, pooled_output)
        return head_outputs + encoder_outputs[1:]
class BeitPooler(nn.Module):
    """
    Reduces the final token sequence to one vector per image.

    If `config.use_mean_pooling` is set, the patch tokens (every token after the first) are averaged and then
    layer-normalized. Otherwise the final hidden state of the first ([CLS]) token is returned unchanged.
    """

    def __init__(self, config: BeitConfig) -> None:
        super().__init__()
        if config.use_mean_pooling:
            self.layernorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
        else:
            self.layernorm = None

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        if self.layernorm is None:
            # Pool by simply taking the final hidden state of the [CLS] token
            return hidden_states[:, 0]

        # Mean pool the final hidden states of the patch tokens (position 0 is skipped)
        patch_tokens = hidden_states[:, 1:, :]
        return self.layernorm(patch_tokens.mean(dim=1))
@auto_docstring(
    custom_intro="""
    Beit Model transformer with a 'language' modeling head on top. BEiT does masked image modeling by predicting
    visual tokens of a Vector-Quantize Variational Autoencoder (VQ-VAE), whereas other vision models like ViT and DeiT
    predict RGB pixel values. As a result, this class is incompatible with [`AutoModelForMaskedImageModeling`], so you
    will need to use [`BeitForMaskedImageModeling`] directly if you wish to do masked image modeling with BEiT.
    """
)
class BeitForMaskedImageModeling(BeitPreTrainedModel):
    def __init__(self, config: BeitConfig) -> None:
        super().__init__(config)

        self.num_labels = config.num_labels
        self.beit = BeitModel(config, add_pooling_layer=False)

        # Masked image modeling head: normalize, then score every patch against the `vocab_size` visual tokens
        self.layernorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
        self.lm_head = nn.Linear(config.hidden_size, config.vocab_size)

        # Initialize weights and apply final processing
        self.post_init()

    def get_output_embeddings(self):
        # NOTE(review): deliberately returns None, presumably so `lm_head` is not treated as tied/resizable
        # output word embeddings by the base class — confirm.
        return None

    @auto_docstring
    def forward(
        self,
        pixel_values: torch.Tensor | None = None,
        bool_masked_pos: torch.BoolTensor | None = None,
        labels: torch.Tensor | None = None,
        output_attentions: bool | None = None,
        output_hidden_states: bool | None = None,
        interpolate_pos_encoding: bool = False,
        return_dict: bool | None = None,
        **kwargs,
    ) -> tuple | MaskedLMOutput:
        r"""
        bool_masked_pos (`torch.BoolTensor` of shape `(batch_size, num_patches)`):
            Boolean masked positions. Indicates which patches are masked (1) and which aren't (0). Required when
            `labels` is provided, since it selects the patches the loss is computed on.
        labels (`torch.LongTensor` of shape `(num_masked_patches,)`, *optional*):
            Target visual-token ids for the masked patches, in `[0, ..., config.vocab_size - 1]`.
            `num_masked_patches` is the total number of `True` entries in `bool_masked_pos` across the whole batch,
            and labels follow the row-major order of those entries. Entries equal to `-100` are ignored by the
            cross-entropy loss.

        Examples:

        ```python
        >>> from transformers import AutoImageProcessor, BeitForMaskedImageModeling
        >>> import torch
        >>> from PIL import Image
        >>> import httpx
        >>> from io import BytesIO

        >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg"
        >>> with httpx.stream("GET", url) as response:
        ...     image = Image.open(BytesIO(response.read()))

        >>> image_processor = AutoImageProcessor.from_pretrained("microsoft/beit-base-patch16-224-pt22k")
        >>> model = BeitForMaskedImageModeling.from_pretrained("microsoft/beit-base-patch16-224-pt22k")

        >>> num_patches = (model.config.image_size // model.config.patch_size) ** 2
        >>> pixel_values = image_processor(images=image, return_tensors="pt").pixel_values
        >>> # create random boolean mask of shape (batch_size, num_patches)
        >>> bool_masked_pos = torch.randint(low=0, high=2, size=(1, num_patches)).bool()

        >>> outputs = model(pixel_values, bool_masked_pos=bool_masked_pos)
        >>> loss, logits = outputs.loss, outputs.logits
        >>> list(logits.shape)
        [1, 196, 8192]
        ```"""
        return_dict = return_dict if return_dict is not None else self.config.return_dict

        # The loss is only defined over masked patches; without a mask, `prediction_scores[None]` would
        # silently add a dimension and fail later with an opaque shape error.
        if labels is not None and bool_masked_pos is None:
            raise ValueError(
                "`bool_masked_pos` must be provided when `labels` are passed, to select the masked patches to score."
            )

        outputs = self.beit(
            pixel_values,
            bool_masked_pos=bool_masked_pos,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            interpolate_pos_encoding=interpolate_pos_encoding,
            return_dict=return_dict,
        )

        sequence_output = self.layernorm(outputs[0])
        # Skip the first ([CLS]) token: predictions are made for patch tokens only
        prediction_scores = self.lm_head(sequence_output[:, 1:])

        masked_lm_loss = None
        if labels is not None:
            loss_fct = CrossEntropyLoss()  # -100 index = padding token
            masked_lm_loss = loss_fct(prediction_scores[bool_masked_pos], labels)

        if not return_dict:
            output = (prediction_scores,) + outputs[1:]
            return ((masked_lm_loss,) + output) if masked_lm_loss is not None else output

        return MaskedLMOutput(
            loss=masked_lm_loss,
            logits=prediction_scores,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )
@auto_docstring(
    custom_intro="""
    Beit Model transformer with an image classification head on top (a linear layer on top of the average of the final
    hidden states of the patch tokens) e.g. for ImageNet.
    """
)
class BeitForImageClassification(BeitPreTrainedModel):
    def __init__(self, config: BeitConfig) -> None:
        super().__init__(config)

        self.num_labels = config.num_labels
        self.beit = BeitModel(config, add_pooling_layer=True)

        # Classifier head; with no labels configured the pooled features pass through unchanged.
        if config.num_labels > 0:
            self.classifier = nn.Linear(config.hidden_size, config.num_labels)
        else:
            self.classifier = nn.Identity()

        # Initialize weights and apply final processing
        self.post_init()

    @auto_docstring
    def forward(
        self,
        pixel_values: torch.Tensor | None = None,
        labels: torch.Tensor | None = None,
        output_attentions: bool | None = None,
        output_hidden_states: bool | None = None,
        interpolate_pos_encoding: bool = False,
        return_dict: bool | None = None,
        **kwargs,
    ) -> tuple | ImageClassifierOutput:
        r"""
        labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
            Labels for computing the image classification/regression loss. Indices should be in `[0, ...,
            config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If
            `config.num_labels > 1` a classification loss is computed (Cross-Entropy).
        """
        return_dict = self.config.return_dict if return_dict is None else return_dict

        outputs = self.beit(
            pixel_values,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            interpolate_pos_encoding=interpolate_pos_encoding,
            return_dict=return_dict,
        )

        # The backbone is built with a pooler, so the pooled vector sits at tuple index 1.
        if return_dict:
            pooled_output = outputs.pooler_output
        else:
            pooled_output = outputs[1]
        logits = self.classifier(pooled_output)

        loss = None if labels is None else self.loss_function(labels, logits, self.config)

        if return_dict:
            return ImageClassifierOutput(
                loss=loss,
                logits=logits,
                hidden_states=outputs.hidden_states,
                attentions=outputs.attentions,
            )

        output = (logits,) + outputs[2:]
        return output if loss is None else (loss,) + output
- class BeitConvModule(nn.Module):
- """
- A convolutional block that bundles conv/norm/activation layers. This block simplifies the usage of convolution
- layers, which are commonly used with a norm layer (e.g., BatchNorm) and activation layer (e.g., ReLU).
- Based on OpenMMLab's implementation, found in https://github.com/open-mmlab/mmsegmentation.
- """
- def __init__(
- self,
- in_channels: int,
- out_channels: int,
- kernel_size: int | tuple[int, int],
- padding: int | tuple[int, int] | str = 0,
- bias: bool = False,
- dilation: int | tuple[int, int] = 1,
- ) -> None:
- super().__init__()
- self.conv = nn.Conv2d(
- in_channels=in_channels,
- out_channels=out_channels,
- kernel_size=kernel_size,
- padding=padding,
- bias=bias,
- dilation=dilation,
- )
- self.bn = nn.BatchNorm2d(out_channels)
- self.activation = nn.ReLU()
- def forward(self, input: torch.Tensor) -> torch.Tensor:
- output = self.conv(input)
- output = self.bn(output)
- output = self.activation(output)
- return output
class BeitPyramidPoolingBlock(nn.Module):
    """
    One branch of the pyramid pooling module.

    Applies adaptive average pooling to `pool_scale`, then a 1x1 BeitConvModule (conv -> BN -> ReLU) that maps
    `in_channels` to `channels`.
    """

    def __init__(self, pool_scale: int, in_channels: int, channels: int) -> None:
        super().__init__()
        stages = (
            nn.AdaptiveAvgPool2d(pool_scale),
            BeitConvModule(in_channels, channels, kernel_size=1),
        )
        # Kept as a plain list for ordered iteration; each stage is registered as a submodule named by its
        # position ("0", "1"), which fixes the parameter names in the state dict.
        self.layers = []
        for position, stage in enumerate(stages):
            self.layers.append(stage)
            self.add_module(str(position), stage)

    def forward(self, input: torch.Tensor) -> torch.Tensor:
        out = input
        for stage in self.layers:
            out = stage(out)
        return out
class BeitPyramidPoolingModule(nn.Module):
    """
    Pyramid Pooling Module (PPM) as used in PSPNet.

    Args:
        pool_scales (tuple[int]): Output sizes of the adaptive average pools; one branch is built per scale.
        in_channels (int): Number of channels of the input feature map.
        channels (int): Number of channels produced by each branch.
        align_corners (bool): `align_corners` argument passed to `nn.functional.interpolate` when upsampling.

    Based on OpenMMLab's implementation, found in https://github.com/open-mmlab/mmsegmentation.
    """

    def __init__(self, pool_scales: tuple[int, ...], in_channels: int, channels: int, align_corners: bool) -> None:
        super().__init__()
        self.pool_scales = pool_scales
        self.align_corners = align_corners
        self.in_channels = in_channels
        self.channels = channels
        # Plain list for iteration; each branch is also registered under its index ("0", "1", ...) so its
        # parameters appear in the state dict as `<index>.*`.
        self.blocks = []
        for index, scale in enumerate(pool_scales):
            branch = BeitPyramidPoolingBlock(pool_scale=scale, in_channels=in_channels, channels=channels)
            self.blocks.append(branch)
            self.add_module(str(index), branch)

    def forward(self, x: torch.Tensor) -> list[torch.Tensor]:
        # Every branch output is resized back to the input's spatial size.
        output_size = x.size()[2:]
        return [
            nn.functional.interpolate(branch(x), size=output_size, mode="bilinear", align_corners=self.align_corners)
            for branch in self.blocks
        ]
class BeitUperHead(nn.Module):
    """
    Unified Perceptual Parsing for Scene Understanding. This head is the implementation of
    [UPerNet](https://huggingface.co/papers/1807.10221).

    It takes a list of four feature maps. The last one goes through a pyramid pooling module, and the others
    through lateral 1x1 convs. The levels are then merged top-down, smoothed, resized to the first level's
    size, concatenated and classified per pixel.

    Based on OpenMMLab's implementation, found in https://github.com/open-mmlab/mmsegmentation.
    """

    def __init__(self, config: BeitConfig) -> None:
        super().__init__()

        self.pool_scales = config.pool_scales  # e.g. (1, 2, 3, 6)
        self.in_channels = [config.hidden_size] * 4  # e.g. [768, 768, 768, 768]
        self.channels = config.hidden_size
        self.align_corners = False
        self.classifier = nn.Conv2d(self.channels, config.num_labels, kernel_size=1)

        top_channels = self.in_channels[-1]

        # PSP Module: multi-scale pooling of the last feature map, fused by a 3x3 bottleneck
        self.psp_modules = BeitPyramidPoolingModule(
            self.pool_scales,
            top_channels,
            self.channels,
            align_corners=self.align_corners,
        )
        self.bottleneck = BeitConvModule(
            top_channels + len(self.pool_scales) * self.channels,
            self.channels,
            kernel_size=3,
            padding=1,
        )

        # FPN Module: one lateral 1x1 conv and one 3x3 smoothing conv per level except the last
        self.lateral_convs = nn.ModuleList()
        self.fpn_convs = nn.ModuleList()
        for level_channels in self.in_channels[:-1]:
            lateral = BeitConvModule(level_channels, self.channels, kernel_size=1)
            smoothing = BeitConvModule(self.channels, self.channels, kernel_size=3, padding=1)
            self.lateral_convs.append(lateral)
            self.fpn_convs.append(smoothing)

        self.fpn_bottleneck = BeitConvModule(
            len(self.in_channels) * self.channels,
            self.channels,
            kernel_size=3,
            padding=1,
        )

    def psp_forward(self, inputs):
        """Pyramid-pool the last feature map and fuse it (with itself) through the bottleneck conv."""
        top = inputs[-1]
        pooled = [top, *self.psp_modules(top)]
        return self.bottleneck(torch.cat(pooled, dim=1))

    def forward(self, encoder_hidden_states: torch.Tensor) -> torch.Tensor:
        # build laterals: one projection per non-last level, plus the PSP output for the last level
        laterals = [conv(encoder_hidden_states[level]) for level, conv in enumerate(self.lateral_convs)]
        laterals.append(self.psp_forward(encoder_hidden_states))
        num_levels = len(laterals)

        # build top-down path: resize each level to the previous level's size and add it in
        for level in reversed(range(1, num_levels)):
            previous_size = laterals[level - 1].shape[2:]
            laterals[level - 1] = laterals[level - 1] + nn.functional.interpolate(
                laterals[level], size=previous_size, mode="bilinear", align_corners=self.align_corners
            )

        # build outputs: smooth every non-last level; the PSP feature is appended as-is
        fpn_outs = [smooth(laterals[level]) for level, smooth in enumerate(self.fpn_convs)]
        fpn_outs.append(laterals[-1])

        # bring every level to the first level's spatial size before concatenating
        target_size = fpn_outs[0].shape[2:]
        for level in reversed(range(1, num_levels)):
            fpn_outs[level] = nn.functional.interpolate(
                fpn_outs[level], size=target_size, mode="bilinear", align_corners=self.align_corners
            )

        fused = self.fpn_bottleneck(torch.cat(fpn_outs, dim=1))
        return self.classifier(fused)
class BeitFCNHead(nn.Module):
    """
    Fully Convolution Networks for Semantic Segmentation. This head is an implementation of
    [FCNNet](https://huggingface.co/papers/1411.4038).

    Args:
        config (BeitConfig): Configuration. Reads `hidden_size` (input channels), `auxiliary_channels`,
            `auxiliary_num_convs`, `auxiliary_concat_input` and `num_labels`.
        in_index (int): Index of the feature map to take from the list passed to `forward`. Default: 2.
        kernel_size (int): The kernel size for convs in the head. Default: 3.
        dilation (int): The dilation rate for convs in the head. Default: 1.

    Based on OpenMMLab's implementation, found in https://github.com/open-mmlab/mmsegmentation.
    """

    def __init__(
        self, config: BeitConfig, in_index: int = 2, kernel_size: int = 3, dilation: int | tuple[int, int] = 1
    ) -> None:
        super().__init__()
        self.in_channels = config.hidden_size
        self.channels = config.auxiliary_channels
        self.num_convs = config.auxiliary_num_convs
        self.concat_input = config.auxiliary_concat_input
        self.in_index = in_index

        # Padding that keeps the spatial size unchanged for odd kernel sizes, taking dilation into account
        conv_padding = (kernel_size // 2) * dilation
        convs = [
            BeitConvModule(
                self.in_channels, self.channels, kernel_size=kernel_size, padding=conv_padding, dilation=dilation
            )
        ]
        for _ in range(self.num_convs - 1):
            convs.append(
                BeitConvModule(
                    self.channels, self.channels, kernel_size=kernel_size, padding=conv_padding, dilation=dilation
                )
            )
        if self.num_convs == 0:
            # NOTE(review): with no convs, `classifier` (and `conv_cat`) receive `in_channels` channels, so this
            # configuration only works when `hidden_size == auxiliary_channels`.
            self.convs = nn.Identity()
        else:
            self.convs = nn.Sequential(*convs)
        if self.concat_input:
            # Fuses the raw input feature map with the conv output (concatenated along channels)
            self.conv_cat = BeitConvModule(
                self.in_channels + self.channels, self.channels, kernel_size=kernel_size, padding=kernel_size // 2
            )

        self.classifier = nn.Conv2d(self.channels, config.num_labels, kernel_size=1)

    def forward(self, encoder_hidden_states: torch.Tensor) -> torch.Tensor:
        # just take the relevant feature maps
        hidden_states = encoder_hidden_states[self.in_index]
        output = self.convs(hidden_states)
        if self.concat_input:
            output = self.conv_cat(torch.cat([hidden_states, output], dim=1))
        output = self.classifier(output)
        return output
@auto_docstring
class BeitForSemanticSegmentation(BeitPreTrainedModel):
    def __init__(self, config: BeitConfig) -> None:
        super().__init__(config)

        self.num_labels = config.num_labels
        self.beit = BeitModel(config, add_pooling_layer=False)

        # FPNs
        if len(self.config.out_indices) != 4:
            raise ValueError(
                "BeitForSemanticSegmentation requires config.out_indices to be a list of 4 integers, "
                "specifying which features to use from the backbone. One can use [3, 5, 7, 11] in case of "
                "a base-sized architecture."
            )
        # Turn the four single-scale feature maps into a pyramid: 4x upsample, 2x upsample, identity, 2x downsample
        self.fpn1 = nn.Sequential(
            nn.ConvTranspose2d(config.hidden_size, config.hidden_size, kernel_size=2, stride=2),
            nn.BatchNorm2d(config.hidden_size),
            nn.GELU(),
            nn.ConvTranspose2d(config.hidden_size, config.hidden_size, kernel_size=2, stride=2),
        )
        self.fpn2 = nn.Sequential(
            nn.ConvTranspose2d(config.hidden_size, config.hidden_size, kernel_size=2, stride=2),
        )
        self.fpn3 = nn.Identity()
        self.fpn4 = nn.MaxPool2d(kernel_size=2, stride=2)

        # Semantic segmentation head(s)
        self.decode_head = BeitUperHead(config)
        self.auxiliary_head = BeitFCNHead(config) if config.use_auxiliary_head else None

        # Initialize weights and apply final processing
        self.post_init()

    def compute_loss(self, logits, auxiliary_logits, labels):
        """Cross-entropy of the (optionally auxiliary-weighted) logits, upsampled to the label map's size."""
        # upsample logits to the images' original size
        upsampled_logits = nn.functional.interpolate(
            logits, size=labels.shape[-2:], mode="bilinear", align_corners=False
        )
        if auxiliary_logits is not None:
            upsampled_auxiliary_logits = nn.functional.interpolate(
                auxiliary_logits, size=labels.shape[-2:], mode="bilinear", align_corners=False
            )
        # compute weighted loss
        loss_fct = CrossEntropyLoss(ignore_index=self.config.semantic_loss_ignore_index)
        main_loss = loss_fct(upsampled_logits, labels)
        loss = main_loss
        if auxiliary_logits is not None:
            auxiliary_loss = loss_fct(upsampled_auxiliary_logits, labels)
            loss += self.config.auxiliary_loss_weight * auxiliary_loss

        return loss

    @auto_docstring
    def forward(
        self,
        pixel_values: torch.Tensor | None = None,
        labels: torch.Tensor | None = None,
        output_attentions: bool | None = None,
        output_hidden_states: bool | None = None,
        interpolate_pos_encoding: bool = False,
        return_dict: bool | None = None,
        **kwargs,
    ) -> tuple | SemanticSegmenterOutput:
        r"""
        labels (`torch.LongTensor` of shape `(batch_size, height, width)`, *optional*):
            Ground truth semantic segmentation maps for computing the loss. Indices should be in `[0, ...,
            config.num_labels - 1]`. If `config.num_labels > 1`, a classification loss is computed (Cross-Entropy).

        Examples:

        ```python
        >>> from transformers import AutoImageProcessor, BeitForSemanticSegmentation
        >>> from PIL import Image
        >>> import httpx
        >>> from io import BytesIO

        >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg"
        >>> with httpx.stream("GET", url) as response:
        ...     image = Image.open(BytesIO(response.read()))

        >>> image_processor = AutoImageProcessor.from_pretrained("microsoft/beit-base-finetuned-ade-640-640")
        >>> model = BeitForSemanticSegmentation.from_pretrained("microsoft/beit-base-finetuned-ade-640-640")

        >>> inputs = image_processor(images=image, return_tensors="pt")
        >>> outputs = model(**inputs)
        >>> # logits are of shape (batch_size, num_labels, height, width)
        >>> logits = outputs.logits
        ```"""
        return_dict = return_dict if return_dict is not None else self.config.return_dict
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )

        if labels is not None and self.config.num_labels == 1:
            raise ValueError("The number of labels should be greater than one")

        outputs = self.beit(
            pixel_values,
            output_attentions=output_attentions,
            output_hidden_states=True,  # we need the intermediate hidden states
            interpolate_pos_encoding=interpolate_pos_encoding,
            return_dict=return_dict,
        )

        encoder_hidden_states = outputs.hidden_states if return_dict else outputs[1]

        # only keep certain features, and reshape
        # note that we do +1 as the encoder_hidden_states also includes the initial embeddings
        features = [feature for idx, feature in enumerate(encoder_hidden_states) if idx + 1 in self.config.out_indices]

        # Derive the patch grid from the actual input resolution instead of `config.image_size`, so inputs of
        # another size (e.g. with `interpolate_pos_encoding=True`) are reshaped correctly. For inputs of size
        # `config.image_size` this equals the previous `image_size // patch_size` grid.
        # NOTE(review): assumes the patch embedding yields `pixel_size // patch_size` patches per side.
        batch_size = pixel_values.shape[0]
        patch_size = self.config.patch_size
        grid_height = pixel_values.shape[-2] // patch_size
        grid_width = pixel_values.shape[-1] // patch_size
        # Drop the first ([CLS]) token and lay the patch tokens out as (batch, channels, grid_h, grid_w)
        features = [x[:, 1:, :].permute(0, 2, 1).reshape(batch_size, -1, grid_height, grid_width) for x in features]

        # apply FPNs
        ops = [self.fpn1, self.fpn2, self.fpn3, self.fpn4]
        for i in range(len(features)):
            features[i] = ops[i](features[i])

        logits = self.decode_head(features)

        auxiliary_logits = None
        if self.auxiliary_head is not None:
            auxiliary_logits = self.auxiliary_head(features)

        loss = None
        if labels is not None:
            loss = self.compute_loss(logits, auxiliary_logits, labels)

        if not return_dict:
            if output_hidden_states:
                output = (logits,) + outputs[1:]
            else:
                output = (logits,) + outputs[2:]
            return ((loss,) + output) if loss is not None else output

        return SemanticSegmenterOutput(
            loss=loss,
            logits=logits,
            hidden_states=outputs.hidden_states if output_hidden_states else None,
            attentions=outputs.attentions,
        )
@auto_docstring(
    custom_intro="""
    BEiT backbone, to be used with frameworks like DETR and MaskFormer.
    """
)
class BeitBackbone(BackboneMixin, BeitPreTrainedModel):
    def __init__(self, config):
        super().__init__(config)

        # `hidden_size` channels for the embedding output and for each of the `num_hidden_layers` layers
        self.num_features = [config.hidden_size] * (config.num_hidden_layers + 1)
        self.embeddings = BeitEmbeddings(config)
        self.encoder = BeitEncoder(config, window_size=self.embeddings.patch_embeddings.patch_shape)

        if config.add_fpn:
            if len(self.config.out_indices) != 4:
                raise ValueError(
                    "BeitBackbone requires config.out_indices to be a list of 4 integers, "
                    "specifying which features to use from the backbone. One can use [3, 5, 7, 11] in case of "
                    "a base-sized architecture."
                )
            width = config.hidden_size
            # Pyramid over the four selected maps: 4x upsample, 2x upsample, identity, 2x downsample
            self.fpn1 = nn.Sequential(
                nn.ConvTranspose2d(width, width, kernel_size=2, stride=2),
                nn.BatchNorm2d(width, eps=config.batch_norm_eps),
                nn.GELU(),
                nn.ConvTranspose2d(width, width, kernel_size=2, stride=2),
            )
            self.fpn2 = nn.Sequential(nn.ConvTranspose2d(width, width, kernel_size=2, stride=2))
            self.fpn3 = nn.Identity()
            self.fpn4 = nn.MaxPool2d(kernel_size=2, stride=2)

        # initialize weights and apply final processing
        self.post_init()

    def get_input_embeddings(self):
        return self.embeddings.patch_embeddings

    @can_return_tuple
    @filter_output_hidden_states
    @auto_docstring
    def forward(
        self,
        pixel_values: Tensor,
        output_hidden_states: bool | None = None,
        output_attentions: bool | None = None,
        return_dict: bool | None = None,
        **kwargs,
    ) -> BackboneOutput:
        r"""
        Examples:

        ```python
        >>> from transformers import AutoImageProcessor, AutoBackbone
        >>> import torch
        >>> from PIL import Image
        >>> import httpx
        >>> from io import BytesIO

        >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg"
        >>> with httpx.stream("GET", url) as response:
        ...     image = Image.open(BytesIO(response.read()))

        >>> processor = AutoImageProcessor.from_pretrained("microsoft/beit-base-patch16-224")
        >>> model = AutoBackbone.from_pretrained(
        ...     "microsoft/beit-base-patch16-224", out_features=["stage1", "stage2", "stage3", "stage4"]
        ... )

        >>> inputs = processor(image, return_tensors="pt")

        >>> outputs = model(**inputs)
        >>> feature_maps = outputs.feature_maps
        >>> list(feature_maps[-1].shape)
        [1, 768, 14, 14]
        ```"""
        config = self.config
        if return_dict is None:
            return_dict = config.return_dict
        if output_hidden_states is None:
            output_hidden_states = config.output_hidden_states
        if output_attentions is None:
            output_attentions = config.output_attentions

        batch_size = pixel_values.shape[0]
        embedding_output, (patch_height, patch_width) = self.embeddings(pixel_values)

        # Hidden states are always requested because the feature maps are taken from them.
        outputs = self.encoder(
            embedding_output,
            output_hidden_states=True,
            output_attentions=output_attentions,
            resolution=pixel_values.shape[2:],
            return_dict=return_dict,
        )
        hidden_states = outputs.hidden_states if return_dict else outputs[1]

        feature_maps = []
        for stage, hidden_state in zip(self.stage_names, hidden_states):
            if stage not in self.out_features:
                continue
            if config.reshape_hidden_states:
                # Drop the first ([CLS]) token and lay the patch tokens out as (batch, channels, height, width)
                hidden_state = (
                    hidden_state[:, 1:, :].permute(0, 2, 1).reshape(batch_size, -1, patch_height, patch_width)
                )
            feature_maps.append(hidden_state)

        if config.add_fpn:
            pyramid = (self.fpn1, self.fpn2, self.fpn3, self.fpn4)
            feature_maps = [fpn(feature_maps[index]) for index, fpn in enumerate(pyramid)]
        feature_maps = tuple(feature_maps)

        if return_dict:
            return BackboneOutput(
                feature_maps=feature_maps,
                hidden_states=outputs.hidden_states if output_hidden_states else None,
                attentions=outputs.attentions,
            )

        # Tuple output: skip the hidden-states slot unless the caller asked for it
        tail = outputs[1:] if output_hidden_states else outputs[2:]
        return (feature_maps,) + tail
# Public classes exported by this module (order preserved).
__all__ = [
    "BeitForImageClassification", "BeitForMaskedImageModeling", "BeitForSemanticSegmentation",
    "BeitModel", "BeitPreTrainedModel", "BeitBackbone",
]
|