modeling_detr.py 71 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
37047057067077087097107117127137147157167177187197207217227237247257267277287297307317327337347357367377387397407417427437447457467477487497507517527537547557567577587597607617627637647657667677687697707717727737747757767777787797807817827837847857867877887897907917927937947957967977987998008018028038048058068078088098108118128138148158168178188198208218228238248258268278288298308318328338348358368378388398408418428438448458468478488498508518528538548558568578588598608618628638648658668678688698708718728738748758768778788798808818828838848858868878888898908918928938948958968978988999009019029039049059069079089099109119129139149159169179189199209219229239249259269279289299309319329339349359369379389399409419429439449459469479489499509519529539549559569579589599609619629639649659669679689699709719729739749759769779789799809819829839849859869879889899909919929939949959969979989991000100110021003100410051006100710081009101010111012101310141015101610171018101910201021102210231024102510261027102810291030103110321033103410351036103710381039104010411042104310441045104610471048104910501051105210531054105510561057105810591060106110621063106410651066106710681069107010711072107310741075107610771078107910801081108210831084108510861087108810891090109110921093109410951096109710981099110011011102110311041105110611071108110911101111111211131114111511161117111811191120112111221123112411251126112711281129113011311132113311341135113611371138113911401141114211431144114511461147114811491150115111521153115411551156115711581159116011611162116311641165116611671168116911701171117211731174117511761177117811791180118111821183118411851186118711881189119011911192119311941195119611971198119912001201120212031204120512061207120812091210121112121213121412151216121712181219122012211222122312241225122612271228122912301231123212331234123512361237123812391240124112421243124412451246124712481249125012511252125312541255125612571258125912601261126212631264126512661267126812691270127112721273127412751276127
712781279128012811282128312841285128612871288128912901291129212931294129512961297129812991300130113021303130413051306130713081309131013111312131313141315131613171318131913201321132213231324132513261327132813291330133113321333133413351336133713381339134013411342134313441345134613471348134913501351135213531354135513561357135813591360136113621363136413651366136713681369137013711372137313741375137613771378137913801381138213831384138513861387138813891390139113921393139413951396139713981399140014011402140314041405140614071408140914101411141214131414141514161417141814191420142114221423142414251426142714281429143014311432143314341435143614371438143914401441144214431444144514461447144814491450145114521453145414551456145714581459146014611462146314641465146614671468146914701471147214731474147514761477147814791480148114821483148414851486148714881489149014911492149314941495149614971498149915001501150215031504150515061507150815091510151115121513151415151516151715181519152015211522152315241525152615271528152915301531153215331534153515361537153815391540154115421543154415451546154715481549155015511552155315541555155615571558155915601561156215631564156515661567156815691570157115721573157415751576157715781579158015811582158315841585158615871588158915901591159215931594159515961597159815991600160116021603160416051606160716081609161016111612161316141615161616171618161916201621162216231624162516261627162816291630163116321633163416351636163716381639
  1. # Copyright 2021 Facebook AI Research The HuggingFace Inc. team. All rights reserved.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. """PyTorch DETR model."""
  15. import math
  16. from collections.abc import Callable
  17. from dataclasses import dataclass
  18. import torch
  19. import torch.nn as nn
  20. from ... import initialization as init
  21. from ...activations import ACT2FN
  22. from ...backbone_utils import load_backbone
  23. from ...masking_utils import create_bidirectional_mask
  24. from ...modeling_layers import GradientCheckpointingLayer
  25. from ...modeling_outputs import (
  26. BaseModelOutput,
  27. BaseModelOutputWithCrossAttentions,
  28. Seq2SeqModelOutput,
  29. )
  30. from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
  31. from ...processing_utils import Unpack
  32. from ...pytorch_utils import compile_compatible_method_lru_cache
  33. from ...utils import (
  34. ModelOutput,
  35. TransformersKwargs,
  36. auto_docstring,
  37. logging,
  38. )
  39. from ...utils.generic import can_return_tuple, merge_with_config_defaults
  40. from ...utils.output_capturing import capture_outputs
  41. from .configuration_detr import DetrConfig
# Module-level logger, named after this module's import path.
logger = logging.get_logger(__name__)
@dataclass
@auto_docstring(
    custom_intro="""
    Base class for outputs of the DETR decoder. This class adds one attribute to BaseModelOutputWithCrossAttentions,
    namely an optional stack of intermediate decoder activations, i.e. the output of each decoder layer, each of them
    gone through a layernorm. This is useful when training the model with auxiliary decoding losses.
    """
)
class DetrDecoderOutput(BaseModelOutputWithCrossAttentions):
    r"""
    cross_attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` and `config.add_cross_attention=True` is passed or when `config.output_attentions=True`):
        Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, num_queries,
        encoder_sequence_length)`. Attentions weights of the decoder's cross-attention layer (object queries attending
        to the encoder outputs), after the attention softmax, used to compute the weighted average in the
        cross-attention heads.
    intermediate_hidden_states (`torch.FloatTensor` of shape `(config.decoder_layers, batch_size, num_queries, hidden_size)`, *optional*, returned when `config.auxiliary_loss=True`):
        Intermediate decoder activations, i.e. the output of each decoder layer, each of them gone through a
        layernorm.
    """

    # Stack of per-layer decoder outputs; stays None unless the decoder fills it in.
    intermediate_hidden_states: torch.FloatTensor | None = None
@dataclass
@auto_docstring(
    custom_intro="""
    Base class for outputs of the DETR encoder-decoder model. This class adds one attribute to Seq2SeqModelOutput,
    namely an optional stack of intermediate decoder activations, i.e. the output of each decoder layer, each of them
    gone through a layernorm. This is useful when training the model with auxiliary decoding losses.
    """
)
class DetrModelOutput(Seq2SeqModelOutput):
    r"""
    last_hidden_state (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`):
        Sequence of hidden-states at the output of the last layer of the decoder of the model.
    intermediate_hidden_states (`torch.FloatTensor` of shape `(config.decoder_layers, batch_size, num_queries, hidden_size)`, *optional*, returned when `config.auxiliary_loss=True`):
        Intermediate decoder activations, i.e. the output of each decoder layer, each of them gone through a
        layernorm.
    """

    # Stack of per-layer decoder outputs (same layout as `DetrDecoderOutput.intermediate_hidden_states`);
    # stays None unless populated.
    intermediate_hidden_states: torch.FloatTensor | None = None
@dataclass
@auto_docstring(
    custom_intro="""
    Output type of [`DetrForObjectDetection`].
    """
)
class DetrObjectDetectionOutput(ModelOutput):
    r"""
    loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` are provided):
        Total loss as a linear combination of a negative log-likelihood (cross-entropy) for class prediction and a
        bounding box loss. The latter is defined as a linear combination of the L1 loss and the generalized
        scale-invariant IoU loss.
    loss_dict (`Dict`, *optional*):
        A dictionary containing the individual losses. Useful for logging.
    logits (`torch.FloatTensor` of shape `(batch_size, num_queries, num_classes + 1)`):
        Classification logits (including no-object) for all queries.
    pred_boxes (`torch.FloatTensor` of shape `(batch_size, num_queries, 4)`):
        Normalized box coordinates for all queries, represented as (center_x, center_y, width, height). These
        values are normalized in [0, 1], relative to the size of each individual image in the batch (disregarding
        possible padding). You can use [`~DetrImageProcessor.post_process_object_detection`] to retrieve the
        unnormalized bounding boxes.
    auxiliary_outputs (`list[Dict]`, *optional*):
        Optional, only returned when auxiliary losses are activated (i.e. `config.auxiliary_loss` is set to `True`)
        and labels are provided. It is a list of dictionaries containing the two above keys (`logits` and
        `pred_boxes`) for each decoder layer.
    last_hidden_state (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
        Sequence of hidden-states at the output of the last layer of the decoder of the model.
    """

    # Field order defines the tuple order of the output; do not reorder.
    loss: torch.FloatTensor | None = None
    loss_dict: dict | None = None
    logits: torch.FloatTensor | None = None
    pred_boxes: torch.FloatTensor | None = None
    auxiliary_outputs: list[dict] | None = None
    last_hidden_state: torch.FloatTensor | None = None
    # Decoder / encoder intermediate results forwarded from the underlying encoder-decoder model.
    decoder_hidden_states: tuple[torch.FloatTensor] | None = None
    decoder_attentions: tuple[torch.FloatTensor] | None = None
    cross_attentions: tuple[torch.FloatTensor] | None = None
    encoder_last_hidden_state: torch.FloatTensor | None = None
    encoder_hidden_states: tuple[torch.FloatTensor] | None = None
    encoder_attentions: tuple[torch.FloatTensor] | None = None
@dataclass
@auto_docstring(
    custom_intro="""
    Output type of [`DetrForSegmentation`].
    """
)
class DetrSegmentationOutput(ModelOutput):
    r"""
    loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` are provided):
        Total loss as a linear combination of a negative log-likelihood (cross-entropy) for class prediction and a
        bounding box loss. The latter is defined as a linear combination of the L1 loss and the generalized
        scale-invariant IoU loss.
    loss_dict (`Dict`, *optional*):
        A dictionary containing the individual losses. Useful for logging.
    logits (`torch.FloatTensor` of shape `(batch_size, num_queries, num_classes + 1)`):
        Classification logits (including no-object) for all queries.
    pred_boxes (`torch.FloatTensor` of shape `(batch_size, num_queries, 4)`):
        Normalized box coordinates for all queries, represented as (center_x, center_y, width, height). These
        values are normalized in [0, 1], relative to the size of each individual image in the batch (disregarding
        possible padding). You can use [`~DetrImageProcessor.post_process_object_detection`] to retrieve the
        unnormalized bounding boxes.
    pred_masks (`torch.FloatTensor` of shape `(batch_size, num_queries, height/4, width/4)`):
        Segmentation masks logits for all queries. See also
        [`~DetrImageProcessor.post_process_semantic_segmentation`],
        [`~DetrImageProcessor.post_process_instance_segmentation`] or
        [`~DetrImageProcessor.post_process_panoptic_segmentation`] to evaluate semantic, instance and panoptic
        segmentation masks respectively.
    auxiliary_outputs (`list[Dict]`, *optional*):
        Optional, only returned when auxiliary losses are activated (i.e. `config.auxiliary_loss` is set to `True`)
        and labels are provided. It is a list of dictionaries containing the two above keys (`logits` and
        `pred_boxes`) for each decoder layer.
    last_hidden_state (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
        Sequence of hidden-states at the output of the last layer of the decoder of the model.
    """

    # Field order defines the tuple order of the output; do not reorder.
    loss: torch.FloatTensor | None = None
    loss_dict: dict | None = None
    logits: torch.FloatTensor | None = None
    pred_boxes: torch.FloatTensor | None = None
    pred_masks: torch.FloatTensor | None = None
    auxiliary_outputs: list[dict] | None = None
    last_hidden_state: torch.FloatTensor | None = None
    # Decoder / encoder intermediate results forwarded from the underlying encoder-decoder model.
    decoder_hidden_states: tuple[torch.FloatTensor] | None = None
    decoder_attentions: tuple[torch.FloatTensor] | None = None
    cross_attentions: tuple[torch.FloatTensor] | None = None
    encoder_last_hidden_state: torch.FloatTensor | None = None
    encoder_hidden_states: tuple[torch.FloatTensor] | None = None
    encoder_attentions: tuple[torch.FloatTensor] | None = None
  166. class DetrFrozenBatchNorm2d(nn.Module):
  167. """
  168. BatchNorm2d where the batch statistics and the affine parameters are fixed.
  169. Copy-paste from torchvision.misc.ops with added eps before rqsrt, without which any other models than
  170. torchvision.models.resnet[18,34,50,101] produce nans.
  171. """
  172. def __init__(self, n):
  173. super().__init__()
  174. self.register_buffer("weight", torch.ones(n))
  175. self.register_buffer("bias", torch.zeros(n))
  176. self.register_buffer("running_mean", torch.zeros(n))
  177. self.register_buffer("running_var", torch.ones(n))
  178. def _load_from_state_dict(
  179. self, state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs
  180. ):
  181. num_batches_tracked_key = prefix + "num_batches_tracked"
  182. if num_batches_tracked_key in state_dict:
  183. del state_dict[num_batches_tracked_key]
  184. super()._load_from_state_dict(
  185. state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs
  186. )
  187. def forward(self, x):
  188. # move reshapes to the beginning
  189. # to make it user-friendly
  190. weight = self.weight.reshape(1, -1, 1, 1)
  191. bias = self.bias.reshape(1, -1, 1, 1)
  192. running_var = self.running_var.reshape(1, -1, 1, 1)
  193. running_mean = self.running_mean.reshape(1, -1, 1, 1)
  194. epsilon = 1e-5
  195. scale = weight * (running_var + epsilon).rsqrt()
  196. bias = bias - running_mean * scale
  197. return x * scale + bias
  198. def replace_batch_norm(model):
  199. r"""
  200. Recursively replace all `torch.nn.BatchNorm2d` with `DetrFrozenBatchNorm2d`.
  201. Args:
  202. model (torch.nn.Module):
  203. input model
  204. """
  205. for name, module in model.named_children():
  206. if isinstance(module, nn.BatchNorm2d):
  207. new_module = DetrFrozenBatchNorm2d(module.num_features)
  208. if module.weight.device != torch.device("meta"):
  209. new_module.weight.copy_(module.weight)
  210. new_module.bias.copy_(module.bias)
  211. new_module.running_mean.copy_(module.running_mean)
  212. new_module.running_var.copy_(module.running_var)
  213. model._modules[name] = new_module
  214. if len(list(module.children())) > 0:
  215. replace_batch_norm(module)
  216. class DetrConvEncoder(nn.Module):
  217. """
  218. Convolutional backbone, using either the AutoBackbone API or one from the timm library.
  219. nn.BatchNorm2d layers are replaced by DetrFrozenBatchNorm2d as defined above.
  220. """
  221. def __init__(self, config):
  222. super().__init__()
  223. self.config = config
  224. backbone = load_backbone(config)
  225. self.intermediate_channel_sizes = backbone.channels
  226. # replace batch norm by frozen batch norm
  227. with torch.no_grad():
  228. replace_batch_norm(backbone)
  229. # We used to load with timm library directly instead of the AutoBackbone API
  230. # so we need to unwrap the `backbone._backbone` module to load weights without mismatch
  231. is_timm_model = False
  232. if hasattr(backbone, "_backbone"):
  233. backbone = backbone._backbone
  234. is_timm_model = True
  235. self.model = backbone
  236. backbone_model_type = config.backbone_config.model_type
  237. if "resnet" in backbone_model_type:
  238. for name, parameter in self.model.named_parameters():
  239. if is_timm_model:
  240. if "layer2" not in name and "layer3" not in name and "layer4" not in name:
  241. parameter.requires_grad_(False)
  242. else:
  243. if "stage.1" not in name and "stage.2" not in name and "stage.3" not in name:
  244. parameter.requires_grad_(False)
  245. def forward(self, pixel_values: torch.Tensor, pixel_mask: torch.Tensor):
  246. # send pixel_values through the model to get list of feature maps
  247. features = self.model(pixel_values)
  248. if isinstance(features, dict):
  249. features = features.feature_maps
  250. out = []
  251. for feature_map in features:
  252. # downsample pixel_mask to match shape of corresponding feature_map
  253. mask = nn.functional.interpolate(pixel_mask[None].float(), size=feature_map.shape[-2:]).to(torch.bool)[0]
  254. out.append((feature_map, mask))
  255. return out
  256. class DetrSinePositionEmbedding(nn.Module):
  257. """
  258. This is a more standard version of the position embedding, very similar to the one used by the Attention is all you
  259. need paper, generalized to work on images.
  260. """
  261. def __init__(
  262. self,
  263. num_position_features: int = 64,
  264. temperature: int = 10000,
  265. normalize: bool = False,
  266. scale: float | None = None,
  267. ):
  268. super().__init__()
  269. if scale is not None and normalize is False:
  270. raise ValueError("normalize should be True if scale is passed")
  271. self.num_position_features = num_position_features
  272. self.temperature = temperature
  273. self.normalize = normalize
  274. self.scale = 2 * math.pi if scale is None else scale
  275. @compile_compatible_method_lru_cache(maxsize=1)
  276. def forward(
  277. self,
  278. shape: torch.Size,
  279. device: torch.device | str,
  280. dtype: torch.dtype,
  281. mask: torch.Tensor | None = None,
  282. ) -> torch.Tensor:
  283. if mask is None:
  284. mask = torch.zeros((shape[0], shape[2], shape[3]), device=device, dtype=torch.bool)
  285. y_embed = mask.cumsum(1, dtype=dtype)
  286. x_embed = mask.cumsum(2, dtype=dtype)
  287. if self.normalize:
  288. eps = 1e-6
  289. y_embed = y_embed / (y_embed[:, -1:, :] + eps) * self.scale
  290. x_embed = x_embed / (x_embed[:, :, -1:] + eps) * self.scale
  291. dim_t = torch.arange(self.num_position_features, dtype=torch.int64, device=device).to(dtype)
  292. dim_t = self.temperature ** (2 * torch.div(dim_t, 2, rounding_mode="floor") / self.num_position_features)
  293. pos_x = x_embed[:, :, :, None] / dim_t
  294. pos_y = y_embed[:, :, :, None] / dim_t
  295. pos_x = torch.stack((pos_x[:, :, :, 0::2].sin(), pos_x[:, :, :, 1::2].cos()), dim=4).flatten(3)
  296. pos_y = torch.stack((pos_y[:, :, :, 0::2].sin(), pos_y[:, :, :, 1::2].cos()), dim=4).flatten(3)
  297. pos = torch.cat((pos_y, pos_x), dim=3).permute(0, 3, 1, 2)
  298. # Flatten spatial dimensions and permute to (batch_size, sequence_length, hidden_size) format
  299. # expected by the encoder
  300. pos = pos.flatten(2).permute(0, 2, 1)
  301. return pos
  302. class DetrLearnedPositionEmbedding(nn.Module):
  303. """
  304. This module learns positional embeddings up to a fixed maximum size.
  305. """
  306. def __init__(self, embedding_dim=256):
  307. super().__init__()
  308. self.row_embeddings = nn.Embedding(50, embedding_dim)
  309. self.column_embeddings = nn.Embedding(50, embedding_dim)
  310. @compile_compatible_method_lru_cache(maxsize=1)
  311. def forward(
  312. self,
  313. shape: torch.Size,
  314. device: torch.device | str,
  315. dtype: torch.dtype,
  316. mask: torch.Tensor | None = None,
  317. ):
  318. height, width = shape[-2:]
  319. width_values = torch.arange(width, device=device)
  320. height_values = torch.arange(height, device=device)
  321. x_emb = self.column_embeddings(width_values)
  322. y_emb = self.row_embeddings(height_values)
  323. pos = torch.cat([x_emb.unsqueeze(0).repeat(height, 1, 1), y_emb.unsqueeze(1).repeat(1, width, 1)], dim=-1)
  324. pos = pos.permute(2, 0, 1)
  325. pos = pos.unsqueeze(0)
  326. pos = pos.repeat(shape[0], 1, 1, 1)
  327. # Flatten spatial dimensions and permute to (batch_size, sequence_length, hidden_size) format
  328. # expected by the encoder
  329. pos = pos.flatten(2).permute(0, 2, 1)
  330. return pos
# Copied from transformers.models.bert.modeling_bert.eager_attention_forward
def eager_attention_forward(
    module: nn.Module,
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    attention_mask: torch.Tensor | None,
    scaling: float | None = None,
    dropout: float = 0.0,
    **kwargs: Unpack[TransformersKwargs],
) -> tuple[torch.Tensor, torch.Tensor]:
    """
    Reference (eager) scaled dot-product attention, used when no other attention implementation is selected.

    NOTE(review): this function carries a "Copied from" marker; keep its code in sync with the BERT original.

    Args:
        module: the calling attention module; only `module.training` is read (to gate dropout).
        query, key, value: per-head projections. The attention modules in this file pass them as
            `(batch, num_heads, seq_len, head_dim)` (built via `.view(...).transpose(1, 2)`).
        attention_mask: optional additive mask, summed onto the raw scores before the softmax.
        scaling: multiplier for the raw scores; defaults to `head_dim ** -0.5` (last dim of `query`).
        dropout: dropout probability applied to the attention probabilities when `module.training` is True.
        **kwargs: accepted for interface compatibility and ignored here.

    Returns:
        `(attn_output, attn_weights)`: `attn_output` is transposed to `(batch, seq_len, num_heads, head_dim)` and made
        contiguous; `attn_weights` are the post-softmax (and post-dropout) probabilities.
    """
    if scaling is None:
        scaling = query.size(-1) ** -0.5
    # Take the dot product between "query" and "key" to get the raw attention scores.
    attn_weights = torch.matmul(query, key.transpose(2, 3)) * scaling
    if attention_mask is not None:
        attn_weights = attn_weights + attention_mask
    attn_weights = nn.functional.softmax(attn_weights, dim=-1)
    attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training)
    attn_output = torch.matmul(attn_weights, value)
    # Move heads next to head_dim so callers can merge them back into the hidden dimension.
    attn_output = attn_output.transpose(1, 2).contiguous()
    return attn_output, attn_weights
  353. class DetrSelfAttention(nn.Module):
  354. """
  355. Multi-headed self-attention from 'Attention Is All You Need' paper.
  356. In DETR, position embeddings are added to both queries and keys (but not values) in self-attention.
  357. """
  358. def __init__(
  359. self,
  360. config: DetrConfig,
  361. hidden_size: int,
  362. num_attention_heads: int,
  363. dropout: float = 0.0,
  364. bias: bool = True,
  365. ):
  366. super().__init__()
  367. self.config = config
  368. self.head_dim = hidden_size // num_attention_heads
  369. self.scaling = self.head_dim**-0.5
  370. self.attention_dropout = dropout
  371. self.is_causal = False
  372. self.k_proj = nn.Linear(hidden_size, hidden_size, bias=bias)
  373. self.v_proj = nn.Linear(hidden_size, hidden_size, bias=bias)
  374. self.q_proj = nn.Linear(hidden_size, hidden_size, bias=bias)
  375. self.o_proj = nn.Linear(hidden_size, hidden_size, bias=bias)
  376. def forward(
  377. self,
  378. hidden_states: torch.Tensor,
  379. attention_mask: torch.Tensor | None = None,
  380. position_embeddings: torch.Tensor | None = None,
  381. **kwargs: Unpack[TransformersKwargs],
  382. ) -> tuple[torch.Tensor, torch.Tensor]:
  383. """
  384. Position embeddings are added to both queries and keys (but not values).
  385. """
  386. input_shape = hidden_states.shape[:-1]
  387. hidden_shape = (*input_shape, -1, self.head_dim)
  388. query_key_input = hidden_states + position_embeddings if position_embeddings is not None else hidden_states
  389. query_states = self.q_proj(query_key_input).view(hidden_shape).transpose(1, 2)
  390. key_states = self.k_proj(query_key_input).view(hidden_shape).transpose(1, 2)
  391. value_states = self.v_proj(hidden_states).view(hidden_shape).transpose(1, 2)
  392. attention_interface: Callable = ALL_ATTENTION_FUNCTIONS.get_interface(
  393. self.config._attn_implementation, eager_attention_forward
  394. )
  395. attn_output, attn_weights = attention_interface(
  396. self,
  397. query_states,
  398. key_states,
  399. value_states,
  400. attention_mask,
  401. dropout=0.0 if not self.training else self.attention_dropout,
  402. scaling=self.scaling,
  403. **kwargs,
  404. )
  405. attn_output = attn_output.reshape(*input_shape, -1).contiguous()
  406. attn_output = self.o_proj(attn_output)
  407. return attn_output, attn_weights
  408. class DetrCrossAttention(nn.Module):
  409. """
  410. Multi-headed cross-attention from 'Attention Is All You Need' paper.
  411. In DETR, queries get their own position embeddings, while keys get encoder position embeddings.
  412. Values don't get any position embeddings.
  413. """
  414. def __init__(
  415. self,
  416. config: DetrConfig,
  417. hidden_size: int,
  418. num_attention_heads: int,
  419. dropout: float = 0.0,
  420. bias: bool = True,
  421. ):
  422. super().__init__()
  423. self.config = config
  424. self.head_dim = hidden_size // num_attention_heads
  425. self.scaling = self.head_dim**-0.5
  426. self.attention_dropout = dropout
  427. self.is_causal = False
  428. self.k_proj = nn.Linear(hidden_size, hidden_size, bias=bias)
  429. self.v_proj = nn.Linear(hidden_size, hidden_size, bias=bias)
  430. self.q_proj = nn.Linear(hidden_size, hidden_size, bias=bias)
  431. self.o_proj = nn.Linear(hidden_size, hidden_size, bias=bias)
  432. def forward(
  433. self,
  434. hidden_states: torch.Tensor,
  435. key_value_states: torch.Tensor,
  436. attention_mask: torch.Tensor | None = None,
  437. position_embeddings: torch.Tensor | None = None,
  438. encoder_position_embeddings: torch.Tensor | None = None,
  439. **kwargs: Unpack[TransformersKwargs],
  440. ) -> tuple[torch.Tensor, torch.Tensor]:
  441. """
  442. Position embeddings logic:
  443. - Queries get position_embeddings
  444. - Keys get encoder_position_embeddings
  445. - Values don't get any position embeddings
  446. """
  447. query_input_shape = hidden_states.shape[:-1]
  448. query_hidden_shape = (*query_input_shape, -1, self.head_dim)
  449. kv_input_shape = key_value_states.shape[:-1]
  450. kv_hidden_shape = (*kv_input_shape, -1, self.head_dim)
  451. query_input = hidden_states + position_embeddings if position_embeddings is not None else hidden_states
  452. key_input = (
  453. key_value_states + encoder_position_embeddings
  454. if encoder_position_embeddings is not None
  455. else key_value_states
  456. )
  457. query_states = self.q_proj(query_input).view(query_hidden_shape).transpose(1, 2)
  458. key_states = self.k_proj(key_input).view(kv_hidden_shape).transpose(1, 2)
  459. value_states = self.v_proj(key_value_states).view(kv_hidden_shape).transpose(1, 2)
  460. attention_interface: Callable = ALL_ATTENTION_FUNCTIONS.get_interface(
  461. self.config._attn_implementation, eager_attention_forward
  462. )
  463. attn_output, attn_weights = attention_interface(
  464. self,
  465. query_states,
  466. key_states,
  467. value_states,
  468. attention_mask,
  469. dropout=0.0 if not self.training else self.attention_dropout,
  470. scaling=self.scaling,
  471. **kwargs,
  472. )
  473. attn_output = attn_output.reshape(*query_input_shape, -1).contiguous()
  474. attn_output = self.o_proj(attn_output)
  475. return attn_output, attn_weights
  476. class DetrMLP(nn.Module):
  477. def __init__(self, config: DetrConfig, hidden_size: int, intermediate_size: int):
  478. super().__init__()
  479. self.fc1 = nn.Linear(hidden_size, intermediate_size)
  480. self.fc2 = nn.Linear(intermediate_size, hidden_size)
  481. self.activation_fn = ACT2FN[config.activation_function]
  482. self.activation_dropout = config.activation_dropout
  483. self.dropout = config.dropout
  484. def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
  485. hidden_states = self.activation_fn(self.fc1(hidden_states))
  486. hidden_states = nn.functional.dropout(hidden_states, p=self.activation_dropout, training=self.training)
  487. hidden_states = self.fc2(hidden_states)
  488. hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training)
  489. return hidden_states
  490. class DetrEncoderLayer(GradientCheckpointingLayer):
  491. def __init__(self, config: DetrConfig):
  492. super().__init__()
  493. self.hidden_size = config.d_model
  494. self.self_attn = DetrSelfAttention(
  495. config=config,
  496. hidden_size=self.hidden_size,
  497. num_attention_heads=config.encoder_attention_heads,
  498. dropout=config.attention_dropout,
  499. )
  500. self.self_attn_layer_norm = nn.LayerNorm(self.hidden_size)
  501. self.dropout = config.dropout
  502. self.mlp = DetrMLP(config, self.hidden_size, config.encoder_ffn_dim)
  503. self.final_layer_norm = nn.LayerNorm(self.hidden_size)
  504. def forward(
  505. self,
  506. hidden_states: torch.Tensor,
  507. attention_mask: torch.Tensor,
  508. spatial_position_embeddings: torch.Tensor | None = None,
  509. **kwargs: Unpack[TransformersKwargs],
  510. ) -> torch.Tensor:
  511. """
  512. Args:
  513. hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, hidden_size)`
  514. attention_mask (`torch.FloatTensor`): attention mask of size
  515. `(batch, 1, target_len, source_len)` where padding elements are indicated by very large negative
  516. values.
  517. spatial_position_embeddings (`torch.FloatTensor`, *optional*):
  518. Spatial position embeddings (2D positional encodings of image locations), to be added to both
  519. the queries and keys in self-attention (but not to values).
  520. """
  521. residual = hidden_states
  522. hidden_states, _ = self.self_attn(
  523. hidden_states=hidden_states,
  524. attention_mask=attention_mask,
  525. position_embeddings=spatial_position_embeddings,
  526. **kwargs,
  527. )
  528. hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training)
  529. hidden_states = residual + hidden_states
  530. hidden_states = self.self_attn_layer_norm(hidden_states)
  531. residual = hidden_states
  532. hidden_states = self.mlp(hidden_states)
  533. hidden_states = residual + hidden_states
  534. hidden_states = self.final_layer_norm(hidden_states)
  535. if self.training:
  536. if not torch.isfinite(hidden_states).all():
  537. clamp_value = torch.finfo(hidden_states.dtype).max - 1000
  538. hidden_states = torch.clamp(hidden_states, min=-clamp_value, max=clamp_value)
  539. return hidden_states
  540. class DetrDecoderLayer(GradientCheckpointingLayer):
  541. def __init__(self, config: DetrConfig):
  542. super().__init__()
  543. self.hidden_size = config.d_model
  544. self.self_attn = DetrSelfAttention(
  545. config=config,
  546. hidden_size=self.hidden_size,
  547. num_attention_heads=config.decoder_attention_heads,
  548. dropout=config.attention_dropout,
  549. )
  550. self.dropout = config.dropout
  551. self.self_attn_layer_norm = nn.LayerNorm(self.hidden_size)
  552. self.encoder_attn = DetrCrossAttention(
  553. config=config,
  554. hidden_size=self.hidden_size,
  555. num_attention_heads=config.decoder_attention_heads,
  556. dropout=config.attention_dropout,
  557. )
  558. self.encoder_attn_layer_norm = nn.LayerNorm(self.hidden_size)
  559. self.mlp = DetrMLP(config, self.hidden_size, config.decoder_ffn_dim)
  560. self.final_layer_norm = nn.LayerNorm(self.hidden_size)
  561. def forward(
  562. self,
  563. hidden_states: torch.Tensor,
  564. attention_mask: torch.Tensor | None = None,
  565. spatial_position_embeddings: torch.Tensor | None = None,
  566. object_queries_position_embeddings: torch.Tensor | None = None,
  567. encoder_hidden_states: torch.Tensor | None = None,
  568. encoder_attention_mask: torch.Tensor | None = None,
  569. **kwargs: Unpack[TransformersKwargs],
  570. ) -> torch.Tensor:
  571. """
  572. Args:
  573. hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, hidden_size)`
  574. attention_mask (`torch.FloatTensor`): attention mask of size
  575. `(batch, 1, target_len, source_len)` where padding elements are indicated by very large negative
  576. values.
  577. spatial_position_embeddings (`torch.FloatTensor`, *optional*):
  578. Spatial position embeddings (2D positional encodings from encoder) that are added to the keys only
  579. in the cross-attention layer (not to values).
  580. object_queries_position_embeddings (`torch.FloatTensor`, *optional*):
  581. Position embeddings for the object query slots. In self-attention, these are added to both queries
  582. and keys (not values). In cross-attention, these are added to queries only (not to keys or values).
  583. encoder_hidden_states (`torch.FloatTensor`):
  584. cross attention input to the layer of shape `(batch, seq_len, hidden_size)`
  585. encoder_attention_mask (`torch.FloatTensor`): encoder attention mask of size
  586. `(batch, 1, target_len, source_len)` where padding elements are indicated by very large negative
  587. values.
  588. """
  589. residual = hidden_states
  590. # Self Attention
  591. hidden_states, _ = self.self_attn(
  592. hidden_states=hidden_states,
  593. position_embeddings=object_queries_position_embeddings,
  594. attention_mask=attention_mask,
  595. **kwargs,
  596. )
  597. hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training)
  598. hidden_states = residual + hidden_states
  599. hidden_states = self.self_attn_layer_norm(hidden_states)
  600. # Cross-Attention Block
  601. if encoder_hidden_states is not None:
  602. residual = hidden_states
  603. hidden_states, _ = self.encoder_attn(
  604. hidden_states=hidden_states,
  605. key_value_states=encoder_hidden_states,
  606. attention_mask=encoder_attention_mask,
  607. position_embeddings=object_queries_position_embeddings,
  608. encoder_position_embeddings=spatial_position_embeddings,
  609. **kwargs,
  610. )
  611. hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training)
  612. hidden_states = residual + hidden_states
  613. hidden_states = self.encoder_attn_layer_norm(hidden_states)
  614. # Fully Connected
  615. residual = hidden_states
  616. hidden_states = self.mlp(hidden_states)
  617. hidden_states = residual + hidden_states
  618. hidden_states = self.final_layer_norm(hidden_states)
  619. return hidden_states
  620. class DetrConvBlock(nn.Module):
  621. """Basic conv block: Conv3x3 -> GroupNorm -> Activation."""
  622. def __init__(self, in_channels: int, out_channels: int, activation: str = "relu"):
  623. super().__init__()
  624. self.conv = nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1)
  625. self.norm = nn.GroupNorm(min(8, out_channels), out_channels)
  626. self.activation = ACT2FN[activation]
  627. def forward(self, x: torch.Tensor) -> torch.Tensor:
  628. return self.activation(self.norm(self.conv(x)))
  629. class DetrFPNFusionStage(nn.Module):
  630. """Single FPN fusion stage combining low-resolution features with high-resolution FPN features."""
  631. def __init__(self, fpn_channels: int, current_channels: int, output_channels: int, activation: str = "relu"):
  632. super().__init__()
  633. self.fpn_adapter = nn.Conv2d(fpn_channels, current_channels, kernel_size=1)
  634. self.refine = DetrConvBlock(current_channels, output_channels, activation)
  635. def forward(self, features: torch.Tensor, fpn_features: torch.Tensor) -> torch.Tensor:
  636. """
  637. Args:
  638. features: Current features to upsample, shape (B*Q, current_channels, H_in, W_in)
  639. fpn_features: FPN features at target resolution, shape (B*Q, fpn_channels, H_out, W_out)
  640. Returns:
  641. Fused and refined features, shape (B*Q, output_channels, H_out, W_out)
  642. """
  643. fpn_features = self.fpn_adapter(fpn_features)
  644. features = nn.functional.interpolate(features, size=fpn_features.shape[-2:], mode="nearest")
  645. return self.refine(fpn_features + features)
  646. class DetrMaskHeadSmallConv(nn.Module):
  647. """
  648. Segmentation mask head that generates per-query masks using FPN-based progressive upsampling.
  649. Combines attention maps (spatial localization) with encoder features (semantics) and progressively
  650. upsamples through multiple scales, fusing with FPN features for high-resolution detail.
  651. """
  652. def __init__(
  653. self,
  654. input_channels: int,
  655. fpn_channels: list[int],
  656. hidden_size: int,
  657. activation_function: str = "relu",
  658. ):
  659. super().__init__()
  660. if input_channels % 8 != 0:
  661. raise ValueError(f"input_channels must be divisible by 8, got {input_channels}")
  662. self.conv1 = DetrConvBlock(input_channels, input_channels, activation_function)
  663. self.conv2 = DetrConvBlock(input_channels, hidden_size // 2, activation_function)
  664. # Progressive channel reduction: /2 -> /4 -> /8 -> /16
  665. self.fpn_stages = nn.ModuleList(
  666. [
  667. DetrFPNFusionStage(fpn_channels[0], hidden_size // 2, hidden_size // 4, activation_function),
  668. DetrFPNFusionStage(fpn_channels[1], hidden_size // 4, hidden_size // 8, activation_function),
  669. DetrFPNFusionStage(fpn_channels[2], hidden_size // 8, hidden_size // 16, activation_function),
  670. ]
  671. )
  672. self.output_conv = nn.Conv2d(hidden_size // 16, 1, kernel_size=3, padding=1)
  673. def forward(
  674. self,
  675. features: torch.Tensor,
  676. attention_masks: torch.Tensor,
  677. fpn_features: list[torch.Tensor],
  678. ) -> torch.Tensor:
  679. """
  680. Args:
  681. features: Encoder output features, shape (batch_size, hidden_size, H, W)
  682. attention_masks: Cross-attention maps from decoder, shape (batch_size, num_queries, num_heads, H, W)
  683. fpn_features: List of 3 FPN features from low to high resolution, each (batch_size, C, H, W)
  684. Returns:
  685. Predicted masks, shape (batch_size * num_queries, 1, output_H, output_W)
  686. """
  687. num_queries = attention_masks.shape[1]
  688. # Expand to (batch_size * num_queries) dimension
  689. features = features.unsqueeze(1).expand(-1, num_queries, -1, -1, -1).flatten(0, 1)
  690. attention_masks = attention_masks.flatten(0, 1)
  691. fpn_features = [
  692. fpn_feat.unsqueeze(1).expand(-1, num_queries, -1, -1, -1).flatten(0, 1) for fpn_feat in fpn_features
  693. ]
  694. hidden_states = torch.cat([features, attention_masks], dim=1)
  695. hidden_states = self.conv1(hidden_states)
  696. hidden_states = self.conv2(hidden_states)
  697. for fpn_stage, fpn_feat in zip(self.fpn_stages, fpn_features):
  698. hidden_states = fpn_stage(hidden_states, fpn_feat)
  699. return self.output_conv(hidden_states)
  700. class DetrMHAttentionMap(nn.Module):
  701. """This is a 2D attention module, which only returns the attention softmax (no multiplication by value)"""
  702. def __init__(
  703. self,
  704. hidden_size: int,
  705. num_attention_heads: int,
  706. dropout: float = 0.0,
  707. bias: bool = True,
  708. ):
  709. super().__init__()
  710. self.head_dim = hidden_size // num_attention_heads
  711. self.scaling = self.head_dim**-0.5
  712. self.attention_dropout = dropout
  713. self.q_proj = nn.Linear(hidden_size, hidden_size, bias=bias)
  714. self.k_proj = nn.Linear(hidden_size, hidden_size, bias=bias)
  715. def forward(
  716. self, query_states: torch.Tensor, key_states: torch.Tensor, attention_mask: torch.Tensor | None = None
  717. ):
  718. query_hidden_shape = (*query_states.shape[:-1], -1, self.head_dim)
  719. key_hidden_shape = (key_states.shape[0], -1, self.head_dim, *key_states.shape[-2:])
  720. query_states = self.q_proj(query_states).view(query_hidden_shape)
  721. key_states = nn.functional.conv2d(
  722. key_states, self.k_proj.weight.unsqueeze(-1).unsqueeze(-1), self.k_proj.bias
  723. ).view(key_hidden_shape)
  724. batch_size, num_queries, num_heads, head_dim = query_states.shape
  725. _, _, _, height, width = key_states.shape
  726. query_shape = (batch_size * num_heads, num_queries, head_dim)
  727. key_shape = (batch_size * num_heads, height * width, head_dim)
  728. attn_weights_shape = (batch_size, num_heads, num_queries, height, width)
  729. query = query_states.transpose(1, 2).contiguous().view(query_shape)
  730. key = key_states.permute(0, 1, 3, 4, 2).contiguous().view(key_shape)
  731. attn_weights = (
  732. (torch.matmul(query * self.scaling, key.transpose(1, 2))).view(attn_weights_shape).transpose(1, 2)
  733. )
  734. if attention_mask is not None:
  735. attn_weights = attn_weights + attention_mask
  736. attn_weights = nn.functional.softmax(attn_weights.flatten(2), dim=-1).view(attn_weights.size())
  737. attn_weights = nn.functional.dropout(attn_weights, p=self.attention_dropout, training=self.training)
  738. return attn_weights
  739. @auto_docstring
  740. class DetrPreTrainedModel(PreTrainedModel):
  741. config: DetrConfig
  742. base_model_prefix = "model"
  743. main_input_name = "pixel_values"
  744. input_modalities = ("image",)
  745. _no_split_modules = [r"DetrConvEncoder", r"DetrEncoderLayer", r"DetrDecoderLayer"]
  746. supports_gradient_checkpointing = True
  747. _supports_sdpa = True
  748. _supports_flash_attn = True
  749. _supports_attention_backend = True
  750. _supports_flex_attn = True # Uses create_bidirectional_masks for attention masking
  751. _keys_to_ignore_on_load_unexpected = [
  752. r"detr\.model\.backbone\.model\.layer\d+\.0\.downsample\.1\.num_batches_tracked"
  753. ]
  754. @torch.no_grad()
  755. def _init_weights(self, module):
  756. std = self.config.init_std
  757. xavier_std = self.config.init_xavier_std
  758. if isinstance(module, DetrMaskHeadSmallConv):
  759. # DetrMaskHeadSmallConv uses kaiming initialization for all its Conv2d layers
  760. for m in module.modules():
  761. if isinstance(m, nn.Conv2d):
  762. init.kaiming_uniform_(m.weight, a=1)
  763. if m.bias is not None:
  764. init.constant_(m.bias, 0)
  765. elif isinstance(module, DetrMHAttentionMap):
  766. init.zeros_(module.k_proj.bias)
  767. init.zeros_(module.q_proj.bias)
  768. init.xavier_uniform_(module.k_proj.weight, gain=xavier_std)
  769. init.xavier_uniform_(module.q_proj.weight, gain=xavier_std)
  770. elif isinstance(module, DetrLearnedPositionEmbedding):
  771. init.uniform_(module.row_embeddings.weight)
  772. init.uniform_(module.column_embeddings.weight)
  773. elif isinstance(module, (nn.Linear, nn.Conv2d)):
  774. init.normal_(module.weight, mean=0.0, std=std)
  775. if module.bias is not None:
  776. init.zeros_(module.bias)
  777. elif isinstance(module, nn.Embedding):
  778. init.normal_(module.weight, mean=0.0, std=std)
  779. # Here we need the check explicitly, as we slice the weight in the `zeros_` call, so it looses the flag
  780. if module.padding_idx is not None and not getattr(module.weight, "_is_hf_initialized", False):
  781. init.zeros_(module.weight[module.padding_idx])
  782. elif isinstance(module, (nn.LayerNorm, nn.GroupNorm)):
  783. init.ones_(module.weight)
  784. init.zeros_(module.bias)
  785. class DetrEncoder(DetrPreTrainedModel):
  786. """
  787. Transformer encoder that processes a flattened feature map from a vision backbone, composed of a stack of
  788. [`DetrEncoderLayer`] modules.
  789. Args:
  790. config (`DetrConfig`): Model configuration object.
  791. """
  792. _can_record_outputs = {"hidden_states": DetrEncoderLayer, "attentions": DetrSelfAttention}
  793. def __init__(self, config: DetrConfig):
  794. super().__init__(config)
  795. self.dropout = config.dropout
  796. self.layers = nn.ModuleList([DetrEncoderLayer(config) for _ in range(config.encoder_layers)])
  797. # Initialize weights and apply final processing
  798. self.post_init()
  799. @merge_with_config_defaults
  800. @capture_outputs
  801. def forward(
  802. self,
  803. inputs_embeds=None,
  804. attention_mask=None,
  805. spatial_position_embeddings=None,
  806. **kwargs: Unpack[TransformersKwargs],
  807. ) -> BaseModelOutput:
  808. r"""
  809. Args:
  810. inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`):
  811. Flattened feature map (output of the backbone + projection layer) that is passed to the encoder.
  812. attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
  813. Mask to avoid performing attention on padding pixel features. Mask values selected in `[0, 1]`:
  814. - 1 for pixel features that are real (i.e. **not masked**),
  815. - 0 for pixel features that are padding (i.e. **masked**).
  816. [What are attention masks?](../glossary#attention-mask)
  817. spatial_position_embeddings (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`):
  818. Spatial position embeddings (2D positional encodings) that are added to the queries and keys in each self-attention layer.
  819. """
  820. hidden_states = inputs_embeds
  821. hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training)
  822. attention_mask = create_bidirectional_mask(
  823. config=self.config,
  824. inputs_embeds=inputs_embeds,
  825. attention_mask=attention_mask,
  826. )
  827. for encoder_layer in self.layers:
  828. # we add spatial_position_embeddings as extra input to the encoder_layer
  829. hidden_states = encoder_layer(
  830. hidden_states, attention_mask, spatial_position_embeddings=spatial_position_embeddings, **kwargs
  831. )
  832. return BaseModelOutput(last_hidden_state=hidden_states)
  833. class DetrDecoder(DetrPreTrainedModel):
  834. """
  835. Transformer decoder that refines a set of object queries. It is composed of a stack of [`DetrDecoderLayer`] modules,
  836. which apply self-attention to the queries and cross-attention to the encoder's outputs.
  837. Args:
  838. config (`DetrConfig`): Model configuration object.
  839. """
  840. _can_record_outputs = {
  841. "hidden_states": DetrDecoderLayer,
  842. "attentions": DetrSelfAttention,
  843. "cross_attentions": DetrCrossAttention,
  844. }
  845. def __init__(self, config: DetrConfig):
  846. super().__init__(config)
  847. self.dropout = config.dropout
  848. self.layers = nn.ModuleList([DetrDecoderLayer(config) for _ in range(config.decoder_layers)])
  849. # in DETR, the decoder uses layernorm after the last decoder layer output
  850. self.layernorm = nn.LayerNorm(config.d_model)
  851. # Initialize weights and apply final processing
  852. self.post_init()
  853. @merge_with_config_defaults
  854. @capture_outputs
  855. def forward(
  856. self,
  857. inputs_embeds=None,
  858. attention_mask=None,
  859. encoder_hidden_states=None,
  860. encoder_attention_mask=None,
  861. spatial_position_embeddings=None,
  862. object_queries_position_embeddings=None,
  863. **kwargs: Unpack[TransformersKwargs],
  864. ) -> DetrDecoderOutput:
  865. r"""
  866. Args:
  867. inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`):
  868. The query embeddings that are passed into the decoder.
  869. attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
  870. Mask to avoid performing attention on certain queries. Mask values selected in `[0, 1]`:
  871. - 1 for queries that are **not masked**,
  872. - 0 for queries that are **masked**.
  873. [What are attention masks?](../glossary#attention-mask)
  874. encoder_hidden_states (`torch.FloatTensor` of shape `(batch_size, encoder_sequence_length, hidden_size)`, *optional*):
  875. Sequence of hidden-states at the output of the last layer of the encoder. Used in the cross-attention
  876. of the decoder.
  877. encoder_attention_mask (`torch.LongTensor` of shape `(batch_size, encoder_sequence_length)`, *optional*):
  878. Mask to avoid performing cross-attention on padding pixel_values of the encoder. Mask values selected
  879. in `[0, 1]`:
  880. - 1 for pixels that are real (i.e. **not masked**),
  881. - 0 for pixels that are padding (i.e. **masked**).
  882. spatial_position_embeddings (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
  883. Spatial position embeddings (2D positional encodings from encoder) that are added to the keys in each cross-attention layer.
  884. object_queries_position_embeddings (`torch.FloatTensor` of shape `(batch_size, num_queries, hidden_size)`, *optional*):
  885. Position embeddings for the object query slots that are added to the queries and keys in each self-attention layer.
  886. """
  887. if inputs_embeds is not None:
  888. hidden_states = inputs_embeds
  889. if attention_mask is not None:
  890. attention_mask = create_bidirectional_mask(
  891. config=self.config,
  892. inputs_embeds=hidden_states,
  893. attention_mask=attention_mask,
  894. )
  895. # expand encoder attention mask (for cross-attention on encoder outputs)
  896. if encoder_hidden_states is not None and encoder_attention_mask is not None:
  897. encoder_attention_mask = create_bidirectional_mask(
  898. config=self.config,
  899. inputs_embeds=hidden_states,
  900. attention_mask=encoder_attention_mask,
  901. encoder_hidden_states=encoder_hidden_states,
  902. )
  903. # optional intermediate hidden states
  904. intermediate = () if self.config.auxiliary_loss else None
  905. # decoder layers
  906. for idx, decoder_layer in enumerate(self.layers):
  907. hidden_states = decoder_layer(
  908. hidden_states,
  909. attention_mask,
  910. spatial_position_embeddings,
  911. object_queries_position_embeddings,
  912. encoder_hidden_states, # as a positional argument for gradient checkpointing
  913. encoder_attention_mask=encoder_attention_mask,
  914. **kwargs,
  915. )
  916. if self.config.auxiliary_loss:
  917. hidden_states = self.layernorm(hidden_states)
  918. intermediate += (hidden_states,)
  919. # finally, apply layernorm
  920. hidden_states = self.layernorm(hidden_states)
  921. # stack intermediate decoder activations
  922. if self.config.auxiliary_loss:
  923. intermediate = torch.stack(intermediate)
  924. return DetrDecoderOutput(last_hidden_state=hidden_states, intermediate_hidden_states=intermediate)
  925. @auto_docstring(
  926. custom_intro="""
  927. The bare DETR Model (consisting of a backbone and encoder-decoder Transformer) outputting raw hidden-states without
  928. any specific head on top.
  929. """
  930. )
  931. class DetrModel(DetrPreTrainedModel):
  932. def __init__(self, config: DetrConfig):
  933. super().__init__(config)
  934. self.backbone = DetrConvEncoder(config)
  935. if config.position_embedding_type == "sine":
  936. self.position_embedding = DetrSinePositionEmbedding(config.d_model // 2, normalize=True)
  937. elif config.position_embedding_type == "learned":
  938. self.position_embedding = DetrLearnedPositionEmbedding(config.d_model // 2)
  939. else:
  940. raise ValueError(f"Not supported {config.position_embedding_type}")
  941. self.query_position_embeddings = nn.Embedding(config.num_queries, config.d_model)
  942. self.input_projection = nn.Conv2d(self.backbone.intermediate_channel_sizes[-1], config.d_model, kernel_size=1)
  943. self.encoder = DetrEncoder(config)
  944. self.decoder = DetrDecoder(config)
  945. # Initialize weights and apply final processing
  946. self.post_init()
  947. def freeze_backbone(self):
  948. for _, param in self.backbone.model.named_parameters():
  949. param.requires_grad_(False)
  950. def unfreeze_backbone(self):
  951. for _, param in self.backbone.model.named_parameters():
  952. param.requires_grad_(True)
  953. @auto_docstring
  954. @can_return_tuple
  955. def forward(
  956. self,
  957. pixel_values: torch.FloatTensor | None = None,
  958. pixel_mask: torch.LongTensor | None = None,
  959. decoder_attention_mask: torch.FloatTensor | None = None,
  960. encoder_outputs: torch.FloatTensor | None = None,
  961. inputs_embeds: torch.FloatTensor | None = None,
  962. decoder_inputs_embeds: torch.FloatTensor | None = None,
  963. **kwargs: Unpack[TransformersKwargs],
  964. ) -> tuple[torch.FloatTensor] | DetrModelOutput:
  965. r"""
  966. decoder_attention_mask (`torch.FloatTensor` of shape `(batch_size, num_queries)`, *optional*):
  967. Mask to avoid performing attention on certain object queries in the decoder. Mask values selected in `[0, 1]`:
  968. - 1 for queries that are **not masked**,
  969. - 0 for queries that are **masked**.
  970. inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
  971. Optionally, instead of passing the flattened feature map (output of the backbone + projection layer), you
  972. can choose to directly pass a flattened representation of an image. Useful for bypassing the vision backbone.
  973. decoder_inputs_embeds (`torch.FloatTensor` of shape `(batch_size, num_queries, hidden_size)`, *optional*):
  974. Optionally, instead of initializing the queries with a tensor of zeros, you can choose to directly pass an
  975. embedded representation. Useful for tasks that require custom query initialization.
  976. Examples:
  977. ```python
  978. >>> from transformers import AutoImageProcessor, DetrModel
  979. >>> from PIL import Image
  980. >>> import httpx
  981. >>> from io import BytesIO
  982. >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg"
  983. >>> with httpx.stream("GET", url) as response:
  984. ... image = Image.open(BytesIO(response.read()))
  985. >>> image_processor = AutoImageProcessor.from_pretrained("facebook/detr-resnet-50")
  986. >>> model = DetrModel.from_pretrained("facebook/detr-resnet-50")
  987. >>> # prepare image for the model
  988. >>> inputs = image_processor(images=image, return_tensors="pt")
  989. >>> # forward pass
  990. >>> outputs = model(**inputs)
  991. >>> # the last hidden states are the final query embeddings of the Transformer decoder
  992. >>> # these are of shape (batch_size, num_queries, hidden_size)
  993. >>> last_hidden_states = outputs.last_hidden_state
  994. >>> list(last_hidden_states.shape)
  995. [1, 100, 256]
  996. ```"""
  997. if pixel_values is None and inputs_embeds is None:
  998. raise ValueError("You have to specify either pixel_values or inputs_embeds")
  999. if inputs_embeds is None:
  1000. batch_size, num_channels, height, width = pixel_values.shape
  1001. device = pixel_values.device
  1002. if pixel_mask is None:
  1003. pixel_mask = torch.ones(((batch_size, height, width)), device=device)
  1004. vision_features = self.backbone(pixel_values, pixel_mask)
  1005. feature_map, mask = vision_features[-1]
  1006. # Apply 1x1 conv to map (batch_size, C, H, W) -> (batch_size, hidden_size, H, W), then flatten to (batch_size, HW, hidden_size)
  1007. # Position embeddings are already flattened to (batch_size, sequence_length, hidden_size) format
  1008. projected_feature_map = self.input_projection(feature_map)
  1009. flattened_features = projected_feature_map.flatten(2).permute(0, 2, 1)
  1010. spatial_position_embeddings = self.position_embedding(
  1011. shape=feature_map.shape, device=device, dtype=pixel_values.dtype, mask=mask
  1012. )
  1013. flattened_mask = mask.flatten(1)
  1014. else:
  1015. batch_size = inputs_embeds.shape[0]
  1016. device = inputs_embeds.device
  1017. flattened_features = inputs_embeds
  1018. # When using inputs_embeds, we need to infer spatial dimensions for position embeddings
  1019. # Assume square feature map
  1020. seq_len = inputs_embeds.shape[1]
  1021. feat_dim = int(seq_len**0.5)
  1022. # Create position embeddings for the inferred spatial size
  1023. spatial_position_embeddings = self.position_embedding(
  1024. shape=torch.Size([batch_size, self.config.d_model, feat_dim, feat_dim]),
  1025. device=device,
  1026. dtype=inputs_embeds.dtype,
  1027. )
  1028. # If a pixel_mask is provided with inputs_embeds, interpolate it to feat_dim, then flatten.
  1029. if pixel_mask is not None:
  1030. mask = nn.functional.interpolate(pixel_mask[None].float(), size=(feat_dim, feat_dim)).to(torch.bool)[0]
  1031. flattened_mask = mask.flatten(1)
  1032. else:
  1033. # If no mask provided, assume all positions are valid
  1034. flattened_mask = torch.ones((batch_size, seq_len), device=device, dtype=torch.long)
  1035. if encoder_outputs is None:
  1036. encoder_outputs = self.encoder(
  1037. inputs_embeds=flattened_features,
  1038. attention_mask=flattened_mask,
  1039. spatial_position_embeddings=spatial_position_embeddings,
  1040. **kwargs,
  1041. )
  1042. object_queries_position_embeddings = self.query_position_embeddings.weight.unsqueeze(0).repeat(
  1043. batch_size, 1, 1
  1044. )
  1045. # Use decoder_inputs_embeds as queries if provided, otherwise initialize with zeros
  1046. if decoder_inputs_embeds is not None:
  1047. queries = decoder_inputs_embeds
  1048. else:
  1049. queries = torch.zeros_like(object_queries_position_embeddings)
  1050. # decoder outputs consists of (dec_features, dec_hidden, dec_attn)
  1051. decoder_outputs = self.decoder(
  1052. inputs_embeds=queries,
  1053. attention_mask=decoder_attention_mask,
  1054. spatial_position_embeddings=spatial_position_embeddings,
  1055. object_queries_position_embeddings=object_queries_position_embeddings,
  1056. encoder_hidden_states=encoder_outputs.last_hidden_state,
  1057. encoder_attention_mask=flattened_mask,
  1058. **kwargs,
  1059. )
  1060. return DetrModelOutput(
  1061. last_hidden_state=decoder_outputs.last_hidden_state,
  1062. decoder_hidden_states=decoder_outputs.hidden_states,
  1063. decoder_attentions=decoder_outputs.attentions,
  1064. cross_attentions=decoder_outputs.cross_attentions,
  1065. encoder_last_hidden_state=encoder_outputs.last_hidden_state,
  1066. encoder_hidden_states=encoder_outputs.hidden_states,
  1067. encoder_attentions=encoder_outputs.attentions,
  1068. intermediate_hidden_states=decoder_outputs.intermediate_hidden_states,
  1069. )
  1070. class DetrMLPPredictionHead(nn.Module):
  1071. """
  1072. Very simple multi-layer perceptron (MLP, also called FFN), used to predict the normalized center coordinates,
  1073. height and width of a bounding box w.r.t. an image.
  1074. """
  1075. def __init__(self, input_dim, hidden_dim, output_dim, num_layers):
  1076. super().__init__()
  1077. self.num_layers = num_layers
  1078. h = [hidden_dim] * (num_layers - 1)
  1079. self.layers = nn.ModuleList(nn.Linear(n, k) for n, k in zip([input_dim] + h, h + [output_dim]))
  1080. def forward(self, x):
  1081. for i, layer in enumerate(self.layers):
  1082. x = nn.functional.relu(layer(x)) if i < self.num_layers - 1 else layer(x)
  1083. return x
  1084. @auto_docstring(
  1085. custom_intro="""
  1086. DETR Model (consisting of a backbone and encoder-decoder Transformer) with object detection heads on top, for tasks
  1087. such as COCO detection.
  1088. """
  1089. )
  1090. class DetrForObjectDetection(DetrPreTrainedModel):
  1091. def __init__(self, config: DetrConfig):
  1092. super().__init__(config)
  1093. # DETR encoder-decoder model
  1094. self.model = DetrModel(config)
  1095. # Object detection heads
  1096. self.class_labels_classifier = nn.Linear(
  1097. config.d_model, config.num_labels + 1
  1098. ) # We add one for the "no object" class
  1099. self.bbox_predictor = DetrMLPPredictionHead(
  1100. input_dim=config.d_model, hidden_dim=config.d_model, output_dim=4, num_layers=3
  1101. )
  1102. # Initialize weights and apply final processing
  1103. self.post_init()
  1104. @auto_docstring
  1105. @can_return_tuple
  1106. def forward(
  1107. self,
  1108. pixel_values: torch.FloatTensor,
  1109. pixel_mask: torch.LongTensor | None = None,
  1110. decoder_attention_mask: torch.FloatTensor | None = None,
  1111. encoder_outputs: torch.FloatTensor | None = None,
  1112. inputs_embeds: torch.FloatTensor | None = None,
  1113. decoder_inputs_embeds: torch.FloatTensor | None = None,
  1114. labels: list[dict] | None = None,
  1115. **kwargs: Unpack[TransformersKwargs],
  1116. ) -> tuple[torch.FloatTensor] | DetrObjectDetectionOutput:
  1117. r"""
  1118. decoder_attention_mask (`torch.FloatTensor` of shape `(batch_size, num_queries)`, *optional*):
  1119. Mask to avoid performing attention on certain object queries in the decoder. Mask values selected in `[0, 1]`:
  1120. - 1 for queries that are **not masked**,
  1121. - 0 for queries that are **masked**.
  1122. inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
  1123. Optionally, instead of passing the flattened feature map (output of the backbone + projection layer), you
  1124. can choose to directly pass a flattened representation of an image. Useful for bypassing the vision backbone.
  1125. decoder_inputs_embeds (`torch.FloatTensor` of shape `(batch_size, num_queries, hidden_size)`, *optional*):
  1126. Optionally, instead of initializing the queries with a tensor of zeros, you can choose to directly pass an
  1127. embedded representation. Useful for tasks that require custom query initialization.
  1128. labels (`list[Dict]` of len `(batch_size,)`, *optional*):
  1129. Labels for computing the bipartite matching loss. List of dicts, each dictionary containing at least the
  1130. following 2 keys: 'class_labels' and 'boxes' (the class labels and bounding boxes of an image in the batch
  1131. respectively). The class labels themselves should be a `torch.LongTensor` of len `(number of bounding boxes
  1132. in the image,)` and the boxes a `torch.FloatTensor` of shape `(number of bounding boxes in the image, 4)`.
  1133. Examples:
  1134. ```python
  1135. >>> from transformers import AutoImageProcessor, DetrForObjectDetection
  1136. >>> import torch
  1137. >>> from PIL import Image
  1138. >>> import httpx
  1139. >>> from io import BytesIO
  1140. >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg"
  1141. >>> with httpx.stream("GET", url) as response:
  1142. ... image = Image.open(BytesIO(response.read()))
  1143. >>> image_processor = AutoImageProcessor.from_pretrained("facebook/detr-resnet-50")
  1144. >>> model = DetrForObjectDetection.from_pretrained("facebook/detr-resnet-50")
  1145. >>> inputs = image_processor(images=image, return_tensors="pt")
  1146. >>> outputs = model(**inputs)
  1147. >>> # convert outputs (bounding boxes and class logits) to Pascal VOC format (xmin, ymin, xmax, ymax)
  1148. >>> target_sizes = torch.tensor([image.size[::-1]])
  1149. >>> results = image_processor.post_process_object_detection(outputs, threshold=0.9, target_sizes=target_sizes)[
  1150. ... 0
  1151. ... ]
  1152. >>> for score, label, box in zip(results["scores"], results["labels"], results["boxes"]):
  1153. ... box = [round(i, 2) for i in box.tolist()]
  1154. ... print(
  1155. ... f"Detected {model.config.id2label[label.item()]} with confidence "
  1156. ... f"{round(score.item(), 3)} at location {box}"
  1157. ... )
  1158. Detected remote with confidence 0.998 at location [40.16, 70.81, 175.55, 117.98]
  1159. Detected remote with confidence 0.996 at location [333.24, 72.55, 368.33, 187.66]
  1160. Detected couch with confidence 0.995 at location [-0.02, 1.15, 639.73, 473.76]
  1161. Detected cat with confidence 0.999 at location [13.24, 52.05, 314.02, 470.93]
  1162. Detected cat with confidence 0.999 at location [345.4, 23.85, 640.37, 368.72]
  1163. ```"""
  1164. # First, send images through DETR base model to obtain encoder + decoder outputs
  1165. outputs = self.model(
  1166. pixel_values,
  1167. pixel_mask=pixel_mask,
  1168. decoder_attention_mask=decoder_attention_mask,
  1169. encoder_outputs=encoder_outputs,
  1170. inputs_embeds=inputs_embeds,
  1171. decoder_inputs_embeds=decoder_inputs_embeds,
  1172. **kwargs,
  1173. )
  1174. sequence_output = outputs[0]
  1175. # class logits + predicted bounding boxes
  1176. logits = self.class_labels_classifier(sequence_output)
  1177. pred_boxes = self.bbox_predictor(sequence_output).sigmoid()
  1178. loss, loss_dict, auxiliary_outputs = None, None, None
  1179. if labels is not None:
  1180. outputs_class, outputs_coord = None, None
  1181. if self.config.auxiliary_loss:
  1182. intermediate = outputs.intermediate_hidden_states
  1183. outputs_class = self.class_labels_classifier(intermediate)
  1184. outputs_coord = self.bbox_predictor(intermediate).sigmoid()
  1185. loss, loss_dict, auxiliary_outputs = self.loss_function(
  1186. logits, labels, self.device, pred_boxes, self.config, outputs_class, outputs_coord
  1187. )
  1188. return DetrObjectDetectionOutput(
  1189. loss=loss,
  1190. loss_dict=loss_dict,
  1191. logits=logits,
  1192. pred_boxes=pred_boxes,
  1193. auxiliary_outputs=auxiliary_outputs,
  1194. last_hidden_state=outputs.last_hidden_state,
  1195. decoder_hidden_states=outputs.decoder_hidden_states,
  1196. decoder_attentions=outputs.decoder_attentions,
  1197. cross_attentions=outputs.cross_attentions,
  1198. encoder_last_hidden_state=outputs.encoder_last_hidden_state,
  1199. encoder_hidden_states=outputs.encoder_hidden_states,
  1200. encoder_attentions=outputs.encoder_attentions,
  1201. )
  1202. @auto_docstring(
  1203. custom_intro="""
  1204. DETR Model (consisting of a backbone and encoder-decoder Transformer) with a segmentation head on top, for tasks
  1205. such as COCO panoptic.
  1206. """
  1207. )
  1208. class DetrForSegmentation(DetrPreTrainedModel):
  1209. def __init__(self, config: DetrConfig):
  1210. super().__init__(config)
  1211. # object detection model
  1212. self.detr = DetrForObjectDetection(config)
  1213. # segmentation head
  1214. hidden_size, number_of_heads = config.d_model, config.encoder_attention_heads
  1215. intermediate_channel_sizes = self.detr.model.backbone.intermediate_channel_sizes
  1216. self.mask_head = DetrMaskHeadSmallConv(
  1217. input_channels=hidden_size + number_of_heads,
  1218. fpn_channels=intermediate_channel_sizes[::-1][-3:],
  1219. hidden_size=hidden_size,
  1220. activation_function=config.activation_function,
  1221. )
  1222. self.bbox_attention = DetrMHAttentionMap(hidden_size, number_of_heads, dropout=0.0)
  1223. # Initialize weights and apply final processing
  1224. self.post_init()
  1225. @auto_docstring
  1226. @can_return_tuple
  1227. def forward(
  1228. self,
  1229. pixel_values: torch.FloatTensor,
  1230. pixel_mask: torch.LongTensor | None = None,
  1231. decoder_attention_mask: torch.FloatTensor | None = None,
  1232. encoder_outputs: torch.FloatTensor | None = None,
  1233. inputs_embeds: torch.FloatTensor | None = None,
  1234. decoder_inputs_embeds: torch.FloatTensor | None = None,
  1235. labels: list[dict] | None = None,
  1236. **kwargs: Unpack[TransformersKwargs],
  1237. ) -> tuple[torch.FloatTensor] | DetrSegmentationOutput:
  1238. r"""
  1239. decoder_attention_mask (`torch.FloatTensor` of shape `(batch_size, num_queries)`, *optional*):
  1240. Mask to avoid performing attention on certain object queries in the decoder. Mask values selected in `[0, 1]`:
  1241. - 1 for queries that are **not masked**,
  1242. - 0 for queries that are **masked**.
  1243. inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
  1244. Kept for backward compatibility, but cannot be used for segmentation, as segmentation requires
  1245. multi-scale features from the backbone that are not available when bypassing it with inputs_embeds.
  1246. decoder_inputs_embeds (`torch.FloatTensor` of shape `(batch_size, num_queries, hidden_size)`, *optional*):
  1247. Optionally, instead of initializing the queries with a tensor of zeros, you can choose to directly pass an
  1248. embedded representation. Useful for tasks that require custom query initialization.
  1249. labels (`list[Dict]` of len `(batch_size,)`, *optional*):
  1250. Labels for computing the bipartite matching loss, DICE/F-1 loss and Focal loss. List of dicts, each
  1251. dictionary containing at least the following 3 keys: 'class_labels', 'boxes' and 'masks' (the class labels,
  1252. bounding boxes and segmentation masks of an image in the batch respectively). The class labels themselves
  1253. should be a `torch.LongTensor` of len `(number of bounding boxes in the image,)`, the boxes a
  1254. `torch.FloatTensor` of shape `(number of bounding boxes in the image, 4)` and the masks a
  1255. `torch.FloatTensor` of shape `(number of bounding boxes in the image, height, width)`.
  1256. Examples:
  1257. ```python
  1258. >>> import io
  1259. >>> import httpx
  1260. >>> from io import BytesIO
  1261. >>> from PIL import Image
  1262. >>> import torch
  1263. >>> import numpy
  1264. >>> from transformers import AutoImageProcessor, DetrForSegmentation
  1265. >>> from transformers.image_transforms import rgb_to_id
  1266. >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg"
  1267. >>> with httpx.stream("GET", url) as response:
  1268. ... image = Image.open(BytesIO(response.read()))
  1269. >>> image_processor = AutoImageProcessor.from_pretrained("facebook/detr-resnet-50-panoptic")
  1270. >>> model = DetrForSegmentation.from_pretrained("facebook/detr-resnet-50-panoptic")
  1271. >>> # prepare image for the model
  1272. >>> inputs = image_processor(images=image, return_tensors="pt")
  1273. >>> # forward pass
  1274. >>> outputs = model(**inputs)
  1275. >>> # Use the `post_process_panoptic_segmentation` method of the `image_processor` to retrieve post-processed panoptic segmentation maps
  1276. >>> # Segmentation results are returned as a list of dictionaries
  1277. >>> result = image_processor.post_process_panoptic_segmentation(outputs, target_sizes=[(300, 500)])
  1278. >>> # A tensor of shape (height, width) where each value denotes a segment id, filled with -1 if no segment is found
  1279. >>> panoptic_seg = result[0]["segmentation"]
  1280. >>> panoptic_seg.shape
  1281. torch.Size([300, 500])
  1282. >>> # Get prediction score and segment_id to class_id mapping of each segment
  1283. >>> panoptic_segments_info = result[0]["segments_info"]
  1284. >>> len(panoptic_segments_info)
  1285. 5
  1286. ```"""
  1287. batch_size, num_channels, height, width = pixel_values.shape
  1288. device = pixel_values.device
  1289. if pixel_mask is None:
  1290. pixel_mask = torch.ones((batch_size, height, width), device=device)
  1291. vision_features = self.detr.model.backbone(pixel_values, pixel_mask)
  1292. feature_map, mask = vision_features[-1]
  1293. # Apply 1x1 conv to map (batch_size, C, H, W) -> (batch_size, hidden_size, H, W), then flatten to (batch_size, HW, hidden_size)
  1294. projected_feature_map = self.detr.model.input_projection(feature_map)
  1295. flattened_features = projected_feature_map.flatten(2).permute(0, 2, 1)
  1296. spatial_position_embeddings = self.detr.model.position_embedding(
  1297. shape=feature_map.shape, device=device, dtype=pixel_values.dtype, mask=mask
  1298. )
  1299. flattened_mask = mask.flatten(1)
  1300. if encoder_outputs is None:
  1301. encoder_outputs = self.detr.model.encoder(
  1302. inputs_embeds=flattened_features,
  1303. attention_mask=flattened_mask,
  1304. spatial_position_embeddings=spatial_position_embeddings,
  1305. **kwargs,
  1306. )
  1307. object_queries_position_embeddings = self.detr.model.query_position_embeddings.weight.unsqueeze(0).repeat(
  1308. batch_size, 1, 1
  1309. )
  1310. # Use decoder_inputs_embeds as queries if provided, otherwise initialize with zeros
  1311. if decoder_inputs_embeds is not None:
  1312. queries = decoder_inputs_embeds
  1313. else:
  1314. queries = torch.zeros_like(object_queries_position_embeddings)
  1315. decoder_outputs = self.detr.model.decoder(
  1316. inputs_embeds=queries,
  1317. attention_mask=decoder_attention_mask,
  1318. spatial_position_embeddings=spatial_position_embeddings,
  1319. object_queries_position_embeddings=object_queries_position_embeddings,
  1320. encoder_hidden_states=encoder_outputs.last_hidden_state,
  1321. encoder_attention_mask=flattened_mask,
  1322. **kwargs,
  1323. )
  1324. sequence_output = decoder_outputs[0]
  1325. logits = self.detr.class_labels_classifier(sequence_output)
  1326. pred_boxes = self.detr.bbox_predictor(sequence_output).sigmoid()
  1327. height, width = feature_map.shape[-2:]
  1328. memory = encoder_outputs.last_hidden_state.permute(0, 2, 1).view(
  1329. batch_size, self.config.d_model, height, width
  1330. )
  1331. attention_mask = flattened_mask.view(batch_size, height, width)
  1332. if attention_mask is not None:
  1333. min_dtype = torch.finfo(memory.dtype).min
  1334. attention_mask = torch.where(
  1335. attention_mask.unsqueeze(1).unsqueeze(1),
  1336. torch.tensor(0.0, device=memory.device, dtype=memory.dtype),
  1337. min_dtype,
  1338. )
  1339. bbox_mask = self.bbox_attention(sequence_output, memory, attention_mask=attention_mask)
  1340. seg_masks = self.mask_head(
  1341. features=projected_feature_map,
  1342. attention_masks=bbox_mask,
  1343. fpn_features=[vision_features[2][0], vision_features[1][0], vision_features[0][0]],
  1344. )
  1345. pred_masks = seg_masks.view(batch_size, self.detr.config.num_queries, seg_masks.shape[-2], seg_masks.shape[-1])
  1346. loss, loss_dict, auxiliary_outputs = None, None, None
  1347. if labels is not None:
  1348. outputs_class, outputs_coord = None, None
  1349. if self.config.auxiliary_loss:
  1350. intermediate = decoder_outputs.intermediate_hidden_states
  1351. outputs_class = self.detr.class_labels_classifier(intermediate)
  1352. outputs_coord = self.detr.bbox_predictor(intermediate).sigmoid()
  1353. loss, loss_dict, auxiliary_outputs = self.loss_function(
  1354. logits, labels, device, pred_boxes, pred_masks, self.config, outputs_class, outputs_coord
  1355. )
  1356. return DetrSegmentationOutput(
  1357. loss=loss,
  1358. loss_dict=loss_dict,
  1359. logits=logits,
  1360. pred_boxes=pred_boxes,
  1361. pred_masks=pred_masks,
  1362. auxiliary_outputs=auxiliary_outputs,
  1363. last_hidden_state=decoder_outputs.last_hidden_state,
  1364. decoder_hidden_states=decoder_outputs.hidden_states,
  1365. decoder_attentions=decoder_outputs.attentions,
  1366. cross_attentions=decoder_outputs.cross_attentions,
  1367. encoder_last_hidden_state=encoder_outputs.last_hidden_state,
  1368. encoder_hidden_states=encoder_outputs.hidden_states,
  1369. encoder_attentions=encoder_outputs.attentions,
  1370. )
  1371. __all__ = [
  1372. "DetrForObjectDetection",
  1373. "DetrForSegmentation",
  1374. "DetrModel",
  1375. "DetrPreTrainedModel",
  1376. ]