| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656657658659660661662663664665666667668669670671672673674675676677678679680681682683684685686687688689690691692693694695696697698699700701702
70370470570670770870971071171271371471571671771871972072172272372472572672772872973073173273373473573673773873974074174274374474574674774874975075175275375475575675775875976076176276376476576676776876977077177277377477577677777877978078178278378478578678778878979079179279379479579679779879980080180280380480580680780880981081181281381481581681781881982082182282382482582682782882983083183283383483583683783883984084184284384484584684784884985085185285385485585685785885986086186286386486586686786886987087187287387487587687787887988088188288388488588688788888989089189289389489589689789889990090190290390490590690790890991091191291391491591691791891992092192292392492592692792892993093193293393493593693793893994094194294394494594694794894995095195295395495595695795895996096196296396496596696796896997097197297397497597697797897998098198298398498598698798898999099199299399499599699799899910001001100210031004100510061007100810091010101110121013101410151016101710181019102010211022102310241025102610271028102910301031103210331034103510361037103810391040104110421043104410451046104710481049105010511052105310541055105610571058105910601061106210631064106510661067106810691070107110721073107410751076107710781079108010811082108310841085108610871088108910901091109210931094109510961097109810991100110111021103110411051106110711081109111011111112111311141115111611171118111911201121112211231124112511261127112811291130113111321133113411351136113711381139114011411142114311441145114611471148114911501151115211531154115511561157115811591160116111621163116411651166116711681169117011711172117311741175117611771178117911801181118211831184118511861187118811891190119111921193119411951196119711981199120012011202120312041205120612071208120912101211121212131214121512161217121812191220122112221223122412251226122712281229123012311232123312341235123612371238123912401241124212431244124512461247124812491250125112521253125412551256125712581259126012611262126312641265126612671268126912701271127212731274127512761
27712781279128012811282128312841285128612871288128912901291129212931294129512961297129812991300130113021303130413051306130713081309131013111312131313141315131613171318131913201321132213231324132513261327132813291330133113321333133413351336133713381339134013411342134313441345134613471348134913501351135213531354135513561357135813591360136113621363136413651366136713681369137013711372137313741375137613771378137913801381138213831384138513861387138813891390139113921393139413951396139713981399140014011402140314041405140614071408140914101411141214131414141514161417141814191420142114221423142414251426142714281429143014311432143314341435143614371438143914401441144214431444144514461447144814491450145114521453145414551456145714581459146014611462146314641465146614671468146914701471147214731474147514761477147814791480148114821483148414851486148714881489149014911492149314941495149614971498149915001501150215031504150515061507150815091510151115121513151415151516151715181519152015211522152315241525152615271528152915301531153215331534153515361537153815391540154115421543154415451546154715481549155015511552155315541555155615571558155915601561156215631564156515661567156815691570157115721573157415751576157715781579158015811582158315841585158615871588158915901591159215931594159515961597159815991600160116021603160416051606160716081609161016111612161316141615161616171618161916201621162216231624162516261627162816291630163116321633163416351636163716381639 |
- # Copyright 2021 Facebook AI Research The HuggingFace Inc. team. All rights reserved.
- #
- # Licensed under the Apache License, Version 2.0 (the "License");
- # you may not use this file except in compliance with the License.
- # You may obtain a copy of the License at
- #
- # http://www.apache.org/licenses/LICENSE-2.0
- #
- # Unless required by applicable law or agreed to in writing, software
- # distributed under the License is distributed on an "AS IS" BASIS,
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- # See the License for the specific language governing permissions and
- # limitations under the License.
- """PyTorch DETR model."""
- import math
- from collections.abc import Callable
- from dataclasses import dataclass
- import torch
- import torch.nn as nn
- from ... import initialization as init
- from ...activations import ACT2FN
- from ...backbone_utils import load_backbone
- from ...masking_utils import create_bidirectional_mask
- from ...modeling_layers import GradientCheckpointingLayer
- from ...modeling_outputs import (
- BaseModelOutput,
- BaseModelOutputWithCrossAttentions,
- Seq2SeqModelOutput,
- )
- from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
- from ...processing_utils import Unpack
- from ...pytorch_utils import compile_compatible_method_lru_cache
- from ...utils import (
- ModelOutput,
- TransformersKwargs,
- auto_docstring,
- logging,
- )
- from ...utils.generic import can_return_tuple, merge_with_config_defaults
- from ...utils.output_capturing import capture_outputs
- from .configuration_detr import DetrConfig
# Module-level logger, obtained from the library's logging helper and keyed by this module's name.
logger = logging.get_logger(__name__)
@dataclass
@auto_docstring(
    custom_intro="""
    Base class for outputs of the DETR decoder. This class adds one attribute to BaseModelOutputWithCrossAttentions,
    namely an optional stack of intermediate decoder activations, i.e. the output of each decoder layer, each of them
    gone through a layernorm. This is useful when training the model with auxiliary decoding losses.
    """
)
class DetrDecoderOutput(BaseModelOutputWithCrossAttentions):
    r"""
    cross_attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` and `config.add_cross_attention=True` is passed or when `config.output_attentions=True`):
        Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
        sequence_length)`. Attention weights of the decoder's cross-attention layer, after the attention softmax,
        used to compute the weighted average in the cross-attention heads.
    intermediate_hidden_states (`torch.FloatTensor` of shape `(config.decoder_layers, batch_size, num_queries, hidden_size)`, *optional*, returned when `config.auxiliary_loss=True`):
        Intermediate decoder activations, i.e. the output of each decoder layer, each of them gone through a
        layernorm.
    """

    # The only field added on top of the parent output class; defaults to None when not populated.
    intermediate_hidden_states: torch.FloatTensor | None = None
@dataclass
@auto_docstring(
    custom_intro="""
    Base class for outputs of the DETR encoder-decoder model. This class adds one attribute to Seq2SeqModelOutput,
    namely an optional stack of intermediate decoder activations, i.e. the output of each decoder layer, each of them
    gone through a layernorm. This is useful when training the model with auxiliary decoding losses.
    """
)
class DetrModelOutput(Seq2SeqModelOutput):
    r"""
    last_hidden_state (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`):
        Sequence of hidden-states at the output of the last layer of the decoder of the model.
    intermediate_hidden_states (`torch.FloatTensor` of shape `(config.decoder_layers, batch_size, sequence_length, hidden_size)`, *optional*, returned when `config.auxiliary_loss=True`):
        Intermediate decoder activations, i.e. the output of each decoder layer, each of them gone through a
        layernorm.
    """

    # The only field added on top of the parent output class; defaults to None when not populated.
    intermediate_hidden_states: torch.FloatTensor | None = None
@dataclass
@auto_docstring(
    custom_intro="""
    Output type of [`DetrForObjectDetection`].
    """
)
class DetrObjectDetectionOutput(ModelOutput):
    r"""
    loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` are provided):
        Total loss as a linear combination of a negative log-likelihood (cross-entropy) for class prediction and a
        bounding box loss. The latter is defined as a linear combination of the L1 loss and the generalized
        scale-invariant IoU loss.
    loss_dict (`Dict`, *optional*):
        A dictionary containing the individual losses. Useful for logging.
    logits (`torch.FloatTensor` of shape `(batch_size, num_queries, num_classes + 1)`):
        Classification logits (including no-object) for all queries.
    pred_boxes (`torch.FloatTensor` of shape `(batch_size, num_queries, 4)`):
        Normalized box coordinates for all queries, represented as (center_x, center_y, width, height). These
        values are normalized in [0, 1], relative to the size of each individual image in the batch (disregarding
        possible padding). You can use [`~DetrImageProcessor.post_process_object_detection`] to retrieve the
        unnormalized bounding boxes.
    auxiliary_outputs (`list[Dict]`, *optional*):
        Optional, only returned when auxiliary losses are activated (i.e. `config.auxiliary_loss` is set to `True`)
        and labels are provided. It is a list of dictionaries containing the two above keys (`logits` and
        `pred_boxes`) for each decoder layer.
    last_hidden_state (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
        Sequence of hidden-states at the output of the last layer of the decoder of the model.
    """

    # Every field defaults to None, so an instance can be built from any subset of outputs.
    loss: torch.FloatTensor | None = None
    loss_dict: dict | None = None
    logits: torch.FloatTensor | None = None
    pred_boxes: torch.FloatTensor | None = None
    auxiliary_outputs: list[dict] | None = None
    last_hidden_state: torch.FloatTensor | None = None
    decoder_hidden_states: tuple[torch.FloatTensor] | None = None
    decoder_attentions: tuple[torch.FloatTensor] | None = None
    cross_attentions: tuple[torch.FloatTensor] | None = None
    encoder_last_hidden_state: torch.FloatTensor | None = None
    encoder_hidden_states: tuple[torch.FloatTensor] | None = None
    encoder_attentions: tuple[torch.FloatTensor] | None = None
@dataclass
@auto_docstring(
    custom_intro="""
    Output type of [`DetrForSegmentation`].
    """
)
class DetrSegmentationOutput(ModelOutput):
    r"""
    loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` are provided):
        Total loss as a linear combination of a negative log-likelihood (cross-entropy) for class prediction and a
        bounding box loss. The latter is defined as a linear combination of the L1 loss and the generalized
        scale-invariant IoU loss.
    loss_dict (`Dict`, *optional*):
        A dictionary containing the individual losses. Useful for logging.
    logits (`torch.FloatTensor` of shape `(batch_size, num_queries, num_classes + 1)`):
        Classification logits (including no-object) for all queries.
    pred_boxes (`torch.FloatTensor` of shape `(batch_size, num_queries, 4)`):
        Normalized box coordinates for all queries, represented as (center_x, center_y, width, height). These
        values are normalized in [0, 1], relative to the size of each individual image in the batch (disregarding
        possible padding). You can use [`~DetrImageProcessor.post_process_object_detection`] to retrieve the
        unnormalized bounding boxes.
    pred_masks (`torch.FloatTensor` of shape `(batch_size, num_queries, height/4, width/4)`):
        Segmentation mask logits for all queries. See also
        [`~DetrImageProcessor.post_process_semantic_segmentation`],
        [`~DetrImageProcessor.post_process_instance_segmentation`] or
        [`~DetrImageProcessor.post_process_panoptic_segmentation`] to evaluate semantic, instance and panoptic
        segmentation masks respectively.
    auxiliary_outputs (`list[Dict]`, *optional*):
        Optional, only returned when auxiliary losses are activated (i.e. `config.auxiliary_loss` is set to `True`)
        and labels are provided. It is a list of dictionaries containing the two above keys (`logits` and
        `pred_boxes`) for each decoder layer.
    last_hidden_state (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
        Sequence of hidden-states at the output of the last layer of the decoder of the model.
    """

    # Every field defaults to None, so an instance can be built from any subset of outputs.
    loss: torch.FloatTensor | None = None
    loss_dict: dict | None = None
    logits: torch.FloatTensor | None = None
    pred_boxes: torch.FloatTensor | None = None
    pred_masks: torch.FloatTensor | None = None
    auxiliary_outputs: list[dict] | None = None
    last_hidden_state: torch.FloatTensor | None = None
    decoder_hidden_states: tuple[torch.FloatTensor] | None = None
    decoder_attentions: tuple[torch.FloatTensor] | None = None
    cross_attentions: tuple[torch.FloatTensor] | None = None
    encoder_last_hidden_state: torch.FloatTensor | None = None
    encoder_hidden_states: tuple[torch.FloatTensor] | None = None
    encoder_attentions: tuple[torch.FloatTensor] | None = None
class DetrFrozenBatchNorm2d(nn.Module):
    """
    BatchNorm2d where the batch statistics and the affine parameters are fixed.

    Copy-paste from torchvision.misc.ops with added eps before rsqrt, without which any other models than
    torchvision.models.resnet[18,34,50,101] produce nans.

    All four tensors (`weight`, `bias`, `running_mean`, `running_var`) are registered as buffers rather than
    parameters, so they are saved and loaded with the module but are never returned by `parameters()`.

    Args:
        n (`int`):
            Number of channels; every buffer is a 1D tensor of this length.
        eps (`float`, *optional*, defaults to 1e-5):
            Value added to the running variance before the reciprocal square root, for numerical stability.
    """

    def __init__(self, n, eps=1e-5):
        super().__init__()
        self.eps = eps
        # Identity normalization by default: unit scale and variance, zero shift and mean.
        self.register_buffer("weight", torch.ones(n))
        self.register_buffer("bias", torch.zeros(n))
        self.register_buffer("running_mean", torch.zeros(n))
        self.register_buffer("running_var", torch.ones(n))

    def _load_from_state_dict(
        self, state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs
    ):
        # This module has no `num_batches_tracked` buffer, so drop that entry before delegating;
        # otherwise it would be reported as an unexpected key.
        state_dict.pop(prefix + "num_batches_tracked", None)

        super()._load_from_state_dict(
            state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs
        )

    def forward(self, x):
        """Apply the frozen per-channel affine normalization to `x`, broadcasting over dimension 1."""
        # move reshapes to the beginning
        # to make it user-friendly
        weight = self.weight.reshape(1, -1, 1, 1)
        bias = self.bias.reshape(1, -1, 1, 1)
        running_var = self.running_var.reshape(1, -1, 1, 1)
        running_mean = self.running_mean.reshape(1, -1, 1, 1)
        # Fold normalization and affine transform into a single scale and shift:
        # (x - mean) / sqrt(var + eps) * weight + bias == x * scale + (bias - mean * scale)
        scale = weight * (running_var + self.eps).rsqrt()
        bias = bias - running_mean * scale
        return x * scale + bias
def replace_batch_norm(model):
    r"""
    Recursively replace all `torch.nn.BatchNorm2d` with `DetrFrozenBatchNorm2d`.

    The replacement is done in place on `model`; nothing is returned. Values are transferred with in-place
    `copy_`, and nothing is copied when the source tensors live on the `meta` device. A `BatchNorm2d` built with
    `affine=False` has `weight` and `bias` set to `None`; in that case the replacement keeps its default unit
    scale and zero shift, which is the same transform.

    Args:
        model (torch.nn.Module):
            input model
    """
    meta_device = torch.device("meta")
    for name, module in model.named_children():
        if not isinstance(module, nn.BatchNorm2d):
            # Recursing into a module without children is a no-op, so no explicit check is needed.
            replace_batch_norm(module)
            continue

        new_module = DetrFrozenBatchNorm2d(module.num_features)

        has_affine = module.weight is not None
        # Use whichever tensor exists to decide whether there are real values to copy.
        reference = module.weight if has_affine else module.running_mean
        if reference.device != meta_device:
            if has_affine:
                new_module.weight.copy_(module.weight)
                new_module.bias.copy_(module.bias)
            new_module.running_mean.copy_(module.running_mean)
            new_module.running_var.copy_(module.running_var)

        # Assign through `_modules` so the child is swapped under the same name.
        model._modules[name] = new_module
class DetrConvEncoder(nn.Module):
    """
    Convolutional backbone, using either the AutoBackbone API or one from the timm library.

    nn.BatchNorm2d layers are replaced by DetrFrozenBatchNorm2d as defined above.
    """

    def __init__(self, config):
        super().__init__()
        self.config = config

        loaded = load_backbone(config)
        self.intermediate_channel_sizes = loaded.channels

        # Swap every batch norm for its frozen counterpart without tracking gradients.
        with torch.no_grad():
            replace_batch_norm(loaded)

        # We used to load with timm library directly instead of the AutoBackbone API
        # so we need to unwrap the `backbone._backbone` module to load weights without mismatch
        is_timm_model = hasattr(loaded, "_backbone")
        if is_timm_model:
            loaded = loaded._backbone
        self.model = loaded

        if "resnet" in config.backbone_config.model_type:
            # Only parameters whose name contains one of these markers stay trainable; the naming
            # scheme differs between the unwrapped timm model and the AutoBackbone model.
            if is_timm_model:
                trainable_markers = ("layer2", "layer3", "layer4")
            else:
                trainable_markers = ("stage.1", "stage.2", "stage.3")
            for name, parameter in self.model.named_parameters():
                if not any(marker in name for marker in trainable_markers):
                    parameter.requires_grad_(False)

    def forward(self, pixel_values: torch.Tensor, pixel_mask: torch.Tensor):
        """Run the backbone and return a list of `(feature_map, mask)` pairs, one per feature map."""
        # send pixel_values through the model to get list of feature maps
        feature_maps = self.model(pixel_values)
        if isinstance(feature_maps, dict):
            feature_maps = feature_maps.feature_maps

        def resized_mask(feature_map):
            # downsample pixel_mask to match shape of corresponding feature_map
            resized = nn.functional.interpolate(pixel_mask[None].float(), size=feature_map.shape[-2:])
            return resized.to(torch.bool)[0]

        return [(feature_map, resized_mask(feature_map)) for feature_map in feature_maps]
class DetrSinePositionEmbedding(nn.Module):
    """
    This is a more standard version of the position embedding, very similar to the one used by the Attention is all you
    need paper, generalized to work on images.

    Args:
        num_position_features (`int`, *optional*, defaults to 64):
            Number of sinusoidal features per spatial axis; the output has twice this many channels.
        temperature (`int`, *optional*, defaults to 10000):
            Base of the geometric progression of frequencies.
        normalize (`bool`, *optional*, defaults to `False`):
            Whether to rescale the cumulative coordinates by their last value along each axis, times `scale`.
        scale (`float`, *optional*):
            Multiplier used when `normalize=True`. Defaults to `2 * math.pi`. Passing it with `normalize=False`
            raises a `ValueError`.
    """

    def __init__(
        self,
        num_position_features: int = 64,
        temperature: int = 10000,
        normalize: bool = False,
        scale: float | None = None,
    ):
        super().__init__()
        if scale is not None and normalize is False:
            raise ValueError("normalize should be True if scale is passed")
        self.num_position_features = num_position_features
        self.temperature = temperature
        self.normalize = normalize
        self.scale = 2 * math.pi if scale is None else scale

    @compile_compatible_method_lru_cache(maxsize=1)
    def forward(
        self,
        shape: torch.Size,
        device: torch.device | str,
        dtype: torch.dtype,
        mask: torch.Tensor | None = None,
    ) -> torch.Tensor:
        """
        Build the sinusoidal position embeddings.

        Args:
            shape (`torch.Size`):
                Shape of the feature map; only indices 0, 2 and 3 are read, and only when `mask` is `None`.
            device (`torch.device` or `str`):
                Device on which the frequency vector (and the default mask) is created.
            dtype (`torch.dtype`):
                Dtype of the cumulative coordinates and of the frequency vector.
            mask (`torch.Tensor`, *optional*):
                3D mask whose non-zero entries are counted as positions: coordinates are the cumulative sum of
                the mask along dimensions 1 and 2. When omitted, every location is treated as a position.

        Returns:
            `torch.Tensor`: embeddings with the two spatial dimensions flattened into one, laid out as
            `(mask.shape[0], mask.shape[1] * mask.shape[2], 2 * num_position_features)`.
        """
        if mask is None:
            # Coordinates are cumulative sums of the mask, so the fallback must be all-True: an all-False
            # mask would assign coordinate 0 to every location and yield a constant embedding.
            mask = torch.ones((shape[0], shape[2], shape[3]), device=device, dtype=torch.bool)
        y_embed = mask.cumsum(1, dtype=dtype)
        x_embed = mask.cumsum(2, dtype=dtype)
        if self.normalize:
            # Divide by the last cumulative value along each axis; eps guards against division by zero.
            eps = 1e-6
            y_embed = y_embed / (y_embed[:, -1:, :] + eps) * self.scale
            x_embed = x_embed / (x_embed[:, :, -1:] + eps) * self.scale

        # Consecutive feature pairs share one frequency: temperature ** (2 * (i // 2) / num_position_features).
        dim_t = torch.arange(self.num_position_features, dtype=torch.int64, device=device).to(dtype)
        dim_t = self.temperature ** (2 * torch.div(dim_t, 2, rounding_mode="floor") / self.num_position_features)

        pos_x = x_embed[:, :, :, None] / dim_t
        pos_y = y_embed[:, :, :, None] / dim_t
        # Sine on even feature indices, cosine on odd ones, interleaved back together by stack + flatten.
        pos_x = torch.stack((pos_x[:, :, :, 0::2].sin(), pos_x[:, :, :, 1::2].cos()), dim=4).flatten(3)
        pos_y = torch.stack((pos_y[:, :, :, 0::2].sin(), pos_y[:, :, :, 1::2].cos()), dim=4).flatten(3)
        pos = torch.cat((pos_y, pos_x), dim=3).permute(0, 3, 1, 2)
        # Flatten spatial dimensions and permute to (batch_size, sequence_length, hidden_size) format
        # expected by the encoder
        pos = pos.flatten(2).permute(0, 2, 1)
        return pos
class DetrLearnedPositionEmbedding(nn.Module):
    """
    This module learns positional embeddings up to a fixed maximum size.

    One embedding table is kept for rows and one for columns, each with 50 entries of size `embedding_dim`.
    """

    def __init__(self, embedding_dim=256):
        super().__init__()
        self.row_embeddings = nn.Embedding(50, embedding_dim)
        self.column_embeddings = nn.Embedding(50, embedding_dim)

    @compile_compatible_method_lru_cache(maxsize=1)
    def forward(
        self,
        shape: torch.Size,
        device: torch.device | str,
        dtype: torch.dtype,
        mask: torch.Tensor | None = None,
    ):
        """
        Look up row/column embeddings for a grid of `shape[-2]` x `shape[-1]` locations and tile them over
        `shape[0]` batch entries. `dtype` and `mask` are accepted but not used.
        """
        height, width = shape[-2:]
        column_part = self.column_embeddings(torch.arange(width, device=device))
        row_part = self.row_embeddings(torch.arange(height, device=device))

        # Tile both lookups to (height, width, embedding_dim) and join them on the feature axis.
        grid = torch.cat(
            [
                column_part.unsqueeze(0).repeat(height, 1, 1),
                row_part.unsqueeze(1).repeat(1, width, 1),
            ],
            dim=-1,
        )
        # Channels first, then a leading batch axis repeated for every batch entry.
        batched = grid.permute(2, 0, 1).unsqueeze(0).repeat(shape[0], 1, 1, 1)
        # Flatten spatial dimensions and permute to (batch_size, sequence_length, hidden_size) format
        # expected by the encoder
        return batched.flatten(2).permute(0, 2, 1)
# Copied from transformers.models.bert.modeling_bert.eager_attention_forward
def eager_attention_forward(
    module: nn.Module,
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    attention_mask: torch.Tensor | None,
    scaling: float | None = None,
    dropout: float = 0.0,
    **kwargs: Unpack[TransformersKwargs],
):
    """
    Plain scaled dot-product attention.

    `attention_mask`, when given, is added to the raw scores before the softmax. Dropout is applied to the
    attention probabilities only while `module.training` is true. `scaling` defaults to the inverse square
    root of the last dimension of `query`. Extra keyword arguments are accepted and ignored.

    Returns the attention output with dimensions 1 and 2 swapped (made contiguous), and the post-dropout
    attention probabilities.
    """
    scale = query.size(-1) ** -0.5 if scaling is None else scaling

    # Take the dot product between "query" and "key" to get the raw attention scores.
    scores = torch.matmul(query, key.transpose(2, 3)) * scale
    if attention_mask is not None:
        scores = scores + attention_mask

    probabilities = nn.functional.softmax(scores, dim=-1)
    probabilities = nn.functional.dropout(probabilities, p=dropout, training=module.training)

    context = torch.matmul(probabilities, value).transpose(1, 2).contiguous()
    return context, probabilities
class DetrSelfAttention(nn.Module):
    """
    Multi-headed self-attention from 'Attention Is All You Need' paper.

    In DETR, position embeddings are added to both queries and keys (but not values) in self-attention.
    """

    def __init__(
        self,
        config: DetrConfig,
        hidden_size: int,
        num_attention_heads: int,
        dropout: float = 0.0,
        bias: bool = True,
    ):
        super().__init__()
        self.config = config
        self.head_dim = hidden_size // num_attention_heads
        self.scaling = self.head_dim**-0.5
        self.attention_dropout = dropout
        self.is_causal = False

        # Four square projections: keys, values, queries and the output.
        self.k_proj = nn.Linear(hidden_size, hidden_size, bias=bias)
        self.v_proj = nn.Linear(hidden_size, hidden_size, bias=bias)
        self.q_proj = nn.Linear(hidden_size, hidden_size, bias=bias)
        self.o_proj = nn.Linear(hidden_size, hidden_size, bias=bias)

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: torch.Tensor | None = None,
        position_embeddings: torch.Tensor | None = None,
        **kwargs: Unpack[TransformersKwargs],
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """
        Queries and keys are projected from `hidden_states + position_embeddings` (when embeddings are given);
        values are projected from `hidden_states` alone. Returns the projected attention output and the
        attention weights produced by the selected attention implementation.
        """
        leading_shape = hidden_states.shape[:-1]
        per_head_shape = (*leading_shape, -1, self.head_dim)

        if position_embeddings is None:
            positioned_states = hidden_states
        else:
            positioned_states = hidden_states + position_embeddings

        def split_heads(projected: torch.Tensor) -> torch.Tensor:
            # Expose the head axis, then move it in front of the sequence axis.
            return projected.view(per_head_shape).transpose(1, 2)

        query_states = split_heads(self.q_proj(positioned_states))
        key_states = split_heads(self.k_proj(positioned_states))
        value_states = split_heads(self.v_proj(hidden_states))

        # Resolve the configured attention backend, falling back to the eager implementation.
        attention_interface: Callable = ALL_ATTENTION_FUNCTIONS.get_interface(
            self.config._attn_implementation, eager_attention_forward
        )

        context, attn_weights = attention_interface(
            self,
            query_states,
            key_states,
            value_states,
            attention_mask,
            dropout=self.attention_dropout if self.training else 0.0,
            scaling=self.scaling,
            **kwargs,
        )

        # Merge the heads back into a single feature axis before the output projection.
        merged = context.reshape(*leading_shape, -1).contiguous()
        return self.o_proj(merged), attn_weights
class DetrCrossAttention(nn.Module):
    """
    Multi-headed cross-attention from 'Attention Is All You Need' paper.

    In DETR, queries get their own position embeddings, while keys get encoder position embeddings.
    Values don't get any position embeddings.
    """

    def __init__(
        self,
        config: DetrConfig,
        hidden_size: int,
        num_attention_heads: int,
        dropout: float = 0.0,
        bias: bool = True,
    ):
        super().__init__()
        self.config = config
        self.head_dim = hidden_size // num_attention_heads
        self.scaling = self.head_dim**-0.5
        self.attention_dropout = dropout
        self.is_causal = False

        # Four square projections: keys, values, queries and the output.
        self.k_proj = nn.Linear(hidden_size, hidden_size, bias=bias)
        self.v_proj = nn.Linear(hidden_size, hidden_size, bias=bias)
        self.q_proj = nn.Linear(hidden_size, hidden_size, bias=bias)
        self.o_proj = nn.Linear(hidden_size, hidden_size, bias=bias)

    def forward(
        self,
        hidden_states: torch.Tensor,
        key_value_states: torch.Tensor,
        attention_mask: torch.Tensor | None = None,
        position_embeddings: torch.Tensor | None = None,
        encoder_position_embeddings: torch.Tensor | None = None,
        **kwargs: Unpack[TransformersKwargs],
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """
        Where each projection takes its input from:
        - Queries: `hidden_states`, plus `position_embeddings` when given
        - Keys: `key_value_states`, plus `encoder_position_embeddings` when given
        - Values: `key_value_states` only

        Returns the projected attention output and the attention weights produced by the selected attention
        implementation.
        """
        query_leading_shape = hidden_states.shape[:-1]
        kv_leading_shape = key_value_states.shape[:-1]
        query_head_shape = (*query_leading_shape, -1, self.head_dim)
        kv_head_shape = (*kv_leading_shape, -1, self.head_dim)

        if position_embeddings is None:
            query_source = hidden_states
        else:
            query_source = hidden_states + position_embeddings

        if encoder_position_embeddings is None:
            key_source = key_value_states
        else:
            key_source = key_value_states + encoder_position_embeddings

        # Expose the head axis, then move it in front of the sequence axis.
        query_states = self.q_proj(query_source).view(query_head_shape).transpose(1, 2)
        key_states = self.k_proj(key_source).view(kv_head_shape).transpose(1, 2)
        value_states = self.v_proj(key_value_states).view(kv_head_shape).transpose(1, 2)

        # Resolve the configured attention backend, falling back to the eager implementation.
        attention_interface: Callable = ALL_ATTENTION_FUNCTIONS.get_interface(
            self.config._attn_implementation, eager_attention_forward
        )

        context, attn_weights = attention_interface(
            self,
            query_states,
            key_states,
            value_states,
            attention_mask,
            dropout=self.attention_dropout if self.training else 0.0,
            scaling=self.scaling,
            **kwargs,
        )

        # Merge the heads back into a single feature axis before the output projection.
        merged = context.reshape(*query_leading_shape, -1).contiguous()
        return self.o_proj(merged), attn_weights
class DetrMLP(nn.Module):
    """Two-layer feed-forward block: linear, activation, dropout, linear, dropout."""

    def __init__(self, config: DetrConfig, hidden_size: int, intermediate_size: int):
        super().__init__()
        self.fc1 = nn.Linear(hidden_size, intermediate_size)
        self.fc2 = nn.Linear(intermediate_size, hidden_size)
        # Activation and both dropout rates are read from the config.
        self.activation_fn = ACT2FN[config.activation_function]
        self.activation_dropout = config.activation_dropout
        self.dropout = config.dropout

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        """Expand to the intermediate size, apply the activation, and project back to the hidden size."""
        expanded = self.fc1(hidden_states)
        activated = self.activation_fn(expanded)
        activated = nn.functional.dropout(activated, p=self.activation_dropout, training=self.training)
        projected = self.fc2(activated)
        return nn.functional.dropout(projected, p=self.dropout, training=self.training)
class DetrEncoderLayer(GradientCheckpointingLayer):
    """Encoder layer: self-attention then MLP, each followed by a residual add and a layer norm."""

    def __init__(self, config: DetrConfig):
        super().__init__()
        self.hidden_size = config.d_model
        self.self_attn = DetrSelfAttention(
            config=config,
            hidden_size=self.hidden_size,
            num_attention_heads=config.encoder_attention_heads,
            dropout=config.attention_dropout,
        )
        self.self_attn_layer_norm = nn.LayerNorm(self.hidden_size)
        self.dropout = config.dropout
        self.mlp = DetrMLP(config, self.hidden_size, config.encoder_ffn_dim)
        self.final_layer_norm = nn.LayerNorm(self.hidden_size)

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: torch.Tensor,
        spatial_position_embeddings: torch.Tensor | None = None,
        **kwargs: Unpack[TransformersKwargs],
    ) -> torch.Tensor:
        """
        Args:
            hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, hidden_size)`
            attention_mask (`torch.FloatTensor`): attention mask of size
                `(batch, 1, target_len, source_len)` where padding elements are indicated by very large negative
                values.
            spatial_position_embeddings (`torch.FloatTensor`, *optional*):
                Spatial position embeddings (2D positional encodings of image locations), to be added to both
                the queries and keys in self-attention (but not to values).
        """
        # Self-attention sub-block; the attention weights are discarded.
        attn_output, _ = self.self_attn(
            hidden_states=hidden_states,
            attention_mask=attention_mask,
            position_embeddings=spatial_position_embeddings,
            **kwargs,
        )
        attn_output = nn.functional.dropout(attn_output, p=self.dropout, training=self.training)
        hidden_states = self.self_attn_layer_norm(hidden_states + attn_output)

        # Feed-forward sub-block.
        mlp_output = self.mlp(hidden_states)
        hidden_states = self.final_layer_norm(hidden_states + mlp_output)

        # During training only, clamp to just inside the dtype's range when non-finite values appear.
        if self.training and not torch.isfinite(hidden_states).all():
            clamp_value = torch.finfo(hidden_states.dtype).max - 1000
            hidden_states = torch.clamp(hidden_states, min=-clamp_value, max=clamp_value)

        return hidden_states
class DetrDecoderLayer(GradientCheckpointingLayer):
    """
    Decoder layer: self-attention, optional cross-attention over encoder states, then MLP. Each sub-block is
    followed by a residual add and a layer norm.
    """

    def __init__(self, config: DetrConfig):
        super().__init__()
        self.hidden_size = config.d_model

        self.self_attn = DetrSelfAttention(
            config=config,
            hidden_size=self.hidden_size,
            num_attention_heads=config.decoder_attention_heads,
            dropout=config.attention_dropout,
        )
        self.dropout = config.dropout

        self.self_attn_layer_norm = nn.LayerNorm(self.hidden_size)
        self.encoder_attn = DetrCrossAttention(
            config=config,
            hidden_size=self.hidden_size,
            num_attention_heads=config.decoder_attention_heads,
            dropout=config.attention_dropout,
        )
        self.encoder_attn_layer_norm = nn.LayerNorm(self.hidden_size)
        self.mlp = DetrMLP(config, self.hidden_size, config.decoder_ffn_dim)
        self.final_layer_norm = nn.LayerNorm(self.hidden_size)

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: torch.Tensor | None = None,
        spatial_position_embeddings: torch.Tensor | None = None,
        object_queries_position_embeddings: torch.Tensor | None = None,
        encoder_hidden_states: torch.Tensor | None = None,
        encoder_attention_mask: torch.Tensor | None = None,
        **kwargs: Unpack[TransformersKwargs],
    ) -> torch.Tensor:
        """
        Args:
            hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, hidden_size)`
            attention_mask (`torch.FloatTensor`): attention mask of size
                `(batch, 1, target_len, source_len)` where padding elements are indicated by very large negative
                values.
            spatial_position_embeddings (`torch.FloatTensor`, *optional*):
                Spatial position embeddings (2D positional encodings from encoder) that are added to the keys only
                in the cross-attention layer (not to values).
            object_queries_position_embeddings (`torch.FloatTensor`, *optional*):
                Position embeddings for the object query slots. In self-attention, these are added to both queries
                and keys (not values). In cross-attention, these are added to queries only (not to keys or values).
            encoder_hidden_states (`torch.FloatTensor`):
                cross attention input to the layer of shape `(batch, seq_len, hidden_size)`
            encoder_attention_mask (`torch.FloatTensor`): encoder attention mask of size
                `(batch, 1, target_len, source_len)` where padding elements are indicated by very large negative
                values.
        """
        # Self Attention (attention weights are discarded)
        self_attn_output, _ = self.self_attn(
            hidden_states=hidden_states,
            position_embeddings=object_queries_position_embeddings,
            attention_mask=attention_mask,
            **kwargs,
        )
        self_attn_output = nn.functional.dropout(self_attn_output, p=self.dropout, training=self.training)
        hidden_states = self.self_attn_layer_norm(hidden_states + self_attn_output)

        # Cross-Attention Block, skipped entirely when no encoder states are supplied
        if encoder_hidden_states is not None:
            cross_attn_output, _ = self.encoder_attn(
                hidden_states=hidden_states,
                key_value_states=encoder_hidden_states,
                attention_mask=encoder_attention_mask,
                position_embeddings=object_queries_position_embeddings,
                encoder_position_embeddings=spatial_position_embeddings,
                **kwargs,
            )
            cross_attn_output = nn.functional.dropout(cross_attn_output, p=self.dropout, training=self.training)
            hidden_states = self.encoder_attn_layer_norm(hidden_states + cross_attn_output)

        # Fully Connected
        mlp_output = self.mlp(hidden_states)
        return self.final_layer_norm(hidden_states + mlp_output)
class DetrConvBlock(nn.Module):
    """A 3x3 convolution followed by group normalization and a configurable activation."""

    def __init__(self, in_channels: int, out_channels: int, activation: str = "relu"):
        """
        Args:
            in_channels: Number of channels of the incoming feature map.
            out_channels: Number of channels produced by the convolution.
            activation: Key into `ACT2FN` selecting the activation function.
        """
        super().__init__()
        # Use at most 8 normalization groups, and never more groups than channels.
        num_groups = min(8, out_channels)
        # padding=1 with a 3x3 kernel keeps the spatial size unchanged.
        self.conv = nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1)
        self.norm = nn.GroupNorm(num_groups, out_channels)
        self.activation = ACT2FN[activation]

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply convolution, normalization and activation, in that order."""
        convolved = self.conv(x)
        normalized = self.norm(convolved)
        return self.activation(normalized)
class DetrFPNFusionStage(nn.Module):
    """One step of the FPN decoder: upsample the running features and merge them with an FPN level."""

    def __init__(self, fpn_channels: int, current_channels: int, output_channels: int, activation: str = "relu"):
        """
        Args:
            fpn_channels: Channel count of the FPN feature map fed to this stage.
            current_channels: Channel count of the running features entering this stage.
            output_channels: Channel count produced by the refinement block.
            activation: Activation name forwarded to `DetrConvBlock`.
        """
        super().__init__()
        # 1x1 conv that brings the FPN features to the channel count of the running features.
        self.fpn_adapter = nn.Conv2d(fpn_channels, current_channels, kernel_size=1)
        self.refine = DetrConvBlock(current_channels, output_channels, activation)

    def forward(self, features: torch.Tensor, fpn_features: torch.Tensor) -> torch.Tensor:
        """
        Args:
            features: Current features to upsample, shape (B*Q, current_channels, H_in, W_in)
            fpn_features: FPN features at target resolution, shape (B*Q, fpn_channels, H_out, W_out)

        Returns:
            Fused and refined features, shape (B*Q, output_channels, H_out, W_out)
        """
        adapted = self.fpn_adapter(fpn_features)
        # Resize the running features to the spatial size of the FPN level before summing.
        target_size = adapted.shape[-2:]
        upsampled = nn.functional.interpolate(features, size=target_size, mode="nearest")
        return self.refine(adapted + upsampled)
class DetrMaskHeadSmallConv(nn.Module):
    """
    Convolutional mask head producing one mask per object query.

    The encoder features are concatenated with the per-query attention maps, passed through two conv blocks and
    then through three FPN fusion stages that progressively upsample while reducing the channel count. A final
    3x3 convolution maps the result to a single-channel mask.
    """

    def __init__(
        self,
        input_channels: int,
        fpn_channels: list[int],
        hidden_size: int,
        activation_function: str = "relu",
    ):
        """
        Args:
            input_channels: Channel count of the concatenated (features + attention maps) input.
            fpn_channels: Channel counts of the FPN feature maps; the first three entries are used.
            hidden_size: Base width from which the stage widths (/2, /4, /8, /16) are derived.
            activation_function: Activation name forwarded to the conv blocks.

        Raises:
            ValueError: If `input_channels` is not divisible by 8.
        """
        super().__init__()
        if input_channels % 8 != 0:
            raise ValueError(f"input_channels must be divisible by 8, got {input_channels}")

        self.conv1 = DetrConvBlock(input_channels, input_channels, activation_function)
        self.conv2 = DetrConvBlock(input_channels, hidden_size // 2, activation_function)

        # Channel widths entering/leaving each fusion stage: /2 -> /4 -> /8 -> /16
        stage_widths = [hidden_size // 2, hidden_size // 4, hidden_size // 8, hidden_size // 16]
        self.fpn_stages = nn.ModuleList(
            [
                DetrFPNFusionStage(fpn_channels[level], stage_widths[level], stage_widths[level + 1], activation_function)
                for level in range(3)
            ]
        )

        self.output_conv = nn.Conv2d(hidden_size // 16, 1, kernel_size=3, padding=1)

    def forward(
        self,
        features: torch.Tensor,
        attention_masks: torch.Tensor,
        fpn_features: list[torch.Tensor],
    ) -> torch.Tensor:
        """
        Args:
            features: Encoder output features, shape (batch_size, hidden_size, H, W)
            attention_masks: Cross-attention maps from decoder, shape (batch_size, num_queries, num_heads, H, W)
            fpn_features: List of 3 FPN features from low to high resolution, each (batch_size, C, H, W)

        Returns:
            Predicted masks, shape (batch_size * num_queries, 1, output_H, output_W)
        """
        num_queries = attention_masks.shape[1]

        def repeat_per_query(feature_map: torch.Tensor) -> torch.Tensor:
            # (batch, C, H, W) -> (batch * num_queries, C, H, W), one copy of the map per query
            return feature_map.unsqueeze(1).expand(-1, num_queries, -1, -1, -1).flatten(0, 1)

        per_query_features = repeat_per_query(features)
        per_query_attention = attention_masks.flatten(0, 1)
        per_query_fpn = [repeat_per_query(level) for level in fpn_features]

        # Stack the attention maps onto the feature channels
        hidden_states = torch.cat([per_query_features, per_query_attention], dim=1)
        hidden_states = self.conv2(self.conv1(hidden_states))

        # zip stops at the shorter of the two sequences
        for stage, level in zip(self.fpn_stages, per_query_fpn):
            hidden_states = stage(hidden_states, level)

        return self.output_conv(hidden_states)
- class DetrMHAttentionMap(nn.Module):
- """This is a 2D attention module, which only returns the attention softmax (no multiplication by value)"""
- def __init__(
- self,
- hidden_size: int,
- num_attention_heads: int,
- dropout: float = 0.0,
- bias: bool = True,
- ):
- super().__init__()
- self.head_dim = hidden_size // num_attention_heads
- self.scaling = self.head_dim**-0.5
- self.attention_dropout = dropout
- self.q_proj = nn.Linear(hidden_size, hidden_size, bias=bias)
- self.k_proj = nn.Linear(hidden_size, hidden_size, bias=bias)
- def forward(
- self, query_states: torch.Tensor, key_states: torch.Tensor, attention_mask: torch.Tensor | None = None
- ):
- query_hidden_shape = (*query_states.shape[:-1], -1, self.head_dim)
- key_hidden_shape = (key_states.shape[0], -1, self.head_dim, *key_states.shape[-2:])
- query_states = self.q_proj(query_states).view(query_hidden_shape)
- key_states = nn.functional.conv2d(
- key_states, self.k_proj.weight.unsqueeze(-1).unsqueeze(-1), self.k_proj.bias
- ).view(key_hidden_shape)
- batch_size, num_queries, num_heads, head_dim = query_states.shape
- _, _, _, height, width = key_states.shape
- query_shape = (batch_size * num_heads, num_queries, head_dim)
- key_shape = (batch_size * num_heads, height * width, head_dim)
- attn_weights_shape = (batch_size, num_heads, num_queries, height, width)
- query = query_states.transpose(1, 2).contiguous().view(query_shape)
- key = key_states.permute(0, 1, 3, 4, 2).contiguous().view(key_shape)
- attn_weights = (
- (torch.matmul(query * self.scaling, key.transpose(1, 2))).view(attn_weights_shape).transpose(1, 2)
- )
- if attention_mask is not None:
- attn_weights = attn_weights + attention_mask
- attn_weights = nn.functional.softmax(attn_weights.flatten(2), dim=-1).view(attn_weights.size())
- attn_weights = nn.functional.dropout(attn_weights, p=self.attention_dropout, training=self.training)
- return attn_weights
@auto_docstring
class DetrPreTrainedModel(PreTrainedModel):
    # Shared base for the DETR models in this file: declares loading/backend capabilities and weight init.
    config: DetrConfig
    base_model_prefix = "model"
    main_input_name = "pixel_values"
    input_modalities = ("image",)
    _no_split_modules = [r"DetrConvEncoder", r"DetrEncoderLayer", r"DetrDecoderLayer"]
    supports_gradient_checkpointing = True
    _supports_sdpa = True
    _supports_flash_attn = True
    _supports_attention_backend = True
    _supports_flex_attn = True  # Uses create_bidirectional_masks for attention masking
    _keys_to_ignore_on_load_unexpected = [
        r"detr\.model\.backbone\.model\.layer\d+\.0\.downsample\.1\.num_batches_tracked"
    ]

    @torch.no_grad()
    def _init_weights(self, module):
        # Initializes `module` in place. The checks run from most specific to most generic, so the
        # DETR-specific modules must be tested before the plain `nn.Linear` / `nn.Conv2d` fallback.
        std = self.config.init_std
        xavier_std = self.config.init_xavier_std

        if isinstance(module, DetrMaskHeadSmallConv):
            # DetrMaskHeadSmallConv uses kaiming initialization for all its Conv2d layers
            for submodule in module.modules():
                if not isinstance(submodule, nn.Conv2d):
                    continue
                init.kaiming_uniform_(submodule.weight, a=1)
                if submodule.bias is not None:
                    init.constant_(submodule.bias, 0)
            return

        if isinstance(module, DetrMHAttentionMap):
            init.zeros_(module.k_proj.bias)
            init.zeros_(module.q_proj.bias)
            init.xavier_uniform_(module.k_proj.weight, gain=xavier_std)
            init.xavier_uniform_(module.q_proj.weight, gain=xavier_std)
            return

        if isinstance(module, DetrLearnedPositionEmbedding):
            init.uniform_(module.row_embeddings.weight)
            init.uniform_(module.column_embeddings.weight)
            return

        if isinstance(module, (nn.Linear, nn.Conv2d)):
            init.normal_(module.weight, mean=0.0, std=std)
            if module.bias is not None:
                init.zeros_(module.bias)
            return

        if isinstance(module, nn.Embedding):
            init.normal_(module.weight, mean=0.0, std=std)
            # The flag has to be checked explicitly here: `zeros_` receives a slice of the weight, and the
            # slice loses the `_is_hf_initialized` marker
            already_initialized = getattr(module.weight, "_is_hf_initialized", False)
            if module.padding_idx is not None and not already_initialized:
                init.zeros_(module.weight[module.padding_idx])
            return

        if isinstance(module, (nn.LayerNorm, nn.GroupNorm)):
            init.ones_(module.weight)
            init.zeros_(module.bias)
class DetrEncoder(DetrPreTrainedModel):
    """
    Transformer encoder that processes a flattened feature map from a vision backbone, composed of a stack of
    [`DetrEncoderLayer`] modules.

    Args:
        config (`DetrConfig`): Model configuration object.
    """

    _can_record_outputs = {"hidden_states": DetrEncoderLayer, "attentions": DetrSelfAttention}

    def __init__(self, config: DetrConfig):
        super().__init__(config)

        self.dropout = config.dropout
        # One encoder layer per `config.encoder_layers`
        layer_stack = [DetrEncoderLayer(config) for _ in range(config.encoder_layers)]
        self.layers = nn.ModuleList(layer_stack)

        # Initialize weights and apply final processing
        self.post_init()

    @merge_with_config_defaults
    @capture_outputs
    def forward(
        self,
        inputs_embeds=None,
        attention_mask=None,
        spatial_position_embeddings=None,
        **kwargs: Unpack[TransformersKwargs],
    ) -> BaseModelOutput:
        r"""
        Args:
            inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`):
                Flattened feature map (output of the backbone + projection layer) that is passed to the encoder.
            attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
                Mask to avoid performing attention on padding pixel features. Mask values selected in `[0, 1]`:

                - 1 for pixel features that are real (i.e. **not masked**),
                - 0 for pixel features that are padding (i.e. **masked**).

                [What are attention masks?](../glossary#attention-mask)
            spatial_position_embeddings (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`):
                Spatial position embeddings (2D positional encodings) that are added to the queries and keys in each self-attention layer.
        """
        # Input dropout is applied once, before the first layer
        hidden_states = nn.functional.dropout(inputs_embeds, p=self.dropout, training=self.training)

        # Convert the padding mask into the form expected by the configured attention backend
        attention_mask = create_bidirectional_mask(
            config=self.config,
            inputs_embeds=inputs_embeds,
            attention_mask=attention_mask,
        )

        for layer in self.layers:
            # the spatial position embeddings are handed to every layer as an extra input
            hidden_states = layer(
                hidden_states, attention_mask, spatial_position_embeddings=spatial_position_embeddings, **kwargs
            )

        return BaseModelOutput(last_hidden_state=hidden_states)
class DetrDecoder(DetrPreTrainedModel):
    """
    Transformer decoder that refines a set of object queries. It is composed of a stack of [`DetrDecoderLayer`] modules,
    which apply self-attention to the queries and cross-attention to the encoder's outputs.

    Args:
        config (`DetrConfig`): Model configuration object.
    """

    _can_record_outputs = {
        "hidden_states": DetrDecoderLayer,
        "attentions": DetrSelfAttention,
        "cross_attentions": DetrCrossAttention,
    }

    def __init__(self, config: DetrConfig):
        super().__init__(config)
        self.dropout = config.dropout

        self.layers = nn.ModuleList([DetrDecoderLayer(config) for _ in range(config.decoder_layers)])
        # in DETR, the decoder uses layernorm after the last decoder layer output
        self.layernorm = nn.LayerNorm(config.d_model)

        # Initialize weights and apply final processing
        self.post_init()

    @merge_with_config_defaults
    @capture_outputs
    def forward(
        self,
        inputs_embeds=None,
        attention_mask=None,
        encoder_hidden_states=None,
        encoder_attention_mask=None,
        spatial_position_embeddings=None,
        object_queries_position_embeddings=None,
        **kwargs: Unpack[TransformersKwargs],
    ) -> DetrDecoderOutput:
        r"""
        Args:
            inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`):
                The query embeddings that are passed into the decoder. Required: a `ValueError` is raised when
                it is `None`.
            attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
                Mask to avoid performing attention on certain queries. Mask values selected in `[0, 1]`:

                - 1 for queries that are **not masked**,
                - 0 for queries that are **masked**.

                [What are attention masks?](../glossary#attention-mask)
            encoder_hidden_states (`torch.FloatTensor` of shape `(batch_size, encoder_sequence_length, hidden_size)`, *optional*):
                Sequence of hidden-states at the output of the last layer of the encoder. Used in the cross-attention
                of the decoder.
            encoder_attention_mask (`torch.LongTensor` of shape `(batch_size, encoder_sequence_length)`, *optional*):
                Mask to avoid performing cross-attention on padding pixel_values of the encoder. Mask values selected
                in `[0, 1]`:

                - 1 for pixels that are real (i.e. **not masked**),
                - 0 for pixels that are padding (i.e. **masked**).

            spatial_position_embeddings (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
                Spatial position embeddings (2D positional encodings from encoder) that are added to the keys in each cross-attention layer.
            object_queries_position_embeddings (`torch.FloatTensor` of shape `(batch_size, num_queries, hidden_size)`, *optional*):
                Position embeddings for the object query slots that are added to the queries and keys in each self-attention layer.
        """
        # Fail early with a clear message; otherwise `hidden_states` would be unbound further down.
        if inputs_embeds is None:
            raise ValueError("You have to specify inputs_embeds (the query embeddings) for the decoder")
        hidden_states = inputs_embeds

        if attention_mask is not None:
            attention_mask = create_bidirectional_mask(
                config=self.config,
                inputs_embeds=hidden_states,
                attention_mask=attention_mask,
            )

        # expand encoder attention mask (for cross-attention on encoder outputs)
        if encoder_hidden_states is not None and encoder_attention_mask is not None:
            encoder_attention_mask = create_bidirectional_mask(
                config=self.config,
                inputs_embeds=hidden_states,
                attention_mask=encoder_attention_mask,
                encoder_hidden_states=encoder_hidden_states,
            )

        # optional intermediate hidden states, only collected when the auxiliary loss is enabled
        collect_intermediate = self.config.auxiliary_loss
        intermediate = () if collect_intermediate else None

        # decoder layers
        for decoder_layer in self.layers:
            hidden_states = decoder_layer(
                hidden_states,
                attention_mask,
                spatial_position_embeddings,
                object_queries_position_embeddings,
                encoder_hidden_states,  # as a positional argument for gradient checkpointing
                encoder_attention_mask=encoder_attention_mask,
                **kwargs,
            )

            if collect_intermediate:
                # NOTE: the layer-normed tensor is rebound to `hidden_states`, so with the auxiliary loss
                # enabled the *normalized* output is what the next layer receives (kept as-is on purpose:
                # changing it would alter the outputs of existing checkpoints).
                hidden_states = self.layernorm(hidden_states)
                intermediate += (hidden_states,)

        # finally, apply layernorm
        hidden_states = self.layernorm(hidden_states)

        # stack intermediate decoder activations along a new leading (layer) dimension
        if collect_intermediate:
            intermediate = torch.stack(intermediate)

        return DetrDecoderOutput(last_hidden_state=hidden_states, intermediate_hidden_states=intermediate)
@auto_docstring(
    custom_intro="""
    The bare DETR Model (consisting of a backbone and encoder-decoder Transformer) outputting raw hidden-states without
    any specific head on top.
    """
)
class DetrModel(DetrPreTrainedModel):
    def __init__(self, config: DetrConfig):
        super().__init__(config)

        self.backbone = DetrConvEncoder(config)

        # Each spatial axis gets half of the model width
        if config.position_embedding_type == "sine":
            self.position_embedding = DetrSinePositionEmbedding(config.d_model // 2, normalize=True)
        elif config.position_embedding_type == "learned":
            self.position_embedding = DetrLearnedPositionEmbedding(config.d_model // 2)
        else:
            raise ValueError(f"Not supported {config.position_embedding_type}")

        self.query_position_embeddings = nn.Embedding(config.num_queries, config.d_model)
        # 1x1 conv projecting the last backbone feature map to the transformer width
        self.input_projection = nn.Conv2d(self.backbone.intermediate_channel_sizes[-1], config.d_model, kernel_size=1)

        self.encoder = DetrEncoder(config)
        self.decoder = DetrDecoder(config)

        # Initialize weights and apply final processing
        self.post_init()

    def freeze_backbone(self):
        # Stop gradient tracking for every parameter of the wrapped backbone model.
        for param in self.backbone.model.parameters():
            param.requires_grad_(False)

    def unfreeze_backbone(self):
        # Re-enable gradient tracking for every parameter of the wrapped backbone model.
        for param in self.backbone.model.parameters():
            param.requires_grad_(True)

    @auto_docstring
    @can_return_tuple
    def forward(
        self,
        pixel_values: torch.FloatTensor | None = None,
        pixel_mask: torch.LongTensor | None = None,
        decoder_attention_mask: torch.FloatTensor | None = None,
        encoder_outputs: torch.FloatTensor | None = None,
        inputs_embeds: torch.FloatTensor | None = None,
        decoder_inputs_embeds: torch.FloatTensor | None = None,
        **kwargs: Unpack[TransformersKwargs],
    ) -> tuple[torch.FloatTensor] | DetrModelOutput:
        r"""
        decoder_attention_mask (`torch.FloatTensor` of shape `(batch_size, num_queries)`, *optional*):
            Mask to avoid performing attention on certain object queries in the decoder. Mask values selected in `[0, 1]`:

            - 1 for queries that are **not masked**,
            - 0 for queries that are **masked**.
        inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
            Optionally, instead of passing the flattened feature map (output of the backbone + projection layer), you
            can choose to directly pass a flattened representation of an image. Useful for bypassing the vision backbone.
        decoder_inputs_embeds (`torch.FloatTensor` of shape `(batch_size, num_queries, hidden_size)`, *optional*):
            Optionally, instead of initializing the queries with a tensor of zeros, you can choose to directly pass an
            embedded representation. Useful for tasks that require custom query initialization.

        Examples:

        ```python
        >>> from transformers import AutoImageProcessor, DetrModel
        >>> from PIL import Image
        >>> import httpx
        >>> from io import BytesIO

        >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg"
        >>> with httpx.stream("GET", url) as response:
        ...     image = Image.open(BytesIO(response.read()))

        >>> image_processor = AutoImageProcessor.from_pretrained("facebook/detr-resnet-50")
        >>> model = DetrModel.from_pretrained("facebook/detr-resnet-50")

        >>> # prepare image for the model
        >>> inputs = image_processor(images=image, return_tensors="pt")

        >>> # forward pass
        >>> outputs = model(**inputs)

        >>> # the last hidden states are the final query embeddings of the Transformer decoder
        >>> # these are of shape (batch_size, num_queries, hidden_size)
        >>> last_hidden_states = outputs.last_hidden_state
        >>> list(last_hidden_states.shape)
        [1, 100, 256]
        ```"""
        if pixel_values is None and inputs_embeds is None:
            raise ValueError("You have to specify either pixel_values or inputs_embeds")

        if inputs_embeds is None:
            batch_size, _, height, width = pixel_values.shape
            device = pixel_values.device

            if pixel_mask is None:
                # No mask given: treat every pixel as valid
                pixel_mask = torch.ones((batch_size, height, width), device=device)
            vision_features = self.backbone(pixel_values, pixel_mask)
            # Only the last (feature map, mask) pair produced by the backbone feeds the transformer
            feature_map, mask = vision_features[-1]

            # Apply 1x1 conv to map (batch_size, C, H, W) -> (batch_size, hidden_size, H, W), then flatten to (batch_size, HW, hidden_size)
            # Position embeddings are already flattened to (batch_size, sequence_length, hidden_size) format
            projected_feature_map = self.input_projection(feature_map)
            flattened_features = projected_feature_map.flatten(2).permute(0, 2, 1)
            spatial_position_embeddings = self.position_embedding(
                shape=feature_map.shape, device=device, dtype=pixel_values.dtype, mask=mask
            )
            flattened_mask = mask.flatten(1)
        else:
            batch_size = inputs_embeds.shape[0]
            device = inputs_embeds.device
            flattened_features = inputs_embeds

            # When using inputs_embeds, we need to infer spatial dimensions for position embeddings
            # Assume square feature map
            seq_len = inputs_embeds.shape[1]
            feat_dim = int(seq_len**0.5)
            # Without this check a non-square length is silently truncated, and the position embeddings
            # (and interpolated mask) end up with feat_dim**2 != seq_len entries.
            if feat_dim * feat_dim != seq_len:
                raise ValueError(
                    f"Cannot infer a square feature map from inputs_embeds with sequence length {seq_len}: "
                    "the sequence length must be a perfect square."
                )

            # Create position embeddings for the inferred spatial size
            spatial_position_embeddings = self.position_embedding(
                shape=torch.Size([batch_size, self.config.d_model, feat_dim, feat_dim]),
                device=device,
                dtype=inputs_embeds.dtype,
            )

            # If a pixel_mask is provided with inputs_embeds, interpolate it to feat_dim, then flatten.
            if pixel_mask is not None:
                mask = nn.functional.interpolate(pixel_mask[None].float(), size=(feat_dim, feat_dim)).to(torch.bool)[0]
                flattened_mask = mask.flatten(1)
            else:
                # If no mask provided, assume all positions are valid
                flattened_mask = torch.ones((batch_size, seq_len), device=device, dtype=torch.long)

        # Precomputed encoder outputs, when given, are used as-is and the encoder is skipped
        if encoder_outputs is None:
            encoder_outputs = self.encoder(
                inputs_embeds=flattened_features,
                attention_mask=flattened_mask,
                spatial_position_embeddings=spatial_position_embeddings,
                **kwargs,
            )

        # The learned query position embeddings are shared by every sample in the batch
        object_queries_position_embeddings = self.query_position_embeddings.weight.unsqueeze(0).repeat(
            batch_size, 1, 1
        )

        # Use decoder_inputs_embeds as queries if provided, otherwise initialize with zeros
        if decoder_inputs_embeds is not None:
            queries = decoder_inputs_embeds
        else:
            queries = torch.zeros_like(object_queries_position_embeddings)

        # decoder outputs consists of (dec_features, dec_hidden, dec_attn)
        decoder_outputs = self.decoder(
            inputs_embeds=queries,
            attention_mask=decoder_attention_mask,
            spatial_position_embeddings=spatial_position_embeddings,
            object_queries_position_embeddings=object_queries_position_embeddings,
            encoder_hidden_states=encoder_outputs.last_hidden_state,
            encoder_attention_mask=flattened_mask,
            **kwargs,
        )

        return DetrModelOutput(
            last_hidden_state=decoder_outputs.last_hidden_state,
            decoder_hidden_states=decoder_outputs.hidden_states,
            decoder_attentions=decoder_outputs.attentions,
            cross_attentions=decoder_outputs.cross_attentions,
            encoder_last_hidden_state=encoder_outputs.last_hidden_state,
            encoder_hidden_states=encoder_outputs.hidden_states,
            encoder_attentions=encoder_outputs.attentions,
            intermediate_hidden_states=decoder_outputs.intermediate_hidden_states,
        )
class DetrMLPPredictionHead(nn.Module):
    """
    Small multi-layer perceptron (a.k.a. FFN) used as a prediction head, e.g. to regress the normalized center
    coordinates, height and width of a bounding box w.r.t. an image.

    ReLU is applied after every layer except the last one.
    """

    def __init__(self, input_dim, hidden_dim, output_dim, num_layers):
        """
        Args:
            input_dim: Width of the input features.
            hidden_dim: Width of every hidden layer.
            output_dim: Width of the output.
            num_layers: Total number of linear layers, including the output layer.
        """
        super().__init__()
        self.num_layers = num_layers

        hidden_widths = [hidden_dim] * (num_layers - 1)
        # Consecutive (in, out) pairs: input -> hidden ... hidden -> output
        layer_inputs = [input_dim, *hidden_widths]
        layer_outputs = [*hidden_widths, output_dim]
        self.layers = nn.ModuleList(
            [nn.Linear(in_features, out_features) for in_features, out_features in zip(layer_inputs, layer_outputs)]
        )

    def forward(self, x):
        """Run `x` through all layers, with ReLU between them but not after the final layer."""
        last_index = self.num_layers - 1
        for index, layer in enumerate(self.layers):
            x = layer(x)
            if index < last_index:
                x = nn.functional.relu(x)
        return x
@auto_docstring(
    custom_intro="""
    DETR Model (consisting of a backbone and encoder-decoder Transformer) with object detection heads on top, for tasks
    such as COCO detection.
    """
)
class DetrForObjectDetection(DetrPreTrainedModel):
    def __init__(self, config: DetrConfig):
        super().__init__(config)

        # DETR encoder-decoder model
        self.model = DetrModel(config)

        # Object detection heads
        # One extra logit is reserved for the "no object" class
        num_class_logits = config.num_labels + 1
        self.class_labels_classifier = nn.Linear(config.d_model, num_class_logits)
        self.bbox_predictor = DetrMLPPredictionHead(
            input_dim=config.d_model, hidden_dim=config.d_model, output_dim=4, num_layers=3
        )

        # Initialize weights and apply final processing
        self.post_init()

    @auto_docstring
    @can_return_tuple
    def forward(
        self,
        pixel_values: torch.FloatTensor,
        pixel_mask: torch.LongTensor | None = None,
        decoder_attention_mask: torch.FloatTensor | None = None,
        encoder_outputs: torch.FloatTensor | None = None,
        inputs_embeds: torch.FloatTensor | None = None,
        decoder_inputs_embeds: torch.FloatTensor | None = None,
        labels: list[dict] | None = None,
        **kwargs: Unpack[TransformersKwargs],
    ) -> tuple[torch.FloatTensor] | DetrObjectDetectionOutput:
        r"""
        decoder_attention_mask (`torch.FloatTensor` of shape `(batch_size, num_queries)`, *optional*):
            Mask to avoid performing attention on certain object queries in the decoder. Mask values selected in `[0, 1]`:

            - 1 for queries that are **not masked**,
            - 0 for queries that are **masked**.
        inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
            Optionally, instead of passing the flattened feature map (output of the backbone + projection layer), you
            can choose to directly pass a flattened representation of an image. Useful for bypassing the vision backbone.
        decoder_inputs_embeds (`torch.FloatTensor` of shape `(batch_size, num_queries, hidden_size)`, *optional*):
            Optionally, instead of initializing the queries with a tensor of zeros, you can choose to directly pass an
            embedded representation. Useful for tasks that require custom query initialization.
        labels (`list[Dict]` of len `(batch_size,)`, *optional*):
            Labels for computing the bipartite matching loss. List of dicts, each dictionary containing at least the
            following 2 keys: 'class_labels' and 'boxes' (the class labels and bounding boxes of an image in the batch
            respectively). The class labels themselves should be a `torch.LongTensor` of len `(number of bounding boxes
            in the image,)` and the boxes a `torch.FloatTensor` of shape `(number of bounding boxes in the image, 4)`.

        Examples:

        ```python
        >>> from transformers import AutoImageProcessor, DetrForObjectDetection
        >>> import torch
        >>> from PIL import Image
        >>> import httpx
        >>> from io import BytesIO

        >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg"
        >>> with httpx.stream("GET", url) as response:
        ...     image = Image.open(BytesIO(response.read()))

        >>> image_processor = AutoImageProcessor.from_pretrained("facebook/detr-resnet-50")
        >>> model = DetrForObjectDetection.from_pretrained("facebook/detr-resnet-50")

        >>> inputs = image_processor(images=image, return_tensors="pt")
        >>> outputs = model(**inputs)

        >>> # convert outputs (bounding boxes and class logits) to Pascal VOC format (xmin, ymin, xmax, ymax)
        >>> target_sizes = torch.tensor([image.size[::-1]])
        >>> results = image_processor.post_process_object_detection(outputs, threshold=0.9, target_sizes=target_sizes)[
        ...     0
        ... ]

        >>> for score, label, box in zip(results["scores"], results["labels"], results["boxes"]):
        ...     box = [round(i, 2) for i in box.tolist()]
        ...     print(
        ...         f"Detected {model.config.id2label[label.item()]} with confidence "
        ...         f"{round(score.item(), 3)} at location {box}"
        ...     )
        Detected remote with confidence 0.998 at location [40.16, 70.81, 175.55, 117.98]
        Detected remote with confidence 0.996 at location [333.24, 72.55, 368.33, 187.66]
        Detected couch with confidence 0.995 at location [-0.02, 1.15, 639.73, 473.76]
        Detected cat with confidence 0.999 at location [13.24, 52.05, 314.02, 470.93]
        Detected cat with confidence 0.999 at location [345.4, 23.85, 640.37, 368.72]
        ```"""
        # Run the images through the DETR base model to get the encoder + decoder outputs
        detr_outputs = self.model(
            pixel_values,
            pixel_mask=pixel_mask,
            decoder_attention_mask=decoder_attention_mask,
            encoder_outputs=encoder_outputs,
            inputs_embeds=inputs_embeds,
            decoder_inputs_embeds=decoder_inputs_embeds,
            **kwargs,
        )

        query_states = detr_outputs[0]

        # class logits + predicted bounding boxes (sigmoid squashes the box values into [0, 1])
        logits = self.class_labels_classifier(query_states)
        pred_boxes = self.bbox_predictor(query_states).sigmoid()

        loss = loss_dict = auxiliary_outputs = None
        if labels is not None:
            aux_logits = aux_boxes = None
            if self.config.auxiliary_loss:
                # Apply the same heads to the per-layer decoder states for the auxiliary losses
                per_layer_states = detr_outputs.intermediate_hidden_states
                aux_logits = self.class_labels_classifier(per_layer_states)
                aux_boxes = self.bbox_predictor(per_layer_states).sigmoid()
            loss, loss_dict, auxiliary_outputs = self.loss_function(
                logits, labels, self.device, pred_boxes, self.config, aux_logits, aux_boxes
            )

        return DetrObjectDetectionOutput(
            loss=loss,
            loss_dict=loss_dict,
            logits=logits,
            pred_boxes=pred_boxes,
            auxiliary_outputs=auxiliary_outputs,
            last_hidden_state=detr_outputs.last_hidden_state,
            decoder_hidden_states=detr_outputs.decoder_hidden_states,
            decoder_attentions=detr_outputs.decoder_attentions,
            cross_attentions=detr_outputs.cross_attentions,
            encoder_last_hidden_state=detr_outputs.encoder_last_hidden_state,
            encoder_hidden_states=detr_outputs.encoder_hidden_states,
            encoder_attentions=detr_outputs.encoder_attentions,
        )
- @auto_docstring(
- custom_intro="""
- DETR Model (consisting of a backbone and encoder-decoder Transformer) with a segmentation head on top, for tasks
- such as COCO panoptic.
- """
- )
- class DetrForSegmentation(DetrPreTrainedModel):
def __init__(self, config: DetrConfig):
    # Builds the detection model first, then the segmentation-specific modules on top of it.
    super().__init__(config)

    # object detection model
    self.detr = DetrForObjectDetection(config)

    # segmentation head
    hidden_size = config.d_model
    number_of_heads = config.encoder_attention_heads
    backbone_channel_sizes = self.detr.model.backbone.intermediate_channel_sizes

    # The mask head consumes the encoder features concatenated with one attention map per head.
    # Its FPN inputs use the backbone channel sizes in reversed order, keeping the last three entries.
    self.mask_head = DetrMaskHeadSmallConv(
        input_channels=hidden_size + number_of_heads,
        fpn_channels=backbone_channel_sizes[::-1][-3:],
        hidden_size=hidden_size,
        activation_function=config.activation_function,
    )

    self.bbox_attention = DetrMHAttentionMap(hidden_size, number_of_heads, dropout=0.0)

    # Initialize weights and apply final processing
    self.post_init()
- @auto_docstring
- @can_return_tuple
- def forward(
- self,
- pixel_values: torch.FloatTensor,
- pixel_mask: torch.LongTensor | None = None,
- decoder_attention_mask: torch.FloatTensor | None = None,
- encoder_outputs: torch.FloatTensor | None = None,
- inputs_embeds: torch.FloatTensor | None = None,
- decoder_inputs_embeds: torch.FloatTensor | None = None,
- labels: list[dict] | None = None,
- **kwargs: Unpack[TransformersKwargs],
- ) -> tuple[torch.FloatTensor] | DetrSegmentationOutput:
- r"""
- decoder_attention_mask (`torch.FloatTensor` of shape `(batch_size, num_queries)`, *optional*):
- Mask to avoid performing attention on certain object queries in the decoder. Mask values selected in `[0, 1]`:
- - 1 for queries that are **not masked**,
- - 0 for queries that are **masked**.
- inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
- Kept for backward compatibility, but cannot be used for segmentation, as segmentation requires
- multi-scale features from the backbone that are not available when bypassing it with inputs_embeds.
- decoder_inputs_embeds (`torch.FloatTensor` of shape `(batch_size, num_queries, hidden_size)`, *optional*):
- Optionally, instead of initializing the queries with a tensor of zeros, you can choose to directly pass an
- embedded representation. Useful for tasks that require custom query initialization.
- labels (`list[Dict]` of len `(batch_size,)`, *optional*):
- Labels for computing the bipartite matching loss, DICE/F-1 loss and Focal loss. List of dicts, each
- dictionary containing at least the following 3 keys: 'class_labels', 'boxes' and 'masks' (the class labels,
- bounding boxes and segmentation masks of an image in the batch respectively). The class labels themselves
- should be a `torch.LongTensor` of len `(number of bounding boxes in the image,)`, the boxes a
- `torch.FloatTensor` of shape `(number of bounding boxes in the image, 4)` and the masks a
- `torch.FloatTensor` of shape `(number of bounding boxes in the image, height, width)`.
- Examples:
- ```python
- >>> import io
- >>> import httpx
- >>> from io import BytesIO
- >>> from PIL import Image
- >>> import torch
- >>> import numpy
- >>> from transformers import AutoImageProcessor, DetrForSegmentation
- >>> from transformers.image_transforms import rgb_to_id
- >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg"
- >>> with httpx.stream("GET", url) as response:
- ... image = Image.open(BytesIO(response.read()))
- >>> image_processor = AutoImageProcessor.from_pretrained("facebook/detr-resnet-50-panoptic")
- >>> model = DetrForSegmentation.from_pretrained("facebook/detr-resnet-50-panoptic")
- >>> # prepare image for the model
- >>> inputs = image_processor(images=image, return_tensors="pt")
- >>> # forward pass
- >>> outputs = model(**inputs)
- >>> # Use the `post_process_panoptic_segmentation` method of the `image_processor` to retrieve post-processed panoptic segmentation maps
- >>> # Segmentation results are returned as a list of dictionaries
- >>> result = image_processor.post_process_panoptic_segmentation(outputs, target_sizes=[(300, 500)])
- >>> # A tensor of shape (height, width) where each value denotes a segment id, filled with -1 if no segment is found
- >>> panoptic_seg = result[0]["segmentation"]
- >>> panoptic_seg.shape
- torch.Size([300, 500])
- >>> # Get prediction score and segment_id to class_id mapping of each segment
- >>> panoptic_segments_info = result[0]["segments_info"]
- >>> len(panoptic_segments_info)
- 5
- ```"""
- batch_size, num_channels, height, width = pixel_values.shape
- device = pixel_values.device
- if pixel_mask is None:
- pixel_mask = torch.ones((batch_size, height, width), device=device)
- vision_features = self.detr.model.backbone(pixel_values, pixel_mask)
- feature_map, mask = vision_features[-1]
- # Apply 1x1 conv to map (batch_size, C, H, W) -> (batch_size, hidden_size, H, W), then flatten to (batch_size, HW, hidden_size)
- projected_feature_map = self.detr.model.input_projection(feature_map)
- flattened_features = projected_feature_map.flatten(2).permute(0, 2, 1)
- spatial_position_embeddings = self.detr.model.position_embedding(
- shape=feature_map.shape, device=device, dtype=pixel_values.dtype, mask=mask
- )
- flattened_mask = mask.flatten(1)
- if encoder_outputs is None:
- encoder_outputs = self.detr.model.encoder(
- inputs_embeds=flattened_features,
- attention_mask=flattened_mask,
- spatial_position_embeddings=spatial_position_embeddings,
- **kwargs,
- )
- object_queries_position_embeddings = self.detr.model.query_position_embeddings.weight.unsqueeze(0).repeat(
- batch_size, 1, 1
- )
- # Use decoder_inputs_embeds as queries if provided, otherwise initialize with zeros
- if decoder_inputs_embeds is not None:
- queries = decoder_inputs_embeds
- else:
- queries = torch.zeros_like(object_queries_position_embeddings)
- decoder_outputs = self.detr.model.decoder(
- inputs_embeds=queries,
- attention_mask=decoder_attention_mask,
- spatial_position_embeddings=spatial_position_embeddings,
- object_queries_position_embeddings=object_queries_position_embeddings,
- encoder_hidden_states=encoder_outputs.last_hidden_state,
- encoder_attention_mask=flattened_mask,
- **kwargs,
- )
- sequence_output = decoder_outputs[0]
- logits = self.detr.class_labels_classifier(sequence_output)
- pred_boxes = self.detr.bbox_predictor(sequence_output).sigmoid()
- height, width = feature_map.shape[-2:]
- memory = encoder_outputs.last_hidden_state.permute(0, 2, 1).view(
- batch_size, self.config.d_model, height, width
- )
- attention_mask = flattened_mask.view(batch_size, height, width)
- if attention_mask is not None:
- min_dtype = torch.finfo(memory.dtype).min
- attention_mask = torch.where(
- attention_mask.unsqueeze(1).unsqueeze(1),
- torch.tensor(0.0, device=memory.device, dtype=memory.dtype),
- min_dtype,
- )
- bbox_mask = self.bbox_attention(sequence_output, memory, attention_mask=attention_mask)
- seg_masks = self.mask_head(
- features=projected_feature_map,
- attention_masks=bbox_mask,
- fpn_features=[vision_features[2][0], vision_features[1][0], vision_features[0][0]],
- )
- pred_masks = seg_masks.view(batch_size, self.detr.config.num_queries, seg_masks.shape[-2], seg_masks.shape[-1])
- loss, loss_dict, auxiliary_outputs = None, None, None
- if labels is not None:
- outputs_class, outputs_coord = None, None
- if self.config.auxiliary_loss:
- intermediate = decoder_outputs.intermediate_hidden_states
- outputs_class = self.detr.class_labels_classifier(intermediate)
- outputs_coord = self.detr.bbox_predictor(intermediate).sigmoid()
- loss, loss_dict, auxiliary_outputs = self.loss_function(
- logits, labels, device, pred_boxes, pred_masks, self.config, outputs_class, outputs_coord
- )
- return DetrSegmentationOutput(
- loss=loss,
- loss_dict=loss_dict,
- logits=logits,
- pred_boxes=pred_boxes,
- pred_masks=pred_masks,
- auxiliary_outputs=auxiliary_outputs,
- last_hidden_state=decoder_outputs.last_hidden_state,
- decoder_hidden_states=decoder_outputs.hidden_states,
- decoder_attentions=decoder_outputs.attentions,
- cross_attentions=decoder_outputs.cross_attentions,
- encoder_last_hidden_state=encoder_outputs.last_hidden_state,
- encoder_hidden_states=encoder_outputs.hidden_states,
- encoder_attentions=encoder_outputs.attentions,
- )
- __all__ = [  # Public API: the names exported by `from <this module> import *`.
- "DetrForObjectDetection",
- "DetrForSegmentation",
- "DetrModel",
- "DetrPreTrainedModel",
- ]
|