modeling_groupvit.py 53 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270370470570670770870971071171271371471571671771871972072172272372472572672772872973073173273373473573673773873974074174274374474574674774874975075175275375475575675775875976076176276376476576676776876977077177277377477577677777877978078178278378478578678778878979079179279379479579679779879980080180280380480580680780880981081181281381481581681781881982082182282382482582682782882983083183283383483583683783883984084184284384484584684784884985085185285385485585685785885986086186286386486586686786886987087187287387487587687787887988088188288388488588688788888989089189289389489589689789889990090190290390490590690790890991091191291391491591691791891992092192292392492592692792892993093193293393493593693793893994094194294394494594694794894995095195295395495595695795895996096196296396496596696796896997097197297397497597697797897998098198298398498598698798898999099199299399499599699799899910001001100210031004100510061007100810091010101110121013101410151016101710181019102010211022102310241025102610271028102910301031103210331034103510361037103810391040104110421043104410451046104710481049105010511052105310541055105610571058105910601061106210631064106510661067106810691070107110721073107410751076107710781079108010811082108310841085108610871088108910901091109210931094109510961097109810991100110111021103110411051106110711081109111011111112111311141115111611171118111911201121112211231124112511261127112811291130113111321133113411351136113711381139114011411142114311441145114611471148114911501151115211531154115511561157115811591160116111621163116411651166116711681169117011711172117311741175117611771178117911801181118211831184118511861187118811891190119111921193119411951196119711981199120012011202120312041205120612071208120912101211121212131214121512161217121812191220122112221223122412251226122712281229123012311232123312341235123612371238123912401241124212431244124512461247124812491250125112521253125412551256125712581259126012611262126312641265126612671268126912701271127212731274127512761277127812791280128112821283128412851286128712881289129012911292129312941295129612971298129913001301130213031304130513061307130813091310131113121313131413151316131713181319132013211322132313241325132613271328132913301331133213331334133513361337133813391340134113421343
  1. # Copyright 2022 NVIDIA and The HuggingFace Team. All rights reserved.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. """PyTorch GroupViT model."""
  15. import collections.abc
  16. from dataclasses import dataclass
  17. from typing import Any
  18. import numpy as np
  19. import torch
  20. from torch import nn
  21. from ... import initialization as init
  22. from ...activations import ACT2FN
  23. from ...masking_utils import create_causal_mask
  24. from ...modeling_layers import GradientCheckpointingLayer
  25. from ...modeling_outputs import BaseModelOutput, BaseModelOutputWithPooling
  26. from ...modeling_utils import PreTrainedModel
  27. from ...processing_utils import Unpack
  28. from ...utils import ModelOutput, TransformersKwargs, auto_docstring, can_return_tuple, logging, torch_int
  29. from ...utils.generic import merge_with_config_defaults
  30. from ...utils.output_capturing import capture_outputs
  31. from .configuration_groupvit import GroupViTConfig, GroupViTTextConfig, GroupViTVisionConfig
  32. logger = logging.get_logger(__name__)
  33. # contrastive loss function, adapted from
  34. # https://sachinruk.github.io/blog/pytorch/pytorch%20lightning/loss%20function/gpu/2021/03/07/CLIP.html
  35. def contrastive_loss(logits: torch.Tensor) -> torch.Tensor:
  36. return nn.functional.cross_entropy(logits, torch.arange(len(logits), device=logits.device))
  37. # Copied from transformers.models.clip.modeling_clip.clip_loss with clip->groupvit
  38. def groupvit_loss(similarity: torch.Tensor) -> torch.Tensor:
  39. caption_loss = contrastive_loss(similarity)
  40. image_loss = contrastive_loss(similarity.t())
  41. return (caption_loss + image_loss) / 2.0
  42. def hard_softmax(logits: torch.Tensor, dim: int):
  43. y_soft = logits.softmax(dim)
  44. # Straight through.
  45. index = y_soft.max(dim, keepdim=True)[1]
  46. y_hard = torch.zeros_like(logits, memory_format=torch.legacy_contiguous_format).scatter_(dim, index, 1.0)
  47. ret = y_hard - y_soft.detach() + y_soft
  48. return ret
  49. def gumbel_softmax(logits: torch.Tensor, tau: float = 1, hard: bool = False, dim: int = -1) -> torch.Tensor:
  50. # more stable https://github.com/pytorch/pytorch/issues/41663
  51. gumbel_dist = torch.distributions.gumbel.Gumbel(
  52. torch.tensor(0.0, device=logits.device, dtype=logits.dtype),
  53. torch.tensor(1.0, device=logits.device, dtype=logits.dtype),
  54. )
  55. gumbels = gumbel_dist.sample(logits.shape)
  56. gumbels = (logits + gumbels) / tau # ~Gumbel(logits,tau)
  57. y_soft = gumbels.softmax(dim)
  58. if hard:
  59. # Straight through.
  60. index = y_soft.max(dim, keepdim=True)[1]
  61. y_hard = torch.zeros_like(logits, memory_format=torch.legacy_contiguous_format).scatter_(dim, index, 1.0)
  62. ret = y_hard - y_soft.detach() + y_soft
  63. else:
  64. # Reparameterization trick.
  65. ret = y_soft
  66. return ret
  67. def resize_attention_map(attentions, height, width, align_corners=False):
  68. """
  69. Args:
  70. attentions (`torch.Tensor`): attention map of shape [batch_size, groups, feat_height*feat_width]
  71. height (`int`): height of the output attention map
  72. width (`int`): width of the output attention map
  73. align_corners (`bool`, *optional*): the `align_corner` argument for `nn.functional.interpolate`.
  74. Returns:
  75. `torch.Tensor`: resized attention map of shape [batch_size, groups, height, width]
  76. """
  77. scale = (height * width // attentions.shape[2]) ** 0.5
  78. if height > width:
  79. feat_width = int(np.round(width / scale))
  80. feat_height = attentions.shape[2] // feat_width
  81. else:
  82. feat_height = int(np.round(height / scale))
  83. feat_width = attentions.shape[2] // feat_height
  84. batch_size = attentions.shape[0]
  85. groups = attentions.shape[1] # number of group token
  86. # [batch_size, groups, height*width, groups] -> [batch_size, groups, height, width]
  87. attentions = attentions.reshape(batch_size, groups, feat_height, feat_width)
  88. attentions = nn.functional.interpolate(
  89. attentions, size=(height, width), mode="bilinear", align_corners=align_corners
  90. )
  91. return attentions
  92. def get_grouping_from_attentions(attentions, hw_shape):
  93. """
  94. Args:
  95. attentions (`tuple(torch.FloatTensor)`: tuple of attention maps returned by `GroupViTVisionTransformer`
  96. hw_shape (`tuple(int)`): height and width of the output attention map
  97. Returns:
  98. `torch.Tensor`: the attention map of shape [batch_size, groups, height, width]
  99. """
  100. attn_maps = []
  101. with torch.no_grad():
  102. prev_attn_masks = None
  103. for attn_masks in attentions:
  104. # [batch_size, num_groups, height x width] -> [batch_size, height x width, num_groups]
  105. attn_masks = attn_masks.permute(0, 2, 1).contiguous()
  106. if prev_attn_masks is None:
  107. prev_attn_masks = attn_masks
  108. else:
  109. prev_attn_masks = prev_attn_masks @ attn_masks
  110. # [batch_size, heightxwidth, num_groups] -> [batch_size, num_groups, heightxwidth] -> [batch_size, num_groups, height, width]
  111. cur_attn_map = resize_attention_map(prev_attn_masks.permute(0, 2, 1).contiguous(), *hw_shape)
  112. attn_maps.append(cur_attn_map)
  113. # [batch_size, num_groups, height, width]
  114. final_grouping = attn_maps[-1]
  115. return final_grouping
  116. class GroupViTCrossAttentionLayer(nn.Module):
  117. def __init__(self, config: GroupViTVisionConfig):
  118. super().__init__()
  119. self.attn = GroupViTAttention(config)
  120. self.norm2 = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
  121. self.mlp = GroupViTMLP(config)
  122. self.norm_post = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
  123. def forward(self, query, key):
  124. x = query
  125. x = x + self.attn(query, encoder_hidden_states=key)[0]
  126. x = x + self.mlp(self.norm2(x))
  127. x = self.norm_post(x)
  128. return x
  129. class GroupViTAssignAttention(nn.Module):
  130. def __init__(self, config: GroupViTVisionConfig):
  131. super().__init__()
  132. self.scale = config.hidden_size**-0.5
  133. self.q_proj = nn.Linear(config.hidden_size, config.hidden_size)
  134. self.k_proj = nn.Linear(config.hidden_size, config.hidden_size)
  135. self.v_proj = nn.Linear(config.hidden_size, config.hidden_size)
  136. self.proj = nn.Linear(config.hidden_size, config.hidden_size)
  137. self.assign_eps = config.assign_eps
  138. def get_attn(self, attn, gumbel=True, hard=True):
  139. if gumbel and self.training:
  140. attn = gumbel_softmax(attn, dim=-2, hard=hard)
  141. else:
  142. if hard:
  143. attn = hard_softmax(attn, dim=-2)
  144. else:
  145. attn = nn.functional.softmax(attn, dim=-2)
  146. return attn
  147. def forward(self, query, key):
  148. value = key
  149. # [batch_size, query_length, channels]
  150. query = self.q_proj(query)
  151. # [batch_size, key_length, channels]
  152. key = self.k_proj(key)
  153. # [batch_size, key_length, channels]
  154. value = self.v_proj(value)
  155. # [batch_size, query_length, key_length]
  156. raw_attn = (query @ key.transpose(-2, -1)) * self.scale
  157. attn = self.get_attn(raw_attn)
  158. soft_attn = self.get_attn(raw_attn, gumbel=False, hard=False)
  159. attn = attn / (attn.sum(dim=-1, keepdim=True) + self.assign_eps)
  160. out = attn @ value
  161. out = self.proj(out)
  162. return out, soft_attn
  163. class GroupViTTokenAssign(nn.Module):
  164. def __init__(self, config: GroupViTVisionConfig, num_group_token, num_output_group):
  165. super().__init__()
  166. self.num_output_group = num_output_group
  167. # norm on group_tokens
  168. self.norm_tokens = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
  169. assign_mlp_ratio = (
  170. config.assign_mlp_ratio
  171. if isinstance(config.assign_mlp_ratio, collections.abc.Iterable)
  172. else (config.assign_mlp_ratio, config.assign_mlp_ratio)
  173. )
  174. tokens_dim, channels_dim = [int(x * config.hidden_size) for x in assign_mlp_ratio]
  175. self.mlp_inter = GroupViTMixerMLP(config, num_group_token, tokens_dim, num_output_group)
  176. self.norm_post_tokens = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
  177. # norm on x
  178. self.norm_x = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
  179. self.pre_assign_attn = GroupViTCrossAttentionLayer(config)
  180. self.assign = GroupViTAssignAttention(config)
  181. self.norm_new_x = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
  182. self.mlp_channels = GroupViTMLP(config, config.hidden_size, channels_dim, config.hidden_size)
  183. def project_group_token(self, group_tokens):
  184. """
  185. Args:
  186. group_tokens (torch.Tensor): group tokens, [batch_size, num_group_tokens, channels]
  187. Returns:
  188. projected_group_tokens (torch.Tensor): [batch_size, num_output_groups, channels]
  189. """
  190. # [B, num_output_groups, C] <- [B, num_group_tokens, C]
  191. projected_group_tokens = self.mlp_inter(group_tokens)
  192. projected_group_tokens = self.norm_post_tokens(projected_group_tokens)
  193. return projected_group_tokens
  194. def forward(self, image_tokens, group_tokens):
  195. """
  196. Args:
  197. image_tokens (`torch.Tensor`): image tokens, of shape [batch_size, input_length, channels]
  198. group_tokens (`torch.Tensor`): group tokens, [batch_size, num_group_tokens, channels]
  199. """
  200. group_tokens = self.norm_tokens(group_tokens)
  201. image_tokens = self.norm_x(image_tokens)
  202. # [batch_size, num_output_groups, channels]
  203. projected_group_tokens = self.project_group_token(group_tokens)
  204. projected_group_tokens = self.pre_assign_attn(projected_group_tokens, image_tokens)
  205. new_image_tokens, attention = self.assign(projected_group_tokens, image_tokens)
  206. new_image_tokens += projected_group_tokens
  207. new_image_tokens = new_image_tokens + self.mlp_channels(self.norm_new_x(new_image_tokens))
  208. return new_image_tokens, attention
  209. @dataclass
  210. @auto_docstring
  211. class GroupViTModelOutput(ModelOutput):
  212. r"""
  213. loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `return_loss` is `True`):
  214. Contrastive loss for image-text similarity.
  215. logits_per_image (`torch.FloatTensor` of shape `(image_batch_size, text_batch_size)`):
  216. The scaled dot product scores between `image_embeds` and `text_embeds`. This represents the image-text
  217. similarity scores.
  218. logits_per_text (`torch.FloatTensor` of shape `(text_batch_size, image_batch_size)`):
  219. The scaled dot product scores between `text_embeds` and `image_embeds`. This represents the text-image
  220. similarity scores.
  221. segmentation_logits (`torch.FloatTensor` of shape `(batch_size, config.num_labels, logits_height, logits_width)`):
  222. Classification scores for each pixel.
  223. <Tip warning={true}>
  224. The logits returned do not necessarily have the same size as the `pixel_values` passed as inputs. This is
  225. to avoid doing two interpolations and lose some quality when a user needs to resize the logits to the
  226. original image size as post-processing. You should always check your logits shape and resize as needed.
  227. </Tip>
  228. text_embeds (`torch.FloatTensor` of shape `(batch_size, output_dim`):
  229. The text embeddings obtained by applying the projection layer to the pooled output of
  230. [`GroupViTTextModel`].
  231. image_embeds (`torch.FloatTensor` of shape `(batch_size, output_dim`):
  232. The image embeddings obtained by applying the projection layer to the pooled output of
  233. [`GroupViTVisionModel`].
  234. text_model_output (`BaseModelOutputWithPooling`):
  235. The output of the [`GroupViTTextModel`].
  236. vision_model_output (`BaseModelOutputWithPooling`):
  237. The output of the [`GroupViTVisionModel`].
  238. """
  239. loss: torch.FloatTensor | None = None
  240. logits_per_image: torch.FloatTensor | None = None
  241. logits_per_text: torch.FloatTensor | None = None
  242. segmentation_logits: torch.FloatTensor | None = None
  243. text_embeds: torch.FloatTensor | None = None
  244. image_embeds: torch.FloatTensor | None = None
  245. text_model_output: BaseModelOutputWithPooling = None
  246. vision_model_output: BaseModelOutputWithPooling = None
  247. def to_tuple(self) -> tuple[Any]:
  248. return tuple(
  249. self[k] if k not in ["text_model_output", "vision_model_output"] else getattr(self, k).to_tuple()
  250. for k in self.keys()
  251. )
  252. class GroupViTPatchEmbeddings(nn.Module):
  253. """
  254. Image to Patch Embedding.
  255. """
  256. def __init__(
  257. self,
  258. image_size: int | list[int] | tuple[int, int] = 224,
  259. patch_size: int | tuple[int, int] = 16,
  260. num_channels: int = 3,
  261. embed_dim: int = 768,
  262. ):
  263. super().__init__()
  264. image_size = image_size if isinstance(image_size, collections.abc.Iterable) else (image_size, image_size)
  265. patch_size = patch_size if isinstance(patch_size, collections.abc.Iterable) else (patch_size, patch_size)
  266. num_patches = (image_size[1] // patch_size[1]) * (image_size[0] // patch_size[0])
  267. self.image_size = image_size
  268. self.patch_size = patch_size
  269. self.num_patches = num_patches
  270. self.projection = nn.Conv2d(num_channels, embed_dim, kernel_size=patch_size, stride=patch_size)
  271. def forward(self, pixel_values: torch.Tensor, interpolate_pos_encoding: bool = False) -> torch.Tensor:
  272. batch_size, num_channels, height, width = pixel_values.shape
  273. if not interpolate_pos_encoding:
  274. if height != self.image_size[0] or width != self.image_size[1]:
  275. raise ValueError(
  276. f"Input image size ({height}*{width}) doesn't match model"
  277. f" ({self.image_size[0]}*{self.image_size[1]})."
  278. )
  279. x = self.projection(pixel_values).flatten(2).transpose(1, 2)
  280. return x
  281. class GroupViTVisionEmbeddings(nn.Module):
  282. def __init__(self, config: GroupViTVisionConfig):
  283. super().__init__()
  284. self.patch_embeddings = GroupViTPatchEmbeddings(
  285. image_size=config.image_size,
  286. patch_size=config.patch_size,
  287. num_channels=config.num_channels,
  288. embed_dim=config.hidden_size,
  289. )
  290. num_patches = self.patch_embeddings.num_patches
  291. self.position_embeddings = nn.Parameter(torch.zeros(1, num_patches, config.hidden_size))
  292. self.dropout = nn.Dropout(config.dropout)
  293. self.layernorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
  294. self.patch_size = config.patch_size
  295. self.config = config
  296. def interpolate_pos_encoding(self, embeddings: torch.Tensor, height: int, width: int) -> torch.Tensor:
  297. """
  298. This method allows to interpolate the pre-trained position encodings, to be able to use the model on higher resolution
  299. images. This method is also adapted to support torch.jit tracing and no class embeddings.
  300. Adapted from:
  301. - https://github.com/facebookresearch/dino/blob/de9ee3df6cf39fac952ab558447af1fa1365362a/vision_transformer.py#L174-L194, and
  302. - https://github.com/facebookresearch/dinov2/blob/e1277af2ba9496fbadf7aec6eba56e8d882d1e35/dinov2/models/vision_transformer.py#L179-L211
  303. """
  304. num_patches = embeddings.shape[1]
  305. num_positions = self.position_embeddings.shape[1]
  306. # always interpolate when tracing to ensure the exported model works for dynamic input shapes
  307. if not torch.jit.is_tracing() and num_patches == num_positions and height == width:
  308. return self.position_embeddings
  309. patch_pos_embed = self.position_embeddings
  310. dim = embeddings.shape[-1]
  311. new_height = height // self.patch_size
  312. new_width = width // self.patch_size
  313. sqrt_num_positions = torch_int(num_positions**0.5)
  314. patch_pos_embed = patch_pos_embed.reshape(1, sqrt_num_positions, sqrt_num_positions, dim)
  315. patch_pos_embed = patch_pos_embed.permute(0, 3, 1, 2)
  316. patch_pos_embed = nn.functional.interpolate(
  317. patch_pos_embed,
  318. size=(new_height, new_width),
  319. mode="bicubic",
  320. align_corners=False,
  321. )
  322. patch_pos_embed = patch_pos_embed.permute(0, 2, 3, 1).view(1, -1, dim)
  323. return patch_pos_embed
  324. def forward(self, pixel_values: torch.Tensor, interpolate_pos_encoding: bool = False) -> torch.Tensor:
  325. batch_size, num_channels, height, width = pixel_values.shape
  326. embeddings = self.patch_embeddings(pixel_values, interpolate_pos_encoding=interpolate_pos_encoding)
  327. embeddings = self.layernorm(embeddings)
  328. batch_size, seq_len, _ = embeddings.size()
  329. # add positional encoding to each token
  330. if interpolate_pos_encoding:
  331. embeddings = embeddings + self.interpolate_pos_encoding(embeddings, height, width)
  332. else:
  333. embeddings = embeddings + self.position_embeddings
  334. embeddings = self.dropout(embeddings)
  335. return embeddings
  336. # Copied from transformers.models.clip.modeling_clip.CLIPTextEmbeddings with CLIP->GroupViT
  337. class GroupViTTextEmbeddings(nn.Module):
  338. def __init__(self, config: GroupViTTextConfig):
  339. super().__init__()
  340. embed_dim = config.hidden_size
  341. self.token_embedding = nn.Embedding(config.vocab_size, embed_dim)
  342. self.position_embedding = nn.Embedding(config.max_position_embeddings, embed_dim)
  343. # position_ids (1, len position emb) is contiguous in memory and exported when serialized
  344. self.register_buffer(
  345. "position_ids", torch.arange(config.max_position_embeddings).expand((1, -1)), persistent=False
  346. )
  347. def forward(
  348. self,
  349. input_ids: torch.LongTensor | None = None,
  350. position_ids: torch.LongTensor | None = None,
  351. inputs_embeds: torch.FloatTensor | None = None,
  352. ) -> torch.Tensor:
  353. seq_length = input_ids.shape[-1] if input_ids is not None else inputs_embeds.shape[-2]
  354. max_position_embedding = self.position_embedding.weight.shape[0]
  355. if seq_length > max_position_embedding:
  356. raise ValueError(
  357. f"Sequence length must be less than max_position_embeddings (got `sequence length`: "
  358. f"{seq_length} and max_position_embeddings: {max_position_embedding}"
  359. )
  360. if position_ids is None:
  361. position_ids = self.position_ids[:, :seq_length]
  362. if inputs_embeds is None:
  363. inputs_embeds = self.token_embedding(input_ids)
  364. position_embeddings = self.position_embedding(position_ids)
  365. embeddings = inputs_embeds + position_embeddings
  366. return embeddings
  367. class GroupViTStage(nn.Module):
  368. """This corresponds to the `GroupingLayer` class in the GroupViT implementation."""
  369. def __init__(
  370. self,
  371. config: GroupViTVisionConfig,
  372. depth: int,
  373. num_prev_group_token: int,
  374. num_group_token: int,
  375. num_output_group: int,
  376. ):
  377. super().__init__()
  378. self.depth = depth
  379. self.num_group_token = num_group_token
  380. if num_group_token > 0:
  381. self.group_token = nn.Parameter(torch.zeros(1, num_group_token, config.hidden_size))
  382. else:
  383. self.group_token = None
  384. self.layers = nn.ModuleList([GroupViTEncoderLayer(config) for _ in range(depth)])
  385. if num_group_token > 0:
  386. self.downsample = GroupViTTokenAssign(
  387. config=config,
  388. num_group_token=num_group_token,
  389. num_output_group=num_output_group,
  390. )
  391. else:
  392. self.downsample = None
  393. if num_prev_group_token > 0 and num_group_token > 0:
  394. self.group_projector = nn.Sequential(
  395. nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps),
  396. GroupViTMixerMLP(config, num_prev_group_token, config.hidden_size // 2, num_group_token),
  397. )
  398. else:
  399. self.group_projector = None
  400. @property
  401. def with_group_token(self):
  402. return self.group_token is not None
  403. def split_x(self, x):
  404. if self.with_group_token:
  405. return x[:, : -self.num_group_token], x[:, -self.num_group_token :]
  406. else:
  407. return x, None
  408. def concat_x(self, x: torch.Tensor, group_token: torch.Tensor | None = None) -> torch.Tensor:
  409. if group_token is None:
  410. return x
  411. return torch.cat([x, group_token], dim=1)
  412. def forward(
  413. self,
  414. hidden_states: torch.Tensor,
  415. prev_group_token: torch.Tensor | None = None,
  416. output_attentions: bool | None = False,
  417. ) -> tuple[torch.FloatTensor]:
  418. """
  419. Args:
  420. hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)`
  421. attention_mask (`torch.FloatTensor`): attention mask of size
  422. `(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values.
  423. `(config.encoder_attention_heads,)`.
  424. output_attentions (`bool`, *optional*):
  425. Whether or not to return the grouping tensors of Grouping block.
  426. """
  427. if self.with_group_token:
  428. group_token = self.group_token.expand(hidden_states.size(0), -1, -1)
  429. if self.group_projector is not None:
  430. group_token = group_token + self.group_projector(prev_group_token)
  431. else:
  432. group_token = None
  433. x = hidden_states
  434. cat_x = self.concat_x(x, group_token)
  435. for layer in self.layers:
  436. cat_x = layer(cat_x, attention_mask=None)
  437. x, group_token = self.split_x(cat_x)
  438. attention = None
  439. if self.downsample is not None:
  440. x, attention = self.downsample(x, group_token)
  441. outputs = (x, group_token)
  442. if output_attentions:
  443. outputs = outputs + (attention,)
  444. return outputs
  445. class GroupViTMLP(nn.Module):
  446. def __init__(
  447. self,
  448. config: GroupViTVisionConfig,
  449. hidden_size: int | None = None,
  450. intermediate_size: int | None = None,
  451. output_size: int | None = None,
  452. ):
  453. super().__init__()
  454. self.config = config
  455. self.activation_fn = ACT2FN[config.hidden_act]
  456. hidden_size = hidden_size if hidden_size is not None else config.hidden_size
  457. intermediate_size = intermediate_size if intermediate_size is not None else config.intermediate_size
  458. output_size = output_size if output_size is not None else hidden_size
  459. self.fc1 = nn.Linear(hidden_size, intermediate_size)
  460. self.fc2 = nn.Linear(intermediate_size, output_size)
  461. def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
  462. hidden_states = self.fc1(hidden_states)
  463. hidden_states = self.activation_fn(hidden_states)
  464. hidden_states = self.fc2(hidden_states)
  465. return hidden_states
  466. class GroupViTMixerMLP(GroupViTMLP):
  467. def forward(self, x):
  468. x = super().forward(x.transpose(1, 2))
  469. return x.transpose(1, 2)
  470. class GroupViTAttention(nn.Module):
  471. """Multi-headed attention from 'Attention Is All You Need' paper"""
  472. def __init__(self, config):
  473. super().__init__()
  474. self.config = config
  475. self.embed_dim = config.hidden_size
  476. self.num_heads = config.num_attention_heads
  477. self.head_dim = self.embed_dim // self.num_heads
  478. if self.head_dim * self.num_heads != self.embed_dim:
  479. raise ValueError(
  480. f"embed_dim must be divisible by num_heads (got `embed_dim`: {self.embed_dim} and `num_heads`:"
  481. f" {self.num_heads})."
  482. )
  483. self.scale = self.head_dim**-0.5
  484. self.dropout = config.attention_dropout
  485. self.k_proj = nn.Linear(self.embed_dim, self.embed_dim)
  486. self.v_proj = nn.Linear(self.embed_dim, self.embed_dim)
  487. self.q_proj = nn.Linear(self.embed_dim, self.embed_dim)
  488. self.out_proj = nn.Linear(self.embed_dim, self.embed_dim)
  489. def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int):
  490. return tensor.view(bsz, seq_len, self.num_heads, self.head_dim).transpose(1, 2).contiguous()
  491. def forward(
  492. self,
  493. hidden_states: torch.Tensor,
  494. attention_mask: torch.Tensor | None = None,
  495. encoder_hidden_states: torch.FloatTensor | None = None,
  496. **kwargs,
  497. ) -> tuple[torch.Tensor, torch.Tensor | None]:
  498. """Input shape: Batch x Time x Channel"""
  499. bsz, tgt_len, embed_dim = hidden_states.size()
  500. is_cross_attention = encoder_hidden_states is not None
  501. # get query proj
  502. query_states = self.q_proj(hidden_states) * self.scale
  503. if is_cross_attention:
  504. key_states = self._shape(self.k_proj(encoder_hidden_states), -1, bsz)
  505. value_states = self._shape(self.v_proj(encoder_hidden_states), -1, bsz)
  506. else:
  507. key_states = self._shape(self.k_proj(hidden_states), -1, bsz)
  508. value_states = self._shape(self.v_proj(hidden_states), -1, bsz)
  509. proj_shape = (bsz * self.num_heads, -1, self.head_dim)
  510. query_states = self._shape(query_states, tgt_len, bsz).view(*proj_shape)
  511. key_states = key_states.view(*proj_shape)
  512. value_states = value_states.view(*proj_shape)
  513. src_len = key_states.size(1)
  514. attn_weights = torch.bmm(query_states, key_states.transpose(1, 2))
  515. if attn_weights.size() != (bsz * self.num_heads, tgt_len, src_len):
  516. raise ValueError(
  517. f"Attention weights should be of size {(bsz * self.num_heads, tgt_len, src_len)}, but is"
  518. f" {attn_weights.size()}"
  519. )
  520. if attention_mask is not None:
  521. if attention_mask.size() != (bsz, 1, tgt_len, src_len):
  522. raise ValueError(
  523. f"Attention mask should be of size {(bsz, 1, tgt_len, src_len)}, but is {attention_mask.size()}"
  524. )
  525. attn_weights = attn_weights.view(bsz, self.num_heads, tgt_len, src_len) + attention_mask
  526. attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len)
  527. attn_weights = nn.functional.softmax(attn_weights, dim=-1)
  528. # this operation is a bit awkward, but it's required to
  529. # make sure that attn_weights keeps its gradient.
  530. # In order to do so, attn_weights have to reshaped
  531. # twice and have to be reused in the following
  532. attn_weights_reshaped = attn_weights.view(bsz, self.num_heads, tgt_len, src_len)
  533. attn_weights = attn_weights_reshaped.view(bsz * self.num_heads, tgt_len, src_len)
  534. attn_probs = nn.functional.dropout(attn_weights, p=self.dropout, training=self.training)
  535. attn_output = torch.bmm(attn_probs, value_states)
  536. if attn_output.size() != (bsz * self.num_heads, tgt_len, self.head_dim):
  537. raise ValueError(
  538. f"`attn_output` should be of size {(bsz, self.num_heads, tgt_len, self.head_dim)}, but is"
  539. f" {attn_output.size()}"
  540. )
  541. attn_output = attn_output.view(bsz, self.num_heads, tgt_len, self.head_dim)
  542. attn_output = attn_output.transpose(1, 2)
  543. attn_output = attn_output.reshape(bsz, tgt_len, embed_dim)
  544. attn_output = self.out_proj(attn_output)
  545. return attn_output, attn_weights_reshaped
  546. # Copied from transformers.models.altclip.modeling_altclip.AltCLIPEncoderLayer with AltCLIP->GroupViT
  547. class GroupViTEncoderLayer(GradientCheckpointingLayer):
  548. def __init__(self, config: GroupViTConfig):
  549. super().__init__()
  550. self.embed_dim = config.hidden_size
  551. self.self_attn = GroupViTAttention(config)
  552. self.layer_norm1 = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_eps)
  553. self.mlp = GroupViTMLP(config)
  554. self.layer_norm2 = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_eps)
  555. def forward(
  556. self,
  557. hidden_states: torch.Tensor,
  558. attention_mask: torch.Tensor,
  559. **kwargs: Unpack[TransformersKwargs],
  560. ) -> tuple[torch.FloatTensor, torch.Tensor | None]:
  561. residual = hidden_states
  562. hidden_states = self.layer_norm1(hidden_states)
  563. hidden_states, _ = self.self_attn(
  564. hidden_states=hidden_states,
  565. attention_mask=attention_mask,
  566. **kwargs,
  567. )
  568. hidden_states = residual + hidden_states
  569. residual = hidden_states
  570. hidden_states = self.layer_norm2(hidden_states)
  571. hidden_states = self.mlp(hidden_states)
  572. hidden_states = residual + hidden_states
  573. return hidden_states
  574. @auto_docstring
  575. class GroupViTPreTrainedModel(PreTrainedModel):
  576. config: GroupViTConfig
  577. base_model_prefix = "groupvit"
  578. input_modalities = ("image", "text")
  579. supports_gradient_checkpointing = True
  580. _can_record_outputs = {
  581. "hidden_states": GroupViTEncoderLayer,
  582. "attentions": GroupViTAttention,
  583. }
  584. @torch.no_grad()
  585. def _init_weights(self, module):
  586. """Initialize the weights"""
  587. init_range = self.config.initializer_range
  588. if isinstance(module, (nn.Linear, nn.Conv2d)):
  589. init.normal_(module.weight, mean=0.0, std=init_range)
  590. if module.bias is not None:
  591. init.zeros_(module.bias)
  592. elif isinstance(module, (nn.LayerNorm, nn.BatchNorm1d)):
  593. init.zeros_(module.bias)
  594. init.ones_(module.weight)
  595. if getattr(module, "running_mean", None) is not None:
  596. init.zeros_(module.running_mean)
  597. init.ones_(module.running_var)
  598. init.zeros_(module.num_batches_tracked)
  599. factor = self.config.initializer_factor
  600. if isinstance(module, GroupViTTextEmbeddings):
  601. init.normal_(module.token_embedding.weight, mean=0.0, std=factor * 0.02)
  602. init.normal_(module.position_embedding.weight, mean=0.0, std=factor * 0.02)
  603. init.copy_(module.position_ids, torch.arange(module.position_ids.shape[-1]).expand((1, -1)))
  604. elif isinstance(module, GroupViTAttention):
  605. factor = self.config.initializer_factor
  606. in_proj_std = (module.embed_dim**-0.5) * ((2 * module.config.num_hidden_layers) ** -0.5) * factor
  607. out_proj_std = (module.embed_dim**-0.5) * factor
  608. init.normal_(module.q_proj.weight, std=in_proj_std)
  609. init.normal_(module.k_proj.weight, std=in_proj_std)
  610. init.normal_(module.v_proj.weight, std=in_proj_std)
  611. init.normal_(module.out_proj.weight, std=out_proj_std)
  612. elif isinstance(module, GroupViTMLP):
  613. factor = self.config.initializer_factor
  614. in_proj_std = (module.config.hidden_size**-0.5) * ((2 * module.config.num_hidden_layers) ** -0.5) * factor
  615. fc_std = (2 * module.config.hidden_size) ** -0.5 * factor
  616. init.normal_(module.fc1.weight, std=fc_std)
  617. init.normal_(module.fc2.weight, std=in_proj_std)
  618. class GroupViTVisionEncoder(nn.Module):
  619. def __init__(self, config: GroupViTVisionConfig) -> None:
  620. super().__init__()
  621. self.config = config
  622. self.stages = nn.ModuleList(
  623. [
  624. GroupViTStage(
  625. config=config,
  626. depth=config.depths[i],
  627. num_group_token=config.num_group_tokens[i],
  628. num_output_group=config.num_output_groups[i],
  629. num_prev_group_token=config.num_output_groups[i - 1] if i > 0 else 0,
  630. )
  631. for i in range(len(config.depths))
  632. ]
  633. )
  634. self.gradient_checkpointing = False
  635. def forward(
  636. self,
  637. hidden_states: torch.Tensor,
  638. output_hidden_states: bool | None = None,
  639. output_attentions: bool | None = None,
  640. return_dict: bool | None = None,
  641. ) -> tuple | BaseModelOutput:
  642. output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
  643. output_hidden_states = (
  644. output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
  645. )
  646. return_dict = return_dict if return_dict is not None else self.config.return_dict
  647. all_hidden_states = () if output_hidden_states else None
  648. all_groupings = () if output_attentions else None
  649. group_tokens = None
  650. for i, stage in enumerate(self.stages):
  651. if output_hidden_states:
  652. all_hidden_states = all_hidden_states + (hidden_states,)
  653. layer_outputs = stage(hidden_states, group_tokens, output_attentions)
  654. hidden_states = layer_outputs[0]
  655. group_tokens = layer_outputs[1]
  656. if output_attentions and layer_outputs[2] is not None:
  657. all_groupings = all_groupings + (layer_outputs[2],)
  658. if output_hidden_states:
  659. all_hidden_states = all_hidden_states + (hidden_states,)
  660. if not return_dict:
  661. return tuple(v for v in [hidden_states, all_hidden_states, all_groupings] if v is not None)
  662. return BaseModelOutput(
  663. last_hidden_state=hidden_states, hidden_states=all_hidden_states, attentions=all_groupings
  664. )
  665. class GroupViTTextEncoder(nn.Module):
  666. """
  667. Transformer encoder consisting of `config.num_hidden_layers` self-attention layers. Each layer is a
  668. [`GroupViTEncoderLayer`].
  669. Args:
  670. config: GroupViTTextConfig
  671. """
  672. def __init__(self, config: GroupViTTextConfig):
  673. super().__init__()
  674. self.config = config
  675. self.layers = nn.ModuleList([GroupViTEncoderLayer(config) for _ in range(config.num_hidden_layers)])
  676. self.gradient_checkpointing = False
  677. def forward(
  678. self,
  679. inputs_embeds,
  680. attention_mask: torch.Tensor | None = None,
  681. **kwargs: Unpack[TransformersKwargs],
  682. ) -> tuple | BaseModelOutput:
  683. r"""
  684. Args:
  685. inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`):
  686. Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation.
  687. This is useful if you want more control over how to convert `input_ids` indices into associated vectors
  688. than the model's internal embedding lookup matrix.
  689. attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
  690. Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
  691. - 1 for tokens that are **not masked**,
  692. - 0 for tokens that are **masked**.
  693. [What are attention masks?](../glossary#attention-mask)
  694. """
  695. hidden_states = inputs_embeds
  696. for encoder_layer in self.layers:
  697. hidden_states = encoder_layer(
  698. hidden_states,
  699. attention_mask,
  700. **kwargs,
  701. )
  702. return BaseModelOutput(
  703. last_hidden_state=hidden_states,
  704. )
  705. class GroupViTTextTransformer(GroupViTPreTrainedModel):
  706. def __init__(self, config: GroupViTTextConfig):
  707. super().__init__(config)
  708. embed_dim = config.hidden_size
  709. self.embeddings = GroupViTTextEmbeddings(config)
  710. self.encoder = GroupViTTextEncoder(config)
  711. self.final_layer_norm = nn.LayerNorm(embed_dim, eps=config.layer_norm_eps)
  712. # For `pooled_output` computation
  713. self.eos_token_id = config.eos_token_id
  714. self.post_init()
  715. @merge_with_config_defaults
  716. @capture_outputs(tie_last_hidden_states=False)
  717. @auto_docstring
  718. def forward(
  719. self,
  720. input_ids: torch.Tensor | None = None,
  721. attention_mask: torch.Tensor | None = None,
  722. position_ids: torch.Tensor | None = None,
  723. **kwargs: Unpack[TransformersKwargs],
  724. ) -> BaseModelOutputWithPooling:
  725. if input_ids is None:
  726. raise ValueError("You have to specify input_ids")
  727. input_shape = input_ids.size()
  728. input_ids = input_ids.view(-1, input_shape[-1])
  729. hidden_states = self.embeddings(input_ids=input_ids, position_ids=position_ids)
  730. attention_mask = create_causal_mask(
  731. config=self.config,
  732. inputs_embeds=hidden_states,
  733. attention_mask=attention_mask,
  734. past_key_values=None,
  735. )
  736. kwargs.pop("is_causal", None)
  737. encoder_outputs: BaseModelOutput = self.encoder(
  738. inputs_embeds=hidden_states,
  739. attention_mask=attention_mask,
  740. is_causal=True,
  741. **kwargs,
  742. )
  743. last_hidden_state = encoder_outputs[0]
  744. last_hidden_state = self.final_layer_norm(last_hidden_state)
  745. if self.eos_token_id == 2:
  746. # The `eos_token_id` was incorrect before PR #24773: Let's keep what have been done here.
  747. # A CLIP model with such `eos_token_id` in the config can't work correctly with extra new tokens added
  748. # ------------------------------------------------------------
  749. # text_embeds.shape = [batch_size, sequence_length, transformer.width]
  750. # take features from the eot embedding (eot_token is the highest number in each sequence)
  751. # casting to torch.int for onnx compatibility: argmax doesn't support int64 inputs with opset 14
  752. pooled_output = last_hidden_state[
  753. torch.arange(last_hidden_state.shape[0], device=last_hidden_state.device),
  754. input_ids.to(dtype=torch.int, device=last_hidden_state.device).argmax(dim=-1),
  755. ]
  756. else:
  757. # The config gets updated `eos_token_id` from PR #24773 (so the use of extra new tokens is possible)
  758. pooled_output = last_hidden_state[
  759. torch.arange(last_hidden_state.shape[0], device=last_hidden_state.device),
  760. # We need to get the first position of `eos_token_id` value (`pad_token_ids` might equal to `eos_token_id`)
  761. # Note: we assume each sequence (along batch dim.) contains an `eos_token_id` (e.g. prepared by the tokenizer)
  762. (input_ids.to(dtype=torch.int, device=last_hidden_state.device) == self.eos_token_id)
  763. .int()
  764. .argmax(dim=-1),
  765. ]
  766. return BaseModelOutputWithPooling(
  767. last_hidden_state=last_hidden_state,
  768. pooler_output=pooled_output,
  769. )
  770. class GroupViTTextModel(GroupViTPreTrainedModel):
  771. config: GroupViTTextConfig
  772. input_modalities = ("text",)
  773. def __init__(self, config: GroupViTTextConfig):
  774. super().__init__(config)
  775. self.text_model = GroupViTTextTransformer(config)
  776. # Initialize weights and apply final processing
  777. self.post_init()
  778. def get_input_embeddings(self) -> nn.Module:
  779. return self.text_model.embeddings.token_embedding
  780. def set_input_embeddings(self, value):
  781. self.text_model.embeddings.token_embedding = value
  782. @auto_docstring
  783. def forward(
  784. self,
  785. input_ids: torch.Tensor | None = None,
  786. attention_mask: torch.Tensor | None = None,
  787. position_ids: torch.Tensor | None = None,
  788. **kwargs: Unpack[TransformersKwargs],
  789. ) -> tuple | BaseModelOutputWithPooling:
  790. r"""
  791. Examples:
  792. ```python
  793. >>> from transformers import CLIPTokenizer, GroupViTTextModel
  794. >>> tokenizer = CLIPTokenizer.from_pretrained("nvidia/groupvit-gcc-yfcc")
  795. >>> model = GroupViTTextModel.from_pretrained("nvidia/groupvit-gcc-yfcc")
  796. >>> inputs = tokenizer(["a photo of a cat", "a photo of a dog"], padding=True, return_tensors="pt")
  797. >>> outputs = model(**inputs)
  798. >>> last_hidden_state = outputs.last_hidden_state
  799. >>> pooled_output = outputs.pooler_output # pooled (EOS token) states
  800. ```"""
  801. return self.text_model(
  802. input_ids=input_ids,
  803. attention_mask=attention_mask,
  804. position_ids=position_ids,
  805. **kwargs,
  806. )
  807. class GroupViTVisionTransformer(nn.Module):
  808. def __init__(self, config: GroupViTVisionConfig):
  809. super().__init__()
  810. self.config = config
  811. embed_dim = config.hidden_size
  812. self.embeddings = GroupViTVisionEmbeddings(config)
  813. self.encoder = GroupViTVisionEncoder(config)
  814. self.layernorm = nn.LayerNorm(embed_dim, eps=config.layer_norm_eps)
  815. @auto_docstring
  816. def forward(
  817. self,
  818. pixel_values: torch.FloatTensor | None = None,
  819. output_hidden_states: bool | None = None,
  820. output_attentions: bool | None = None,
  821. return_dict: bool | None = None,
  822. ) -> tuple | BaseModelOutputWithPooling:
  823. output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
  824. output_hidden_states = (
  825. output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
  826. )
  827. return_dict = return_dict if return_dict is not None else self.config.return_dict
  828. if pixel_values is None:
  829. raise ValueError("You have to specify pixel_values")
  830. hidden_states = self.embeddings(pixel_values)
  831. encoder_outputs = self.encoder(
  832. hidden_states=hidden_states,
  833. output_hidden_states=output_hidden_states,
  834. output_attentions=output_attentions,
  835. return_dict=return_dict,
  836. )
  837. last_hidden_state = encoder_outputs[0]
  838. # normalize the last hidden state
  839. last_hidden_state = self.layernorm(last_hidden_state)
  840. pooled_output = last_hidden_state.mean(dim=1)
  841. if not return_dict:
  842. return (last_hidden_state, pooled_output) + encoder_outputs[1:]
  843. return BaseModelOutputWithPooling(
  844. last_hidden_state=last_hidden_state,
  845. pooler_output=pooled_output,
  846. hidden_states=encoder_outputs.hidden_states,
  847. attentions=encoder_outputs.attentions,
  848. )
  849. class GroupViTVisionModel(GroupViTPreTrainedModel):
  850. config: GroupViTVisionConfig
  851. main_input_name = "pixel_values"
  852. input_modalities = ("image",)
  853. _can_record_outputs = {}
  854. def __init__(self, config: GroupViTVisionConfig):
  855. super().__init__(config)
  856. self.vision_model = GroupViTVisionTransformer(config)
  857. # Initialize weights and apply final processing
  858. self.post_init()
  859. def get_input_embeddings(self) -> GroupViTPatchEmbeddings:
  860. return self.vision_model.embeddings.patch_embeddings
  861. @auto_docstring
  862. def forward(
  863. self,
  864. pixel_values: torch.FloatTensor | None = None,
  865. output_attentions: bool | None = None,
  866. output_hidden_states: bool | None = None,
  867. return_dict: bool | None = None,
  868. **kwargs,
  869. ) -> tuple | BaseModelOutputWithPooling:
  870. r"""
  871. Examples:
  872. ```python
  873. >>> from PIL import Image
  874. >>> import httpx
  875. >>> from io import BytesIO
  876. >>> from transformers import AutoProcessor, GroupViTVisionModel
  877. >>> processor = AutoProcessor.from_pretrained("nvidia/groupvit-gcc-yfcc")
  878. >>> model = GroupViTVisionModel.from_pretrained("nvidia/groupvit-gcc-yfcc")
  879. >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg"
  880. >>> with httpx.stream("GET", url) as response:
  881. ... image = Image.open(BytesIO(response.read()))
  882. >>> inputs = processor(images=image, return_tensors="pt")
  883. >>> outputs = model(**inputs)
  884. >>> last_hidden_state = outputs.last_hidden_state
  885. >>> pooled_output = outputs.pooler_output # pooled CLS states
  886. ```"""
  887. return self.vision_model(
  888. pixel_values=pixel_values,
  889. output_attentions=output_attentions,
  890. output_hidden_states=output_hidden_states,
  891. return_dict=return_dict,
  892. )
  893. @auto_docstring
  894. class GroupViTModel(GroupViTPreTrainedModel):
  895. config: GroupViTConfig
  896. def __init__(self, config: GroupViTConfig):
  897. super().__init__(config)
  898. if not isinstance(config.text_config, GroupViTTextConfig):
  899. raise TypeError(
  900. "config.text_config is expected to be of type GroupViTTextConfig but is of type"
  901. f" {type(config.text_config)}."
  902. )
  903. if not isinstance(config.vision_config, GroupViTVisionConfig):
  904. raise TypeError(
  905. "config.vision_config is expected to be of type GroupViTVisionConfig but is of type"
  906. f" {type(config.vision_config)}."
  907. )
  908. text_config = config.text_config
  909. vision_config = config.vision_config
  910. self.projection_dim = config.projection_dim
  911. self.projection_intermediate_dim = config.projection_intermediate_dim
  912. self.text_embed_dim = text_config.hidden_size
  913. self.vision_embed_dim = vision_config.hidden_size
  914. self.text_model = GroupViTTextTransformer(text_config)
  915. self.vision_model = GroupViTVisionTransformer(vision_config)
  916. self.visual_projection = nn.Sequential(
  917. nn.Linear(self.vision_embed_dim, self.projection_intermediate_dim, bias=True),
  918. nn.BatchNorm1d(self.projection_intermediate_dim),
  919. nn.ReLU(inplace=True),
  920. nn.Linear(self.projection_intermediate_dim, self.projection_dim, bias=True),
  921. )
  922. self.text_projection = nn.Sequential(
  923. nn.Linear(self.text_embed_dim, self.projection_intermediate_dim, bias=True),
  924. nn.BatchNorm1d(self.projection_intermediate_dim),
  925. nn.ReLU(inplace=True),
  926. nn.Linear(self.projection_intermediate_dim, self.projection_dim, bias=True),
  927. )
  928. self.logit_scale = nn.Parameter(torch.tensor(self.config.logit_scale_init_value))
  929. # Initialize weights and apply final processing
  930. self.post_init()
  931. @can_return_tuple
  932. @auto_docstring
  933. def get_text_features(
  934. self,
  935. input_ids: torch.Tensor,
  936. attention_mask: torch.Tensor | None = None,
  937. position_ids: torch.Tensor | None = None,
  938. **kwargs: Unpack[TransformersKwargs],
  939. ) -> tuple | BaseModelOutputWithPooling:
  940. r"""
  941. Examples:
  942. ```python
  943. >>> import torch
  944. >>> from transformers import CLIPTokenizer, GroupViTModel
  945. >>> model = GroupViTModel.from_pretrained("nvidia/groupvit-gcc-yfcc")
  946. >>> tokenizer = CLIPTokenizer.from_pretrained("nvidia/groupvit-gcc-yfcc")
  947. >>> inputs = tokenizer(["a photo of a cat", "a photo of a dog"], padding=True, return_tensors="pt")
  948. >>> with torch.inference_mode():
  949. ... text_features = model.get_text_features(**inputs)
  950. ```"""
  951. text_outputs: BaseModelOutputWithPooling = self.text_model(
  952. input_ids=input_ids,
  953. attention_mask=attention_mask,
  954. position_ids=position_ids,
  955. return_dict=True,
  956. **kwargs,
  957. )
  958. pooled_output = text_outputs.pooler_output
  959. text_outputs.pooler_output = self.text_projection(pooled_output)
  960. return text_outputs
  961. @can_return_tuple
  962. @auto_docstring
  963. def get_image_features(
  964. self,
  965. pixel_values: torch.Tensor,
  966. **kwargs: Unpack[TransformersKwargs],
  967. ) -> tuple | BaseModelOutputWithPooling:
  968. r"""
  969. Examples:
  970. ```python
  971. >>> import torch
  972. >>> from transformers import AutoProcessor, GroupViTModel
  973. >>> from transformers.image_utils import load_image
  974. >>> model = GroupViTModel.from_pretrained("nvidia/groupvit-gcc-yfcc")
  975. >>> processor = AutoProcessor.from_pretrained("nvidia/groupvit-gcc-yfcc")
  976. >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg"
  977. >>> image = load_image(url)
  978. >>> inputs = processor(images=image, return_tensors="pt")
  979. >>> with torch.inference_mode():
  980. ... image_features = model.get_image_features(**inputs)
  981. ```"""
  982. vision_outputs: BaseModelOutputWithPooling = self.vision_model(pixel_values, return_dict=True, **kwargs)
  983. vision_outputs.pooler_output = self.visual_projection(vision_outputs.pooler_output)
  984. return vision_outputs
  985. @can_return_tuple
  986. @auto_docstring
  987. def forward(
  988. self,
  989. input_ids: torch.LongTensor | None = None,
  990. pixel_values: torch.FloatTensor | None = None,
  991. attention_mask: torch.Tensor | None = None,
  992. position_ids: torch.LongTensor | None = None,
  993. return_loss: bool | None = None,
  994. output_attentions: bool | None = None,
  995. output_hidden_states: bool | None = None,
  996. output_segmentation: bool | None = None,
  997. **kwargs: Unpack[TransformersKwargs],
  998. ) -> tuple | GroupViTModelOutput:
  999. r"""
  1000. return_loss (`bool`, *optional*):
  1001. Whether or not to return the contrastive loss.
  1002. output_segmentation (`bool`, *optional*):
  1003. Whether or not to return the segmentation logits.
  1004. Examples:
  1005. ```python
  1006. >>> from PIL import Image
  1007. >>> import httpx
  1008. >>> from io import BytesIO
  1009. >>> from transformers import AutoProcessor, GroupViTModel
  1010. >>> model = GroupViTModel.from_pretrained("nvidia/groupvit-gcc-yfcc")
  1011. >>> processor = AutoProcessor.from_pretrained("nvidia/groupvit-gcc-yfcc")
  1012. >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg"
  1013. >>> with httpx.stream("GET", url) as response:
  1014. ... image = Image.open(BytesIO(response.read()))
  1015. >>> inputs = processor(
  1016. ... text=["a photo of a cat", "a photo of a dog"], images=image, return_tensors="pt", padding=True
  1017. ... )
  1018. >>> outputs = model(**inputs)
  1019. >>> logits_per_image = outputs.logits_per_image # this is the image-text similarity score
  1020. >>> probs = logits_per_image.softmax(dim=1) # we can take the softmax to get the label probabilities
  1021. ```"""
  1022. # Use GROUPVIT model's config for some fields (if specified) instead of those of vision & text components.
  1023. output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
  1024. output_segmentation = (
  1025. output_segmentation if output_segmentation is not None else self.config.output_segmentation
  1026. )
  1027. if output_segmentation:
  1028. output_attentions = True
  1029. output_hidden_states = (
  1030. output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
  1031. )
  1032. # Vision side uses explicit flags (nn.Module-based, not hook-based)
  1033. vision_outputs = self.vision_model(
  1034. pixel_values=pixel_values,
  1035. output_attentions=output_attentions,
  1036. output_hidden_states=output_hidden_states,
  1037. return_dict=True,
  1038. )
  1039. text_outputs: BaseModelOutputWithPooling = self.text_model(
  1040. input_ids=input_ids,
  1041. attention_mask=attention_mask,
  1042. position_ids=position_ids,
  1043. **kwargs,
  1044. )
  1045. image_embeds = vision_outputs.pooler_output
  1046. image_embeds = self.visual_projection(image_embeds)
  1047. text_embeds = text_outputs.pooler_output
  1048. text_embeds = self.text_projection(text_embeds)
  1049. # normalized features
  1050. image_embeds = image_embeds / image_embeds.norm(dim=-1, keepdim=True)
  1051. text_embeds = text_embeds / text_embeds.norm(dim=-1, keepdim=True)
  1052. # cosine similarity as logits
  1053. logit_scale = self.logit_scale.exp()
  1054. logits_per_text = torch.matmul(text_embeds, image_embeds.t()) * logit_scale
  1055. logits_per_image = logits_per_text.t()
  1056. seg_logits = None
  1057. if output_segmentation:
  1058. # grouped features
  1059. # [batch_size_image, num_group, hidden_size]
  1060. image_group_embeds = vision_outputs.last_hidden_state
  1061. # [batch_size_image*num_group, hidden_size]
  1062. image_group_embeds = self.visual_projection(image_group_embeds.reshape(-1, image_group_embeds.shape[-1]))
  1063. attentions = vision_outputs.attentions
  1064. # [batch_size_image, num_group, height, width]
  1065. grouping = get_grouping_from_attentions(attentions, pixel_values.shape[2:])
  1066. # normalized features
  1067. image_group_embeds = image_group_embeds / image_group_embeds.norm(dim=-1, keepdim=True)
  1068. # [batch_size_image x num_group, batch_size_text]
  1069. logits_per_image_group = torch.matmul(image_group_embeds, text_embeds.t()) * logit_scale
  1070. # [batch_size_image, batch_size_text, num_group]
  1071. logits_per_image_group = logits_per_image_group.reshape(
  1072. image_embeds.shape[0], -1, text_embeds.shape[0]
  1073. ).permute(0, 2, 1)
  1074. # [batch_size_image, batch_size_text, height x width]
  1075. flatten_grouping = grouping.reshape(grouping.shape[0], grouping.shape[1], -1)
  1076. # [batch_size_image, batch_size_text, height, width]
  1077. seg_logits = torch.matmul(logits_per_image_group, flatten_grouping) * logit_scale
  1078. seg_logits = seg_logits.reshape(
  1079. seg_logits.shape[0], seg_logits.shape[1], grouping.shape[2], grouping.shape[3]
  1080. )
  1081. loss = None
  1082. if return_loss:
  1083. loss = groupvit_loss(logits_per_text)
  1084. return GroupViTModelOutput(
  1085. loss=loss,
  1086. logits_per_image=logits_per_image,
  1087. logits_per_text=logits_per_text,
  1088. segmentation_logits=seg_logits,
  1089. text_embeds=text_embeds,
  1090. image_embeds=image_embeds,
  1091. text_model_output=text_outputs,
  1092. vision_model_output=vision_outputs,
  1093. )
  1094. __all__ = ["GroupViTModel", "GroupViTPreTrainedModel", "GroupViTTextModel", "GroupViTVisionModel"]