# Copyright 2022 School of EIC, Huazhong University of Science & Technology and The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
  14. """PyTorch YOLOS model."""
  15. import collections.abc
  16. from collections.abc import Callable
  17. from dataclasses import dataclass
  18. import torch
  19. from torch import nn
  20. from ...activations import ACT2FN
  21. from ...modeling_layers import GradientCheckpointingLayer
  22. from ...modeling_outputs import BaseModelOutput, BaseModelOutputWithPooling
  23. from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
  24. from ...processing_utils import Unpack
  25. from ...utils import ModelOutput, TransformersKwargs, auto_docstring, logging
  26. from ...utils.generic import can_return_tuple, merge_with_config_defaults
  27. from ...utils.output_capturing import capture_outputs
  28. from .configuration_yolos import YolosConfig
# Module-level logger, named after this module.
logger = logging.get_logger(__name__)
@dataclass
@auto_docstring(
    custom_intro="""
    Output type of [`YolosForObjectDetection`].
    """
)
class YolosObjectDetectionOutput(ModelOutput):
    r"""
    loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` are provided):
        Total loss as a linear combination of a negative log-likelihood (cross-entropy) for class prediction and a
        bounding box loss. The latter is defined as a linear combination of the L1 loss and the generalized
        scale-invariant IoU loss.
    loss_dict (`Dict`, *optional*):
        A dictionary containing the individual losses. Useful for logging.
    logits (`torch.FloatTensor` of shape `(batch_size, num_queries, num_classes + 1)`):
        Classification logits (including no-object) for all queries.
    pred_boxes (`torch.FloatTensor` of shape `(batch_size, num_queries, 4)`):
        Normalized boxes coordinates for all queries, represented as (center_x, center_y, width, height). These
        values are normalized in [0, 1], relative to the size of each individual image in the batch (disregarding
        possible padding). You can use [`~YolosImageProcessor.post_process`] to retrieve the unnormalized bounding
        boxes.
    auxiliary_outputs (`list[Dict]`, *optional*):
        Optional, only returned when auxiliary losses are activated (i.e. `config.auxiliary_loss` is set to `True`)
        and labels are provided. It is a list of dictionaries containing the two above keys (`logits` and
        `pred_boxes`) for each decoder layer.
    last_hidden_state (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
        Sequence of hidden-states at the output of the last layer of the decoder of the model.
    """

    # Every field defaults to `None`, so the output can be built with any subset of values.
    loss: torch.FloatTensor | None = None
    loss_dict: dict | None = None
    logits: torch.FloatTensor | None = None
    pred_boxes: torch.FloatTensor | None = None
    auxiliary_outputs: list[dict] | None = None
    last_hidden_state: torch.FloatTensor | None = None
    hidden_states: tuple[torch.FloatTensor] | None = None
    attentions: tuple[torch.FloatTensor] | None = None
  66. class YolosEmbeddings(nn.Module):
  67. """
  68. Construct the CLS token, detection tokens, position and patch embeddings.
  69. """
  70. def __init__(self, config: YolosConfig) -> None:
  71. super().__init__()
  72. self.cls_token = nn.Parameter(torch.zeros(1, 1, config.hidden_size))
  73. self.detection_tokens = nn.Parameter(torch.zeros(1, config.num_detection_tokens, config.hidden_size))
  74. self.patch_embeddings = YolosPatchEmbeddings(config)
  75. num_patches = self.patch_embeddings.num_patches
  76. self.position_embeddings = nn.Parameter(
  77. torch.zeros(1, num_patches + config.num_detection_tokens + 1, config.hidden_size)
  78. )
  79. self.dropout = nn.Dropout(config.hidden_dropout_prob)
  80. self.interpolation = InterpolateInitialPositionEmbeddings(config)
  81. self.config = config
  82. def forward(self, pixel_values: torch.Tensor) -> torch.Tensor:
  83. batch_size, num_channels, height, width = pixel_values.shape
  84. embeddings = self.patch_embeddings(pixel_values)
  85. batch_size, seq_len, _ = embeddings.size()
  86. # add the [CLS] and detection tokens to the embedded patch tokens
  87. cls_tokens = self.cls_token.expand(batch_size, -1, -1)
  88. detection_tokens = self.detection_tokens.expand(batch_size, -1, -1)
  89. embeddings = torch.cat((cls_tokens, embeddings, detection_tokens), dim=1)
  90. # add positional encoding to each token
  91. # this might require interpolation of the existing position embeddings
  92. position_embeddings = self.interpolation(self.position_embeddings, (height, width))
  93. embeddings = embeddings + position_embeddings
  94. embeddings = self.dropout(embeddings)
  95. return embeddings
  96. class InterpolateInitialPositionEmbeddings(nn.Module):
  97. def __init__(self, config) -> None:
  98. super().__init__()
  99. self.config = config
  100. def forward(self, pos_embed, img_size=(800, 1344)) -> torch.Tensor:
  101. cls_pos_embed = pos_embed[:, 0, :]
  102. cls_pos_embed = cls_pos_embed[:, None]
  103. det_pos_embed = pos_embed[:, -self.config.num_detection_tokens :, :]
  104. patch_pos_embed = pos_embed[:, 1 : -self.config.num_detection_tokens, :]
  105. patch_pos_embed = patch_pos_embed.transpose(1, 2)
  106. batch_size, hidden_size, seq_len = patch_pos_embed.shape
  107. patch_height, patch_width = (
  108. self.config.image_size[0] // self.config.patch_size,
  109. self.config.image_size[1] // self.config.patch_size,
  110. )
  111. patch_pos_embed = patch_pos_embed.view(batch_size, hidden_size, patch_height, patch_width)
  112. height, width = img_size
  113. new_patch_height, new_patch_width = height // self.config.patch_size, width // self.config.patch_size
  114. patch_pos_embed = nn.functional.interpolate(
  115. patch_pos_embed, size=(new_patch_height, new_patch_width), mode="bicubic", align_corners=False
  116. )
  117. patch_pos_embed = patch_pos_embed.flatten(2).transpose(1, 2)
  118. scale_pos_embed = torch.cat((cls_pos_embed, patch_pos_embed, det_pos_embed), dim=1)
  119. return scale_pos_embed
  120. class InterpolateMidPositionEmbeddings(nn.Module):
  121. def __init__(self, config) -> None:
  122. super().__init__()
  123. self.config = config
  124. def forward(self, pos_embed, img_size=(800, 1344)) -> torch.Tensor:
  125. cls_pos_embed = pos_embed[:, :, 0, :]
  126. cls_pos_embed = cls_pos_embed[:, None]
  127. det_pos_embed = pos_embed[:, :, -self.config.num_detection_tokens :, :]
  128. patch_pos_embed = pos_embed[:, :, 1 : -self.config.num_detection_tokens, :]
  129. patch_pos_embed = patch_pos_embed.transpose(2, 3)
  130. depth, batch_size, hidden_size, seq_len = patch_pos_embed.shape
  131. patch_height, patch_width = (
  132. self.config.image_size[0] // self.config.patch_size,
  133. self.config.image_size[1] // self.config.patch_size,
  134. )
  135. patch_pos_embed = patch_pos_embed.view(depth * batch_size, hidden_size, patch_height, patch_width)
  136. height, width = img_size
  137. new_patch_height, new_patch_width = height // self.config.patch_size, width // self.config.patch_size
  138. patch_pos_embed = nn.functional.interpolate(
  139. patch_pos_embed, size=(new_patch_height, new_patch_width), mode="bicubic", align_corners=False
  140. )
  141. patch_pos_embed = (
  142. patch_pos_embed.flatten(2)
  143. .transpose(1, 2)
  144. .contiguous()
  145. .view(depth, batch_size, new_patch_height * new_patch_width, hidden_size)
  146. )
  147. scale_pos_embed = torch.cat((cls_pos_embed, patch_pos_embed, det_pos_embed), dim=2)
  148. return scale_pos_embed
  149. class YolosPatchEmbeddings(nn.Module):
  150. """
  151. This class turns `pixel_values` of shape `(batch_size, num_channels, height, width)` into the initial
  152. `hidden_states` (patch embeddings) of shape `(batch_size, seq_length, hidden_size)` to be consumed by a
  153. Transformer.
  154. """
  155. def __init__(self, config):
  156. super().__init__()
  157. image_size, patch_size = config.image_size, config.patch_size
  158. num_channels, hidden_size = config.num_channels, config.hidden_size
  159. image_size = image_size if isinstance(image_size, collections.abc.Iterable) else (image_size, image_size)
  160. patch_size = patch_size if isinstance(patch_size, collections.abc.Iterable) else (patch_size, patch_size)
  161. num_patches = (image_size[1] // patch_size[1]) * (image_size[0] // patch_size[0])
  162. self.image_size = image_size
  163. self.patch_size = patch_size
  164. self.num_channels = num_channels
  165. self.num_patches = num_patches
  166. self.projection = nn.Conv2d(num_channels, hidden_size, kernel_size=patch_size, stride=patch_size)
  167. def forward(self, pixel_values: torch.Tensor) -> torch.Tensor:
  168. batch_size, num_channels, height, width = pixel_values.shape
  169. if num_channels != self.num_channels:
  170. raise ValueError(
  171. "Make sure that the channel dimension of the pixel values match with the one set in the configuration."
  172. )
  173. embeddings = self.projection(pixel_values).flatten(2).transpose(1, 2)
  174. return embeddings
  175. # Copied from transformers.models.bert.modeling_bert.eager_attention_forward
  176. def eager_attention_forward(
  177. module: nn.Module,
  178. query: torch.Tensor,
  179. key: torch.Tensor,
  180. value: torch.Tensor,
  181. attention_mask: torch.Tensor | None,
  182. scaling: float | None = None,
  183. dropout: float = 0.0,
  184. **kwargs: Unpack[TransformersKwargs],
  185. ):
  186. if scaling is None:
  187. scaling = query.size(-1) ** -0.5
  188. # Take the dot product between "query" and "key" to get the raw attention scores.
  189. attn_weights = torch.matmul(query, key.transpose(2, 3)) * scaling
  190. if attention_mask is not None:
  191. attn_weights = attn_weights + attention_mask
  192. attn_weights = nn.functional.softmax(attn_weights, dim=-1)
  193. attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training)
  194. attn_output = torch.matmul(attn_weights, value)
  195. attn_output = attn_output.transpose(1, 2).contiguous()
  196. return attn_output, attn_weights
  197. # Copied from transformers.models.vit.modeling_vit.ViTSelfAttention with ViT->Yolos
  198. class YolosSelfAttention(nn.Module):
  199. def __init__(self, config: YolosConfig):
  200. super().__init__()
  201. if config.hidden_size % config.num_attention_heads != 0 and not hasattr(config, "embedding_size"):
  202. raise ValueError(
  203. f"The hidden size {config.hidden_size} is not a multiple of the number of attention "
  204. f"heads {config.num_attention_heads}."
  205. )
  206. self.config = config
  207. self.num_attention_heads = config.num_attention_heads
  208. self.attention_head_size = int(config.hidden_size / config.num_attention_heads)
  209. self.all_head_size = self.num_attention_heads * self.attention_head_size
  210. self.dropout_prob = config.attention_probs_dropout_prob
  211. self.scaling = self.attention_head_size**-0.5
  212. self.is_causal = False
  213. self.query = nn.Linear(config.hidden_size, self.all_head_size, bias=config.qkv_bias)
  214. self.key = nn.Linear(config.hidden_size, self.all_head_size, bias=config.qkv_bias)
  215. self.value = nn.Linear(config.hidden_size, self.all_head_size, bias=config.qkv_bias)
  216. def forward(
  217. self,
  218. hidden_states: torch.Tensor,
  219. **kwargs: Unpack[TransformersKwargs],
  220. ) -> tuple[torch.Tensor, torch.Tensor]:
  221. batch_size = hidden_states.shape[0]
  222. new_shape = batch_size, -1, self.num_attention_heads, self.attention_head_size
  223. key_layer = self.key(hidden_states).view(*new_shape).transpose(1, 2)
  224. value_layer = self.value(hidden_states).view(*new_shape).transpose(1, 2)
  225. query_layer = self.query(hidden_states).view(*new_shape).transpose(1, 2)
  226. attention_interface: Callable = ALL_ATTENTION_FUNCTIONS.get_interface(
  227. self.config._attn_implementation, eager_attention_forward
  228. )
  229. context_layer, attention_probs = attention_interface(
  230. self,
  231. query_layer,
  232. key_layer,
  233. value_layer,
  234. None,
  235. is_causal=self.is_causal,
  236. scaling=self.scaling,
  237. dropout=0.0 if not self.training else self.dropout_prob,
  238. **kwargs,
  239. )
  240. new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,)
  241. context_layer = context_layer.reshape(new_context_layer_shape)
  242. return context_layer, attention_probs
  243. # Copied from transformers.models.vit.modeling_vit.ViTSelfOutput with ViT->Yolos
  244. class YolosSelfOutput(nn.Module):
  245. """
  246. The residual connection is defined in YolosLayer instead of here (as is the case with other models), due to the
  247. layernorm applied before each block.
  248. """
  249. def __init__(self, config: YolosConfig):
  250. super().__init__()
  251. self.dense = nn.Linear(config.hidden_size, config.hidden_size)
  252. self.dropout = nn.Dropout(config.hidden_dropout_prob)
  253. def forward(self, hidden_states: torch.Tensor, input_tensor: torch.Tensor) -> torch.Tensor:
  254. hidden_states = self.dense(hidden_states)
  255. hidden_states = self.dropout(hidden_states)
  256. return hidden_states
  257. # Copied from transformers.models.vit.modeling_vit.ViTAttention with ViT->Yolos
  258. class YolosAttention(nn.Module):
  259. def __init__(self, config: YolosConfig):
  260. super().__init__()
  261. self.attention = YolosSelfAttention(config)
  262. self.output = YolosSelfOutput(config)
  263. def forward(
  264. self,
  265. hidden_states: torch.Tensor,
  266. **kwargs: Unpack[TransformersKwargs],
  267. ) -> torch.Tensor:
  268. self_attn_output, _ = self.attention(hidden_states, **kwargs)
  269. output = self.output(self_attn_output, hidden_states)
  270. return output
  271. # Copied from transformers.models.vit.modeling_vit.ViTIntermediate with ViT->Yolos
  272. class YolosIntermediate(nn.Module):
  273. def __init__(self, config: YolosConfig):
  274. super().__init__()
  275. self.dense = nn.Linear(config.hidden_size, config.intermediate_size)
  276. if isinstance(config.hidden_act, str):
  277. self.intermediate_act_fn = ACT2FN[config.hidden_act]
  278. else:
  279. self.intermediate_act_fn = config.hidden_act
  280. def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
  281. hidden_states = self.dense(hidden_states)
  282. hidden_states = self.intermediate_act_fn(hidden_states)
  283. return hidden_states
  284. # Copied from transformers.models.vit.modeling_vit.ViTOutput with ViT->Yolos
  285. class YolosOutput(nn.Module):
  286. def __init__(self, config: YolosConfig):
  287. super().__init__()
  288. self.dense = nn.Linear(config.intermediate_size, config.hidden_size)
  289. self.dropout = nn.Dropout(config.hidden_dropout_prob)
  290. def forward(self, hidden_states: torch.Tensor, input_tensor: torch.Tensor) -> torch.Tensor:
  291. hidden_states = self.dense(hidden_states)
  292. hidden_states = self.dropout(hidden_states)
  293. hidden_states = hidden_states + input_tensor
  294. return hidden_states
  295. # Copied from transformers.models.vit.modeling_vit.ViTLayer with ViT->Yolos,VIT->YOLOS
  296. class YolosLayer(GradientCheckpointingLayer):
  297. """This corresponds to the Block class in the timm implementation."""
  298. def __init__(self, config: YolosConfig):
  299. super().__init__()
  300. self.chunk_size_feed_forward = config.chunk_size_feed_forward
  301. self.seq_len_dim = 1
  302. self.attention = YolosAttention(config)
  303. self.intermediate = YolosIntermediate(config)
  304. self.output = YolosOutput(config)
  305. self.layernorm_before = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
  306. self.layernorm_after = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
  307. def forward(
  308. self,
  309. hidden_states: torch.Tensor,
  310. **kwargs: Unpack[TransformersKwargs],
  311. ) -> torch.Tensor:
  312. hidden_states_norm = self.layernorm_before(hidden_states)
  313. attention_output = self.attention(hidden_states_norm, **kwargs)
  314. # first residual connection
  315. hidden_states = attention_output + hidden_states
  316. # in Yolos, layernorm is also applied after self-attention
  317. layer_output = self.layernorm_after(hidden_states)
  318. layer_output = self.intermediate(layer_output)
  319. # second residual connection is done here
  320. layer_output = self.output(layer_output, hidden_states)
  321. return layer_output
  322. class YolosEncoder(nn.Module):
  323. def __init__(self, config: YolosConfig) -> None:
  324. super().__init__()
  325. self.config = config
  326. self.layer = nn.ModuleList([YolosLayer(config) for _ in range(config.num_hidden_layers)])
  327. self.gradient_checkpointing = False
  328. seq_length = (
  329. 1 + (config.image_size[0] * config.image_size[1] // config.patch_size**2) + config.num_detection_tokens
  330. )
  331. self.mid_position_embeddings = (
  332. nn.Parameter(
  333. torch.zeros(
  334. config.num_hidden_layers - 1,
  335. 1,
  336. seq_length,
  337. config.hidden_size,
  338. )
  339. )
  340. if config.use_mid_position_embeddings
  341. else None
  342. )
  343. self.interpolation = InterpolateMidPositionEmbeddings(config) if config.use_mid_position_embeddings else None
  344. def forward(
  345. self,
  346. hidden_states: torch.Tensor,
  347. height: int,
  348. width: int,
  349. ) -> BaseModelOutput:
  350. if self.config.use_mid_position_embeddings:
  351. interpolated_mid_position_embeddings = self.interpolation(self.mid_position_embeddings, (height, width))
  352. for i, layer_module in enumerate(self.layer):
  353. hidden_states = layer_module(hidden_states)
  354. if self.config.use_mid_position_embeddings:
  355. if i < (self.config.num_hidden_layers - 1):
  356. hidden_states = hidden_states + interpolated_mid_position_embeddings[i]
  357. return BaseModelOutput(last_hidden_state=hidden_states)
@auto_docstring
class YolosPreTrainedModel(PreTrainedModel):
    # Shared base class for the YOLOS models in this file; it only declares class-level settings.
    config: YolosConfig
    # `YolosForObjectDetection` stores its backbone under the attribute `vit`.
    base_model_prefix = "vit"
    main_input_name = "pixel_values"
    input_modalities = ("image",)
    supports_gradient_checkpointing = True
    _no_split_modules = []
    # Attention backend capability flags — presumably read by `PreTrainedModel` when resolving
    # `config._attn_implementation`; verify against the base class.
    _supports_sdpa = True
    _supports_flash_attn = True
    _supports_flex_attn = True
    _supports_attention_backend = True
    # Maps output field names to the module classes whose outputs are recorded — presumably consumed by
    # `capture_outputs`; TODO confirm.
    _can_record_outputs = {
        "hidden_states": YolosLayer,
        "attentions": YolosSelfAttention,
    }
  374. @auto_docstring
  375. class YolosModel(YolosPreTrainedModel):
  376. def __init__(self, config: YolosConfig, add_pooling_layer: bool = True):
  377. r"""
  378. add_pooling_layer (bool, *optional*, defaults to `True`):
  379. Whether to add a pooling layer
  380. """
  381. super().__init__(config)
  382. self.config = config
  383. self.embeddings = YolosEmbeddings(config)
  384. self.encoder = YolosEncoder(config)
  385. self.layernorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
  386. self.pooler = YolosPooler(config) if add_pooling_layer else None
  387. # Initialize weights and apply final processing
  388. self.post_init()
  389. def get_input_embeddings(self) -> YolosPatchEmbeddings:
  390. return self.embeddings.patch_embeddings
  391. @merge_with_config_defaults
  392. @capture_outputs(tie_last_hidden_states=False)
  393. @auto_docstring
  394. def forward(
  395. self,
  396. pixel_values: torch.Tensor | None = None,
  397. **kwargs: Unpack[TransformersKwargs],
  398. ) -> BaseModelOutputWithPooling:
  399. if pixel_values is None:
  400. raise ValueError("You have to specify pixel_values")
  401. embedding_output = self.embeddings(pixel_values)
  402. height, width = pixel_values.shape[-2:]
  403. encoder_outputs: BaseModelOutput = self.encoder(embedding_output, height=height, width=width)
  404. sequence_output = encoder_outputs.last_hidden_state
  405. sequence_output = self.layernorm(sequence_output)
  406. pooled_output = self.pooler(sequence_output) if self.pooler is not None else None
  407. return BaseModelOutputWithPooling(last_hidden_state=sequence_output, pooler_output=pooled_output)
  408. class YolosPooler(nn.Module):
  409. def __init__(self, config: YolosConfig):
  410. super().__init__()
  411. self.dense = nn.Linear(config.hidden_size, config.hidden_size)
  412. self.activation = nn.Tanh()
  413. def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
  414. # We "pool" the model by simply taking the hidden state corresponding
  415. # to the first token.
  416. first_token_tensor = hidden_states[:, 0]
  417. pooled_output = self.dense(first_token_tensor)
  418. pooled_output = self.activation(pooled_output)
  419. return pooled_output
  420. # Copied from transformers.models.detr.modeling_detr.DetrMLPPredictionHead with Detr->Yolos
  421. class YolosMLPPredictionHead(nn.Module):
  422. """
  423. Very simple multi-layer perceptron (MLP, also called FFN), used to predict the normalized center coordinates,
  424. height and width of a bounding box w.r.t. an image.
  425. """
  426. def __init__(self, input_dim, hidden_dim, output_dim, num_layers):
  427. super().__init__()
  428. self.num_layers = num_layers
  429. h = [hidden_dim] * (num_layers - 1)
  430. self.layers = nn.ModuleList(nn.Linear(n, k) for n, k in zip([input_dim] + h, h + [output_dim]))
  431. def forward(self, x):
  432. for i, layer in enumerate(self.layers):
  433. x = nn.functional.relu(layer(x)) if i < self.num_layers - 1 else layer(x)
  434. return x
  435. @auto_docstring(
  436. custom_intro="""
  437. YOLOS Model (consisting of a ViT encoder) with object detection heads on top, for tasks such as COCO detection.
  438. """
  439. )
  440. class YolosForObjectDetection(YolosPreTrainedModel):
  441. def __init__(self, config: YolosConfig):
  442. super().__init__(config)
  443. # YOLOS (ViT) encoder model
  444. self.vit = YolosModel(config, add_pooling_layer=False)
  445. # Object detection heads
  446. # We add one for the "no object" class
  447. self.class_labels_classifier = YolosMLPPredictionHead(
  448. input_dim=config.hidden_size, hidden_dim=config.hidden_size, output_dim=config.num_labels + 1, num_layers=3
  449. )
  450. self.bbox_predictor = YolosMLPPredictionHead(
  451. input_dim=config.hidden_size, hidden_dim=config.hidden_size, output_dim=4, num_layers=3
  452. )
  453. # Initialize weights and apply final processing
  454. self.post_init()
  455. # taken from https://github.com/facebookresearch/detr/blob/master/models/detr.py
  456. def _set_aux_loss(self, outputs_class, outputs_coord):
  457. return [{"logits": a, "pred_boxes": b} for a, b in zip(outputs_class[:-1], outputs_coord[:-1])]
  458. @can_return_tuple
  459. @auto_docstring
  460. def forward(
  461. self,
  462. pixel_values: torch.FloatTensor,
  463. labels: list[dict] | None = None,
  464. **kwargs: Unpack[TransformersKwargs],
  465. ) -> YolosObjectDetectionOutput:
  466. r"""
  467. labels (`list[Dict]` of len `(batch_size,)`, *optional*):
  468. Labels for computing the bipartite matching loss. List of dicts, each dictionary containing at least the
  469. following 2 keys: `'class_labels'` and `'boxes'` (the class labels and bounding boxes of an image in the
  470. batch respectively). The class labels themselves should be a `torch.LongTensor` of len `(number of bounding
  471. boxes in the image,)` and the boxes a `torch.FloatTensor` of shape `(number of bounding boxes in the image,
  472. 4)`.
  473. Examples:
  474. ```python
  475. >>> from transformers import AutoImageProcessor, AutoModelForObjectDetection
  476. >>> import torch
  477. >>> from PIL import Image
  478. >>> import httpx
  479. >>> from io import BytesIO
  480. >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg"
  481. >>> with httpx.stream("GET", url) as response:
  482. ... image = Image.open(BytesIO(response.read()))
  483. >>> image_processor = AutoImageProcessor.from_pretrained("hustvl/yolos-tiny")
  484. >>> model = AutoModelForObjectDetection.from_pretrained("hustvl/yolos-tiny")
  485. >>> inputs = image_processor(images=image, return_tensors="pt")
  486. >>> outputs = model(**inputs)
  487. >>> # convert outputs (bounding boxes and class logits) to Pascal VOC format (xmin, ymin, xmax, ymax)
  488. >>> target_sizes = torch.tensor([image.size[::-1]])
  489. >>> results = image_processor.post_process_object_detection(outputs, threshold=0.9, target_sizes=target_sizes)[
  490. ... 0
  491. ... ]
  492. >>> for score, label, box in zip(results["scores"], results["labels"], results["boxes"]):
  493. ... box = [round(i, 2) for i in box.tolist()]
  494. ... print(
  495. ... f"Detected {model.config.id2label[label.item()]} with confidence "
  496. ... f"{round(score.item(), 3)} at location {box}"
  497. ... )
  498. Detected remote with confidence 0.991 at location [46.48, 72.78, 178.98, 119.3]
  499. Detected remote with confidence 0.908 at location [336.48, 79.27, 368.23, 192.36]
  500. Detected cat with confidence 0.934 at location [337.18, 18.06, 638.14, 373.09]
  501. Detected cat with confidence 0.979 at location [10.93, 53.74, 313.41, 470.67]
  502. Detected remote with confidence 0.974 at location [41.63, 72.23, 178.09, 119.99]
  503. ```"""
  504. # First, sent images through YOLOS base model to obtain hidden states
  505. outputs: BaseModelOutputWithPooling = self.vit(pixel_values, **kwargs)
  506. sequence_output = outputs.last_hidden_state
  507. # Take the final hidden states of the detection tokens
  508. sequence_output = sequence_output[:, -self.config.num_detection_tokens :, :]
  509. # Class logits + predicted bounding boxes
  510. logits = self.class_labels_classifier(sequence_output)
  511. pred_boxes = self.bbox_predictor(sequence_output).sigmoid()
  512. loss, loss_dict, auxiliary_outputs = None, None, None
  513. if labels is not None:
  514. outputs_class, outputs_coord = None, None
  515. if self.config.auxiliary_loss:
  516. intermediate = outputs.hidden_states
  517. outputs_class = self.class_labels_classifier(intermediate)
  518. outputs_coord = self.bbox_predictor(intermediate).sigmoid()
  519. loss, loss_dict, auxiliary_outputs = self.loss_function(
  520. logits, labels, self.device, pred_boxes, self.config, outputs_class, outputs_coord
  521. )
  522. return YolosObjectDetectionOutput(
  523. loss=loss,
  524. loss_dict=loss_dict,
  525. logits=logits,
  526. pred_boxes=pred_boxes,
  527. auxiliary_outputs=auxiliary_outputs,
  528. last_hidden_state=outputs.last_hidden_state,
  529. hidden_states=outputs.hidden_states,
  530. attentions=outputs.attentions,
  531. )
# Names exported by a star-import of this module.
__all__ = ["YolosForObjectDetection", "YolosModel", "YolosPreTrainedModel"]