modeling_owlvit.py 63 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
37047057067077087097107117127137147157167177187197207217227237247257267277287297307317327337347357367377387397407417427437447457467477487497507517527537547557567577587597607617627637647657667677687697707717727737747757767777787797807817827837847857867877887897907917927937947957967977987998008018028038048058068078088098108118128138148158168178188198208218228238248258268278288298308318328338348358368378388398408418428438448458468478488498508518528538548558568578588598608618628638648658668678688698708718728738748758768778788798808818828838848858868878888898908918928938948958968978988999009019029039049059069079089099109119129139149159169179189199209219229239249259269279289299309319329339349359369379389399409419429439449459469479489499509519529539549559569579589599609619629639649659669679689699709719729739749759769779789799809819829839849859869879889899909919929939949959969979989991000100110021003100410051006100710081009101010111012101310141015101610171018101910201021102210231024102510261027102810291030103110321033103410351036103710381039104010411042104310441045104610471048104910501051105210531054105510561057105810591060106110621063106410651066106710681069107010711072107310741075107610771078107910801081108210831084108510861087108810891090109110921093109410951096109710981099110011011102110311041105110611071108110911101111111211131114111511161117111811191120112111221123112411251126112711281129113011311132113311341135113611371138113911401141114211431144114511461147114811491150115111521153115411551156115711581159116011611162116311641165116611671168116911701171117211731174117511761177117811791180118111821183118411851186118711881189119011911192119311941195119611971198119912001201120212031204120512061207120812091210121112121213121412151216121712181219122012211222122312241225122612271228122912301231123212331234123512361237123812391240124112421243124412451246124712481249125012511252125312541255125612571258125912601261126212631264126512661267126812691270127112721273127412751276127
712781279128012811282128312841285128612871288128912901291129212931294129512961297129812991300130113021303130413051306130713081309131013111312131313141315131613171318131913201321132213231324132513261327132813291330133113321333133413351336133713381339134013411342134313441345134613471348134913501351135213531354135513561357135813591360136113621363136413651366136713681369137013711372137313741375137613771378137913801381138213831384138513861387138813891390139113921393139413951396139713981399140014011402140314041405140614071408140914101411141214131414141514161417141814191420142114221423142414251426142714281429143014311432143314341435143614371438143914401441144214431444144514461447144814491450145114521453145414551456145714581459146014611462146314641465146614671468146914701471147214731474147514761477147814791480148114821483148414851486148714881489149014911492
  1. # Copyright 2022 Google AI and The HuggingFace Team. All rights reserved.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. """PyTorch OWL-ViT model."""
  15. from collections.abc import Callable
  16. from dataclasses import dataclass
  17. from typing import Any
  18. import torch
  19. from torch import Tensor, nn
  20. from ... import initialization as init
  21. from ...activations import ACT2FN
  22. from ...masking_utils import create_causal_mask
  23. from ...modeling_layers import GradientCheckpointingLayer
  24. from ...modeling_outputs import BaseModelOutput, BaseModelOutputWithPooling
  25. from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
  26. from ...processing_utils import Unpack
  27. from ...utils import (
  28. ModelOutput,
  29. TransformersKwargs,
  30. auto_docstring,
  31. is_vision_available,
  32. logging,
  33. torch_int,
  34. )
  35. from ...utils.generic import can_return_tuple, merge_with_config_defaults
  36. from ...utils.output_capturing import capture_outputs
  37. from .configuration_owlvit import OwlViTConfig, OwlViTTextConfig, OwlViTVisionConfig
  38. if is_vision_available():
  39. from transformers.image_transforms import center_to_corners_format
  40. logger = logging.get_logger(__name__)
  41. # See all OwlViT models at https://huggingface.co/models?filter=owlvit
  42. # Copied from transformers.models.clip.modeling_clip.contrastive_loss with clip->owlvit
  43. def contrastive_loss(logits: torch.Tensor) -> torch.Tensor:
  44. return nn.functional.cross_entropy(logits, torch.arange(len(logits), device=logits.device))
  45. # Copied from transformers.models.clip.modeling_clip.clip_loss with clip->owlvit
  46. def owlvit_loss(similarity: torch.Tensor) -> torch.Tensor:
  47. caption_loss = contrastive_loss(similarity)
  48. image_loss = contrastive_loss(similarity.t())
  49. return (caption_loss + image_loss) / 2.0
  50. @dataclass
  51. @auto_docstring
  52. class OwlViTOutput(ModelOutput):
  53. r"""
  54. loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `return_loss` is `True`):
  55. Contrastive loss for image-text similarity.
  56. logits_per_image (`torch.FloatTensor` of shape `(image_batch_size, text_batch_size)`):
  57. The scaled dot product scores between `image_embeds` and `text_embeds`. This represents the image-text
  58. similarity scores.
  59. logits_per_text (`torch.FloatTensor` of shape `(text_batch_size, image_batch_size)`):
  60. The scaled dot product scores between `text_embeds` and `image_embeds`. This represents the text-image
  61. similarity scores.
  62. text_embeds (`torch.FloatTensor` of shape `(batch_size * num_max_text_queries, output_dim`):
  63. The text embeddings obtained by applying the projection layer to the pooled output of [`OwlViTTextModel`].
  64. image_embeds (`torch.FloatTensor` of shape `(batch_size, output_dim`):
  65. The image embeddings obtained by applying the projection layer to the pooled output of
  66. [`OwlViTVisionModel`].
  67. text_model_output (tuple[`BaseModelOutputWithPooling`]):
  68. The output of the [`OwlViTTextModel`].
  69. vision_model_output (`BaseModelOutputWithPooling`):
  70. The output of the [`OwlViTVisionModel`].
  71. """
  72. loss: torch.FloatTensor | None = None
  73. logits_per_image: torch.FloatTensor | None = None
  74. logits_per_text: torch.FloatTensor | None = None
  75. text_embeds: torch.FloatTensor | None = None
  76. image_embeds: torch.FloatTensor | None = None
  77. text_model_output: BaseModelOutputWithPooling = None
  78. vision_model_output: BaseModelOutputWithPooling = None
  79. def to_tuple(self) -> tuple[Any]:
  80. return tuple(
  81. self[k] if k not in ["text_model_output", "vision_model_output"] else getattr(self, k).to_tuple()
  82. for k in self.keys()
  83. )
  84. # Copied from transformers.loss.loss_for_object_detection._upcast
  85. def _upcast(t: Tensor) -> Tensor:
  86. # Protects from numerical overflows in multiplications by upcasting to the equivalent higher type
  87. if t.is_floating_point():
  88. return t if t.dtype in (torch.float32, torch.float64) else t.float()
  89. else:
  90. return t if t.dtype in (torch.int32, torch.int64) else t.int()
  91. # Copied from transformers.loss.loss_for_object_detection.box_area
  92. def box_area(boxes: Tensor) -> Tensor:
  93. """
  94. Computes the area of a set of bounding boxes, which are specified by its (x1, y1, x2, y2) coordinates.
  95. Args:
  96. boxes (`torch.FloatTensor` of shape `(number_of_boxes, 4)`):
  97. Boxes for which the area will be computed. They are expected to be in (x1, y1, x2, y2) format with `0 <= x1
  98. < x2` and `0 <= y1 < y2`.
  99. Returns:
  100. `torch.FloatTensor`: a tensor containing the area for each box.
  101. """
  102. boxes = _upcast(boxes)
  103. return (boxes[:, 2] - boxes[:, 0]) * (boxes[:, 3] - boxes[:, 1])
  104. # Copied from transformers.loss.loss_for_object_detection.box_iou
  105. def box_iou(boxes1, boxes2):
  106. area1 = box_area(boxes1)
  107. area2 = box_area(boxes2)
  108. left_top = torch.max(boxes1[:, None, :2], boxes2[:, :2]) # [N,M,2]
  109. right_bottom = torch.min(boxes1[:, None, 2:], boxes2[:, 2:]) # [N,M,2]
  110. width_height = (right_bottom - left_top).clamp(min=0) # [N,M,2]
  111. inter = width_height[:, :, 0] * width_height[:, :, 1] # [N,M]
  112. union = area1[:, None] + area2 - inter
  113. iou = inter / union
  114. return iou, union
  115. # Copied from transformers.loss.loss_for_object_detection.generalized_box_iou
  116. def generalized_box_iou(boxes1, boxes2):
  117. """
  118. Generalized IoU from https://giou.stanford.edu/. The boxes should be in [x0, y0, x1, y1] (corner) format.
  119. Returns:
  120. `torch.FloatTensor`: a [N, M] pairwise matrix, where N = len(boxes1) and M = len(boxes2)
  121. """
  122. # degenerate boxes gives inf / nan results
  123. # so do an early check
  124. if not (boxes1[:, 2:] >= boxes1[:, :2]).all():
  125. raise ValueError(f"boxes1 must be in [x0, y0, x1, y1] (corner) format, but got {boxes1}")
  126. if not (boxes2[:, 2:] >= boxes2[:, :2]).all():
  127. raise ValueError(f"boxes2 must be in [x0, y0, x1, y1] (corner) format, but got {boxes2}")
  128. iou, union = box_iou(boxes1, boxes2)
  129. top_left = torch.min(boxes1[:, None, :2], boxes2[:, :2])
  130. bottom_right = torch.max(boxes1[:, None, 2:], boxes2[:, 2:])
  131. width_height = (bottom_right - top_left).clamp(min=0) # [N,M,2]
  132. area = width_height[:, :, 0] * width_height[:, :, 1]
  133. return iou - (area - union) / area
  134. @dataclass
  135. @auto_docstring(
  136. custom_intro="""
  137. Output type of [`OwlViTForObjectDetection`].
  138. """
  139. )
  140. class OwlViTObjectDetectionOutput(ModelOutput):
  141. r"""
  142. loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` are provided)):
  143. Total loss as a linear combination of a negative log-likehood (cross-entropy) for class prediction and a
  144. bounding box loss. The latter is defined as a linear combination of the L1 loss and the generalized
  145. scale-invariant IoU loss.
  146. loss_dict (`Dict`, *optional*):
  147. A dictionary containing the individual losses. Useful for logging.
  148. logits (`torch.FloatTensor` of shape `(batch_size, num_patches, num_queries)`):
  149. Classification logits (including no-object) for all queries.
  150. pred_boxes (`torch.FloatTensor` of shape `(batch_size, num_patches, 4)`):
  151. Normalized boxes coordinates for all queries, represented as (center_x, center_y, width, height). These
  152. values are normalized in [0, 1], relative to the size of each individual image in the batch (disregarding
  153. possible padding). You can use [`~OwlViTImageProcessor.post_process_object_detection`] to retrieve the
  154. unnormalized bounding boxes.
  155. text_embeds (`torch.FloatTensor` of shape `(batch_size, num_max_text_queries, output_dim`):
  156. The text embeddings obtained by applying the projection layer to the pooled output of [`OwlViTTextModel`].
  157. image_embeds (`torch.FloatTensor` of shape `(batch_size, patch_size, patch_size, output_dim`):
  158. Pooled output of [`OwlViTVisionModel`]. OWL-ViT represents images as a set of image patches and computes
  159. image embeddings for each patch.
  160. class_embeds (`torch.FloatTensor` of shape `(batch_size, num_patches, hidden_size)`):
  161. Class embeddings of all image patches. OWL-ViT represents images as a set of image patches where the total
  162. number of patches is (image_size / patch_size)**2.
  163. text_model_output (tuple[`BaseModelOutputWithPooling`]):
  164. The output of the [`OwlViTTextModel`].
  165. vision_model_output (`BaseModelOutputWithPooling`):
  166. The output of the [`OwlViTVisionModel`].
  167. """
  168. loss: torch.FloatTensor | None = None
  169. loss_dict: dict | None = None
  170. logits: torch.FloatTensor | None = None
  171. pred_boxes: torch.FloatTensor | None = None
  172. text_embeds: torch.FloatTensor | None = None
  173. image_embeds: torch.FloatTensor | None = None
  174. class_embeds: torch.FloatTensor | None = None
  175. text_model_output: BaseModelOutputWithPooling = None
  176. vision_model_output: BaseModelOutputWithPooling = None
  177. def to_tuple(self) -> tuple[Any]:
  178. return tuple(
  179. self[k] if k not in ["text_model_output", "vision_model_output"] else getattr(self, k).to_tuple()
  180. for k in self.keys()
  181. )
  182. @dataclass
  183. @auto_docstring(
  184. custom_intro="""
  185. Output type of [`OwlViTForObjectDetection.image_guided_detection`].
  186. """
  187. )
  188. class OwlViTImageGuidedObjectDetectionOutput(ModelOutput):
  189. r"""
  190. logits (`torch.FloatTensor` of shape `(batch_size, num_patches, num_queries)`):
  191. Classification logits (including no-object) for all queries.
  192. image_embeds (`torch.FloatTensor` of shape `(batch_size, patch_size, patch_size, output_dim`):
  193. Pooled output of [`OwlViTVisionModel`]. OWL-ViT represents images as a set of image patches and computes
  194. image embeddings for each patch.
  195. query_image_embeds (`torch.FloatTensor` of shape `(batch_size, patch_size, patch_size, output_dim`):
  196. Pooled output of [`OwlViTVisionModel`]. OWL-ViT represents images as a set of image patches and computes
  197. image embeddings for each patch.
  198. target_pred_boxes (`torch.FloatTensor` of shape `(batch_size, num_patches, 4)`):
  199. Normalized boxes coordinates for all queries, represented as (center_x, center_y, width, height). These
  200. values are normalized in [0, 1], relative to the size of each individual target image in the batch
  201. (disregarding possible padding). You can use [`~OwlViTImageProcessor.post_process_object_detection`] to
  202. retrieve the unnormalized bounding boxes.
  203. query_pred_boxes (`torch.FloatTensor` of shape `(batch_size, num_patches, 4)`):
  204. Normalized boxes coordinates for all queries, represented as (center_x, center_y, width, height). These
  205. values are normalized in [0, 1], relative to the size of each individual query image in the batch
  206. (disregarding possible padding). You can use [`~OwlViTImageProcessor.post_process_object_detection`] to
  207. retrieve the unnormalized bounding boxes.
  208. class_embeds (`torch.FloatTensor` of shape `(batch_size, num_patches, hidden_size)`):
  209. Class embeddings of all image patches. OWL-ViT represents images as a set of image patches where the total
  210. number of patches is (image_size / patch_size)**2.
  211. text_model_output (tuple[`BaseModelOutputWithPooling`]):
  212. The output of the [`OwlViTTextModel`].
  213. vision_model_output (`BaseModelOutputWithPooling`):
  214. The output of the [`OwlViTVisionModel`].
  215. """
  216. logits: torch.FloatTensor | None = None
  217. image_embeds: torch.FloatTensor | None = None
  218. query_image_embeds: torch.FloatTensor | None = None
  219. target_pred_boxes: torch.FloatTensor | None = None
  220. query_pred_boxes: torch.FloatTensor | None = None
  221. class_embeds: torch.FloatTensor | None = None
  222. text_model_output: BaseModelOutputWithPooling = None
  223. vision_model_output: BaseModelOutputWithPooling = None
  224. def to_tuple(self) -> tuple[Any]:
  225. return tuple(
  226. self[k] if k not in ["text_model_output", "vision_model_output"] else getattr(self, k).to_tuple()
  227. for k in self.keys()
  228. )
  229. class OwlViTVisionEmbeddings(nn.Module):
  230. def __init__(self, config: OwlViTVisionConfig):
  231. super().__init__()
  232. self.patch_size = config.patch_size
  233. self.config = config
  234. self.embed_dim = config.hidden_size
  235. self.class_embedding = nn.Parameter(torch.randn(config.hidden_size))
  236. self.patch_embedding = nn.Conv2d(
  237. in_channels=config.num_channels,
  238. out_channels=self.embed_dim,
  239. kernel_size=config.patch_size,
  240. stride=config.patch_size,
  241. bias=False,
  242. )
  243. self.num_patches = (config.image_size // config.patch_size) ** 2
  244. self.num_positions = self.num_patches + 1
  245. self.position_embedding = nn.Embedding(self.num_positions, self.embed_dim)
  246. self.register_buffer("position_ids", torch.arange(self.num_positions).expand((1, -1)), persistent=False)
  247. # Copied from transformers.models.clip.modeling_clip.CLIPVisionEmbeddings.interpolate_pos_encoding
  248. def interpolate_pos_encoding(self, embeddings: torch.Tensor, height: int, width: int) -> torch.Tensor:
  249. """
  250. This method allows to interpolate the pre-trained position encodings, to be able to use the model on higher resolution
  251. images. This method is also adapted to support torch.jit tracing.
  252. Adapted from:
  253. - https://github.com/facebookresearch/dino/blob/de9ee3df6cf39fac952ab558447af1fa1365362a/vision_transformer.py#L174-L194, and
  254. - https://github.com/facebookresearch/dinov2/blob/e1277af2ba9496fbadf7aec6eba56e8d882d1e35/dinov2/models/vision_transformer.py#L179-L211
  255. """
  256. num_patches = embeddings.shape[1] - 1
  257. position_embedding = self.position_embedding.weight.unsqueeze(0)
  258. num_positions = position_embedding.shape[1] - 1
  259. # always interpolate when tracing to ensure the exported model works for dynamic input shapes
  260. if not torch.jit.is_tracing() and num_patches == num_positions and height == width:
  261. return self.position_embedding(self.position_ids)
  262. class_pos_embed = position_embedding[:, :1]
  263. patch_pos_embed = position_embedding[:, 1:]
  264. dim = embeddings.shape[-1]
  265. new_height = height // self.patch_size
  266. new_width = width // self.patch_size
  267. sqrt_num_positions = torch_int(num_positions**0.5)
  268. patch_pos_embed = patch_pos_embed.reshape(1, sqrt_num_positions, sqrt_num_positions, dim)
  269. patch_pos_embed = patch_pos_embed.permute(0, 3, 1, 2)
  270. patch_pos_embed = nn.functional.interpolate(
  271. patch_pos_embed,
  272. size=(new_height, new_width),
  273. mode="bicubic",
  274. align_corners=False,
  275. )
  276. patch_pos_embed = patch_pos_embed.permute(0, 2, 3, 1).view(1, -1, dim)
  277. return torch.cat((class_pos_embed, patch_pos_embed), dim=1)
  278. def forward(self, pixel_values: torch.FloatTensor, interpolate_pos_encoding: bool = False) -> torch.Tensor:
  279. batch_size, _, height, width = pixel_values.shape
  280. patch_embeds = self.patch_embedding(pixel_values) # shape = [batch_size, num_channels, height, width]
  281. patch_embeds = patch_embeds.flatten(2).transpose(1, 2)
  282. class_embeds = self.class_embedding.expand(batch_size, 1, -1)
  283. embeddings = torch.cat([class_embeds, patch_embeds], dim=1)
  284. if interpolate_pos_encoding:
  285. embeddings = embeddings + self.interpolate_pos_encoding(embeddings, height, width)
  286. else:
  287. embeddings = embeddings + self.position_embedding(self.position_ids)
  288. return embeddings
  289. class OwlViTTextEmbeddings(nn.Module):
  290. def __init__(self, config: OwlViTTextConfig):
  291. super().__init__()
  292. self.token_embedding = nn.Embedding(config.vocab_size, config.hidden_size)
  293. self.position_embedding = nn.Embedding(config.max_position_embeddings, config.hidden_size)
  294. # position_ids (1, len position emb) is contiguous in memory and exported when serialized
  295. self.register_buffer(
  296. "position_ids", torch.arange(config.max_position_embeddings).expand((1, -1)), persistent=False
  297. )
  298. def forward(
  299. self,
  300. input_ids: torch.LongTensor | None = None,
  301. position_ids: torch.LongTensor | None = None,
  302. inputs_embeds: torch.FloatTensor | None = None,
  303. ) -> torch.Tensor:
  304. seq_length = input_ids.shape[-1] if input_ids is not None else inputs_embeds.shape[-2]
  305. if position_ids is None:
  306. position_ids = self.position_ids[:, :seq_length]
  307. if inputs_embeds is None:
  308. inputs_embeds = self.token_embedding(input_ids)
  309. position_embeddings = self.position_embedding(position_ids)
  310. embeddings = inputs_embeds + position_embeddings
  311. return embeddings
  312. # Copied from transformers.models.bert.modeling_bert.eager_attention_forward
  313. def eager_attention_forward(
  314. module: nn.Module,
  315. query: torch.Tensor,
  316. key: torch.Tensor,
  317. value: torch.Tensor,
  318. attention_mask: torch.Tensor | None,
  319. scaling: float | None = None,
  320. dropout: float = 0.0,
  321. **kwargs: Unpack[TransformersKwargs],
  322. ):
  323. if scaling is None:
  324. scaling = query.size(-1) ** -0.5
  325. # Take the dot product between "query" and "key" to get the raw attention scores.
  326. attn_weights = torch.matmul(query, key.transpose(2, 3)) * scaling
  327. if attention_mask is not None:
  328. attn_weights = attn_weights + attention_mask
  329. attn_weights = nn.functional.softmax(attn_weights, dim=-1)
  330. attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training)
  331. attn_output = torch.matmul(attn_weights, value)
  332. attn_output = attn_output.transpose(1, 2).contiguous()
  333. return attn_output, attn_weights
  334. class OwlViTAttention(nn.Module):
  335. """Multi-headed attention from 'Attention Is All You Need' paper"""
  336. def __init__(self, config):
  337. super().__init__()
  338. self.config = config
  339. self.embed_dim = config.hidden_size
  340. self.num_heads = config.num_attention_heads
  341. self.head_dim = self.embed_dim // self.num_heads
  342. if self.head_dim * self.num_heads != self.embed_dim:
  343. raise ValueError(
  344. f"embed_dim must be divisible by num_heads (got `embed_dim`: {self.embed_dim} and `num_heads`:"
  345. f" {self.num_heads})."
  346. )
  347. self.scale = self.head_dim**-0.5
  348. self.dropout = config.attention_dropout
  349. self.is_causal = False
  350. self.k_proj = nn.Linear(self.embed_dim, self.embed_dim)
  351. self.v_proj = nn.Linear(self.embed_dim, self.embed_dim)
  352. self.q_proj = nn.Linear(self.embed_dim, self.embed_dim)
  353. self.out_proj = nn.Linear(self.embed_dim, self.embed_dim)
  354. def forward(
  355. self,
  356. hidden_states: torch.Tensor,
  357. attention_mask: torch.Tensor | None = None,
  358. **kwargs: Unpack[TransformersKwargs],
  359. ) -> tuple[torch.Tensor, torch.Tensor | None]:
  360. input_shape = hidden_states.shape[:-1]
  361. hidden_shape = (*input_shape, -1, self.head_dim)
  362. query_states = self.q_proj(hidden_states).view(*hidden_shape).transpose(1, 2)
  363. key_states = self.k_proj(hidden_states).view(*hidden_shape).transpose(1, 2)
  364. value_states = self.v_proj(hidden_states).view(*hidden_shape).transpose(1, 2)
  365. attention_interface: Callable = ALL_ATTENTION_FUNCTIONS.get_interface(
  366. self.config._attn_implementation, eager_attention_forward
  367. )
  368. attn_output, attn_weights = attention_interface(
  369. self,
  370. query_states,
  371. key_states,
  372. value_states,
  373. attention_mask,
  374. scaling=self.scale,
  375. dropout=0.0 if not self.training else self.dropout,
  376. **kwargs,
  377. )
  378. attn_output = attn_output.reshape(*input_shape, -1).contiguous()
  379. attn_output = self.out_proj(attn_output)
  380. return attn_output, attn_weights
  381. # Copied from transformers.models.clip.modeling_clip.CLIPMLP with CLIP->OwlViT
  382. class OwlViTMLP(nn.Module):
  383. def __init__(self, config):
  384. super().__init__()
  385. self.config = config
  386. self.activation_fn = ACT2FN[config.hidden_act]
  387. self.fc1 = nn.Linear(config.hidden_size, config.intermediate_size)
  388. self.fc2 = nn.Linear(config.intermediate_size, config.hidden_size)
  389. def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
  390. hidden_states = self.fc1(hidden_states)
  391. hidden_states = self.activation_fn(hidden_states)
  392. hidden_states = self.fc2(hidden_states)
  393. return hidden_states
  394. # Copied from transformers.models.clip.modeling_clip.CLIPEncoderLayer with CLIP->OwlViT
  395. class OwlViTEncoderLayer(GradientCheckpointingLayer):
  396. def __init__(self, config: OwlViTVisionConfig | OwlViTTextConfig):
  397. super().__init__()
  398. self.embed_dim = config.hidden_size
  399. self.self_attn = OwlViTAttention(config)
  400. self.layer_norm1 = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_eps)
  401. self.mlp = OwlViTMLP(config)
  402. self.layer_norm2 = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_eps)
  403. def forward(
  404. self,
  405. hidden_states: torch.Tensor,
  406. attention_mask: torch.Tensor,
  407. **kwargs: Unpack[TransformersKwargs],
  408. ) -> torch.FloatTensor:
  409. residual = hidden_states
  410. hidden_states = self.layer_norm1(hidden_states)
  411. hidden_states, _ = self.self_attn(
  412. hidden_states=hidden_states,
  413. attention_mask=attention_mask,
  414. **kwargs,
  415. )
  416. hidden_states = residual + hidden_states
  417. residual = hidden_states
  418. hidden_states = self.layer_norm2(hidden_states)
  419. hidden_states = self.mlp(hidden_states)
  420. hidden_states = residual + hidden_states
  421. return hidden_states
@auto_docstring
class OwlViTPreTrainedModel(PreTrainedModel):
    # Common base class for the OWL-ViT models in this file: framework capability flags plus `_init_weights`.
    config: OwlViTConfig
    base_model_prefix = "owlvit"
    input_modalities = ("image", "text")
    supports_gradient_checkpointing = True
    # Attention backends advertised as supported; `OwlViTAttention.forward` looks the kernel up in
    # `ALL_ATTENTION_FUNCTIONS` using `config._attn_implementation`.
    _supports_sdpa = True
    _supports_flash_attn = True
    _supports_flex_attn = True
    _supports_attention_backend = True
    # Module classes that should not be split across devices (presumably used for device-map sharding — confirm).
    _no_split_modules = ["OwlViTEncoderLayer"]
    # Output name -> module class whose outputs are recorded under that name (presumably consumed by the
    # `capture_outputs` decorator on the forward methods — confirm).
    _can_record_outputs = {
        "hidden_states": OwlViTEncoderLayer,
        "attentions": OwlViTAttention,
    }
    # Both embedding modules register `position_ids` with persistent=False, so these keys are not part of the
    # state dict; tolerate checkpoints that still carry them.
    _keys_to_ignore_on_load_unexpected = [
        r".*text_model\.embeddings\.position_ids",
        r".*vision_model\.embeddings\.position_ids",
    ]

    @torch.no_grad()
    def _init_weights(self, module: nn.Module):
        """
        Initialize the weights of a single `module`.

        Dispatches on the module type (every std below is multiplied by `config.initializer_factor`):
        - `OwlViTTextEmbeddings`: token/position tables ~ N(0, 0.02); `position_ids` reset to `arange`.
        - `OwlViTVisionEmbeddings`: class token ~ N(0, embed_dim**-0.5); patch/position weights with
          `initializer_range`; `position_ids` reset to `arange`.
        - `OwlViTAttention`: q/k/v with `embed_dim**-0.5 * (2 * num_hidden_layers)**-0.5`, out_proj with
          `embed_dim**-0.5`.
        - `OwlViTMLP`: fc1 with `(2 * hidden_size)**-0.5`, fc2 with `hidden_size**-0.5 * (2 * num_hidden_layers)**-0.5`.
        - `OwlViTModel`: text/visual projections, and `logit_scale` set to `config.logit_scale_init_value`.
        - `OwlViTForObjectDetection`: `box_bias` recomputed via `compute_box_bias`.
        Independently of that dispatch, `nn.LayerNorm` gets weight=1/bias=0, and `nn.Linear` gets weight ~ N(0,
        factor) and a zero bias.
        """
        factor = self.config.initializer_factor
        if isinstance(module, OwlViTTextEmbeddings):
            init.normal_(module.token_embedding.weight, mean=0.0, std=factor * 0.02)
            init.normal_(module.position_embedding.weight, mean=0.0, std=factor * 0.02)
            # The buffer is non-persistent, so it is rebuilt here rather than loaded from a checkpoint.
            init.copy_(module.position_ids, torch.arange(module.position_ids.shape[-1]).expand((1, -1)))
        elif isinstance(module, OwlViTVisionEmbeddings):
            init.normal_(module.class_embedding, mean=0.0, std=module.embed_dim**-0.5 * factor)
            init.normal_(module.patch_embedding.weight, std=module.config.initializer_range * factor)
            init.normal_(module.position_embedding.weight, std=module.config.initializer_range * factor)
            init.copy_(module.position_ids, torch.arange(module.position_ids.shape[-1]).expand((1, -1)))
        elif isinstance(module, OwlViTAttention):
            # Input projections are additionally scaled down with depth.
            in_proj_std = (module.embed_dim**-0.5) * ((2 * module.config.num_hidden_layers) ** -0.5) * factor
            out_proj_std = (module.embed_dim**-0.5) * factor
            init.normal_(module.q_proj.weight, std=in_proj_std)
            init.normal_(module.k_proj.weight, std=in_proj_std)
            init.normal_(module.v_proj.weight, std=in_proj_std)
            init.normal_(module.out_proj.weight, std=out_proj_std)
        elif isinstance(module, OwlViTMLP):
            in_proj_std = (module.config.hidden_size**-0.5) * ((2 * module.config.num_hidden_layers) ** -0.5) * factor
            fc_std = (2 * module.config.hidden_size) ** -0.5 * factor
            init.normal_(module.fc1.weight, std=fc_std)
            # fc2 uses the depth-scaled std (named `in_proj_std` here).
            init.normal_(module.fc2.weight, std=in_proj_std)
        elif isinstance(module, OwlViTModel):
            # OwlViTModel and OwlViTForObjectDetection are defined further down in this file (not shown here).
            init.normal_(
                module.text_projection.weight,
                std=module.text_embed_dim**-0.5 * factor,
            )
            init.normal_(
                module.visual_projection.weight,
                std=module.vision_embed_dim**-0.5 * factor,
            )
            init.constant_(module.logit_scale, self.config.logit_scale_init_value)
        elif isinstance(module, OwlViTForObjectDetection):
            init.copy_(module.box_bias, module.compute_box_bias(module.num_patches_height, module.num_patches_width))
        # The two checks below are not part of the elif chain: they fire for any LayerNorm / Linear module.
        if isinstance(module, nn.LayerNorm):
            init.zeros_(module.bias)
            init.ones_(module.weight)
        # NOTE(review): this generic N(0, factor) init fires when `module` itself is an nn.Linear. Whether the
        # projection-specific stds above survive depends on the caller visiting children before their parent
        # module (traversal order is not visible here) — confirm.
        if isinstance(module, nn.Linear):
            init.normal_(module.weight, mean=0.0, std=factor)
            if module.bias is not None:
                init.zeros_(module.bias)
  # Adapted from transformers.models.clip.modeling_clip.CLIPEncoder (CLIP -> OwlViT).
  class OwlViTEncoder(nn.Module):
      """
      A stack of `config.num_hidden_layers` [`OwlViTEncoderLayer`] blocks applied one after another.

      Args:
          config: OwlViTConfig
              Configuration that provides `num_hidden_layers` and is passed unchanged to every layer.
      """

      def __init__(self, config: OwlViTConfig):
          super().__init__()
          self.config = config
          depth = config.num_hidden_layers
          self.layers = nn.ModuleList(OwlViTEncoderLayer(config) for _ in range(depth))
          self.gradient_checkpointing = False

      def forward(
          self,
          inputs_embeds,
          attention_mask: torch.Tensor | None = None,
          **kwargs: Unpack[TransformersKwargs],
      ) -> BaseModelOutput:
          r"""
          Feed an already-embedded sequence through every encoder layer in order.

          Args:
              inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`):
                  Embedded input sequence consumed by the first layer. Each layer's output is the next layer's input.
              attention_mask (`torch.Tensor`, *optional*):
                  Mask handed unchanged to every layer. The text transformer passes the result of
                  `create_causal_mask` here. The vision transformer passes no mask.
              kwargs:
                  Extra keyword arguments forwarded verbatim to each layer.

          Returns:
              [`BaseModelOutput`] whose `last_hidden_state` is the output of the final layer.
          """
          current = inputs_embeds
          for layer in self.layers:
              current = layer(current, attention_mask, **kwargs)
          return BaseModelOutput(last_hidden_state=current)
  class OwlViTTextTransformer(OwlViTPreTrainedModel):
      """
      Text tower of OwlViT.

      It embeds tokens, runs the shared encoder with a causal mask, applies a final LayerNorm, and pools each
      sequence at the position of its highest token id (the end-of-sequence token).
      """

      def __init__(self, config: OwlViTTextConfig):
          super().__init__(config)
          embed_dim = config.hidden_size
          self.embeddings = OwlViTTextEmbeddings(config)
          self.encoder = OwlViTEncoder(config)
          self.final_layer_norm = nn.LayerNorm(embed_dim, eps=config.layer_norm_eps)
          # Initialize weights and apply final processing
          self.post_init()

      @merge_with_config_defaults
      @capture_outputs(tie_last_hidden_states=False)
      @auto_docstring
      def forward(
          self,
          input_ids: torch.Tensor | None = None,
          attention_mask: torch.Tensor | None = None,
          position_ids: torch.Tensor | None = None,
          **kwargs: Unpack[TransformersKwargs],
      ) -> tuple | BaseModelOutputWithPooling:
          r"""
          input_ids (`torch.LongTensor` of shape `(batch_size * num_max_text_queries, sequence_length)`):
              Indices of input sequence tokens in the vocabulary. Indices can be obtained using [`AutoTokenizer`]. See
              [`PreTrainedTokenizer.encode`] and [`PreTrainedTokenizer.__call__`] for details. [What are input
              IDs?](../glossary#input-ids). Required: a `ValueError` is raised when it is missing.
          """
          # The signature defaults `input_ids` to None, but both the embedding lookup and the EOS pooling below need
          # the token ids. Fail with a clear message instead of an AttributeError on `None.size()`.
          if input_ids is None:
              raise ValueError("You have to specify input_ids")

          # Collapse any leading query dimensions into a single batch dimension
          input_shape = input_ids.size()
          input_ids = input_ids.view(-1, input_shape[-1])

          hidden_states = self.embeddings(input_ids=input_ids, position_ids=position_ids)

          attention_mask = create_causal_mask(
              config=self.config,
              inputs_embeds=hidden_states,
              attention_mask=attention_mask,
              past_key_values=None,
          )
          # `is_causal` is always set explicitly below; drop any caller-provided value to avoid a duplicate kwarg
          kwargs.pop("is_causal", None)
          encoder_outputs: BaseModelOutput = self.encoder(
              inputs_embeds=hidden_states,
              attention_mask=attention_mask,
              is_causal=True,
              **kwargs,
          )

          last_hidden_state = encoder_outputs.last_hidden_state
          last_hidden_state = self.final_layer_norm(last_hidden_state)

          # take features from the end of tokens embedding (end of token is the highest number in each sequence)
          # casting to torch.int for onnx compatibility: argmax doesn't support int64 inputs with opset 14
          pooled_output = last_hidden_state[
              torch.arange(last_hidden_state.shape[0], device=last_hidden_state.device),
              input_ids.to(torch.int).argmax(dim=-1).to(last_hidden_state.device),
          ]

          return BaseModelOutputWithPooling(
              last_hidden_state=last_hidden_state,
              pooler_output=pooled_output,
          )
  class OwlViTTextModel(OwlViTPreTrainedModel):
      """Standalone OwlViT text encoder: a thin wrapper that owns an `OwlViTTextTransformer` as `self.text_model`."""

      config: OwlViTTextConfig
      input_modalities = ("text",)

      def __init__(self, config: OwlViTTextConfig):
          super().__init__(config)
          self.text_model = OwlViTTextTransformer(config)
          # Initialize weights and apply final processing
          self.post_init()

      def get_input_embeddings(self) -> nn.Module:
          """Return the token-embedding module of the wrapped text transformer."""
          text_embeddings = self.text_model.embeddings
          return text_embeddings.token_embedding

      def set_input_embeddings(self, value):
          """Install `value` as the token-embedding module of the wrapped text transformer."""
          text_embeddings = self.text_model.embeddings
          text_embeddings.token_embedding = value

      @auto_docstring
      def forward(
          self,
          input_ids: torch.Tensor | None = None,
          attention_mask: torch.Tensor | None = None,
          **kwargs: Unpack[TransformersKwargs],
      ) -> tuple | BaseModelOutputWithPooling:
          r"""
          input_ids (`torch.LongTensor` of shape `(batch_size * num_max_text_queries, sequence_length)`):
              Indices of input sequence tokens in the vocabulary. Indices can be obtained using [`AutoTokenizer`]. See
              [`PreTrainedTokenizer.encode`] and [`PreTrainedTokenizer.__call__`] for details. [What are input
              IDs?](../glossary#input-ids)

          Examples:
          ```python
          >>> from transformers import AutoProcessor, OwlViTTextModel

          >>> model = OwlViTTextModel.from_pretrained("google/owlvit-base-patch32")
          >>> processor = AutoProcessor.from_pretrained("google/owlvit-base-patch32")
          >>> inputs = processor(
          ...     text=[["a photo of a cat", "a photo of a dog"], ["photo of a astranaut"]], return_tensors="pt"
          ... )
          >>> outputs = model(**inputs)
          >>> last_hidden_state = outputs.last_hidden_state
          >>> pooled_output = outputs.pooler_output  # pooled (EOS token) states
          ```"""
          # All work is delegated to the wrapped transformer; extra kwargs pass straight through.
          transformer = self.text_model
          return transformer(input_ids=input_ids, attention_mask=attention_mask, **kwargs)
  class OwlViTVisionTransformer(OwlViTPreTrainedModel):
      """
      Vision tower of OwlViT.

      It applies patch embeddings, a pre-encoder LayerNorm, and the shared encoder. The pooled output is the
      post-LayerNorm of the first token of each sequence.
      """

      def __init__(self, config: OwlViTVisionConfig):
          super().__init__(config)
          width = config.hidden_size
          eps = config.layer_norm_eps
          # Submodules are created in the same order as before: embeddings, pre-norm, encoder, post-norm
          self.embeddings = OwlViTVisionEmbeddings(config)
          self.pre_layernorm = nn.LayerNorm(width, eps=eps)
          self.encoder = OwlViTEncoder(config)
          self.post_layernorm = nn.LayerNorm(width, eps=eps)
          # Initialize weights and apply final processing
          self.post_init()

      @merge_with_config_defaults
      @capture_outputs(tie_last_hidden_states=False)
      @auto_docstring
      def forward(
          self,
          pixel_values: torch.FloatTensor,
          interpolate_pos_encoding: bool | None = False,
          **kwargs: Unpack[TransformersKwargs],
      ) -> tuple | BaseModelOutputWithPooling:
          # Cast the pixels to the dtype of the patch-embedding weights so the first projection sees matching dtypes
          weight_dtype = self.embeddings.patch_embedding.weight.dtype
          embedded = self.embeddings(
              pixel_values.to(weight_dtype), interpolate_pos_encoding=interpolate_pos_encoding
          )
          normalized = self.pre_layernorm(embedded)

          encoded: BaseModelOutput = self.encoder(inputs_embeds=normalized, **kwargs)
          sequence_output = encoded.last_hidden_state

          # Pool by normalizing the first token of every sequence.
          # `last_hidden_state` itself is returned without post_layernorm.
          first_token = sequence_output[:, 0, :]
          pooled = self.post_layernorm(first_token)

          return BaseModelOutputWithPooling(last_hidden_state=sequence_output, pooler_output=pooled)
  class OwlViTVisionModel(OwlViTPreTrainedModel):
      """Standalone OwlViT image encoder: a thin wrapper that owns an `OwlViTVisionTransformer` as `self.vision_model`."""

      config: OwlViTVisionConfig
      main_input_name = "pixel_values"
      input_modalities = ("image",)

      def __init__(self, config: OwlViTVisionConfig):
          super().__init__(config)
          self.vision_model = OwlViTVisionTransformer(config)
          # Initialize weights and apply final processing
          self.post_init()

      def get_input_embeddings(self) -> nn.Module:
          """Return the patch-embedding module of the wrapped vision transformer."""
          vision_embeddings = self.vision_model.embeddings
          return vision_embeddings.patch_embedding

      @auto_docstring
      def forward(
          self,
          pixel_values: torch.FloatTensor | None = None,
          interpolate_pos_encoding: bool = False,
          **kwargs: Unpack[TransformersKwargs],
      ) -> BaseModelOutputWithPooling:
          r"""
          Examples:
          ```python
          >>> from PIL import Image
          >>> import httpx
          >>> from io import BytesIO
          >>> from transformers import AutoProcessor, OwlViTVisionModel

          >>> model = OwlViTVisionModel.from_pretrained("google/owlvit-base-patch32")
          >>> processor = AutoProcessor.from_pretrained("google/owlvit-base-patch32")
          >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg"
          >>> with httpx.stream("GET", url) as response:
          ...     image = Image.open(BytesIO(response.read()))

          >>> inputs = processor(images=image, return_tensors="pt")

          >>> outputs = model(**inputs)
          >>> last_hidden_state = outputs.last_hidden_state
          >>> pooled_output = outputs.pooler_output  # pooled CLS states
          ```"""
          # All work is delegated to the wrapped transformer; extra kwargs pass straight through.
          transformer = self.vision_model
          return transformer(
              pixel_values=pixel_values, interpolate_pos_encoding=interpolate_pos_encoding, **kwargs
          )
  @auto_docstring
  class OwlViTModel(OwlViTPreTrainedModel):
      # Joint text/image model: two towers plus bias-free projections into a shared `projection_dim` space.
      config: OwlViTConfig

      def __init__(self, config: OwlViTConfig):
          super().__init__(config)

          text_config = config.text_config
          vision_config = config.vision_config

          self.projection_dim = config.projection_dim
          self.text_embed_dim = text_config.hidden_size
          self.vision_embed_dim = vision_config.hidden_size

          self.text_model = OwlViTTextTransformer(text_config)
          self.vision_model = OwlViTVisionTransformer(vision_config)

          self.visual_projection = nn.Linear(self.vision_embed_dim, self.projection_dim, bias=False)
          self.text_projection = nn.Linear(self.text_embed_dim, self.projection_dim, bias=False)
          # Stored in log space; `forward` exponentiates it before scaling the cosine similarities
          self.logit_scale = nn.Parameter(torch.tensor(config.logit_scale_init_value))

          # Initialize weights and apply final processing
          self.post_init()

      @can_return_tuple
      @auto_docstring
      def get_text_features(
          self,
          input_ids: torch.Tensor,
          attention_mask: torch.Tensor | None = None,
          **kwargs: Unpack[TransformersKwargs],
      ) -> tuple | BaseModelOutputWithPooling:
          r"""
          input_ids (`torch.LongTensor` of shape `(batch_size * num_max_text_queries, sequence_length)`):
              Indices of input sequence tokens in the vocabulary. Indices can be obtained using [`AutoTokenizer`]. See
              [`PreTrainedTokenizer.encode`] and [`PreTrainedTokenizer.__call__`] for details. [What are input
              IDs?](../glossary#input-ids)

          Returns the text tower output with `pooler_output` replaced by its projection through `text_projection`
          (not L2-normalized).

          Examples:
          ```python
          >>> import torch
          >>> from transformers import AutoProcessor, OwlViTModel

          >>> model = OwlViTModel.from_pretrained("google/owlvit-base-patch32")
          >>> processor = AutoProcessor.from_pretrained("google/owlvit-base-patch32")
          >>> inputs = processor(
          ...     text=[["a photo of a cat", "a photo of a dog"], ["photo of an astronaut"]], return_tensors="pt"
          ... )
          >>> with torch.inference_mode():
          ...     text_features = model.get_text_features(**inputs)
          ```"""
          text_outputs: BaseModelOutputWithPooling = self.text_model(
              input_ids=input_ids,
              attention_mask=attention_mask,
              **kwargs,
          )
          pooled_output = text_outputs.pooler_output
          text_outputs.pooler_output = self.text_projection(pooled_output)

          return text_outputs

      @can_return_tuple
      @auto_docstring
      def get_image_features(
          self,
          pixel_values: torch.Tensor,
          interpolate_pos_encoding: bool = False,
          **kwargs: Unpack[TransformersKwargs],
      ) -> tuple | BaseModelOutputWithPooling:
          r"""
          Returns the vision tower output with `pooler_output` replaced by its projection through `visual_projection`
          (not L2-normalized).

          Examples:
          ```python
          >>> import torch
          >>> from transformers.image_utils import load_image
          >>> from transformers import AutoProcessor, OwlViTModel

          >>> model = OwlViTModel.from_pretrained("google/owlvit-base-patch32")
          >>> processor = AutoProcessor.from_pretrained("google/owlvit-base-patch32")

          >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg"
          >>> image = load_image(url)

          >>> inputs = processor(images=image, return_tensors="pt")
          >>> with torch.inference_mode():
          ...     image_features = model.get_image_features(**inputs)
          ```"""
          vision_outputs: BaseModelOutputWithPooling = self.vision_model(
              pixel_values=pixel_values,
              interpolate_pos_encoding=interpolate_pos_encoding,
              **kwargs,
          )
          vision_outputs.pooler_output = self.visual_projection(vision_outputs.pooler_output)

          return vision_outputs

      @can_return_tuple
      @auto_docstring
      def forward(
          self,
          input_ids: torch.LongTensor | None = None,
          pixel_values: torch.FloatTensor | None = None,
          attention_mask: torch.Tensor | None = None,
          return_loss: bool | None = None,
          interpolate_pos_encoding: bool = False,
          return_base_image_embeds: bool | None = None,
          **kwargs: Unpack[TransformersKwargs],
      ) -> tuple | OwlViTOutput:
          r"""
          return_loss (`bool`, *optional*):
              Whether or not to return the contrastive loss.
          return_base_image_embeds (`bool`, *optional*):
              Accepted for backward compatibility only and currently ignored. `image_embeds` in the output is always
              the projected, L2-normalized pooled image embedding. The unprojected per-token image features are
              available from `outputs.vision_model_output.last_hidden_state`.

          Examples:
          ```python
          >>> from PIL import Image
          >>> import httpx
          >>> from io import BytesIO
          >>> from transformers import AutoProcessor, OwlViTModel

          >>> model = OwlViTModel.from_pretrained("google/owlvit-base-patch32")
          >>> processor = AutoProcessor.from_pretrained("google/owlvit-base-patch32")
          >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg"
          >>> with httpx.stream("GET", url) as response:
          ...     image = Image.open(BytesIO(response.read()))

          >>> inputs = processor(text=[["a photo of a cat", "a photo of a dog"]], images=image, return_tensors="pt")
          >>> outputs = model(**inputs)
          >>> logits_per_image = outputs.logits_per_image  # this is the image-text similarity score
          >>> probs = logits_per_image.softmax(dim=1)  # we can take the softmax to get the label probabilities
          ```"""
          # NOTE: `return_base_image_embeds` is intentionally not read; it is kept only so existing callers that
          # pass it keep working (see the docstring above).
          vision_outputs: BaseModelOutputWithPooling = self.vision_model(
              pixel_values=pixel_values,
              interpolate_pos_encoding=interpolate_pos_encoding,
              **kwargs,
          )

          # Get embeddings for all text queries in all batch samples
          text_outputs: BaseModelOutputWithPooling = self.text_model(
              input_ids=input_ids,
              attention_mask=attention_mask,
              **kwargs,
          )

          text_embeds = text_outputs.pooler_output
          text_embeds = self.text_projection(text_embeds)
          image_embeds = vision_outputs.pooler_output
          image_embeds = self.visual_projection(image_embeds)

          # normalized features
          image_embeds = image_embeds / torch.linalg.norm(image_embeds, ord=2, dim=-1, keepdim=True)
          text_embeds_norm = text_embeds / torch.linalg.norm(text_embeds, ord=2, dim=-1, keepdim=True)

          # cosine similarity as logits and set it on the correct device
          logit_scale = self.logit_scale.exp().to(image_embeds.device)

          logits_per_text = torch.matmul(text_embeds_norm, image_embeds.t()) * logit_scale
          logits_per_image = logits_per_text.t()

          loss = None
          if return_loss:
              loss = owlvit_loss(logits_per_text)

          # The returned text embeddings are the L2-normalized ones used for the similarity logits
          text_embeds = text_embeds_norm

          return OwlViTOutput(
              loss=loss,
              logits_per_image=logits_per_image,
              logits_per_text=logits_per_text,
              text_embeds=text_embeds,
              image_embeds=image_embeds,
              text_model_output=text_outputs,
              vision_model_output=vision_outputs,
          )
  class OwlViTBoxPredictionHead(nn.Module):
      """
      Three-layer MLP mapping per-patch vision features to `out_dim` raw box values.

      The two hidden layers keep the vision hidden size and are each followed by GELU. The last layer is linear.
      """

      def __init__(self, config: OwlViTConfig, out_dim: int = 4):
          super().__init__()
          hidden = config.vision_config.hidden_size
          # Layers are created in the original order (dense0, dense1, gelu, dense2)
          self.dense0 = nn.Linear(hidden, hidden)
          self.dense1 = nn.Linear(hidden, hidden)
          self.gelu = nn.GELU()
          self.dense2 = nn.Linear(hidden, out_dim)

      def forward(self, image_features: torch.Tensor) -> torch.FloatTensor:
          activations = image_features
          for hidden_layer in (self.dense0, self.dense1):
              activations = self.gelu(hidden_layer(activations))
          return self.dense2(activations)
  class OwlViTClassPredictionHead(nn.Module):
      """
      Scores image patches against text/image query embeddings.

      Patch features are projected to the text hidden size by `dense0`. Projected patches and queries are
      L2-normalized and compared by dot product. A per-patch learned shift and positive scale (ELU + 1) are then
      applied.
      """

      def __init__(self, config: OwlViTConfig):
          super().__init__()

          out_dim = config.text_config.hidden_size
          self.query_dim = config.vision_config.hidden_size

          self.dense0 = nn.Linear(self.query_dim, out_dim)
          self.logit_shift = nn.Linear(self.query_dim, 1)
          self.logit_scale = nn.Linear(self.query_dim, 1)
          self.elu = nn.ELU()

      def forward(
          self,
          image_embeds: torch.FloatTensor,
          query_embeds: torch.FloatTensor | None,
          query_mask: torch.Tensor | None,
      ) -> tuple[torch.FloatTensor, torch.FloatTensor]:
          """
          Args:
              image_embeds: per-patch vision features; the last dimension is `query_dim`.
              query_embeds: query embeddings to score against, or None to only compute `image_class_embeds`.
              query_mask: optional mask; queries where it equals 0 get the dtype's minimum logit.

          Returns:
              `(pred_logits, image_class_embeds)`. When `query_embeds` is None, `pred_logits` is an all-zero tensor of
              shape `(batch_size, num_patches, query_dim)` and `image_class_embeds` is un-normalized. Otherwise
              `pred_logits` is cast to float32 and `image_class_embeds` is L2-normalized.
          """
          image_class_embeds = self.dense0(image_embeds)
          if query_embeds is None:
              batch_size, num_patches = image_class_embeds.shape[:2]
              # Allocate directly on the target device instead of building on CPU and copying over.
              pred_logits = torch.zeros(
                  (batch_size, num_patches, self.query_dim), device=image_class_embeds.device
              )
              return (pred_logits, image_class_embeds)

          # Normalize image and text features
          image_class_embeds = image_class_embeds / (torch.linalg.norm(image_class_embeds, dim=-1, keepdim=True) + 1e-6)
          query_embeds = query_embeds / (torch.linalg.norm(query_embeds, dim=-1, keepdim=True) + 1e-6)

          # Get class predictions
          pred_logits = torch.einsum("...pd,...qd->...pq", image_class_embeds, query_embeds)

          # Apply a learnable shift and scale to logits
          logit_shift = self.logit_shift(image_embeds)
          logit_scale = self.logit_scale(image_embeds)
          logit_scale = self.elu(logit_scale) + 1
          pred_logits = (pred_logits + logit_shift) * logit_scale

          if query_mask is not None:
              if query_mask.ndim > 1:
                  # Broadcast the per-query mask over the patch axis
                  query_mask = torch.unsqueeze(query_mask, dim=-2)

              pred_logits = torch.where(query_mask == 0, torch.finfo(pred_logits.dtype).min, pred_logits)
              pred_logits = pred_logits.to(torch.float32)

          return (pred_logits, image_class_embeds)
  893. class OwlViTForObjectDetection(OwlViTPreTrainedModel):
  894. config: OwlViTConfig
  def __init__(self, config: OwlViTConfig):
      super().__init__(config)

      # Joint backbone plus the two detection heads
      self.owlvit = OwlViTModel(config)
      self.class_head = OwlViTClassPredictionHead(config)
      self.box_head = OwlViTBoxPredictionHead(config)

      vision_config = config.vision_config
      self.layer_norm = nn.LayerNorm(vision_config.hidden_size, eps=vision_config.layer_norm_eps)
      self.sigmoid = nn.Sigmoid()

      self.config = config
      # Without position-encoding interpolation the patch grid is square: image_size // patch_size per side
      patches_per_side = vision_config.image_size // vision_config.patch_size
      self.num_patches_height = patches_per_side
      self.num_patches_width = patches_per_side

      # Box prior for the default grid; non-persistent, so it is recomputed instead of loaded from checkpoints
      default_box_bias = self.compute_box_bias(self.num_patches_height, self.num_patches_width)
      self.register_buffer("box_bias", default_box_bias, persistent=False)

      self.post_init()
  @staticmethod
  def normalize_grid_corner_coordinates(num_patches_height: int, num_patches_width: int) -> torch.Tensor:
      """
      Normalized grid coordinates for every patch.

      Column `j` maps to x = (j + 1) / num_patches_width and row `i` to y = (i + 1) / num_patches_height.

      Returns:
          float32 tensor of shape `(num_patches_height * num_patches_width, 2)` holding `(x, y)` pairs, ordered
          row by row (all columns of row 0 first).
      """
      # Dividing the 1-D ranges before building the grid yields the same float32 values as dividing afterwards
      x_positions = torch.arange(1, num_patches_width + 1, dtype=torch.float32) / num_patches_width
      y_positions = torch.arange(1, num_patches_height + 1, dtype=torch.float32) / num_patches_height
      # "xy" indexing gives (height, width)-shaped grids
      grid_x, grid_y = torch.meshgrid(x_positions, y_positions, indexing="xy")
      return torch.stack((grid_x, grid_y), dim=-1).reshape(-1, 2)
  def compute_box_bias(self, num_patches_height: int, num_patches_width: int) -> torch.Tensor:
      """
      Per-patch prior added to the raw box predictions before the sigmoid in `box_predictor`.

      Centers are biased toward each patch's grid coordinates and sizes toward one patch. Both are expressed in
      inverse-sigmoid space.

      Returns:
          float32 tensor of shape `(num_patches_height * num_patches_width, 4)`: two center-bias columns followed by
          two size-bias columns.
      """

      def inverse_sigmoid(probs: torch.Tensor) -> torch.Tensor:
          # log(p) - log(1 - p); the 1e-4 offsets keep both logarithms finite at p = 0 and p = 1
          return torch.log(probs + 1e-4) - torch.log1p(-probs + 1e-4)

      grid = self.normalize_grid_corner_coordinates(num_patches_height, num_patches_width)
      center_bias = inverse_sigmoid(torch.clip(grid, 0.0, 1.0))

      patch_extent = torch.full_like(center_bias, 1.0)
      patch_extent[..., 0] /= num_patches_width
      patch_extent[..., 1] /= num_patches_height
      size_bias = inverse_sigmoid(patch_extent)

      return torch.cat([center_bias, size_bias], dim=-1)
  def box_predictor(
      self,
      image_feats: torch.FloatTensor,
      feature_map: torch.FloatTensor,
      interpolate_pos_encoding: bool = False,
  ) -> torch.FloatTensor:
      """
      Predict one box per image patch.

      Args:
          image_feats (`torch.FloatTensor` of shape `(batch_size, num_patches, hidden_size)`):
              Flattened per-patch features, i.e. the feature map produced by `image_embedder` /
              `image_text_embedder` reshaped so that the patch grid becomes a single axis.
          feature_map (`torch.FloatTensor` of shape `(batch_size, num_patches_height, num_patches_width, hidden_size)`):
              The spatial form of the same features. Only its device is used, plus its grid size when
              `interpolate_pos_encoding` is True.
          interpolate_pos_encoding (`bool`, defaults to `False`):
              If True, compute a box bias for the grid size of `feature_map`. Otherwise use the cached `box_bias`
              buffer built for the default grid.

      Returns:
          `torch.FloatTensor` of shape `(batch_size, num_patches, 4)`: a single tensor of predicted boxes in
          (center_x, center_y, width, height) format, squashed into [0, 1] by a sigmoid.
      """
      # Bounding box detection head [batch_size, num_boxes, 4].
      pred_boxes = self.box_head(image_feats)

      # Compute the location of each token on the grid and use it to compute a bias for the bbox prediction
      if interpolate_pos_encoding:
          _, num_patches_height, num_patches_width, _ = feature_map.shape
          box_bias = self.compute_box_bias(num_patches_height, num_patches_width)
      else:
          box_bias = self.box_bias

      box_bias = box_bias.to(feature_map.device)
      pred_boxes += box_bias
      pred_boxes = self.sigmoid(pred_boxes)
      return pred_boxes
  def class_predictor(
      self,
      image_feats: torch.FloatTensor,
      query_embeds: torch.FloatTensor | None = None,
      query_mask: torch.Tensor | None = None,
  ) -> tuple[torch.FloatTensor]:
      """
      Score every image patch against the query embeddings using `self.class_head`.

      Args:
          image_feats:
              Per-patch image features, e.g. a flattened feature map from `image_text_embedder`.
          query_embeds:
              Text or image query embeddings. When omitted, only the patch class embeddings are computed and the
              logits are all zeros.
          query_mask:
              Optional mask marking which query embeddings are valid; should accompany `query_embeds`.

      Returns:
          `(pred_logits, image_class_embeds)` as produced by the class head.
      """
      logits, patch_class_embeds = self.class_head(image_feats, query_embeds, query_mask)
      return logits, patch_class_embeds
  def image_text_embedder(
      self,
      input_ids: torch.Tensor,
      pixel_values: torch.FloatTensor,
      attention_mask: torch.Tensor,
      interpolate_pos_encoding: bool = False,
      **kwargs: Unpack[TransformersKwargs],
  ) -> tuple[torch.FloatTensor]:
      """
      Run the joint OwlViT model and turn its vision output into a spatial feature map.

      Returns:
          `(text_embeds, image_embeds, outputs)`:
          - `text_embeds`: the projected, L2-normalized text embeddings from `outputs.text_embeds`.
          - `image_embeds`: feature map of shape `(batch_size, num_patches_height, num_patches_width, hidden_size)`,
            built from the post-LayerNorm patch tokens multiplied by the first (class) token and re-normalized.
          - `outputs`: the full `OwlViTModel` output.
      """
      # Encode text and image
      outputs = self.owlvit(
          pixel_values=pixel_values,
          input_ids=input_ids,
          attention_mask=attention_mask,
          interpolate_pos_encoding=interpolate_pos_encoding,
          **kwargs,
      )

      if interpolate_pos_encoding:
          _, _, height, width = pixel_values.shape
          num_patches_height = height // self.config.vision_config.patch_size
          num_patches_width = width // self.config.vision_config.patch_size
      else:
          num_patches_height = self.num_patches_height
          num_patches_width = self.num_patches_width

      # Get image embeddings
      last_hidden_state = outputs.vision_model_output[0]
      image_embeds = self.owlvit.vision_model.post_layernorm(last_hidden_state)

      # Resize class token
      class_token_out = torch.broadcast_to(image_embeds[:, :1, :], image_embeds[:, :-1].shape)

      # Merge image embedding with class tokens
      image_embeds = image_embeds[:, 1:, :] * class_token_out
      image_embeds = self.layer_norm(image_embeds)

      # Resize to [batch_size, num_patches_height, num_patches_width, hidden_size]
      new_size = (
          image_embeds.shape[0],
          num_patches_height,
          num_patches_width,
          image_embeds.shape[-1],
      )
      image_embeds = image_embeds.reshape(new_size)

      # Read the text embeddings by field name rather than via a positional index into the output tuple, which
      # silently depends on the field order and on which optional fields are None.
      text_embeds = outputs.text_embeds

      return (text_embeds, image_embeds, outputs)
  def image_embedder(
      self,
      pixel_values: torch.FloatTensor,
      interpolate_pos_encoding: bool = False,
      **kwargs: Unpack[TransformersKwargs],
  ) -> tuple[torch.FloatTensor]:
      """
      Encode images with the vision tower only and arrange the patch tokens on their spatial grid.

      Returns:
          `(feature_map, vision_outputs)`. `feature_map` has shape
          `(batch_size, num_patches_height, num_patches_width, hidden_size)` and is not projected.
          `vision_outputs` is the raw vision-tower output.
      """
      # Get OwlViTModel vision embeddings (same as CLIP)
      vision_outputs: BaseModelOutputWithPooling = self.owlvit.vision_model(
          pixel_values=pixel_values, interpolate_pos_encoding=interpolate_pos_encoding, **kwargs
      )

      if not interpolate_pos_encoding:
          grid_height, grid_width = self.num_patches_height, self.num_patches_width
      else:
          patch_size = self.config.vision_config.patch_size
          _, _, image_height, image_width = pixel_values.shape
          grid_height = image_height // patch_size
          grid_width = image_width // patch_size

      # post_layernorm is applied to every token here, not only the pooled first token
      tokens = self.owlvit.vision_model.post_layernorm(vision_outputs[0])

      # Multiply each patch token by the first (class) token, then re-normalize
      class_token = torch.broadcast_to(tokens[:, :1, :], tokens[:, :-1].shape)
      patch_tokens = self.layer_norm(tokens[:, 1:, :] * class_token)

      feature_map = patch_tokens.reshape(
          (patch_tokens.shape[0], grid_height, grid_width, patch_tokens.shape[-1])
      )
      return (feature_map, vision_outputs)
  def embed_image_query(
      self,
      query_image_features: torch.FloatTensor,
      query_feature_map: torch.FloatTensor,
      interpolate_pos_encoding: bool = False,
  ) -> torch.FloatTensor:
      """
      For each query image, choose one predicted box and return its class embedding as the image query.

      Candidate boxes are those whose IoU with the whole-image box `[0, 0, 1, 1]` is at least 80% of the best IoU.
      If no box overlaps it at all, generalized IoU is used instead. Among the candidates, the box whose class
      embedding is least similar to the mean class embedding of that query image is selected.

      Returns:
          `(query_embeds, box_indices, pred_boxes)`:
          - `query_embeds` / `box_indices`: stacked selections, or `(None, None)` if no query image yielded a
            candidate. Query images without a candidate are skipped, so the stacked outputs can have fewer
            entries than the batch.
          - `pred_boxes`: the box predictions for all query patches, in center format.
      """
      _, class_embeds = self.class_predictor(query_image_features)
      pred_boxes = self.box_predictor(query_image_features, query_feature_map, interpolate_pos_encoding)
      pred_boxes_as_corners = center_to_corners_format(pred_boxes)

      # The reference box spanning the whole (normalized) query image is identical for every query: build it once.
      whole_image_box = torch.tensor([[0, 0, 1, 1]], device=pred_boxes_as_corners.device)

      # Loop over query images
      best_class_embeds = []
      best_box_indices = []
      for i in range(query_image_features.shape[0]):
          each_query_pred_boxes = pred_boxes_as_corners[i]
          ious, _ = box_iou(whole_image_box, each_query_pred_boxes)

          # If there are no overlapping boxes, fall back to generalized IoU
          if torch.all(ious[0] == 0.0):
              ious = generalized_box_iou(whole_image_box, each_query_pred_boxes)

          # Use an adaptive threshold to include all boxes within 80% of the best IoU.
          # NOTE(review): with the gIoU fallback the best score can be negative, in which case 0.8 * max exceeds
          # max and no box is selected, so this query image is skipped.
          iou_threshold = torch.max(ious) * 0.8

          selected_inds = (ious[0] >= iou_threshold).nonzero()
          if selected_inds.numel():
              selected_embeddings = class_embeds[i][selected_inds.squeeze(1)]
              mean_embeds = torch.mean(class_embeds[i], dim=0)
              mean_sim = torch.einsum("d,id->i", mean_embeds, selected_embeddings)
              # Keep the candidate least similar to the average patch embedding
              best_box_ind = selected_inds[torch.argmin(mean_sim)]
              best_class_embeds.append(class_embeds[i][best_box_ind])
              best_box_indices.append(best_box_ind)

      if best_class_embeds:
          query_embeds = torch.stack(best_class_embeds)
          box_indices = torch.stack(best_box_indices)
      else:
          query_embeds, box_indices = None, None

      return query_embeds, box_indices, pred_boxes
  @can_return_tuple
  @auto_docstring
  def image_guided_detection(
      self,
      pixel_values: torch.FloatTensor,
      query_pixel_values: torch.FloatTensor | None = None,
      interpolate_pos_encoding: bool = False,
      **kwargs: Unpack[TransformersKwargs],
  ) -> OwlViTImageGuidedObjectDetectionOutput:
      r"""
      query_pixel_values (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`):
          Pixel values of query image(s) to be detected. Pass in one query image per target image. Required: a
          `ValueError` is raised when it is missing.

      Examples:
      ```python
      >>> import httpx
      >>> from io import BytesIO
      >>> from PIL import Image
      >>> import torch
      >>> from transformers import AutoProcessor, OwlViTForObjectDetection

      >>> processor = AutoProcessor.from_pretrained("google/owlvit-base-patch16")
      >>> model = OwlViTForObjectDetection.from_pretrained("google/owlvit-base-patch16")
      >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg"
      >>> with httpx.stream("GET", url) as response:
      ...     image = Image.open(BytesIO(response.read()))
      >>> query_url = "http://images.cocodataset.org/val2017/000000001675.jpg"
      >>> with httpx.stream("GET", query_url) as response:
      ...     query_image = Image.open(BytesIO(response.read()))
      >>> inputs = processor(images=image, query_images=query_image, return_tensors="pt")
      >>> with torch.no_grad():
      ...     outputs = model.image_guided_detection(**inputs)
      >>> # Target image sizes (height, width) to rescale box predictions [batch_size, 2]
      >>> target_sizes = torch.Tensor([image.size[::-1]])
      >>> # Convert outputs (bounding boxes and class logits) to Pascal VOC format (xmin, ymin, xmax, ymax)
      >>> results = processor.post_process_image_guided_detection(
      ...     outputs=outputs, threshold=0.6, nms_threshold=0.3, target_sizes=target_sizes
      ... )
      >>> i = 0  # Retrieve predictions for the first image
      >>> boxes, scores = results[i]["boxes"], results[i]["scores"]
      >>> for box, score in zip(boxes, scores):
      ...     box = [round(i, 2) for i in box.tolist()]
      ...     print(f"Detected similar object with confidence {round(score.item(), 3)} at location {box}")
      Detected similar object with confidence 0.856 at location [10.94, 50.4, 315.8, 471.39]
      Detected similar object with confidence 1.0 at location [334.84, 25.33, 636.16, 374.71]
      ```"""
      # The signature defaults this to None, but every step below needs query images. Fail clearly instead of with
      # an AttributeError deep inside the vision tower.
      if query_pixel_values is None:
          raise ValueError("You have to specify query_pixel_values")

      # Compute feature maps for the query and target images. Only the target pass receives **kwargs, so the
      # returned `vision_model_output` describes the target images.
      query_feature_map = self.image_embedder(
          pixel_values=query_pixel_values, interpolate_pos_encoding=interpolate_pos_encoding
      )[0]
      feature_map, vision_outputs = self.image_embedder(
          pixel_values=pixel_values,
          interpolate_pos_encoding=interpolate_pos_encoding,
          **kwargs,
      )

      # Flatten the (num_patches_height, num_patches_width) grids into a single patch axis
      image_feats = feature_map.flatten(1, 2)
      query_image_feats = query_feature_map.flatten(1, 2)

      # Get top class embedding for each query image in batch (the selected box indices are not needed here)
      query_embeds, _, query_pred_boxes = self.embed_image_query(
          query_image_feats, query_feature_map, interpolate_pos_encoding
      )

      # Score every target patch against the selected query embedding(s)
      (pred_logits, class_embeds) = self.class_predictor(image_feats=image_feats, query_embeds=query_embeds)

      # Predict object boxes
      target_pred_boxes = self.box_predictor(image_feats, feature_map, interpolate_pos_encoding)

      return OwlViTImageGuidedObjectDetectionOutput(
          image_embeds=feature_map,
          query_image_embeds=query_feature_map,
          target_pred_boxes=target_pred_boxes,
          query_pred_boxes=query_pred_boxes,
          logits=pred_logits,
          class_embeds=class_embeds,
          text_model_output=None,
          vision_model_output=vision_outputs,
      )
  1170. @can_return_tuple
  1171. @auto_docstring
  1172. def forward(
  1173. self,
  1174. input_ids: torch.Tensor,
  1175. pixel_values: torch.FloatTensor,
  1176. attention_mask: torch.Tensor | None = None,
  1177. interpolate_pos_encoding: bool = False,
  1178. **kwargs: Unpack[TransformersKwargs],
  1179. ) -> OwlViTObjectDetectionOutput:
  1180. r"""
  1181. input_ids (`torch.LongTensor` of shape `(batch_size * num_max_text_queries, sequence_length)`, *optional*):
  1182. Indices of input sequence tokens in the vocabulary. Indices can be obtained using [`AutoTokenizer`]. See
  1183. [`PreTrainedTokenizer.encode`] and [`PreTrainedTokenizer.__call__`] for details. [What are input
  1184. IDs?](../glossary#input-ids).
  1185. Examples:
  1186. ```python
  1187. >>> import httpx
  1188. >>> from io import BytesIO
  1189. >>> from PIL import Image
  1190. >>> import torch
  1191. >>> from transformers import OwlViTProcessor, OwlViTForObjectDetection
  1192. >>> processor = OwlViTProcessor.from_pretrained("google/owlvit-base-patch32")
  1193. >>> model = OwlViTForObjectDetection.from_pretrained("google/owlvit-base-patch32")
  1194. >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg"
  1195. >>> with httpx.stream("GET", url) as response:
  1196. ... image = Image.open(BytesIO(response.read()))
  1197. >>> text_labels = [["a photo of a cat", "a photo of a dog"]]
  1198. >>> inputs = processor(text=text_labels, images=image, return_tensors="pt")
  1199. >>> outputs = model(**inputs)
  1200. >>> # Target image sizes (height, width) to rescale box predictions [batch_size, 2]
  1201. >>> target_sizes = torch.tensor([(image.height, image.width)])
  1202. >>> # Convert outputs (bounding boxes and class logits) to Pascal VOC format (xmin, ymin, xmax, ymax)
  1203. >>> results = processor.post_process_grounded_object_detection(
  1204. ... outputs=outputs, target_sizes=target_sizes, threshold=0.1, text_labels=text_labels
  1205. ... )
  1206. >>> # Retrieve predictions for the first image for the corresponding text queries
  1207. >>> result = results[0]
  1208. >>> boxes, scores, text_labels = result["boxes"], result["scores"], result["text_labels"]
  1209. >>> for box, score, text_label in zip(boxes, scores, text_labels):
  1210. ... box = [round(i, 2) for i in box.tolist()]
  1211. ... print(f"Detected {text_label} with confidence {round(score.item(), 3)} at location {box}")
  1212. Detected a photo of a cat with confidence 0.707 at location [324.97, 20.44, 640.58, 373.29]
  1213. Detected a photo of a cat with confidence 0.717 at location [1.46, 55.26, 315.55, 472.17]
  1214. ```"""
  1215. # Embed images and text queries
  1216. query_embeds, feature_map, outputs = self.image_text_embedder(
  1217. input_ids=input_ids,
  1218. pixel_values=pixel_values,
  1219. attention_mask=attention_mask,
  1220. interpolate_pos_encoding=interpolate_pos_encoding,
  1221. **kwargs,
  1222. )
  1223. # Text and vision model outputs
  1224. text_outputs = outputs.text_model_output
  1225. vision_outputs = outputs.vision_model_output
  1226. batch_size, num_patches_height, num_patches_width, hidden_dim = feature_map.shape
  1227. image_feats = torch.reshape(feature_map, (batch_size, num_patches_height * num_patches_width, hidden_dim))
  1228. # Reshape from [batch_size * max_text_queries, hidden_dim] -> [batch_size, max_text_queries, hidden_dim]
  1229. max_text_queries = input_ids.shape[0] // batch_size
  1230. query_embeds = query_embeds.reshape(batch_size, max_text_queries, query_embeds.shape[-1])
  1231. # If first token is 0, then this is a padded query [batch_size, num_queries].
  1232. input_ids = input_ids.reshape(batch_size, max_text_queries, input_ids.shape[-1])
  1233. query_mask = input_ids[..., 0] > 0
  1234. # Predict object classes [batch_size, num_patches, num_queries+1]
  1235. (pred_logits, class_embeds) = self.class_predictor(image_feats, query_embeds, query_mask)
  1236. # Predict object boxes
  1237. pred_boxes = self.box_predictor(image_feats, feature_map, interpolate_pos_encoding)
  1238. return OwlViTObjectDetectionOutput(
  1239. image_embeds=feature_map,
  1240. text_embeds=query_embeds,
  1241. pred_boxes=pred_boxes,
  1242. logits=pred_logits,
  1243. class_embeds=class_embeds,
  1244. text_model_output=text_outputs,
  1245. vision_model_output=vision_outputs,
  1246. )
  1247. __all__ = ["OwlViTModel", "OwlViTPreTrainedModel", "OwlViTTextModel", "OwlViTVisionModel", "OwlViTForObjectDetection"]