modeling_owlv2.py 67 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
37047057067077087097107117127137147157167177187197207217227237247257267277287297307317327337347357367377387397407417427437447457467477487497507517527537547557567577587597607617627637647657667677687697707717727737747757767777787797807817827837847857867877887897907917927937947957967977987998008018028038048058068078088098108118128138148158168178188198208218228238248258268278288298308318328338348358368378388398408418428438448458468478488498508518528538548558568578588598608618628638648658668678688698708718728738748758768778788798808818828838848858868878888898908918928938948958968978988999009019029039049059069079089099109119129139149159169179189199209219229239249259269279289299309319329339349359369379389399409419429439449459469479489499509519529539549559569579589599609619629639649659669679689699709719729739749759769779789799809819829839849859869879889899909919929939949959969979989991000100110021003100410051006100710081009101010111012101310141015101610171018101910201021102210231024102510261027102810291030103110321033103410351036103710381039104010411042104310441045104610471048104910501051105210531054105510561057105810591060106110621063106410651066106710681069107010711072107310741075107610771078107910801081108210831084108510861087108810891090109110921093109410951096109710981099110011011102110311041105110611071108110911101111111211131114111511161117111811191120112111221123112411251126112711281129113011311132113311341135113611371138113911401141114211431144114511461147114811491150115111521153115411551156115711581159116011611162116311641165116611671168116911701171117211731174117511761177117811791180118111821183118411851186118711881189119011911192119311941195119611971198119912001201120212031204120512061207120812091210121112121213121412151216121712181219122012211222122312241225122612271228122912301231123212331234123512361237123812391240124112421243124412451246124712481249125012511252125312541255125612571258125912601261126212631264126512661267126812691270127112721273127412751276127
71278127912801281128212831284128512861287128812891290129112921293129412951296129712981299130013011302130313041305130613071308130913101311131213131314131513161317131813191320132113221323132413251326132713281329133013311332133313341335133613371338133913401341134213431344134513461347134813491350135113521353135413551356135713581359136013611362136313641365136613671368136913701371137213731374137513761377137813791380138113821383138413851386138713881389139013911392139313941395139613971398139914001401140214031404140514061407140814091410141114121413141414151416141714181419142014211422142314241425142614271428142914301431143214331434143514361437143814391440144114421443144414451446144714481449145014511452145314541455145614571458145914601461146214631464146514661467146814691470147114721473147414751476147714781479148014811482148314841485148614871488148914901491149214931494149514961497149814991500150115021503150415051506150715081509151015111512151315141515151615171518151915201521152215231524152515261527152815291530153115321533153415351536153715381539154015411542154315441545154615471548154915501551155215531554
  1. # Copyright 2023 Google AI and The HuggingFace Team. All rights reserved.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. """PyTorch OWLv2 model."""
  15. from collections.abc import Callable
  16. from dataclasses import dataclass
  17. from typing import Any
  18. import torch
  19. from torch import Tensor, nn
  20. from ... import initialization as init
  21. from ...activations import ACT2FN
  22. from ...masking_utils import create_causal_mask
  23. from ...modeling_layers import GradientCheckpointingLayer
  24. from ...modeling_outputs import BaseModelOutput, BaseModelOutputWithPooling
  25. from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
  26. from ...processing_utils import Unpack
  27. from ...utils import (
  28. ModelOutput,
  29. TransformersKwargs,
  30. auto_docstring,
  31. is_vision_available,
  32. logging,
  33. torch_int,
  34. )
  35. from ...utils.generic import can_return_tuple, merge_with_config_defaults
  36. from ...utils.output_capturing import capture_outputs
  37. from .configuration_owlv2 import Owlv2Config, Owlv2TextConfig, Owlv2VisionConfig
  38. if is_vision_available():
  39. from transformers.image_transforms import center_to_corners_format
  40. logger = logging.get_logger(__name__)
  41. # See all Owlv2 models at https://huggingface.co/models?filter=owlv2
  42. # Copied from transformers.models.clip.modeling_clip.contrastive_loss with clip->owlv2
  43. def contrastive_loss(logits: torch.Tensor) -> torch.Tensor:
  44. return nn.functional.cross_entropy(logits, torch.arange(len(logits), device=logits.device))
  45. # Copied from transformers.models.clip.modeling_clip.clip_loss with clip->owlv2
  46. def owlv2_loss(similarity: torch.Tensor) -> torch.Tensor:
  47. caption_loss = contrastive_loss(similarity)
  48. image_loss = contrastive_loss(similarity.t())
  49. return (caption_loss + image_loss) / 2.0
  50. @dataclass
  51. @auto_docstring
  52. class Owlv2Output(ModelOutput):
  53. r"""
  54. loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `return_loss` is `True`):
  55. Contrastive loss for image-text similarity.
  56. logits_per_image (`torch.FloatTensor` of shape `(image_batch_size, text_batch_size)`):
  57. The scaled dot product scores between `image_embeds` and `text_embeds`. This represents the image-text
  58. similarity scores.
  59. logits_per_text (`torch.FloatTensor` of shape `(text_batch_size, image_batch_size)`):
  60. The scaled dot product scores between `text_embeds` and `image_embeds`. This represents the text-image
  61. similarity scores.
  62. text_embeds (`torch.FloatTensor` of shape `(batch_size * num_max_text_queries, output_dim`):
  63. The text embeddings obtained by applying the projection layer to the pooled output of [`Owlv2TextModel`].
  64. image_embeds (`torch.FloatTensor` of shape `(batch_size, output_dim`):
  65. The image embeddings obtained by applying the projection layer to the pooled output of
  66. [`Owlv2VisionModel`].
  67. text_model_output (tuple[`BaseModelOutputWithPooling`]):
  68. The output of the [`Owlv2TextModel`].
  69. vision_model_output (`BaseModelOutputWithPooling`):
  70. The output of the [`Owlv2VisionModel`].
  71. """
  72. loss: torch.FloatTensor | None = None
  73. logits_per_image: torch.FloatTensor | None = None
  74. logits_per_text: torch.FloatTensor | None = None
  75. text_embeds: torch.FloatTensor | None = None
  76. image_embeds: torch.FloatTensor | None = None
  77. text_model_output: BaseModelOutputWithPooling = None
  78. vision_model_output: BaseModelOutputWithPooling = None
  79. def to_tuple(self) -> tuple[Any]:
  80. return tuple(
  81. self[k] if k not in ["text_model_output", "vision_model_output"] else getattr(self, k).to_tuple()
  82. for k in self.keys()
  83. )
  84. # Copied from transformers.loss.loss_for_object_detection._upcast
  85. def _upcast(t: Tensor) -> Tensor:
  86. # Protects from numerical overflows in multiplications by upcasting to the equivalent higher type
  87. if t.is_floating_point():
  88. return t if t.dtype in (torch.float32, torch.float64) else t.float()
  89. else:
  90. return t if t.dtype in (torch.int32, torch.int64) else t.int()
  91. # Copied from transformers.loss.loss_for_object_detection.box_area
  92. def box_area(boxes: Tensor) -> Tensor:
  93. """
  94. Computes the area of a set of bounding boxes, which are specified by its (x1, y1, x2, y2) coordinates.
  95. Args:
  96. boxes (`torch.FloatTensor` of shape `(number_of_boxes, 4)`):
  97. Boxes for which the area will be computed. They are expected to be in (x1, y1, x2, y2) format with `0 <= x1
  98. < x2` and `0 <= y1 < y2`.
  99. Returns:
  100. `torch.FloatTensor`: a tensor containing the area for each box.
  101. """
  102. boxes = _upcast(boxes)
  103. return (boxes[:, 2] - boxes[:, 0]) * (boxes[:, 3] - boxes[:, 1])
  104. # Copied from transformers.loss.loss_for_object_detection.box_iou
  105. def box_iou(boxes1, boxes2):
  106. area1 = box_area(boxes1)
  107. area2 = box_area(boxes2)
  108. left_top = torch.max(boxes1[:, None, :2], boxes2[:, :2]) # [N,M,2]
  109. right_bottom = torch.min(boxes1[:, None, 2:], boxes2[:, 2:]) # [N,M,2]
  110. width_height = (right_bottom - left_top).clamp(min=0) # [N,M,2]
  111. inter = width_height[:, :, 0] * width_height[:, :, 1] # [N,M]
  112. union = area1[:, None] + area2 - inter
  113. iou = inter / union
  114. return iou, union
  115. # Copied from transformers.loss.loss_for_object_detection.generalized_box_iou
  116. def generalized_box_iou(boxes1, boxes2):
  117. """
  118. Generalized IoU from https://giou.stanford.edu/. The boxes should be in [x0, y0, x1, y1] (corner) format.
  119. Returns:
  120. `torch.FloatTensor`: a [N, M] pairwise matrix, where N = len(boxes1) and M = len(boxes2)
  121. """
  122. # degenerate boxes gives inf / nan results
  123. # so do an early check
  124. if not (boxes1[:, 2:] >= boxes1[:, :2]).all():
  125. raise ValueError(f"boxes1 must be in [x0, y0, x1, y1] (corner) format, but got {boxes1}")
  126. if not (boxes2[:, 2:] >= boxes2[:, :2]).all():
  127. raise ValueError(f"boxes2 must be in [x0, y0, x1, y1] (corner) format, but got {boxes2}")
  128. iou, union = box_iou(boxes1, boxes2)
  129. top_left = torch.min(boxes1[:, None, :2], boxes2[:, :2])
  130. bottom_right = torch.max(boxes1[:, None, 2:], boxes2[:, 2:])
  131. width_height = (bottom_right - top_left).clamp(min=0) # [N,M,2]
  132. area = width_height[:, :, 0] * width_height[:, :, 1]
  133. return iou - (area - union) / area
  134. @dataclass
  135. @auto_docstring(
  136. custom_intro="""
  137. Output type of [`Owlv2ForObjectDetection`].
  138. """
  139. )
  140. class Owlv2ObjectDetectionOutput(ModelOutput):
  141. r"""
  142. loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` are provided)):
  143. Total loss as a linear combination of a negative log-likehood (cross-entropy) for class prediction and a
  144. bounding box loss. The latter is defined as a linear combination of the L1 loss and the generalized
  145. scale-invariant IoU loss.
  146. loss_dict (`Dict`, *optional*):
  147. A dictionary containing the individual losses. Useful for logging.
  148. logits (`torch.FloatTensor` of shape `(batch_size, num_patches, num_queries)`):
  149. Classification logits (including no-object) for all queries.
  150. objectness_logits (`torch.FloatTensor` of shape `(batch_size, num_patches, 1)`):
  151. The objectness logits of all image patches. OWL-ViT represents images as a set of image patches where the
  152. total number of patches is (image_size / patch_size)**2.
  153. pred_boxes (`torch.FloatTensor` of shape `(batch_size, num_patches, 4)`):
  154. Normalized boxes coordinates for all queries, represented as (center_x, center_y, width, height). These
  155. values are normalized in [0, 1], relative to the size of each individual image in the batch (disregarding
  156. possible padding). You can use [`~Owlv2ImageProcessor.post_process_object_detection`] to retrieve the
  157. unnormalized bounding boxes.
  158. text_embeds (`torch.FloatTensor` of shape `(batch_size, num_max_text_queries, output_dim`):
  159. The text embeddings obtained by applying the projection layer to the pooled output of [`Owlv2TextModel`].
  160. image_embeds (`torch.FloatTensor` of shape `(batch_size, patch_size, patch_size, output_dim`):
  161. Pooled output of [`Owlv2VisionModel`]. OWLv2 represents images as a set of image patches and computes image
  162. embeddings for each patch.
  163. class_embeds (`torch.FloatTensor` of shape `(batch_size, num_patches, hidden_size)`):
  164. Class embeddings of all image patches. OWLv2 represents images as a set of image patches where the total
  165. number of patches is (image_size / patch_size)**2.
  166. text_model_output (tuple[`BaseModelOutputWithPooling`]):
  167. The output of the [`Owlv2TextModel`].
  168. vision_model_output (`BaseModelOutputWithPooling`):
  169. The output of the [`Owlv2VisionModel`].
  170. """
  171. loss: torch.FloatTensor | None = None
  172. loss_dict: dict | None = None
  173. logits: torch.FloatTensor | None = None
  174. objectness_logits: torch.FloatTensor | None = None
  175. pred_boxes: torch.FloatTensor | None = None
  176. text_embeds: torch.FloatTensor | None = None
  177. image_embeds: torch.FloatTensor | None = None
  178. class_embeds: torch.FloatTensor | None = None
  179. text_model_output: BaseModelOutputWithPooling = None
  180. vision_model_output: BaseModelOutputWithPooling = None
  181. def to_tuple(self) -> tuple[Any]:
  182. return tuple(
  183. self[k] if k not in ["text_model_output", "vision_model_output"] else getattr(self, k).to_tuple()
  184. for k in self.keys()
  185. )
  186. @dataclass
  187. @auto_docstring(
  188. custom_intro="""
  189. Output type of [`Owlv2ForObjectDetection.image_guided_detection`].
  190. """
  191. )
  192. # Copied from transformers.models.owlvit.modeling_owlvit.OwlViTImageGuidedObjectDetectionOutput with OwlViT->Owlv2,OWL-ViT->OWLv2
  193. class Owlv2ImageGuidedObjectDetectionOutput(ModelOutput):
  194. r"""
  195. logits (`torch.FloatTensor` of shape `(batch_size, num_patches, num_queries)`):
  196. Classification logits (including no-object) for all queries.
  197. image_embeds (`torch.FloatTensor` of shape `(batch_size, patch_size, patch_size, output_dim`):
  198. Pooled output of [`Owlv2VisionModel`]. OWLv2 represents images as a set of image patches and computes
  199. image embeddings for each patch.
  200. query_image_embeds (`torch.FloatTensor` of shape `(batch_size, patch_size, patch_size, output_dim`):
  201. Pooled output of [`Owlv2VisionModel`]. OWLv2 represents images as a set of image patches and computes
  202. image embeddings for each patch.
  203. target_pred_boxes (`torch.FloatTensor` of shape `(batch_size, num_patches, 4)`):
  204. Normalized boxes coordinates for all queries, represented as (center_x, center_y, width, height). These
  205. values are normalized in [0, 1], relative to the size of each individual target image in the batch
  206. (disregarding possible padding). You can use [`~Owlv2ImageProcessor.post_process_object_detection`] to
  207. retrieve the unnormalized bounding boxes.
  208. query_pred_boxes (`torch.FloatTensor` of shape `(batch_size, num_patches, 4)`):
  209. Normalized boxes coordinates for all queries, represented as (center_x, center_y, width, height). These
  210. values are normalized in [0, 1], relative to the size of each individual query image in the batch
  211. (disregarding possible padding). You can use [`~Owlv2ImageProcessor.post_process_object_detection`] to
  212. retrieve the unnormalized bounding boxes.
  213. class_embeds (`torch.FloatTensor` of shape `(batch_size, num_patches, hidden_size)`):
  214. Class embeddings of all image patches. OWLv2 represents images as a set of image patches where the total
  215. number of patches is (image_size / patch_size)**2.
  216. text_model_output (tuple[`BaseModelOutputWithPooling`]):
  217. The output of the [`Owlv2TextModel`].
  218. vision_model_output (`BaseModelOutputWithPooling`):
  219. The output of the [`Owlv2VisionModel`].
  220. """
  221. logits: torch.FloatTensor | None = None
  222. image_embeds: torch.FloatTensor | None = None
  223. query_image_embeds: torch.FloatTensor | None = None
  224. target_pred_boxes: torch.FloatTensor | None = None
  225. query_pred_boxes: torch.FloatTensor | None = None
  226. class_embeds: torch.FloatTensor | None = None
  227. text_model_output: BaseModelOutputWithPooling = None
  228. vision_model_output: BaseModelOutputWithPooling = None
  229. def to_tuple(self) -> tuple[Any]:
  230. return tuple(
  231. self[k] if k not in ["text_model_output", "vision_model_output"] else getattr(self, k).to_tuple()
  232. for k in self.keys()
  233. )
  234. # Copied from transformers.models.owlvit.modeling_owlvit.OwlViTVisionEmbeddings with OwlViT->Owlv2
  235. class Owlv2VisionEmbeddings(nn.Module):
  236. def __init__(self, config: Owlv2VisionConfig):
  237. super().__init__()
  238. self.patch_size = config.patch_size
  239. self.config = config
  240. self.embed_dim = config.hidden_size
  241. self.class_embedding = nn.Parameter(torch.randn(config.hidden_size))
  242. self.patch_embedding = nn.Conv2d(
  243. in_channels=config.num_channels,
  244. out_channels=self.embed_dim,
  245. kernel_size=config.patch_size,
  246. stride=config.patch_size,
  247. bias=False,
  248. )
  249. self.num_patches = (config.image_size // config.patch_size) ** 2
  250. self.num_positions = self.num_patches + 1
  251. self.position_embedding = nn.Embedding(self.num_positions, self.embed_dim)
  252. self.register_buffer("position_ids", torch.arange(self.num_positions).expand((1, -1)), persistent=False)
  253. # Copied from transformers.models.clip.modeling_clip.CLIPVisionEmbeddings.interpolate_pos_encoding
  254. def interpolate_pos_encoding(self, embeddings: torch.Tensor, height: int, width: int) -> torch.Tensor:
  255. """
  256. This method allows to interpolate the pre-trained position encodings, to be able to use the model on higher resolution
  257. images. This method is also adapted to support torch.jit tracing.
  258. Adapted from:
  259. - https://github.com/facebookresearch/dino/blob/de9ee3df6cf39fac952ab558447af1fa1365362a/vision_transformer.py#L174-L194, and
  260. - https://github.com/facebookresearch/dinov2/blob/e1277af2ba9496fbadf7aec6eba56e8d882d1e35/dinov2/models/vision_transformer.py#L179-L211
  261. """
  262. num_patches = embeddings.shape[1] - 1
  263. position_embedding = self.position_embedding.weight.unsqueeze(0)
  264. num_positions = position_embedding.shape[1] - 1
  265. # always interpolate when tracing to ensure the exported model works for dynamic input shapes
  266. if not torch.jit.is_tracing() and num_patches == num_positions and height == width:
  267. return self.position_embedding(self.position_ids)
  268. class_pos_embed = position_embedding[:, :1]
  269. patch_pos_embed = position_embedding[:, 1:]
  270. dim = embeddings.shape[-1]
  271. new_height = height // self.patch_size
  272. new_width = width // self.patch_size
  273. sqrt_num_positions = torch_int(num_positions**0.5)
  274. patch_pos_embed = patch_pos_embed.reshape(1, sqrt_num_positions, sqrt_num_positions, dim)
  275. patch_pos_embed = patch_pos_embed.permute(0, 3, 1, 2)
  276. patch_pos_embed = nn.functional.interpolate(
  277. patch_pos_embed,
  278. size=(new_height, new_width),
  279. mode="bicubic",
  280. align_corners=False,
  281. )
  282. patch_pos_embed = patch_pos_embed.permute(0, 2, 3, 1).view(1, -1, dim)
  283. return torch.cat((class_pos_embed, patch_pos_embed), dim=1)
  284. def forward(self, pixel_values: torch.FloatTensor, interpolate_pos_encoding: bool = False) -> torch.Tensor:
  285. batch_size, _, height, width = pixel_values.shape
  286. patch_embeds = self.patch_embedding(pixel_values) # shape = [batch_size, num_channels, height, width]
  287. patch_embeds = patch_embeds.flatten(2).transpose(1, 2)
  288. class_embeds = self.class_embedding.expand(batch_size, 1, -1)
  289. embeddings = torch.cat([class_embeds, patch_embeds], dim=1)
  290. if interpolate_pos_encoding:
  291. embeddings = embeddings + self.interpolate_pos_encoding(embeddings, height, width)
  292. else:
  293. embeddings = embeddings + self.position_embedding(self.position_ids)
  294. return embeddings
  295. # Copied from transformers.models.owlvit.modeling_owlvit.OwlViTTextEmbeddings with OwlViT->Owlv2
  296. class Owlv2TextEmbeddings(nn.Module):
  297. def __init__(self, config: Owlv2TextConfig):
  298. super().__init__()
  299. self.token_embedding = nn.Embedding(config.vocab_size, config.hidden_size)
  300. self.position_embedding = nn.Embedding(config.max_position_embeddings, config.hidden_size)
  301. # position_ids (1, len position emb) is contiguous in memory and exported when serialized
  302. self.register_buffer(
  303. "position_ids", torch.arange(config.max_position_embeddings).expand((1, -1)), persistent=False
  304. )
  305. def forward(
  306. self,
  307. input_ids: torch.LongTensor | None = None,
  308. position_ids: torch.LongTensor | None = None,
  309. inputs_embeds: torch.FloatTensor | None = None,
  310. ) -> torch.Tensor:
  311. seq_length = input_ids.shape[-1] if input_ids is not None else inputs_embeds.shape[-2]
  312. if position_ids is None:
  313. position_ids = self.position_ids[:, :seq_length]
  314. if inputs_embeds is None:
  315. inputs_embeds = self.token_embedding(input_ids)
  316. position_embeddings = self.position_embedding(position_ids)
  317. embeddings = inputs_embeds + position_embeddings
  318. return embeddings
  319. # Copied from transformers.models.bert.modeling_bert.eager_attention_forward
  320. def eager_attention_forward(
  321. module: nn.Module,
  322. query: torch.Tensor,
  323. key: torch.Tensor,
  324. value: torch.Tensor,
  325. attention_mask: torch.Tensor | None,
  326. scaling: float | None = None,
  327. dropout: float = 0.0,
  328. **kwargs: Unpack[TransformersKwargs],
  329. ):
  330. if scaling is None:
  331. scaling = query.size(-1) ** -0.5
  332. # Take the dot product between "query" and "key" to get the raw attention scores.
  333. attn_weights = torch.matmul(query, key.transpose(2, 3)) * scaling
  334. if attention_mask is not None:
  335. attn_weights = attn_weights + attention_mask
  336. attn_weights = nn.functional.softmax(attn_weights, dim=-1)
  337. attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training)
  338. attn_output = torch.matmul(attn_weights, value)
  339. attn_output = attn_output.transpose(1, 2).contiguous()
  340. return attn_output, attn_weights
  341. # Copied from transformers.models.owlvit.modeling_owlvit.OwlViTAttention with OwlViT->Owlv2
  342. class Owlv2Attention(nn.Module):
  343. """Multi-headed attention from 'Attention Is All You Need' paper"""
  344. def __init__(self, config):
  345. super().__init__()
  346. self.config = config
  347. self.embed_dim = config.hidden_size
  348. self.num_heads = config.num_attention_heads
  349. self.head_dim = self.embed_dim // self.num_heads
  350. if self.head_dim * self.num_heads != self.embed_dim:
  351. raise ValueError(
  352. f"embed_dim must be divisible by num_heads (got `embed_dim`: {self.embed_dim} and `num_heads`:"
  353. f" {self.num_heads})."
  354. )
  355. self.scale = self.head_dim**-0.5
  356. self.dropout = config.attention_dropout
  357. self.is_causal = False
  358. self.k_proj = nn.Linear(self.embed_dim, self.embed_dim)
  359. self.v_proj = nn.Linear(self.embed_dim, self.embed_dim)
  360. self.q_proj = nn.Linear(self.embed_dim, self.embed_dim)
  361. self.out_proj = nn.Linear(self.embed_dim, self.embed_dim)
  362. def forward(
  363. self,
  364. hidden_states: torch.Tensor,
  365. attention_mask: torch.Tensor | None = None,
  366. **kwargs: Unpack[TransformersKwargs],
  367. ) -> tuple[torch.Tensor, torch.Tensor | None]:
  368. input_shape = hidden_states.shape[:-1]
  369. hidden_shape = (*input_shape, -1, self.head_dim)
  370. query_states = self.q_proj(hidden_states).view(*hidden_shape).transpose(1, 2)
  371. key_states = self.k_proj(hidden_states).view(*hidden_shape).transpose(1, 2)
  372. value_states = self.v_proj(hidden_states).view(*hidden_shape).transpose(1, 2)
  373. attention_interface: Callable = ALL_ATTENTION_FUNCTIONS.get_interface(
  374. self.config._attn_implementation, eager_attention_forward
  375. )
  376. attn_output, attn_weights = attention_interface(
  377. self,
  378. query_states,
  379. key_states,
  380. value_states,
  381. attention_mask,
  382. scaling=self.scale,
  383. dropout=0.0 if not self.training else self.dropout,
  384. **kwargs,
  385. )
  386. attn_output = attn_output.reshape(*input_shape, -1).contiguous()
  387. attn_output = self.out_proj(attn_output)
  388. return attn_output, attn_weights
  389. # Copied from transformers.models.clip.modeling_clip.CLIPMLP with CLIP->Owlv2
  390. class Owlv2MLP(nn.Module):
  391. def __init__(self, config):
  392. super().__init__()
  393. self.config = config
  394. self.activation_fn = ACT2FN[config.hidden_act]
  395. self.fc1 = nn.Linear(config.hidden_size, config.intermediate_size)
  396. self.fc2 = nn.Linear(config.intermediate_size, config.hidden_size)
  397. def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
  398. hidden_states = self.fc1(hidden_states)
  399. hidden_states = self.activation_fn(hidden_states)
  400. hidden_states = self.fc2(hidden_states)
  401. return hidden_states
  402. # Copied from transformers.models.clip.modeling_clip.CLIPEncoderLayer with CLIP->Owlv2
  403. class Owlv2EncoderLayer(GradientCheckpointingLayer):
  404. def __init__(self, config: Owlv2VisionConfig | Owlv2TextConfig):
  405. super().__init__()
  406. self.embed_dim = config.hidden_size
  407. self.self_attn = Owlv2Attention(config)
  408. self.layer_norm1 = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_eps)
  409. self.mlp = Owlv2MLP(config)
  410. self.layer_norm2 = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_eps)
  411. def forward(
  412. self,
  413. hidden_states: torch.Tensor,
  414. attention_mask: torch.Tensor,
  415. **kwargs: Unpack[TransformersKwargs],
  416. ) -> torch.FloatTensor:
  417. residual = hidden_states
  418. hidden_states = self.layer_norm1(hidden_states)
  419. hidden_states, _ = self.self_attn(
  420. hidden_states=hidden_states,
  421. attention_mask=attention_mask,
  422. **kwargs,
  423. )
  424. hidden_states = residual + hidden_states
  425. residual = hidden_states
  426. hidden_states = self.layer_norm2(hidden_states)
  427. hidden_states = self.mlp(hidden_states)
  428. hidden_states = residual + hidden_states
  429. return hidden_states
  430. @auto_docstring
  431. # Copied from transformers.models.owlvit.modeling_owlvit.OwlViTPreTrainedModel with OwlViT->Owlv2,owlvit->owlv2
  432. class Owlv2PreTrainedModel(PreTrainedModel):
  433. config: Owlv2Config
  434. base_model_prefix = "owlv2"
  435. input_modalities = ("image", "text")
  436. supports_gradient_checkpointing = True
  437. _supports_sdpa = True
  438. _supports_flash_attn = True
  439. _supports_flex_attn = True
  440. _supports_attention_backend = True
  441. _no_split_modules = ["Owlv2EncoderLayer"]
  442. _can_record_outputs = {
  443. "hidden_states": Owlv2EncoderLayer,
  444. "attentions": Owlv2Attention,
  445. }
  446. _keys_to_ignore_on_load_unexpected = [
  447. r".*text_model\.embeddings\.position_ids",
  448. r".*vision_model\.embeddings\.position_ids",
  449. ]
  450. @torch.no_grad()
  451. def _init_weights(self, module: nn.Module):
  452. """Initialize the weights"""
  453. factor = self.config.initializer_factor
  454. if isinstance(module, Owlv2TextEmbeddings):
  455. init.normal_(module.token_embedding.weight, mean=0.0, std=factor * 0.02)
  456. init.normal_(module.position_embedding.weight, mean=0.0, std=factor * 0.02)
  457. init.copy_(module.position_ids, torch.arange(module.position_ids.shape[-1]).expand((1, -1)))
  458. elif isinstance(module, Owlv2VisionEmbeddings):
  459. init.normal_(module.class_embedding, mean=0.0, std=module.embed_dim**-0.5 * factor)
  460. init.normal_(module.patch_embedding.weight, std=module.config.initializer_range * factor)
  461. init.normal_(module.position_embedding.weight, std=module.config.initializer_range * factor)
  462. init.copy_(module.position_ids, torch.arange(module.position_ids.shape[-1]).expand((1, -1)))
  463. elif isinstance(module, Owlv2Attention):
  464. in_proj_std = (module.embed_dim**-0.5) * ((2 * module.config.num_hidden_layers) ** -0.5) * factor
  465. out_proj_std = (module.embed_dim**-0.5) * factor
  466. init.normal_(module.q_proj.weight, std=in_proj_std)
  467. init.normal_(module.k_proj.weight, std=in_proj_std)
  468. init.normal_(module.v_proj.weight, std=in_proj_std)
  469. init.normal_(module.out_proj.weight, std=out_proj_std)
  470. elif isinstance(module, Owlv2MLP):
  471. in_proj_std = (module.config.hidden_size**-0.5) * ((2 * module.config.num_hidden_layers) ** -0.5) * factor
  472. fc_std = (2 * module.config.hidden_size) ** -0.5 * factor
  473. init.normal_(module.fc1.weight, std=fc_std)
  474. init.normal_(module.fc2.weight, std=in_proj_std)
  475. elif isinstance(module, Owlv2Model):
  476. init.normal_(
  477. module.text_projection.weight,
  478. std=module.text_embed_dim**-0.5 * factor,
  479. )
  480. init.normal_(
  481. module.visual_projection.weight,
  482. std=module.vision_embed_dim**-0.5 * factor,
  483. )
  484. init.constant_(module.logit_scale, self.config.logit_scale_init_value)
  485. elif isinstance(module, Owlv2ForObjectDetection):
  486. init.copy_(module.box_bias, module.compute_box_bias(module.num_patches_height, module.num_patches_width))
  487. if isinstance(module, nn.LayerNorm):
  488. init.zeros_(module.bias)
  489. init.ones_(module.weight)
  490. if isinstance(module, nn.Linear):
  491. init.normal_(module.weight, mean=0.0, std=factor)
  492. if module.bias is not None:
  493. init.zeros_(module.bias)
  494. # Copied from transformers.models.owlvit.modeling_owlvit.OwlViTEncoder with OwlViT->Owlv2
  495. class Owlv2Encoder(nn.Module):
  496. """
  497. Transformer encoder consisting of `config.num_hidden_layers` self attention layers. Each layer is a
  498. [`Owlv2EncoderLayer`].
  499. Args:
  500. config: Owlv2Config
  501. """
  502. def __init__(self, config: Owlv2Config):
  503. super().__init__()
  504. self.config = config
  505. self.layers = nn.ModuleList([Owlv2EncoderLayer(config) for _ in range(config.num_hidden_layers)])
  506. self.gradient_checkpointing = False
  507. def forward(
  508. self,
  509. inputs_embeds,
  510. attention_mask: torch.Tensor | None = None,
  511. **kwargs: Unpack[TransformersKwargs],
  512. ) -> BaseModelOutput:
  513. r"""
  514. Args:
  515. inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`):
  516. Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation.
  517. This is useful if you want more control over how to convert `input_ids` indices into associated vectors
  518. than the model's internal embedding lookup matrix.
  519. attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
  520. Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
  521. - 1 for tokens that are **not masked**,
  522. - 0 for tokens that are **masked**.
  523. [What are attention masks?](../glossary#attention-mask)
  524. """
  525. hidden_states = inputs_embeds
  526. for encoder_layer in self.layers:
  527. hidden_states = encoder_layer(
  528. hidden_states,
  529. attention_mask,
  530. **kwargs,
  531. )
  532. return BaseModelOutput(
  533. last_hidden_state=hidden_states,
  534. )
  535. # Copied from transformers.models.owlvit.modeling_owlvit.OwlViTTextTransformer with OWLVIT->OWLV2,OwlViT->Owlv2
  536. class Owlv2TextTransformer(Owlv2PreTrainedModel):
  537. def __init__(self, config: Owlv2TextConfig):
  538. super().__init__(config)
  539. embed_dim = config.hidden_size
  540. self.embeddings = Owlv2TextEmbeddings(config)
  541. self.encoder = Owlv2Encoder(config)
  542. self.final_layer_norm = nn.LayerNorm(embed_dim, eps=config.layer_norm_eps)
  543. # Initialize weights and apply final processing
  544. self.post_init()
  545. @merge_with_config_defaults
  546. @capture_outputs(tie_last_hidden_states=False)
  547. @auto_docstring
  548. def forward(
  549. self,
  550. input_ids: torch.Tensor | None = None,
  551. attention_mask: torch.Tensor | None = None,
  552. position_ids: torch.Tensor | None = None,
  553. **kwargs: Unpack[TransformersKwargs],
  554. ) -> tuple | BaseModelOutputWithPooling:
  555. r"""
  556. input_ids (`torch.LongTensor` of shape `(batch_size * num_max_text_queries, sequence_length)`):
  557. Indices of input sequence tokens in the vocabulary. Indices can be obtained using [`AutoTokenizer`]. See
  558. [`PreTrainedTokenizer.encode`] and [`PreTrainedTokenizer.__call__`] for details. [What are input
  559. IDs?](../glossary#input-ids)
  560. """
  561. input_shape = input_ids.size()
  562. input_ids = input_ids.view(-1, input_shape[-1])
  563. hidden_states = self.embeddings(input_ids=input_ids, position_ids=position_ids)
  564. attention_mask = create_causal_mask(
  565. config=self.config,
  566. inputs_embeds=hidden_states,
  567. attention_mask=attention_mask,
  568. past_key_values=None,
  569. )
  570. kwargs.pop("is_causal", None)
  571. encoder_outputs: BaseModelOutput = self.encoder(
  572. inputs_embeds=hidden_states,
  573. attention_mask=attention_mask,
  574. is_causal=True,
  575. **kwargs,
  576. )
  577. last_hidden_state = encoder_outputs.last_hidden_state
  578. last_hidden_state = self.final_layer_norm(last_hidden_state)
  579. # take features from the end of tokens embedding (end of token is the highest number in each sequence)
  580. # casting to torch.int for onnx compatibility: argmax doesn't support int64 inputs with opset 14
  581. pooled_output = last_hidden_state[
  582. torch.arange(last_hidden_state.shape[0], device=last_hidden_state.device),
  583. input_ids.to(torch.int).argmax(dim=-1).to(last_hidden_state.device),
  584. ]
  585. return BaseModelOutputWithPooling(
  586. last_hidden_state=last_hidden_state,
  587. pooler_output=pooled_output,
  588. )
  589. # Copied from transformers.models.owlvit.modeling_owlvit.OwlViTTextModel with google/owlvit-base-patch32->google/owlv2-base-patch16, OWLVIT->OWLV2,OwlViT->Owlv2
  590. class Owlv2TextModel(Owlv2PreTrainedModel):
  591. config: Owlv2TextConfig
  592. input_modalities = ("text",)
  593. def __init__(self, config: Owlv2TextConfig):
  594. super().__init__(config)
  595. self.text_model = Owlv2TextTransformer(config)
  596. # Initialize weights and apply final processing
  597. self.post_init()
  598. def get_input_embeddings(self) -> nn.Module:
  599. return self.text_model.embeddings.token_embedding
  600. def set_input_embeddings(self, value):
  601. self.text_model.embeddings.token_embedding = value
  602. @auto_docstring
  603. def forward(
  604. self,
  605. input_ids: torch.Tensor | None = None,
  606. attention_mask: torch.Tensor | None = None,
  607. **kwargs: Unpack[TransformersKwargs],
  608. ) -> tuple | BaseModelOutputWithPooling:
  609. r"""
  610. input_ids (`torch.LongTensor` of shape `(batch_size * num_max_text_queries, sequence_length)`):
  611. Indices of input sequence tokens in the vocabulary. Indices can be obtained using [`AutoTokenizer`]. See
  612. [`PreTrainedTokenizer.encode`] and [`PreTrainedTokenizer.__call__`] for details. [What are input
  613. IDs?](../glossary#input-ids)
  614. Examples:
  615. ```python
  616. >>> from transformers import AutoProcessor, Owlv2TextModel
  617. >>> model = Owlv2TextModel.from_pretrained("google/owlv2-base-patch16")
  618. >>> processor = AutoProcessor.from_pretrained("google/owlv2-base-patch16")
  619. >>> inputs = processor(
  620. ... text=[["a photo of a cat", "a photo of a dog"], ["photo of a astranaut"]], return_tensors="pt"
  621. ... )
  622. >>> outputs = model(**inputs)
  623. >>> last_hidden_state = outputs.last_hidden_state
  624. >>> pooled_output = outputs.pooler_output # pooled (EOS token) states
  625. ```"""
  626. return self.text_model(
  627. input_ids=input_ids,
  628. attention_mask=attention_mask,
  629. **kwargs,
  630. )
  631. # Copied from transformers.models.owlvit.modeling_owlvit.OwlViTVisionTransformer with OWLVIT->OWLV2,OwlViT->Owlv2
  632. class Owlv2VisionTransformer(Owlv2PreTrainedModel):
  633. def __init__(self, config: Owlv2VisionConfig):
  634. super().__init__(config)
  635. self.embeddings = Owlv2VisionEmbeddings(config)
  636. self.pre_layernorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
  637. self.encoder = Owlv2Encoder(config)
  638. self.post_layernorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
  639. # Initialize weights and apply final processing
  640. self.post_init()
  641. @merge_with_config_defaults
  642. @capture_outputs(tie_last_hidden_states=False)
  643. @auto_docstring
  644. def forward(
  645. self,
  646. pixel_values: torch.FloatTensor,
  647. interpolate_pos_encoding: bool | None = False,
  648. **kwargs: Unpack[TransformersKwargs],
  649. ) -> tuple | BaseModelOutputWithPooling:
  650. # Cast the input to the expected `dtype`
  651. expected_input_dtype = self.embeddings.patch_embedding.weight.dtype
  652. pixel_values = pixel_values.to(expected_input_dtype)
  653. hidden_states = self.embeddings(pixel_values, interpolate_pos_encoding=interpolate_pos_encoding)
  654. hidden_states = self.pre_layernorm(hidden_states)
  655. encoder_outputs: BaseModelOutput = self.encoder(
  656. inputs_embeds=hidden_states,
  657. **kwargs,
  658. )
  659. last_hidden_state = encoder_outputs.last_hidden_state
  660. pooled_output = last_hidden_state[:, 0, :]
  661. pooled_output = self.post_layernorm(pooled_output)
  662. return BaseModelOutputWithPooling(
  663. last_hidden_state=last_hidden_state,
  664. pooler_output=pooled_output,
  665. )
  666. # Copied from transformers.models.owlvit.modeling_owlvit.OwlViTVisionModel with OWLVIT->OWLV2,OwlViT->Owlv2,google/owlvit-base-patch32->google/owlv2-base-patch16
  667. class Owlv2VisionModel(Owlv2PreTrainedModel):
  668. config: Owlv2VisionConfig
  669. main_input_name = "pixel_values"
  670. input_modalities = ("image",)
  671. def __init__(self, config: Owlv2VisionConfig):
  672. super().__init__(config)
  673. self.vision_model = Owlv2VisionTransformer(config)
  674. # Initialize weights and apply final processing
  675. self.post_init()
  676. def get_input_embeddings(self) -> nn.Module:
  677. return self.vision_model.embeddings.patch_embedding
  678. @auto_docstring
  679. def forward(
  680. self,
  681. pixel_values: torch.FloatTensor | None = None,
  682. interpolate_pos_encoding: bool = False,
  683. **kwargs: Unpack[TransformersKwargs],
  684. ) -> BaseModelOutputWithPooling:
  685. r"""
  686. Examples:
  687. ```python
  688. >>> from PIL import Image
  689. >>> import httpx
  690. >>> from io import BytesIO
  691. >>> from transformers import AutoProcessor, Owlv2VisionModel
  692. >>> model = Owlv2VisionModel.from_pretrained("google/owlv2-base-patch16")
  693. >>> processor = AutoProcessor.from_pretrained("google/owlv2-base-patch16")
  694. >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg"
  695. >>> with httpx.stream("GET", url) as response:
  696. ... image = Image.open(BytesIO(response.read()))
  697. >>> inputs = processor(images=image, return_tensors="pt")
  698. >>> outputs = model(**inputs)
  699. >>> last_hidden_state = outputs.last_hidden_state
  700. >>> pooled_output = outputs.pooler_output # pooled CLS states
  701. ```"""
  702. return self.vision_model(
  703. pixel_values=pixel_values,
  704. interpolate_pos_encoding=interpolate_pos_encoding,
  705. **kwargs,
  706. )
  707. @auto_docstring
  708. # Copied from transformers.models.owlvit.modeling_owlvit.OwlViTModel with google/owlvit-base-patch32->google/owlv2-base-patch16-ensemble, OWLVIT->OWLV2,OwlViT->Owlv2,owlvit->owlv2,OWL-ViT->OWLv2
  709. class Owlv2Model(Owlv2PreTrainedModel):
  710. config: Owlv2Config
  711. def __init__(self, config: Owlv2Config):
  712. super().__init__(config)
  713. text_config = config.text_config
  714. vision_config = config.vision_config
  715. self.projection_dim = config.projection_dim
  716. self.text_embed_dim = text_config.hidden_size
  717. self.vision_embed_dim = vision_config.hidden_size
  718. self.text_model = Owlv2TextTransformer(text_config)
  719. self.vision_model = Owlv2VisionTransformer(vision_config)
  720. self.visual_projection = nn.Linear(self.vision_embed_dim, self.projection_dim, bias=False)
  721. self.text_projection = nn.Linear(self.text_embed_dim, self.projection_dim, bias=False)
  722. self.logit_scale = nn.Parameter(torch.tensor(config.logit_scale_init_value))
  723. # Initialize weights and apply final processing
  724. self.post_init()
  725. @can_return_tuple
  726. @auto_docstring
  727. def get_text_features(
  728. self,
  729. input_ids: torch.Tensor,
  730. attention_mask: torch.Tensor | None = None,
  731. **kwargs: Unpack[TransformersKwargs],
  732. ) -> tuple | BaseModelOutputWithPooling:
  733. r"""
  734. input_ids (`torch.LongTensor` of shape `(batch_size * num_max_text_queries, sequence_length)`):
  735. Indices of input sequence tokens in the vocabulary. Indices can be obtained using [`AutoTokenizer`]. See
  736. [`PreTrainedTokenizer.encode`] and [`PreTrainedTokenizer.__call__`] for details. [What are input
  737. IDs?](../glossary#input-ids)
  738. Examples:
  739. ```python
  740. >>> import torch
  741. >>> from transformers import AutoProcessor, Owlv2Model
  742. >>> model = Owlv2Model.from_pretrained("google/owlv2-base-patch16-ensemble")
  743. >>> processor = AutoProcessor.from_pretrained("google/owlv2-base-patch16-ensemble")
  744. >>> inputs = processor(
  745. ... text=[["a photo of a cat", "a photo of a dog"], ["photo of a astranaut"]], return_tensors="pt"
  746. ... )
  747. >>> with torch.inference_mode():
  748. ... text_features = model.get_text_features(**inputs)
  749. ```"""
  750. text_outputs: BaseModelOutputWithPooling = self.text_model(
  751. input_ids=input_ids,
  752. attention_mask=attention_mask,
  753. **kwargs,
  754. )
  755. pooled_output = text_outputs.pooler_output
  756. text_outputs.pooler_output = self.text_projection(pooled_output)
  757. return text_outputs
  758. @can_return_tuple
  759. @auto_docstring
  760. def get_image_features(
  761. self,
  762. pixel_values: torch.Tensor,
  763. interpolate_pos_encoding: bool = False,
  764. **kwargs: Unpack[TransformersKwargs],
  765. ) -> tuple | BaseModelOutputWithPooling:
  766. r"""
  767. Examples:
  768. ```python
  769. >>> import torch
  770. >>> from transformers.image_utils import load_image
  771. >>> from transformers import AutoProcessor, Owlv2Model
  772. >>> model = Owlv2Model.from_pretrained("google/owlv2-base-patch16-ensemble")
  773. >>> processor = AutoProcessor.from_pretrained("google/owlv2-base-patch16-ensemble")
  774. >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg"
  775. >>> image = load_image(url)
  776. >>> inputs = processor(images=image, return_tensors="pt")
  777. >>> with torch.inference_mode():
  778. ... image_features = model.get_image_features(**inputs)
  779. ```"""
  780. vision_outputs: BaseModelOutputWithPooling = self.vision_model(
  781. pixel_values=pixel_values,
  782. interpolate_pos_encoding=interpolate_pos_encoding,
  783. **kwargs,
  784. )
  785. vision_outputs.pooler_output = self.visual_projection(vision_outputs.pooler_output)
  786. return vision_outputs
  787. @can_return_tuple
  788. @auto_docstring
  789. def forward(
  790. self,
  791. input_ids: torch.LongTensor | None = None,
  792. pixel_values: torch.FloatTensor | None = None,
  793. attention_mask: torch.Tensor | None = None,
  794. return_loss: bool | None = None,
  795. interpolate_pos_encoding: bool = False,
  796. return_base_image_embeds: bool | None = None,
  797. **kwargs: Unpack[TransformersKwargs],
  798. ) -> tuple | Owlv2Output:
  799. r"""
  800. return_loss (`bool`, *optional*):
  801. Whether or not to return the contrastive loss.
  802. return_base_image_embeds (`bool`, *optional*):
  803. Whether or not to return the base image embeddings.
  804. Examples:
  805. ```python
  806. >>> from PIL import Image
  807. >>> import httpx
  808. >>> from io import BytesIO
  809. >>> from transformers import AutoProcessor, Owlv2Model
  810. >>> model = Owlv2Model.from_pretrained("google/owlv2-base-patch16-ensemble")
  811. >>> processor = AutoProcessor.from_pretrained("google/owlv2-base-patch16-ensemble")
  812. >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg"
  813. >>> with httpx.stream("GET", url) as response:
  814. ... image = Image.open(BytesIO(response.read()))
  815. >>> inputs = processor(text=[["a photo of a cat", "a photo of a dog"]], images=image, return_tensors="pt")
  816. >>> outputs = model(**inputs)
  817. >>> logits_per_image = outputs.logits_per_image # this is the image-text similarity score
  818. >>> probs = logits_per_image.softmax(dim=1) # we can take the softmax to get the label probabilities
  819. ```"""
  820. vision_outputs: BaseModelOutputWithPooling = self.vision_model(
  821. pixel_values=pixel_values,
  822. interpolate_pos_encoding=interpolate_pos_encoding,
  823. **kwargs,
  824. )
  825. # Get embeddings for all text queries in all batch samples
  826. text_outputs: BaseModelOutputWithPooling = self.text_model(
  827. input_ids=input_ids,
  828. attention_mask=attention_mask,
  829. **kwargs,
  830. )
  831. text_embeds = text_outputs.pooler_output
  832. text_embeds = self.text_projection(text_embeds)
  833. image_embeds = vision_outputs.pooler_output
  834. image_embeds = self.visual_projection(image_embeds)
  835. # normalized features
  836. image_embeds = image_embeds / torch.linalg.norm(image_embeds, ord=2, dim=-1, keepdim=True)
  837. text_embeds_norm = text_embeds / torch.linalg.norm(text_embeds, ord=2, dim=-1, keepdim=True)
  838. # cosine similarity as logits and set it on the correct device
  839. logit_scale = self.logit_scale.exp().to(image_embeds.device)
  840. logits_per_text = torch.matmul(text_embeds_norm, image_embeds.t()) * logit_scale
  841. logits_per_image = logits_per_text.t()
  842. loss = None
  843. if return_loss:
  844. loss = owlv2_loss(logits_per_text)
  845. text_embeds = text_embeds_norm
  846. return Owlv2Output(
  847. loss=loss,
  848. logits_per_image=logits_per_image,
  849. logits_per_text=logits_per_text,
  850. text_embeds=text_embeds,
  851. image_embeds=image_embeds,
  852. text_model_output=text_outputs,
  853. vision_model_output=vision_outputs,
  854. )
  855. # Copied from transformers.models.owlvit.modeling_owlvit.OwlViTBoxPredictionHead with OwlViT->Owlv2
  856. class Owlv2BoxPredictionHead(nn.Module):
  857. def __init__(self, config: Owlv2Config, out_dim: int = 4):
  858. super().__init__()
  859. width = config.vision_config.hidden_size
  860. self.dense0 = nn.Linear(width, width)
  861. self.dense1 = nn.Linear(width, width)
  862. self.gelu = nn.GELU()
  863. self.dense2 = nn.Linear(width, out_dim)
  864. def forward(self, image_features: torch.Tensor) -> torch.FloatTensor:
  865. output = self.dense0(image_features)
  866. output = self.gelu(output)
  867. output = self.dense1(output)
  868. output = self.gelu(output)
  869. output = self.dense2(output)
  870. return output
  871. # Copied from transformers.models.owlvit.modeling_owlvit.OwlViTClassPredictionHead with OwlViT->Owlv2
  872. class Owlv2ClassPredictionHead(nn.Module):
  873. def __init__(self, config: Owlv2Config):
  874. super().__init__()
  875. out_dim = config.text_config.hidden_size
  876. self.query_dim = config.vision_config.hidden_size
  877. self.dense0 = nn.Linear(self.query_dim, out_dim)
  878. self.logit_shift = nn.Linear(self.query_dim, 1)
  879. self.logit_scale = nn.Linear(self.query_dim, 1)
  880. self.elu = nn.ELU()
  881. def forward(
  882. self,
  883. image_embeds: torch.FloatTensor,
  884. query_embeds: torch.FloatTensor | None,
  885. query_mask: torch.Tensor | None,
  886. ) -> tuple[torch.FloatTensor]:
  887. image_class_embeds = self.dense0(image_embeds)
  888. if query_embeds is None:
  889. device = image_class_embeds.device
  890. batch_size, num_patches = image_class_embeds.shape[:2]
  891. pred_logits = torch.zeros((batch_size, num_patches, self.query_dim)).to(device)
  892. return (pred_logits, image_class_embeds)
  893. # Normalize image and text features
  894. image_class_embeds = image_class_embeds / (torch.linalg.norm(image_class_embeds, dim=-1, keepdim=True) + 1e-6)
  895. query_embeds = query_embeds / (torch.linalg.norm(query_embeds, dim=-1, keepdim=True) + 1e-6)
  896. # Get class predictions
  897. pred_logits = torch.einsum("...pd,...qd->...pq", image_class_embeds, query_embeds)
  898. # Apply a learnable shift and scale to logits
  899. logit_shift = self.logit_shift(image_embeds)
  900. logit_scale = self.logit_scale(image_embeds)
  901. logit_scale = self.elu(logit_scale) + 1
  902. pred_logits = (pred_logits + logit_shift) * logit_scale
  903. if query_mask is not None:
  904. if query_mask.ndim > 1:
  905. query_mask = torch.unsqueeze(query_mask, dim=-2)
  906. pred_logits = torch.where(query_mask == 0, torch.finfo(pred_logits.dtype).min, pred_logits)
  907. pred_logits = pred_logits.to(torch.float32)
  908. return (pred_logits, image_class_embeds)
  909. class Owlv2ForObjectDetection(Owlv2PreTrainedModel):
  910. config: Owlv2Config
  911. def __init__(self, config: Owlv2Config):
  912. super().__init__(config)
  913. self.owlv2 = Owlv2Model(config)
  914. self.class_head = Owlv2ClassPredictionHead(config)
  915. self.box_head = Owlv2BoxPredictionHead(config)
  916. self.objectness_head = Owlv2BoxPredictionHead(config, out_dim=1)
  917. self.layer_norm = nn.LayerNorm(config.vision_config.hidden_size, eps=config.vision_config.layer_norm_eps)
  918. self.sigmoid = nn.Sigmoid()
  919. self.config = config
  920. self.num_patches_height = self.config.vision_config.image_size // self.config.vision_config.patch_size
  921. self.num_patches_width = self.config.vision_config.image_size // self.config.vision_config.patch_size
  922. self.register_buffer(
  923. "box_bias", self.compute_box_bias(self.num_patches_height, self.num_patches_width), persistent=False
  924. )
  925. # Initialize weights and apply final processing
  926. self.post_init()
  927. @staticmethod
  928. # Copied from transformers.models.owlvit.modeling_owlvit.OwlViTForObjectDetection.normalize_grid_corner_coordinates
  929. def normalize_grid_corner_coordinates(num_patches_height: int, num_patches_width: int) -> torch.Tensor:
  930. # Create grid coordinates using torch
  931. x_coordinates = torch.arange(1, num_patches_width + 1, dtype=torch.float32)
  932. y_coordinates = torch.arange(1, num_patches_height + 1, dtype=torch.float32)
  933. xx, yy = torch.meshgrid(x_coordinates, y_coordinates, indexing="xy")
  934. # Stack the coordinates and divide by their respective patch counts
  935. box_coordinates = torch.stack((xx, yy), dim=-1)
  936. box_coordinates[..., 0] /= num_patches_width
  937. box_coordinates[..., 1] /= num_patches_height
  938. # Flatten (h, w, 2) -> (h*w, 2)
  939. box_coordinates = box_coordinates.view(-1, 2)
  940. return box_coordinates
  941. def objectness_predictor(self, image_features: torch.FloatTensor) -> torch.FloatTensor:
  942. """Predicts the probability that each image feature token is an object.
  943. Args:
  944. image_features (`torch.FloatTensor` of shape `(batch_size, num_patches, hidden_dim)`)):
  945. Features extracted from the image.
  946. Returns:
  947. Objectness scores.
  948. """
  949. image_features = image_features.detach()
  950. objectness_logits = self.objectness_head(image_features)
  951. objectness_logits = objectness_logits[..., 0]
  952. return objectness_logits
  953. # Copied from transformers.models.owlvit.modeling_owlvit.OwlViTForObjectDetection.compute_box_bias
  954. def compute_box_bias(self, num_patches_height: int, num_patches_width: int) -> torch.Tensor:
  955. # The box center is biased to its position on the feature grid
  956. box_coordinates = self.normalize_grid_corner_coordinates(num_patches_height, num_patches_width)
  957. box_coordinates = torch.clip(box_coordinates, 0.0, 1.0)
  958. # Unnormalize xy
  959. box_coord_bias = torch.log(box_coordinates + 1e-4) - torch.log1p(-box_coordinates + 1e-4)
  960. # The box size is biased to the patch size
  961. box_size = torch.full_like(box_coord_bias, 1.0)
  962. box_size[..., 0] /= num_patches_width
  963. box_size[..., 1] /= num_patches_height
  964. box_size_bias = torch.log(box_size + 1e-4) - torch.log1p(-box_size + 1e-4)
  965. # Compute box bias
  966. box_bias = torch.cat([box_coord_bias, box_size_bias], dim=-1)
  967. return box_bias
  968. # Copied from transformers.models.owlvit.modeling_owlvit.OwlViTForObjectDetection.box_predictor
  969. def box_predictor(
  970. self,
  971. image_feats: torch.FloatTensor,
  972. feature_map: torch.FloatTensor,
  973. interpolate_pos_encoding: bool = False,
  974. ) -> torch.FloatTensor:
  975. """
  976. Args:
  977. image_feats:
  978. Features extracted from the image, returned by the `image_text_embedder` method.
  979. feature_map:
  980. A spatial re-arrangement of image_features, also returned by the `image_text_embedder` method.
  981. interpolate_pos_encoding:
  982. Whether to interpolate the pre-trained position encodings.
  983. Returns:
  984. pred_boxes:
  985. List of predicted boxes (cxcywh normalized to 0, 1) nested within a dictionary.
  986. """
  987. # Bounding box detection head [batch_size, num_boxes, 4].
  988. pred_boxes = self.box_head(image_feats)
  989. # Compute the location of each token on the grid and use it to compute a bias for the bbox prediction
  990. if interpolate_pos_encoding:
  991. _, num_patches_height, num_patches_width, _ = feature_map.shape
  992. box_bias = self.compute_box_bias(num_patches_height, num_patches_width)
  993. else:
  994. box_bias = self.box_bias
  995. box_bias = box_bias.to(feature_map.device)
  996. pred_boxes += box_bias
  997. pred_boxes = self.sigmoid(pred_boxes)
  998. return pred_boxes
  999. # Copied from transformers.models.owlvit.modeling_owlvit.OwlViTForObjectDetection.class_predictor
  1000. def class_predictor(
  1001. self,
  1002. image_feats: torch.FloatTensor,
  1003. query_embeds: torch.FloatTensor | None = None,
  1004. query_mask: torch.Tensor | None = None,
  1005. ) -> tuple[torch.FloatTensor]:
  1006. """
  1007. Args:
  1008. image_feats:
  1009. Features extracted from the `image_text_embedder`.
  1010. query_embeds:
  1011. Text query embeddings.
  1012. query_mask:
  1013. Must be provided with query_embeddings. A mask indicating which query embeddings are valid.
  1014. """
  1015. (pred_logits, image_class_embeds) = self.class_head(image_feats, query_embeds, query_mask)
  1016. return (pred_logits, image_class_embeds)
  1017. # Copied from transformers.models.owlvit.modeling_owlvit.OwlViTForObjectDetection.image_text_embedder with owlvit->owlv2
  1018. def image_text_embedder(
  1019. self,
  1020. input_ids: torch.Tensor,
  1021. pixel_values: torch.FloatTensor,
  1022. attention_mask: torch.Tensor,
  1023. interpolate_pos_encoding: bool = False,
  1024. **kwargs: Unpack[TransformersKwargs],
  1025. ) -> tuple[torch.FloatTensor]:
  1026. outputs = self.owlv2(
  1027. pixel_values=pixel_values,
  1028. input_ids=input_ids,
  1029. attention_mask=attention_mask,
  1030. interpolate_pos_encoding=interpolate_pos_encoding,
  1031. **kwargs,
  1032. )
  1033. if interpolate_pos_encoding:
  1034. _, _, height, width = pixel_values.shape
  1035. num_patches_height = height // self.config.vision_config.patch_size
  1036. num_patches_width = width // self.config.vision_config.patch_size
  1037. else:
  1038. num_patches_height = self.num_patches_height
  1039. num_patches_width = self.num_patches_width
  1040. # Get image embeddings
  1041. last_hidden_state = outputs.vision_model_output[0]
  1042. image_embeds = self.owlv2.vision_model.post_layernorm(last_hidden_state)
  1043. # Resize class token
  1044. class_token_out = torch.broadcast_to(image_embeds[:, :1, :], image_embeds[:, :-1].shape)
  1045. # Merge image embedding with class tokens
  1046. image_embeds = image_embeds[:, 1:, :] * class_token_out
  1047. image_embeds = self.layer_norm(image_embeds)
  1048. # Resize to [batch_size, num_patches_height, num_patches_width, hidden_size]
  1049. new_size = (
  1050. image_embeds.shape[0],
  1051. num_patches_height,
  1052. num_patches_width,
  1053. image_embeds.shape[-1],
  1054. )
  1055. image_embeds = image_embeds.reshape(new_size)
  1056. text_embeds = outputs[-4]
  1057. return (text_embeds, image_embeds, outputs)
  1058. # Copied from transformers.models.owlvit.modeling_owlvit.OwlViTForObjectDetection.image_embedder with owlvit->owlv2, OwlViTModel->Owlv2Model
  1059. def image_embedder(
  1060. self,
  1061. pixel_values: torch.FloatTensor,
  1062. interpolate_pos_encoding: bool = False,
  1063. **kwargs: Unpack[TransformersKwargs],
  1064. ) -> tuple[torch.FloatTensor]:
  1065. # Get Owlv2Model vision embeddings (same as CLIP)
  1066. vision_outputs: BaseModelOutputWithPooling = self.owlv2.vision_model(
  1067. pixel_values=pixel_values, interpolate_pos_encoding=interpolate_pos_encoding, **kwargs
  1068. )
  1069. if interpolate_pos_encoding:
  1070. _, _, height, width = pixel_values.shape
  1071. num_patches_height = height // self.config.vision_config.patch_size
  1072. num_patches_width = width // self.config.vision_config.patch_size
  1073. else:
  1074. num_patches_height = self.num_patches_height
  1075. num_patches_width = self.num_patches_width
  1076. # Apply post_layernorm to last_hidden_state, return non-projected output
  1077. last_hidden_state = vision_outputs[0]
  1078. image_embeds = self.owlv2.vision_model.post_layernorm(last_hidden_state)
  1079. # Resize class token
  1080. class_token_out = torch.broadcast_to(image_embeds[:, :1, :], image_embeds[:, :-1].shape)
  1081. # Merge image embedding with class tokens
  1082. image_embeds = image_embeds[:, 1:, :] * class_token_out
  1083. image_embeds = self.layer_norm(image_embeds)
  1084. # Resize to [batch_size, num_patches_height, num_patches_width, hidden_size]
  1085. new_size = (
  1086. image_embeds.shape[0],
  1087. num_patches_height,
  1088. num_patches_width,
  1089. image_embeds.shape[-1],
  1090. )
  1091. image_embeds = image_embeds.reshape(new_size)
  1092. return (image_embeds, vision_outputs)
  1093. # Copied from transformers.models.owlvit.modeling_owlvit.OwlViTForObjectDetection.embed_image_query
  1094. def embed_image_query(
  1095. self,
  1096. query_image_features: torch.FloatTensor,
  1097. query_feature_map: torch.FloatTensor,
  1098. interpolate_pos_encoding: bool = False,
  1099. ) -> torch.FloatTensor:
  1100. _, class_embeds = self.class_predictor(query_image_features)
  1101. pred_boxes = self.box_predictor(query_image_features, query_feature_map, interpolate_pos_encoding)
  1102. pred_boxes_as_corners = center_to_corners_format(pred_boxes)
  1103. # Loop over query images
  1104. best_class_embeds = []
  1105. best_box_indices = []
  1106. pred_boxes_device = pred_boxes_as_corners.device
  1107. for i in range(query_image_features.shape[0]):
  1108. each_query_box = torch.tensor([[0, 0, 1, 1]], device=pred_boxes_device)
  1109. each_query_pred_boxes = pred_boxes_as_corners[i]
  1110. ious, _ = box_iou(each_query_box, each_query_pred_boxes)
  1111. # If there are no overlapping boxes, fall back to generalized IoU
  1112. if torch.all(ious[0] == 0.0):
  1113. ious = generalized_box_iou(each_query_box, each_query_pred_boxes)
  1114. # Use an adaptive threshold to include all boxes within 80% of the best IoU
  1115. iou_threshold = torch.max(ious) * 0.8
  1116. selected_inds = (ious[0] >= iou_threshold).nonzero()
  1117. if selected_inds.numel():
  1118. selected_embeddings = class_embeds[i][selected_inds.squeeze(1)]
  1119. mean_embeds = torch.mean(class_embeds[i], axis=0)
  1120. mean_sim = torch.einsum("d,id->i", mean_embeds, selected_embeddings)
  1121. best_box_ind = selected_inds[torch.argmin(mean_sim)]
  1122. best_class_embeds.append(class_embeds[i][best_box_ind])
  1123. best_box_indices.append(best_box_ind)
  1124. if best_class_embeds:
  1125. query_embeds = torch.stack(best_class_embeds)
  1126. box_indices = torch.stack(best_box_indices)
  1127. else:
  1128. query_embeds, box_indices = None, None
  1129. return query_embeds, box_indices, pred_boxes
  @can_return_tuple
  @auto_docstring
  def image_guided_detection(
      self,
      pixel_values: torch.FloatTensor,
      query_pixel_values: torch.FloatTensor | None = None,
      interpolate_pos_encoding: bool = False,
      **kwargs: Unpack[TransformersKwargs],
  ) -> Owlv2ImageGuidedObjectDetectionOutput:
      r"""
      query_pixel_values (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`):
          Pixel values of query image(s) to be detected. Pass in one query image per target image. Required
          despite the `None` default: a `ValueError` is raised when it is omitted.

      Examples:
      ```python
      >>> import httpx
      >>> from io import BytesIO
      >>> from PIL import Image
      >>> import torch
      >>> from transformers import AutoProcessor, Owlv2ForObjectDetection

      >>> processor = AutoProcessor.from_pretrained("google/owlv2-base-patch16-ensemble")
      >>> model = Owlv2ForObjectDetection.from_pretrained("google/owlv2-base-patch16-ensemble")
      >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg"
      >>> with httpx.stream("GET", url) as response:
      ...     image = Image.open(BytesIO(response.read()))
      >>> query_url = "http://images.cocodataset.org/val2017/000000001675.jpg"
      >>> with httpx.stream("GET", query_url) as response:
      ...     query_image = Image.open(BytesIO(response.read()))
      >>> inputs = processor(images=image, query_images=query_image, return_tensors="pt")

      >>> # forward pass
      >>> with torch.no_grad():
      ...     outputs = model.image_guided_detection(**inputs)

      >>> target_sizes = torch.Tensor([image.size[::-1]])

      >>> # Convert outputs (bounding boxes and class logits) to Pascal VOC format (xmin, ymin, xmax, ymax)
      >>> results = processor.post_process_image_guided_detection(
      ...     outputs=outputs, threshold=0.9, nms_threshold=0.3, target_sizes=target_sizes
      ... )
      >>> i = 0  # Retrieve predictions for the first image
      >>> boxes, scores = results[i]["boxes"], results[i]["scores"]
      >>> for box, score in zip(boxes, scores):
      ...     box = [round(i, 2) for i in box.tolist()]
      ...     print(f"Detected similar object with confidence {round(score.item(), 3)} at location {box}")
      Detected similar object with confidence 0.938 at location [327.31, 54.94, 547.39, 268.06]
      Detected similar object with confidence 0.959 at location [5.78, 360.65, 619.12, 366.39]
      Detected similar object with confidence 0.902 at location [2.85, 360.01, 627.63, 380.8]
      Detected similar object with confidence 0.985 at location [176.98, -29.45, 672.69, 182.83]
      Detected similar object with confidence 1.0 at location [6.53, 14.35, 624.87, 470.82]
      Detected similar object with confidence 0.998 at location [579.98, 29.14, 615.49, 489.05]
      Detected similar object with confidence 0.985 at location [206.15, 10.53, 247.74, 466.01]
      Detected similar object with confidence 0.947 at location [18.62, 429.72, 646.5, 457.72]
      Detected similar object with confidence 0.996 at location [523.88, 20.69, 586.84, 483.18]
      Detected similar object with confidence 0.998 at location [3.39, 360.59, 617.29, 499.21]
      Detected similar object with confidence 0.969 at location [4.47, 449.05, 614.5, 474.76]
      Detected similar object with confidence 0.966 at location [31.44, 463.65, 654.66, 471.07]
      Detected similar object with confidence 0.924 at location [30.93, 468.07, 635.35, 475.39]
      ```"""
      # The signature defaults `query_pixel_values` to None, yet it is fed to the image embedder
      # unconditionally below. Fail early with an explicit message instead of an opaque downstream error.
      if query_pixel_values is None:
          raise ValueError("`query_pixel_values` must be provided for image-guided detection.")

      # Compute feature maps for the query images; only the first element of the embedder output is used.
      query_feature_map = self.image_embedder(
          pixel_values=query_pixel_values, interpolate_pos_encoding=interpolate_pos_encoding
      )[0]
      # Compute feature maps for the target images. `**kwargs` are forwarded on this pass only.
      feature_map, vision_outputs = self.image_embedder(
          pixel_values=pixel_values,
          interpolate_pos_encoding=interpolate_pos_encoding,
          **kwargs,
      )

      # Collapse the two patch-grid axes of the target feature map into a single patch axis.
      batch_size, num_patches_height, num_patches_width, hidden_dim = feature_map.shape
      image_feats = torch.reshape(feature_map, (batch_size, num_patches_height * num_patches_width, hidden_dim))

      # Same for the query feature map, using separate names so the target dimensions are not overwritten.
      query_batch_size, query_patches_height, query_patches_width, query_hidden_dim = query_feature_map.shape
      query_image_feats = torch.reshape(
          query_feature_map, (query_batch_size, query_patches_height * query_patches_width, query_hidden_dim)
      )

      # Get top class embedding for each query image in batch. The second return value (best box
      # indices) is not needed here, so it is discarded.
      query_embeds, _, query_pred_boxes = self.embed_image_query(
          query_image_feats, query_feature_map, interpolate_pos_encoding
      )

      # Predict object classes [batch_size, num_patches, num_queries+1]
      pred_logits, class_embeds = self.class_predictor(image_feats=image_feats, query_embeds=query_embeds)

      # Predict object boxes
      target_pred_boxes = self.box_predictor(image_feats, feature_map, interpolate_pos_encoding)

      return Owlv2ImageGuidedObjectDetectionOutput(
          image_embeds=feature_map,
          query_image_embeds=query_feature_map,
          target_pred_boxes=target_pred_boxes,
          query_pred_boxes=query_pred_boxes,
          logits=pred_logits,
          class_embeds=class_embeds,
          text_model_output=None,
          vision_model_output=vision_outputs,
      )
  @can_return_tuple
  @auto_docstring
  def forward(
      self,
      input_ids: torch.Tensor,
      pixel_values: torch.FloatTensor,
      attention_mask: torch.Tensor | None = None,
      interpolate_pos_encoding: bool = False,
      **kwargs: Unpack[TransformersKwargs],
  ) -> Owlv2ObjectDetectionOutput:
      r"""
      input_ids (`torch.LongTensor` of shape `(batch_size * num_max_text_queries, sequence_length)`, *optional*):
          Indices of input sequence tokens in the vocabulary. Indices can be obtained using [`AutoTokenizer`]. See
          [`PreTrainedTokenizer.encode`] and [`PreTrainedTokenizer.__call__`] for details. [What are input
          IDs?](../glossary#input-ids).

      Examples:
      ```python
      >>> import httpx
      >>> from io import BytesIO
      >>> from PIL import Image
      >>> import torch

      >>> from transformers import Owlv2Processor, Owlv2ForObjectDetection

      >>> processor = Owlv2Processor.from_pretrained("google/owlv2-base-patch16-ensemble")
      >>> model = Owlv2ForObjectDetection.from_pretrained("google/owlv2-base-patch16-ensemble")

      >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg"
      >>> with httpx.stream("GET", url) as response:
      ...     image = Image.open(BytesIO(response.read()))
      >>> text_labels = [["a photo of a cat", "a photo of a dog"]]
      >>> inputs = processor(text=text_labels, images=image, return_tensors="pt")
      >>> outputs = model(**inputs)

      >>> # Target image sizes (height, width) to rescale box predictions [batch_size, 2]
      >>> target_sizes = torch.tensor([(image.height, image.width)])
      >>> # Convert outputs (bounding boxes and class logits) to Pascal VOC format (xmin, ymin, xmax, ymax)
      >>> results = processor.post_process_grounded_object_detection(
      ...     outputs=outputs, target_sizes=target_sizes, threshold=0.1, text_labels=text_labels
      ... )
      >>> # Retrieve predictions for the first image for the corresponding text queries
      >>> result = results[0]
      >>> boxes, scores, text_labels = result["boxes"], result["scores"], result["text_labels"]
      >>> for box, score, text_label in zip(boxes, scores, text_labels):
      ...     box = [round(i, 2) for i in box.tolist()]
      ...     print(f"Detected {text_label} with confidence {round(score.item(), 3)} at location {box}")
      Detected a photo of a cat with confidence 0.614 at location [341.67, 23.39, 642.32, 371.35]
      Detected a photo of a cat with confidence 0.665 at location [6.75, 51.96, 326.62, 473.13]
      ```"""
      # Run the joint image/text embedder; it yields the text query embeddings, the image
      # feature map, and an object carrying the text and vision sub-model outputs.
      text_query_embeds, feature_map, embedder_outputs = self.image_text_embedder(
          input_ids=input_ids,
          pixel_values=pixel_values,
          attention_mask=attention_mask,
          interpolate_pos_encoding=interpolate_pos_encoding,
          **kwargs,
      )

      # Merge the two patch-grid axes of the feature map into one patch axis.
      batch_size, grid_height, grid_width, feat_dim = feature_map.shape
      image_feats = feature_map.reshape(batch_size, grid_height * grid_width, feat_dim)

      # The text inputs arrive flattened over (batch, query); regroup them per image:
      # [batch_size * max_text_queries, ...] -> [batch_size, max_text_queries, ...]
      max_text_queries = input_ids.shape[0] // batch_size
      text_query_embeds = torch.reshape(
          text_query_embeds, (batch_size, max_text_queries, text_query_embeds.shape[-1])
      )
      token_ids = torch.reshape(input_ids, (batch_size, max_text_queries, input_ids.shape[-1]))

      # A query whose first token id is 0 is treated as padding [batch_size, num_queries].
      first_token_ids = token_ids[..., 0]
      query_mask = first_token_ids > 0

      # Class prediction, objectness prediction and box prediction, in that order.
      pred_logits, class_embeds = self.class_predictor(image_feats, text_query_embeds, query_mask)
      objectness_logits = self.objectness_predictor(image_feats)
      pred_boxes = self.box_predictor(image_feats, feature_map, interpolate_pos_encoding)

      return Owlv2ObjectDetectionOutput(
          image_embeds=feature_map,
          text_embeds=text_query_embeds,
          pred_boxes=pred_boxes,
          logits=pred_logits,
          objectness_logits=objectness_logits,
          class_embeds=class_embeds,
          text_model_output=embedder_outputs.text_model_output,
          vision_model_output=embedder_outputs.vision_model_output,
      )
  # Names exported by `from <module> import *`.
  __all__ = [
      "Owlv2Model",
      "Owlv2PreTrainedModel",
      "Owlv2TextModel",
      "Owlv2VisionModel",
      "Owlv2ForObjectDetection",
  ]