| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656657658659660661662663664665666667668669670671672673674675676677678679680681682683684685686687688689690691692693694695696697698699700701702
70370470570670770870971071171271371471571671771871972072172272372472572672772872973073173273373473573673773873974074174274374474574674774874975075175275375475575675775875976076176276376476576676776876977077177277377477577677777877978078178278378478578678778878979079179279379479579679779879980080180280380480580680780880981081181281381481581681781881982082182282382482582682782882983083183283383483583683783883984084184284384484584684784884985085185285385485585685785885986086186286386486586686786886987087187287387487587687787887988088188288388488588688788888989089189289389489589689789889990090190290390490590690790890991091191291391491591691791891992092192292392492592692792892993093193293393493593693793893994094194294394494594694794894995095195295395495595695795895996096196296396496596696796896997097197297397497597697797897998098198298398498598698798898999099199299399499599699799899910001001100210031004100510061007100810091010101110121013101410151016101710181019102010211022102310241025102610271028102910301031103210331034103510361037103810391040104110421043104410451046104710481049105010511052105310541055105610571058105910601061106210631064106510661067106810691070107110721073107410751076107710781079108010811082108310841085108610871088108910901091109210931094109510961097109810991100110111021103110411051106110711081109111011111112111311141115111611171118111911201121112211231124112511261127112811291130113111321133113411351136113711381139114011411142114311441145114611471148114911501151115211531154115511561157115811591160116111621163116411651166116711681169117011711172117311741175117611771178117911801181118211831184118511861187118811891190119111921193119411951196119711981199120012011202120312041205120612071208120912101211121212131214121512161217121812191220122112221223122412251226122712281229123012311232123312341235123612371238123912401241124212431244124512461247124812491250125112521253125412551256125712581259126012611262126312641265126612671268126912701271127212731274127512761
277127812791280128112821283128412851286128712881289129012911292129312941295129612971298129913001301130213031304130513061307130813091310131113121313131413151316131713181319132013211322132313241325132613271328132913301331133213331334133513361337133813391340134113421343 |
- # Copyright 2022 NVIDIA and The HuggingFace Team. All rights reserved.
- #
- # Licensed under the Apache License, Version 2.0 (the "License");
- # you may not use this file except in compliance with the License.
- # You may obtain a copy of the License at
- #
- # http://www.apache.org/licenses/LICENSE-2.0
- #
- # Unless required by applicable law or agreed to in writing, software
- # distributed under the License is distributed on an "AS IS" BASIS,
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- # See the License for the specific language governing permissions and
- # limitations under the License.
- """PyTorch GroupViT model."""
- import collections.abc
- from dataclasses import dataclass
- from typing import Any
- import numpy as np
- import torch
- from torch import nn
- from ... import initialization as init
- from ...activations import ACT2FN
- from ...masking_utils import create_causal_mask
- from ...modeling_layers import GradientCheckpointingLayer
- from ...modeling_outputs import BaseModelOutput, BaseModelOutputWithPooling
- from ...modeling_utils import PreTrainedModel
- from ...processing_utils import Unpack
- from ...utils import ModelOutput, TransformersKwargs, auto_docstring, can_return_tuple, logging, torch_int
- from ...utils.generic import merge_with_config_defaults
- from ...utils.output_capturing import capture_outputs
- from .configuration_groupvit import GroupViTConfig, GroupViTTextConfig, GroupViTVisionConfig
- logger = logging.get_logger(__name__)
# contrastive loss function, adapted from
# https://sachinruk.github.io/blog/pytorch/pytorch%20lightning/loss%20function/gpu/2021/03/07/CLIP.html
def contrastive_loss(logits: torch.Tensor) -> torch.Tensor:
    """
    Cross-entropy where the correct class for row `i` of `logits` is column `i`.

    Args:
        logits (`torch.Tensor`): similarity scores of shape `(N, C)` with `C >= N`; the matching pair for row `i`
            is expected on the diagonal.

    Returns:
        `torch.Tensor`: the cross-entropy averaged over the `N` rows (PyTorch's default `reduction="mean"`).
    """
    num_rows = len(logits)
    diagonal_targets = torch.arange(num_rows, device=logits.device)
    return nn.functional.cross_entropy(logits, diagonal_targets)
# Copied from transformers.models.clip.modeling_clip.clip_loss with clip->groupvit
def groupvit_loss(similarity: torch.Tensor) -> torch.Tensor:
    """
    Symmetric contrastive loss over a similarity matrix.

    Args:
        similarity (`torch.Tensor`): 2D similarity matrix whose diagonal holds the matching pairs.

    Returns:
        `torch.Tensor`: the mean of `contrastive_loss` applied to `similarity` and to its transpose.
    """
    row_wise = contrastive_loss(similarity)
    column_wise = contrastive_loss(similarity.t())
    total = row_wise + column_wise
    return total / 2.0
def hard_softmax(logits: torch.Tensor, dim: int):
    """
    Deterministic straight-through softmax along `dim`.

    The forward value is a one-hot tensor marking the largest softmax probability along `dim`. Because the result is
    built as `one_hot - probs.detach() + probs`, gradients flow through the soft probabilities.

    Args:
        logits (`torch.Tensor`): unnormalized scores.
        dim (`int`): dimension along which to normalize and pick the maximum.

    Returns:
        `torch.Tensor`: tensor with the same shape as `logits`.
    """
    probs = logits.softmax(dim)
    winner = probs.max(dim, keepdim=True).indices
    one_hot = torch.zeros_like(logits, memory_format=torch.legacy_contiguous_format)
    one_hot.scatter_(dim, winner, 1.0)
    # Straight through: value of `one_hot`, gradient of `probs`.
    return one_hot - probs.detach() + probs
def gumbel_softmax(logits: torch.Tensor, tau: float = 1, hard: bool = False, dim: int = -1) -> torch.Tensor:
    """
    Sample from the Gumbel-softmax distribution defined by `logits`.

    Args:
        logits (`torch.Tensor`): unnormalized log-probabilities.
        tau (`float`, *optional*, defaults to 1): temperature dividing the perturbed logits.
        hard (`bool`, *optional*, defaults to `False`): if `True`, return a straight-through one-hot sample (the value
            is one-hot, the gradient is that of the soft sample).
        dim (`int`, *optional*, defaults to -1): dimension along which softmax is computed.

    Returns:
        `torch.Tensor`: sample with the same shape as `logits`.
    """
    # The Gumbel distribution is built on logits' device/dtype; upstream notes this as the more stable
    # option, see https://github.com/pytorch/pytorch/issues/41663
    loc = torch.tensor(0.0, device=logits.device, dtype=logits.dtype)
    scale = torch.tensor(1.0, device=logits.device, dtype=logits.dtype)
    noise = torch.distributions.gumbel.Gumbel(loc, scale).sample(logits.shape)
    y_soft = ((logits + noise) / tau).softmax(dim)  # ~Gumbel(logits, tau)

    if not hard:
        # Reparameterization trick.
        return y_soft

    # Straight through.
    winner = y_soft.max(dim, keepdim=True).indices
    y_hard = torch.zeros_like(logits, memory_format=torch.legacy_contiguous_format).scatter_(dim, winner, 1.0)
    return y_hard - y_soft.detach() + y_soft
def resize_attention_map(attentions, height, width, align_corners=False):
    """
    Reshape flattened per-group attention maps onto a 2D feature grid and bilinearly resize them.

    The feature grid is inferred from the number of flattened positions and the requested aspect ratio. The shorter
    output side is divided by the square root of the (floored) pixels-per-position ratio, rounded, and clamped to at
    least 1. The other feature side is the number of positions divided by that value.

    Args:
        attentions (`torch.Tensor`): attention map of shape [batch_size, groups, feat_height*feat_width]
        height (`int`): height of the output attention map
        width (`int`): width of the output attention map
        align_corners (`bool`, *optional*): the `align_corner` argument for `nn.functional.interpolate`.

    Returns:
        `torch.Tensor`: resized attention map of shape [batch_size, groups, height, width]

    Raises:
        ZeroDivisionError: if there are more flattened positions than output pixels (the floored ratio is 0).
        RuntimeError: from `reshape` if the inferred grid does not hold exactly `attentions.shape[2]` positions.
    """
    num_positions = attentions.shape[2]
    scale = (height * width // num_positions) ** 0.5
    # Clamp to >= 1: for extreme aspect ratios the rounded side can be 0, which would make the
    # following floor division fail.
    if height > width:
        feat_width = max(1, int(np.round(width / scale)))
        feat_height = num_positions // feat_width
    else:
        feat_height = max(1, int(np.round(height / scale)))
        feat_width = num_positions // feat_height

    batch_size, groups = attentions.shape[0], attentions.shape[1]  # groups = number of group tokens
    # [batch_size, groups, feat_height*feat_width] -> [batch_size, groups, feat_height, feat_width]
    attentions = attentions.reshape(batch_size, groups, feat_height, feat_width)
    return nn.functional.interpolate(attentions, size=(height, width), mode="bilinear", align_corners=align_corners)


def get_grouping_from_attentions(attentions, hw_shape):
    """
    Compose per-stage grouping attentions into a single map from input positions to the final groups.

    Consecutive maps are chained by matrix product, so each map's last dimension must equal the previous map's
    number of groups.

    Args:
        attentions (`tuple(torch.FloatTensor)`): tuple of attention maps returned by `GroupViTVisionTransformer`,
            each of shape [batch_size, num_groups, num_inputs].
        hw_shape (`tuple(int)`): height and width of the output attention map

    Returns:
        `torch.Tensor`: the attention map of shape [batch_size, groups, height, width] for the last stage.

    Raises:
        ValueError: if `attentions` is empty.
    """
    with torch.no_grad():
        composed = None
        for attn_masks in attentions:
            # [batch_size, num_groups, num_inputs] -> [batch_size, num_inputs, num_groups]
            attn_masks = attn_masks.permute(0, 2, 1).contiguous()
            composed = attn_masks if composed is None else composed @ attn_masks

        if composed is None:
            raise ValueError("`attentions` must contain at least one attention map.")

        # Only the final composition is returned, so it is resized once rather than after every stage.
        # [batch_size, height*width, num_groups] -> [batch_size, num_groups, height*width]
        # -> [batch_size, num_groups, height, width]
        return resize_attention_map(composed.permute(0, 2, 1).contiguous(), *hw_shape)
class GroupViTCrossAttentionLayer(nn.Module):
    """
    Cross-attention block: `query` attends to `key`, then an MLP is applied. Both steps are residual, and a final
    LayerNorm is applied to the result.
    """

    def __init__(self, config: GroupViTVisionConfig):
        super().__init__()
        self.attn = GroupViTAttention(config)
        self.norm2 = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
        self.mlp = GroupViTMLP(config)
        self.norm_post = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)

    def forward(self, query, key):
        attended, _ = self.attn(query, encoder_hidden_states=key)
        hidden = query + attended
        hidden = hidden + self.mlp(self.norm2(hidden))
        return self.norm_post(hidden)
class GroupViTAssignAttention(nn.Module):
    """
    Single-head attention that assigns key (image) tokens to query (group) tokens.

    Attention scores are normalized over the query axis (`dim=-2`), so each key token distributes its mass across
    the queries.
    """

    def __init__(self, config: GroupViTVisionConfig):
        super().__init__()
        width = config.hidden_size
        self.scale = width**-0.5

        self.q_proj = nn.Linear(width, width)
        self.k_proj = nn.Linear(width, width)
        self.v_proj = nn.Linear(width, width)
        self.proj = nn.Linear(width, width)
        self.assign_eps = config.assign_eps

    def get_attn(self, attn, gumbel=True, hard=True):
        # Gumbel sampling is used only during training; otherwise use a deterministic (hard or soft) softmax.
        if gumbel and self.training:
            return gumbel_softmax(attn, dim=-2, hard=hard)
        if hard:
            return hard_softmax(attn, dim=-2)
        return nn.functional.softmax(attn, dim=-2)

    def forward(self, query, key):
        projected_query = self.q_proj(query)  # [batch_size, query_length, channels]
        projected_key = self.k_proj(key)  # [batch_size, key_length, channels]
        projected_value = self.v_proj(key)  # [batch_size, key_length, channels]; values come from the raw `key`

        # [batch_size, query_length, key_length]
        scores = (projected_query @ projected_key.transpose(-2, -1)) * self.scale

        assignment = self.get_attn(scores)
        soft_assignment = self.get_attn(scores, gumbel=False, hard=False)

        # Normalize over keys (with eps) so each query aggregates a weighted mean of its assigned tokens.
        assignment = assignment / (assignment.sum(dim=-1, keepdim=True) + self.assign_eps)

        return self.proj(assignment @ projected_value), soft_assignment
class GroupViTTokenAssign(nn.Module):
    """
    Downsampling step of a grouping stage.

    The `num_group_token` group tokens are projected to `num_output_group` tokens. These cross-attend to the image
    tokens, and then the image tokens are assigned to them via `GroupViTAssignAttention`.
    """

    def __init__(self, config: GroupViTVisionConfig, num_group_token, num_output_group):
        super().__init__()
        self.num_output_group = num_output_group
        hidden = config.hidden_size
        eps = config.layer_norm_eps

        # norm on group_tokens
        self.norm_tokens = nn.LayerNorm(hidden, eps=eps)

        # `assign_mlp_ratio` may be a scalar or a (tokens_ratio, channels_ratio) pair.
        ratios = config.assign_mlp_ratio
        if not isinstance(ratios, collections.abc.Iterable):
            ratios = (ratios, ratios)
        tokens_dim, channels_dim = [int(ratio * hidden) for ratio in ratios]

        self.mlp_inter = GroupViTMixerMLP(config, num_group_token, tokens_dim, num_output_group)
        self.norm_post_tokens = nn.LayerNorm(hidden, eps=eps)
        # norm on x
        self.norm_x = nn.LayerNorm(hidden, eps=eps)
        self.pre_assign_attn = GroupViTCrossAttentionLayer(config)

        self.assign = GroupViTAssignAttention(config)
        self.norm_new_x = nn.LayerNorm(hidden, eps=eps)
        self.mlp_channels = GroupViTMLP(config, hidden, channels_dim, hidden)

    def project_group_token(self, group_tokens):
        """
        Args:
            group_tokens (torch.Tensor): group tokens, [batch_size, num_group_tokens, channels]

        Returns:
            projected_group_tokens (torch.Tensor): [batch_size, num_output_groups, channels]
        """
        # [B, num_output_groups, C] <- [B, num_group_tokens, C]
        return self.norm_post_tokens(self.mlp_inter(group_tokens))

    def forward(self, image_tokens, group_tokens):
        """
        Args:
            image_tokens (`torch.Tensor`): image tokens, of shape [batch_size, input_length, channels]
            group_tokens (`torch.Tensor`): group tokens, [batch_size, num_group_tokens, channels]

        Returns:
            `tuple`: the new (grouped) tokens and the soft assignment attention from `GroupViTAssignAttention`.
        """
        normed_groups = self.norm_tokens(group_tokens)
        normed_image = self.norm_x(image_tokens)

        # [batch_size, num_output_groups, channels]
        groups = self.project_group_token(normed_groups)
        groups = self.pre_assign_attn(groups, normed_image)

        grouped, attention = self.assign(groups, normed_image)
        grouped += groups
        grouped = grouped + self.mlp_channels(self.norm_new_x(grouped))
        return grouped, attention
@dataclass
@auto_docstring
class GroupViTModelOutput(ModelOutput):
    r"""
    loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `return_loss` is `True`):
        Contrastive loss for image-text similarity.
    logits_per_image (`torch.FloatTensor` of shape `(image_batch_size, text_batch_size)`):
        The scaled dot product scores between `image_embeds` and `text_embeds`. This represents the image-text
        similarity scores.
    logits_per_text (`torch.FloatTensor` of shape `(text_batch_size, image_batch_size)`):
        The scaled dot product scores between `text_embeds` and `image_embeds`. This represents the text-image
        similarity scores.
    segmentation_logits (`torch.FloatTensor` of shape `(batch_size, config.num_labels, logits_height, logits_width)`):
        Classification scores for each pixel.

        <Tip warning={true}>

        The logits returned do not necessarily have the same size as the `pixel_values` passed as inputs. This is
        to avoid doing two interpolations and losing some quality when a user needs to resize the logits to the
        original image size as post-processing. You should always check your logits shape and resize as needed.

        </Tip>
    text_embeds (`torch.FloatTensor` of shape `(batch_size, output_dim)`):
        The text embeddings obtained by applying the projection layer to the pooled output of
        [`GroupViTTextModel`].
    image_embeds (`torch.FloatTensor` of shape `(batch_size, output_dim)`):
        The image embeddings obtained by applying the projection layer to the pooled output of
        [`GroupViTVisionModel`].
    text_model_output (`BaseModelOutputWithPooling`):
        The output of the [`GroupViTTextModel`].
    vision_model_output (`BaseModelOutputWithPooling`):
        The output of the [`GroupViTVisionModel`].
    """

    loss: torch.FloatTensor | None = None
    logits_per_image: torch.FloatTensor | None = None
    logits_per_text: torch.FloatTensor | None = None
    segmentation_logits: torch.FloatTensor | None = None
    text_embeds: torch.FloatTensor | None = None
    image_embeds: torch.FloatTensor | None = None
    text_model_output: BaseModelOutputWithPooling = None
    vision_model_output: BaseModelOutputWithPooling = None

    def to_tuple(self) -> tuple[Any]:
        # The nested sub-model outputs are converted with their own `to_tuple()`, so the result holds plain
        # tuples instead of ModelOutput objects for those two fields.
        return tuple(
            self[k] if k not in ["text_model_output", "vision_model_output"] else getattr(self, k).to_tuple()
            for k in self.keys()
        )
- class GroupViTPatchEmbeddings(nn.Module):
- """
- Image to Patch Embedding.
- """
- def __init__(
- self,
- image_size: int | list[int] | tuple[int, int] = 224,
- patch_size: int | tuple[int, int] = 16,
- num_channels: int = 3,
- embed_dim: int = 768,
- ):
- super().__init__()
- image_size = image_size if isinstance(image_size, collections.abc.Iterable) else (image_size, image_size)
- patch_size = patch_size if isinstance(patch_size, collections.abc.Iterable) else (patch_size, patch_size)
- num_patches = (image_size[1] // patch_size[1]) * (image_size[0] // patch_size[0])
- self.image_size = image_size
- self.patch_size = patch_size
- self.num_patches = num_patches
- self.projection = nn.Conv2d(num_channels, embed_dim, kernel_size=patch_size, stride=patch_size)
- def forward(self, pixel_values: torch.Tensor, interpolate_pos_encoding: bool = False) -> torch.Tensor:
- batch_size, num_channels, height, width = pixel_values.shape
- if not interpolate_pos_encoding:
- if height != self.image_size[0] or width != self.image_size[1]:
- raise ValueError(
- f"Input image size ({height}*{width}) doesn't match model"
- f" ({self.image_size[0]}*{self.image_size[1]})."
- )
- x = self.projection(pixel_values).flatten(2).transpose(1, 2)
- return x
class GroupViTVisionEmbeddings(nn.Module):
    """
    Vision-tower input embeddings: patch embeddings, LayerNorm, learned position embeddings (one per patch, no class
    token), then dropout.
    """

    def __init__(self, config: GroupViTVisionConfig):
        super().__init__()

        self.patch_embeddings = GroupViTPatchEmbeddings(
            image_size=config.image_size,
            patch_size=config.patch_size,
            num_channels=config.num_channels,
            embed_dim=config.hidden_size,
        )
        self.position_embeddings = nn.Parameter(
            torch.zeros(1, self.patch_embeddings.num_patches, config.hidden_size)
        )
        self.dropout = nn.Dropout(config.dropout)
        self.layernorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
        self.patch_size = config.patch_size
        self.config = config

    def interpolate_pos_encoding(self, embeddings: torch.Tensor, height: int, width: int) -> torch.Tensor:
        """
        This method allows to interpolate the pre-trained position encodings, to be able to use the model on higher resolution
        images. This method is also adapted to support torch.jit tracing and no class embeddings.

        Adapted from:
        - https://github.com/facebookresearch/dino/blob/de9ee3df6cf39fac952ab558447af1fa1365362a/vision_transformer.py#L174-L194, and
        - https://github.com/facebookresearch/dinov2/blob/e1277af2ba9496fbadf7aec6eba56e8d882d1e35/dinov2/models/vision_transformer.py#L179-L211
        """
        num_patches = embeddings.shape[1]
        num_positions = self.position_embeddings.shape[1]

        # always interpolate when tracing to ensure the exported model works for dynamic input shapes
        unchanged_grid = num_patches == num_positions and height == width
        if unchanged_grid and not torch.jit.is_tracing():
            return self.position_embeddings

        dim = embeddings.shape[-1]
        side = torch_int(num_positions**0.5)
        target_size = (height // self.patch_size, width // self.patch_size)

        # [1, num_positions, dim] -> [1, dim, side, side]
        grid = self.position_embeddings.reshape(1, side, side, dim).permute(0, 3, 1, 2)
        grid = nn.functional.interpolate(grid, size=target_size, mode="bicubic", align_corners=False)
        # [1, dim, new_height, new_width] -> [1, new_height * new_width, dim]
        return grid.permute(0, 2, 3, 1).view(1, -1, dim)

    def forward(self, pixel_values: torch.Tensor, interpolate_pos_encoding: bool = False) -> torch.Tensor:
        _, _, height, width = pixel_values.shape
        tokens = self.patch_embeddings(pixel_values, interpolate_pos_encoding=interpolate_pos_encoding)
        tokens = self.layernorm(tokens)

        # add positional encoding to each token
        if interpolate_pos_encoding:
            positions = self.interpolate_pos_encoding(tokens, height, width)
        else:
            positions = self.position_embeddings

        return self.dropout(tokens + positions)
# Copied from transformers.models.clip.modeling_clip.CLIPTextEmbeddings with CLIP->GroupViT
class GroupViTTextEmbeddings(nn.Module):
    """Token embeddings plus learned absolute position embeddings for the text tower."""

    def __init__(self, config: GroupViTTextConfig):
        super().__init__()
        embed_dim = config.hidden_size

        self.token_embedding = nn.Embedding(config.vocab_size, embed_dim)
        self.position_embedding = nn.Embedding(config.max_position_embeddings, embed_dim)

        # Default position ids of shape (1, max_position_embeddings); registered with persistent=False, so they are
        # not saved in the state dict.
        self.register_buffer(
            "position_ids", torch.arange(config.max_position_embeddings).expand((1, -1)), persistent=False
        )

    def forward(
        self,
        input_ids: torch.LongTensor | None = None,
        position_ids: torch.LongTensor | None = None,
        inputs_embeds: torch.FloatTensor | None = None,
    ) -> torch.Tensor:
        """
        Args:
            input_ids (`torch.LongTensor`, *optional*): token ids; the sequence length is their last dimension.
            position_ids (`torch.LongTensor`, *optional*): defaults to `0 .. seq_length - 1`.
            inputs_embeds (`torch.FloatTensor`, *optional*): precomputed token embeddings used instead of
                `input_ids`; the sequence length is their second-to-last dimension.

        Returns:
            `torch.Tensor`: `inputs_embeds + position_embedding(position_ids)`.

        Raises:
            ValueError: if neither `input_ids` nor `inputs_embeds` is given, or if the sequence length exceeds
                `max_position_embeddings`.
        """
        if input_ids is None and inputs_embeds is None:
            raise ValueError("You have to specify either input_ids or inputs_embeds")

        seq_length = input_ids.shape[-1] if input_ids is not None else inputs_embeds.shape[-2]
        max_position_embedding = self.position_embedding.weight.shape[0]

        # A length equal to the table size is valid; only longer sequences are rejected.
        if seq_length > max_position_embedding:
            raise ValueError(
                f"Sequence length must be less than or equal to max_position_embeddings (got `sequence length`: "
                f"{seq_length} and max_position_embeddings: {max_position_embedding})"
            )

        if position_ids is None:
            position_ids = self.position_ids[:, :seq_length]

        if inputs_embeds is None:
            inputs_embeds = self.token_embedding(input_ids)

        position_embeddings = self.position_embedding(position_ids)
        return inputs_embeds + position_embeddings
class GroupViTStage(nn.Module):
    """This corresponds to the `GroupingLayer` class in the GroupViT implementation."""

    def __init__(
        self,
        config: GroupViTVisionConfig,
        depth: int,
        num_prev_group_token: int,
        num_group_token: int,
        num_output_group: int,
    ):
        super().__init__()
        self.depth = depth
        self.num_group_token = num_group_token
        has_group_tokens = num_group_token > 0

        self.group_token = (
            nn.Parameter(torch.zeros(1, num_group_token, config.hidden_size)) if has_group_tokens else None
        )
        self.layers = nn.ModuleList([GroupViTEncoderLayer(config) for _ in range(depth)])

        self.downsample = (
            GroupViTTokenAssign(
                config=config,
                num_group_token=num_group_token,
                num_output_group=num_output_group,
            )
            if has_group_tokens
            else None
        )

        # Projects the previous stage's group tokens onto this stage's group-token count.
        if has_group_tokens and num_prev_group_token > 0:
            self.group_projector = nn.Sequential(
                nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps),
                GroupViTMixerMLP(config, num_prev_group_token, config.hidden_size // 2, num_group_token),
            )
        else:
            self.group_projector = None

    @property
    def with_group_token(self):
        return self.group_token is not None

    def split_x(self, x):
        # Group tokens occupy the last `num_group_token` positions of the sequence.
        if not self.with_group_token:
            return x, None
        count = self.num_group_token
        return x[:, :-count], x[:, -count:]

    def concat_x(self, x: torch.Tensor, group_token: torch.Tensor | None = None) -> torch.Tensor:
        return x if group_token is None else torch.cat([x, group_token], dim=1)

    def forward(
        self,
        hidden_states: torch.Tensor,
        prev_group_token: torch.Tensor | None = None,
        output_attentions: bool | None = False,
    ) -> tuple[torch.FloatTensor]:
        """
        Args:
            hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)`
            prev_group_token (`torch.FloatTensor`, *optional*): group tokens produced by the previous stage; only
                used when this stage has a `group_projector`.
            output_attentions (`bool`, *optional*):
                Whether or not to return the grouping tensors of Grouping block.

        Returns:
            `tuple`: `(hidden_states, group_token)`. When `output_attentions` is set, the grouping attention from
            the downsampling step is appended, or `None` if this stage does not downsample.
        """
        group_token = None
        if self.with_group_token:
            group_token = self.group_token.expand(hidden_states.size(0), -1, -1)
            if self.group_projector is not None:
                group_token = group_token + self.group_projector(prev_group_token)

        # Image tokens and group tokens are processed jointly by the encoder layers, then separated again.
        joint = self.concat_x(hidden_states, group_token)
        for layer in self.layers:
            joint = layer(joint, attention_mask=None)
        tokens, group_token = self.split_x(joint)

        grouping = None
        if self.downsample is not None:
            tokens, grouping = self.downsample(tokens, group_token)

        result = (tokens, group_token)
        return (result + (grouping,)) if output_attentions else result
class GroupViTMLP(nn.Module):
    """
    Two-layer feed-forward network: `fc1 -> activation -> fc2`.

    Sizes default to `config.hidden_size` / `config.intermediate_size`; `output_size` defaults to the input size.
    """

    def __init__(
        self,
        config: GroupViTVisionConfig,
        hidden_size: int | None = None,
        intermediate_size: int | None = None,
        output_size: int | None = None,
    ):
        super().__init__()
        self.config = config
        self.activation_fn = ACT2FN[config.hidden_act]

        in_features = config.hidden_size if hidden_size is None else hidden_size
        inner_features = config.intermediate_size if intermediate_size is None else intermediate_size
        out_features = in_features if output_size is None else output_size

        self.fc1 = nn.Linear(in_features, inner_features)
        self.fc2 = nn.Linear(inner_features, out_features)

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        activated = self.activation_fn(self.fc1(hidden_states))
        return self.fc2(activated)
class GroupViTMixerMLP(GroupViTMLP):
    """
    `GroupViTMLP` applied along dimension 1: dims 1 and 2 are swapped before the MLP and swapped back after it.

    As used in this file, inputs look like [batch, tokens, channels], so this mixes across tokens. Presumably the
    same holds for all callers; verify before relying on it.
    """

    def forward(self, x):
        swapped = x.transpose(1, 2)
        mixed = super().forward(swapped)
        return mixed.transpose(1, 2)
- class GroupViTAttention(nn.Module):
- """Multi-headed attention from 'Attention Is All You Need' paper"""
- def __init__(self, config):
- super().__init__()
- self.config = config
- self.embed_dim = config.hidden_size
- self.num_heads = config.num_attention_heads
- self.head_dim = self.embed_dim // self.num_heads
- if self.head_dim * self.num_heads != self.embed_dim:
- raise ValueError(
- f"embed_dim must be divisible by num_heads (got `embed_dim`: {self.embed_dim} and `num_heads`:"
- f" {self.num_heads})."
- )
- self.scale = self.head_dim**-0.5
- self.dropout = config.attention_dropout
- self.k_proj = nn.Linear(self.embed_dim, self.embed_dim)
- self.v_proj = nn.Linear(self.embed_dim, self.embed_dim)
- self.q_proj = nn.Linear(self.embed_dim, self.embed_dim)
- self.out_proj = nn.Linear(self.embed_dim, self.embed_dim)
- def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int):
- return tensor.view(bsz, seq_len, self.num_heads, self.head_dim).transpose(1, 2).contiguous()
- def forward(
- self,
- hidden_states: torch.Tensor,
- attention_mask: torch.Tensor | None = None,
- encoder_hidden_states: torch.FloatTensor | None = None,
- **kwargs,
- ) -> tuple[torch.Tensor, torch.Tensor | None]:
- """Input shape: Batch x Time x Channel"""
- bsz, tgt_len, embed_dim = hidden_states.size()
- is_cross_attention = encoder_hidden_states is not None
- # get query proj
- query_states = self.q_proj(hidden_states) * self.scale
- if is_cross_attention:
- key_states = self._shape(self.k_proj(encoder_hidden_states), -1, bsz)
- value_states = self._shape(self.v_proj(encoder_hidden_states), -1, bsz)
- else:
- key_states = self._shape(self.k_proj(hidden_states), -1, bsz)
- value_states = self._shape(self.v_proj(hidden_states), -1, bsz)
- proj_shape = (bsz * self.num_heads, -1, self.head_dim)
- query_states = self._shape(query_states, tgt_len, bsz).view(*proj_shape)
- key_states = key_states.view(*proj_shape)
- value_states = value_states.view(*proj_shape)
- src_len = key_states.size(1)
- attn_weights = torch.bmm(query_states, key_states.transpose(1, 2))
- if attn_weights.size() != (bsz * self.num_heads, tgt_len, src_len):
- raise ValueError(
- f"Attention weights should be of size {(bsz * self.num_heads, tgt_len, src_len)}, but is"
- f" {attn_weights.size()}"
- )
- if attention_mask is not None:
- if attention_mask.size() != (bsz, 1, tgt_len, src_len):
- raise ValueError(
- f"Attention mask should be of size {(bsz, 1, tgt_len, src_len)}, but is {attention_mask.size()}"
- )
- attn_weights = attn_weights.view(bsz, self.num_heads, tgt_len, src_len) + attention_mask
- attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len)
- attn_weights = nn.functional.softmax(attn_weights, dim=-1)
- # this operation is a bit awkward, but it's required to
- # make sure that attn_weights keeps its gradient.
- # In order to do so, attn_weights have to reshaped
- # twice and have to be reused in the following
- attn_weights_reshaped = attn_weights.view(bsz, self.num_heads, tgt_len, src_len)
- attn_weights = attn_weights_reshaped.view(bsz * self.num_heads, tgt_len, src_len)
- attn_probs = nn.functional.dropout(attn_weights, p=self.dropout, training=self.training)
- attn_output = torch.bmm(attn_probs, value_states)
- if attn_output.size() != (bsz * self.num_heads, tgt_len, self.head_dim):
- raise ValueError(
- f"`attn_output` should be of size {(bsz, self.num_heads, tgt_len, self.head_dim)}, but is"
- f" {attn_output.size()}"
- )
- attn_output = attn_output.view(bsz, self.num_heads, tgt_len, self.head_dim)
- attn_output = attn_output.transpose(1, 2)
- attn_output = attn_output.reshape(bsz, tgt_len, embed_dim)
- attn_output = self.out_proj(attn_output)
- return attn_output, attn_weights_reshaped
# Copied from transformers.models.altclip.modeling_altclip.AltCLIPEncoderLayer with AltCLIP->GroupViT
class GroupViTEncoderLayer(GradientCheckpointingLayer):
    """Pre-LayerNorm transformer block: self-attention and an MLP, each wrapped in a residual connection."""

    def __init__(self, config: GroupViTConfig):
        super().__init__()
        self.embed_dim = config.hidden_size
        self.self_attn = GroupViTAttention(config)
        self.layer_norm1 = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_eps)
        self.mlp = GroupViTMLP(config)
        self.layer_norm2 = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_eps)

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: torch.Tensor,
        **kwargs: Unpack[TransformersKwargs],
    ) -> torch.Tensor:
        """
        Args:
            hidden_states (`torch.Tensor`): input of shape `(batch, seq_len, embed_dim)`.
            attention_mask (`torch.Tensor` or `None`): additive mask forwarded to `GroupViTAttention`.

        Returns:
            `torch.Tensor`: the updated hidden states. The attention weights are not returned; the previous
            `tuple` return annotation did not match what the method returns.
        """
        attn_out, _ = self.self_attn(
            hidden_states=self.layer_norm1(hidden_states),
            attention_mask=attention_mask,
            **kwargs,
        )
        hidden_states = hidden_states + attn_out

        mlp_out = self.mlp(self.layer_norm2(hidden_states))
        return hidden_states + mlp_out
@auto_docstring
class GroupViTPreTrainedModel(PreTrainedModel):
    config: GroupViTConfig
    base_model_prefix = "groupvit"
    input_modalities = ("image", "text")
    supports_gradient_checkpointing = True
    # Module classes whose outputs are recorded as `hidden_states` / `attentions`.
    _can_record_outputs = {
        "hidden_states": GroupViTEncoderLayer,
        "attentions": GroupViTAttention,
    }

    @torch.no_grad()
    def _init_weights(self, module):
        """Initialize the weights"""
        std = self.config.initializer_range

        # Generic layers: normal weights with zero bias; norms reset to identity (and running stats reset).
        if isinstance(module, (nn.Linear, nn.Conv2d)):
            init.normal_(module.weight, mean=0.0, std=std)
            if module.bias is not None:
                init.zeros_(module.bias)
        elif isinstance(module, (nn.LayerNorm, nn.BatchNorm1d)):
            init.zeros_(module.bias)
            init.ones_(module.weight)
            if getattr(module, "running_mean", None) is not None:
                init.zeros_(module.running_mean)
                init.ones_(module.running_var)
                init.zeros_(module.num_batches_tracked)

        # GroupViT-specific modules, scaled by `initializer_factor`.
        factor = self.config.initializer_factor
        if isinstance(module, GroupViTTextEmbeddings):
            embed_std = factor * 0.02
            init.normal_(module.token_embedding.weight, mean=0.0, std=embed_std)
            init.normal_(module.position_embedding.weight, mean=0.0, std=embed_std)
            init.copy_(module.position_ids, torch.arange(module.position_ids.shape[-1]).expand((1, -1)))
        elif isinstance(module, GroupViTAttention):
            depth_scale = (2 * module.config.num_hidden_layers) ** -0.5
            in_proj_std = (module.embed_dim**-0.5) * depth_scale * factor
            out_proj_std = (module.embed_dim**-0.5) * factor
            for projection in (module.q_proj, module.k_proj, module.v_proj):
                init.normal_(projection.weight, std=in_proj_std)
            init.normal_(module.out_proj.weight, std=out_proj_std)
        elif isinstance(module, GroupViTMLP):
            hidden = module.config.hidden_size
            fc2_std = (hidden**-0.5) * ((2 * module.config.num_hidden_layers) ** -0.5) * factor
            fc1_std = (2 * hidden) ** -0.5 * factor
            init.normal_(module.fc1.weight, std=fc1_std)
            init.normal_(module.fc2.weight, std=fc2_std)
class GroupViTVisionEncoder(nn.Module):
    """Vision encoder built from one `GroupViTStage` per entry of `config.depths`."""

    def __init__(self, config: GroupViTVisionConfig) -> None:
        super().__init__()
        self.config = config

        stages = []
        for stage_idx, stage_depth in enumerate(config.depths):
            # The first stage has no preceding grouping stage, hence 0 previous group tokens.
            prev_group_tokens = config.num_output_groups[stage_idx - 1] if stage_idx > 0 else 0
            stages.append(
                GroupViTStage(
                    config=config,
                    depth=stage_depth,
                    num_group_token=config.num_group_tokens[stage_idx],
                    num_output_group=config.num_output_groups[stage_idx],
                    num_prev_group_token=prev_group_tokens,
                )
            )
        self.stages = nn.ModuleList(stages)
        self.gradient_checkpointing = False

    def forward(
        self,
        hidden_states: torch.Tensor,
        output_hidden_states: bool | None = None,
        output_attentions: bool | None = None,
        return_dict: bool | None = None,
    ) -> tuple | BaseModelOutput:
        """
        Run every stage in sequence, threading group tokens from one stage into the next.

        Returns the final hidden states, plus (when requested) the hidden states seen
        before each stage and after the last one, and the per-stage grouping attentions.
        Unset flags fall back to the corresponding `self.config` values.
        """
        cfg = self.config
        if output_attentions is None:
            output_attentions = cfg.output_attentions
        if output_hidden_states is None:
            output_hidden_states = cfg.output_hidden_states
        if return_dict is None:
            return_dict = cfg.return_dict

        state_history = [] if output_hidden_states else None
        grouping_history = [] if output_attentions else None
        group_tokens = None

        for stage in self.stages:
            if output_hidden_states:
                state_history.append(hidden_states)

            stage_outputs = stage(hidden_states, group_tokens, output_attentions)
            hidden_states, group_tokens = stage_outputs[0], stage_outputs[1]

            # Index 2 holds the grouping attention only when attentions were requested.
            if output_attentions and stage_outputs[2] is not None:
                grouping_history.append(stage_outputs[2])

        if output_hidden_states:
            state_history.append(hidden_states)

        all_hidden_states = None if state_history is None else tuple(state_history)
        all_groupings = None if grouping_history is None else tuple(grouping_history)

        if return_dict:
            return BaseModelOutput(
                last_hidden_state=hidden_states, hidden_states=all_hidden_states, attentions=all_groupings
            )
        return tuple(item for item in (hidden_states, all_hidden_states, all_groupings) if item is not None)
class GroupViTTextEncoder(nn.Module):
    """
    Text transformer encoder: a stack of `config.num_hidden_layers` [`GroupViTEncoderLayer`]
    self-attention layers applied one after another.

    Args:
        config: GroupViTTextConfig
    """

    def __init__(self, config: GroupViTTextConfig):
        super().__init__()
        self.config = config
        self.layers = nn.ModuleList(GroupViTEncoderLayer(config) for _ in range(config.num_hidden_layers))
        self.gradient_checkpointing = False

    def forward(
        self,
        inputs_embeds,
        attention_mask: torch.Tensor | None = None,
        **kwargs: Unpack[TransformersKwargs],
    ) -> tuple | BaseModelOutput:
        r"""
        Args:
            inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`):
                Already-embedded input sequence fed to the first encoder layer.
            attention_mask (`torch.Tensor`, *optional*):
                Mask forwarded unchanged to every encoder layer. Mask values selected in `[0, 1]`:

                - 1 for tokens that are **not masked**,
                - 0 for tokens that are **masked**.

                [What are attention masks?](../glossary#attention-mask)

        Extra keyword arguments are passed through to each layer.
        """
        hidden_states = inputs_embeds
        for layer in self.layers:
            hidden_states = layer(hidden_states, attention_mask, **kwargs)
        return BaseModelOutput(last_hidden_state=hidden_states)
class GroupViTTextTransformer(GroupViTPreTrainedModel):
    def __init__(self, config: GroupViTTextConfig):
        super().__init__(config)
        self.embeddings = GroupViTTextEmbeddings(config)
        self.encoder = GroupViTTextEncoder(config)
        self.final_layer_norm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
        # Token id whose position is pooled into `pooler_output`.
        self.eos_token_id = config.eos_token_id
        self.post_init()

    @merge_with_config_defaults
    @capture_outputs(tie_last_hidden_states=False)
    @auto_docstring
    def forward(
        self,
        input_ids: torch.Tensor | None = None,
        attention_mask: torch.Tensor | None = None,
        position_ids: torch.Tensor | None = None,
        **kwargs: Unpack[TransformersKwargs],
    ) -> BaseModelOutputWithPooling:
        if input_ids is None:
            raise ValueError("You have to specify input_ids")

        # Flatten any leading dimensions into the batch dimension.
        input_ids = input_ids.view(-1, input_ids.size()[-1])
        embedded = self.embeddings(input_ids=input_ids, position_ids=position_ids)

        causal_mask = create_causal_mask(
            config=self.config,
            inputs_embeds=embedded,
            attention_mask=attention_mask,
            past_key_values=None,
        )

        # The encoder is always run causally; drop any caller-supplied `is_causal`
        # so it does not collide with the explicit keyword below.
        kwargs.pop("is_causal", None)
        encoder_outputs: BaseModelOutput = self.encoder(
            inputs_embeds=embedded,
            attention_mask=causal_mask,
            is_causal=True,
            **kwargs,
        )

        sequence_output = self.final_layer_norm(encoder_outputs[0])
        device = sequence_output.device
        batch_index = torch.arange(sequence_output.shape[0], device=device)
        # Cast to int32: ONNX argmax (opset 14) does not accept int64 inputs.
        token_ids = input_ids.to(dtype=torch.int, device=device)

        if self.eos_token_id == 2:
            # Legacy configs (before PR #24773) carry a wrong `eos_token_id`; keep the old rule of
            # pooling at the highest token id, which breaks if extra tokens are added to the vocab.
            eos_positions = token_ids.argmax(dim=-1)
        else:
            # Pool at the first occurrence of `eos_token_id` (padding may reuse the same id).
            # Assumes every sequence contains it, e.g. as prepared by the tokenizer.
            eos_positions = (token_ids == self.eos_token_id).int().argmax(dim=-1)

        pooled_output = sequence_output[batch_index, eos_positions]

        return BaseModelOutputWithPooling(
            last_hidden_state=sequence_output,
            pooler_output=pooled_output,
        )
class GroupViTTextModel(GroupViTPreTrainedModel):
    config: GroupViTTextConfig
    input_modalities = ("text",)

    def __init__(self, config: GroupViTTextConfig):
        super().__init__(config)
        self.text_model = GroupViTTextTransformer(config)
        # Initialize weights and apply final processing
        self.post_init()

    def get_input_embeddings(self) -> nn.Module:
        """Return the token-embedding module of the wrapped text transformer."""
        embeddings = self.text_model.embeddings
        return embeddings.token_embedding

    def set_input_embeddings(self, value):
        """Swap in `value` as the token-embedding module of the wrapped text transformer."""
        embeddings = self.text_model.embeddings
        embeddings.token_embedding = value

    @auto_docstring
    def forward(
        self,
        input_ids: torch.Tensor | None = None,
        attention_mask: torch.Tensor | None = None,
        position_ids: torch.Tensor | None = None,
        **kwargs: Unpack[TransformersKwargs],
    ) -> tuple | BaseModelOutputWithPooling:
        r"""
        Examples:

        ```python
        >>> from transformers import CLIPTokenizer, GroupViTTextModel

        >>> tokenizer = CLIPTokenizer.from_pretrained("nvidia/groupvit-gcc-yfcc")
        >>> model = GroupViTTextModel.from_pretrained("nvidia/groupvit-gcc-yfcc")

        >>> inputs = tokenizer(["a photo of a cat", "a photo of a dog"], padding=True, return_tensors="pt")

        >>> outputs = model(**inputs)
        >>> last_hidden_state = outputs.last_hidden_state
        >>> pooled_output = outputs.pooler_output  # pooled (EOS token) states
        ```"""
        text_inputs = {
            "input_ids": input_ids,
            "attention_mask": attention_mask,
            "position_ids": position_ids,
        }
        return self.text_model(**text_inputs, **kwargs)
class GroupViTVisionTransformer(nn.Module):
    def __init__(self, config: GroupViTVisionConfig):
        super().__init__()
        self.config = config
        self.embeddings = GroupViTVisionEmbeddings(config)
        self.encoder = GroupViTVisionEncoder(config)
        self.layernorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)

    @auto_docstring
    def forward(
        self,
        pixel_values: torch.FloatTensor | None = None,
        output_hidden_states: bool | None = None,
        output_attentions: bool | None = None,
        return_dict: bool | None = None,
    ) -> tuple | BaseModelOutputWithPooling:
        # Unset flags fall back to the vision config.
        cfg = self.config
        if output_attentions is None:
            output_attentions = cfg.output_attentions
        if output_hidden_states is None:
            output_hidden_states = cfg.output_hidden_states
        if return_dict is None:
            return_dict = cfg.return_dict

        if pixel_values is None:
            raise ValueError("You have to specify pixel_values")

        encoder_outputs = self.encoder(
            hidden_states=self.embeddings(pixel_values),
            output_hidden_states=output_hidden_states,
            output_attentions=output_attentions,
            return_dict=return_dict,
        )

        # Normalize the final tokens, then average them over the token axis to get the pooled vector.
        normed_output = self.layernorm(encoder_outputs[0])
        pooled_output = normed_output.mean(dim=1)

        if return_dict:
            return BaseModelOutputWithPooling(
                last_hidden_state=normed_output,
                pooler_output=pooled_output,
                hidden_states=encoder_outputs.hidden_states,
                attentions=encoder_outputs.attentions,
            )
        return (normed_output, pooled_output) + encoder_outputs[1:]
class GroupViTVisionModel(GroupViTPreTrainedModel):
    config: GroupViTVisionConfig
    main_input_name = "pixel_values"
    input_modalities = ("image",)
    # The vision tower reports outputs through explicit flags, not capture hooks.
    _can_record_outputs = {}

    def __init__(self, config: GroupViTVisionConfig):
        super().__init__(config)
        self.vision_model = GroupViTVisionTransformer(config)
        # Initialize weights and apply final processing
        self.post_init()

    def get_input_embeddings(self) -> GroupViTPatchEmbeddings:
        """Return the patch-embedding module of the wrapped vision transformer."""
        embeddings = self.vision_model.embeddings
        return embeddings.patch_embeddings

    @auto_docstring
    def forward(
        self,
        pixel_values: torch.FloatTensor | None = None,
        output_attentions: bool | None = None,
        output_hidden_states: bool | None = None,
        return_dict: bool | None = None,
        **kwargs,
    ) -> tuple | BaseModelOutputWithPooling:
        r"""
        Examples:

        ```python
        >>> from PIL import Image
        >>> import httpx
        >>> from io import BytesIO
        >>> from transformers import AutoProcessor, GroupViTVisionModel

        >>> processor = AutoProcessor.from_pretrained("nvidia/groupvit-gcc-yfcc")
        >>> model = GroupViTVisionModel.from_pretrained("nvidia/groupvit-gcc-yfcc")

        >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg"
        >>> with httpx.stream("GET", url) as response:
        ...     image = Image.open(BytesIO(response.read()))

        >>> inputs = processor(images=image, return_tensors="pt")

        >>> outputs = model(**inputs)
        >>> last_hidden_state = outputs.last_hidden_state
        >>> pooled_output = outputs.pooler_output  # pooled CLS states
        ```"""
        # Extra `kwargs` are accepted for interface compatibility but are not forwarded.
        vision_flags = {
            "output_attentions": output_attentions,
            "output_hidden_states": output_hidden_states,
            "return_dict": return_dict,
        }
        return self.vision_model(pixel_values=pixel_values, **vision_flags)
@auto_docstring
class GroupViTModel(GroupViTPreTrainedModel):
    config: GroupViTConfig

    def __init__(self, config: GroupViTConfig):
        super().__init__(config)

        if not isinstance(config.text_config, GroupViTTextConfig):
            raise TypeError(
                "config.text_config is expected to be of type GroupViTTextConfig but is of type"
                f" {type(config.text_config)}."
            )

        if not isinstance(config.vision_config, GroupViTVisionConfig):
            raise TypeError(
                "config.vision_config is expected to be of type GroupViTVisionConfig but is of type"
                f" {type(config.vision_config)}."
            )

        text_config = config.text_config
        vision_config = config.vision_config

        self.projection_dim = config.projection_dim
        self.projection_intermediate_dim = config.projection_intermediate_dim
        self.text_embed_dim = text_config.hidden_size
        self.vision_embed_dim = vision_config.hidden_size

        self.text_model = GroupViTTextTransformer(text_config)
        self.vision_model = GroupViTVisionTransformer(vision_config)

        # Each modality is mapped into the shared `projection_dim` space by a
        # Linear -> BatchNorm1d -> ReLU -> Linear head.
        self.visual_projection = nn.Sequential(
            nn.Linear(self.vision_embed_dim, self.projection_intermediate_dim, bias=True),
            nn.BatchNorm1d(self.projection_intermediate_dim),
            nn.ReLU(inplace=True),
            nn.Linear(self.projection_intermediate_dim, self.projection_dim, bias=True),
        )
        self.text_projection = nn.Sequential(
            nn.Linear(self.text_embed_dim, self.projection_intermediate_dim, bias=True),
            nn.BatchNorm1d(self.projection_intermediate_dim),
            nn.ReLU(inplace=True),
            nn.Linear(self.projection_intermediate_dim, self.projection_dim, bias=True),
        )
        # Learnable temperature; `forward` uses exp(logit_scale) as the similarity multiplier.
        self.logit_scale = nn.Parameter(torch.tensor(self.config.logit_scale_init_value))

        # Initialize weights and apply final processing
        self.post_init()

    @can_return_tuple
    @auto_docstring
    def get_text_features(
        self,
        input_ids: torch.Tensor,
        attention_mask: torch.Tensor | None = None,
        position_ids: torch.Tensor | None = None,
        **kwargs: Unpack[TransformersKwargs],
    ) -> tuple | BaseModelOutputWithPooling:
        r"""
        Examples:

        ```python
        >>> import torch
        >>> from transformers import CLIPTokenizer, GroupViTModel

        >>> model = GroupViTModel.from_pretrained("nvidia/groupvit-gcc-yfcc")
        >>> tokenizer = CLIPTokenizer.from_pretrained("nvidia/groupvit-gcc-yfcc")

        >>> inputs = tokenizer(["a photo of a cat", "a photo of a dog"], padding=True, return_tensors="pt")
        >>> with torch.inference_mode():
        ...     text_features = model.get_text_features(**inputs)
        ```"""
        text_outputs: BaseModelOutputWithPooling = self.text_model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            position_ids=position_ids,
            return_dict=True,
            **kwargs,
        )
        # Replace the pooled EOS state with its projection into the joint embedding space.
        pooled_output = text_outputs.pooler_output
        text_outputs.pooler_output = self.text_projection(pooled_output)
        return text_outputs

    @can_return_tuple
    @auto_docstring
    def get_image_features(
        self,
        pixel_values: torch.Tensor,
        **kwargs: Unpack[TransformersKwargs],
    ) -> tuple | BaseModelOutputWithPooling:
        r"""
        Examples:

        ```python
        >>> import torch
        >>> from transformers import AutoProcessor, GroupViTModel
        >>> from transformers.image_utils import load_image

        >>> model = GroupViTModel.from_pretrained("nvidia/groupvit-gcc-yfcc")
        >>> processor = AutoProcessor.from_pretrained("nvidia/groupvit-gcc-yfcc")

        >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg"
        >>> image = load_image(url)

        >>> inputs = processor(images=image, return_tensors="pt")

        >>> with torch.inference_mode():
        ...     image_features = model.get_image_features(**inputs)
        ```"""
        vision_outputs: BaseModelOutputWithPooling = self.vision_model(pixel_values, return_dict=True, **kwargs)
        # Replace the mean-pooled vision state with its projection into the joint embedding space.
        vision_outputs.pooler_output = self.visual_projection(vision_outputs.pooler_output)
        return vision_outputs

    @can_return_tuple
    @auto_docstring
    def forward(
        self,
        input_ids: torch.LongTensor | None = None,
        pixel_values: torch.FloatTensor | None = None,
        attention_mask: torch.Tensor | None = None,
        position_ids: torch.LongTensor | None = None,
        return_loss: bool | None = None,
        output_attentions: bool | None = None,
        output_hidden_states: bool | None = None,
        output_segmentation: bool | None = None,
        **kwargs: Unpack[TransformersKwargs],
    ) -> tuple | GroupViTModelOutput:
        r"""
        return_loss (`bool`, *optional*):
            Whether or not to return the contrastive loss.
        output_segmentation (`bool`, *optional*):
            Whether or not to return the segmentation logits.

        Examples:

        ```python
        >>> from PIL import Image
        >>> import httpx
        >>> from io import BytesIO
        >>> from transformers import AutoProcessor, GroupViTModel

        >>> model = GroupViTModel.from_pretrained("nvidia/groupvit-gcc-yfcc")
        >>> processor = AutoProcessor.from_pretrained("nvidia/groupvit-gcc-yfcc")

        >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg"
        >>> with httpx.stream("GET", url) as response:
        ...     image = Image.open(BytesIO(response.read()))

        >>> inputs = processor(
        ...     text=["a photo of a cat", "a photo of a dog"], images=image, return_tensors="pt", padding=True
        ... )

        >>> outputs = model(**inputs)
        >>> logits_per_image = outputs.logits_per_image  # this is the image-text similarity score
        >>> probs = logits_per_image.softmax(dim=1)  # we can take the softmax to get the label probabilities
        ```"""
        # Use GROUPVIT model's config for some fields (if specified) instead of those of vision & text components.
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_segmentation = (
            output_segmentation if output_segmentation is not None else self.config.output_segmentation
        )
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )

        # Segmentation is derived from the vision grouping attentions only, so it forces attentions
        # on for the vision tower. The text tower keeps the caller's own attention request.
        text_output_attentions = output_attentions
        vision_output_attentions = True if output_segmentation else output_attentions

        # Vision side uses explicit flags (nn.Module-based, not hook-based)
        vision_outputs = self.vision_model(
            pixel_values=pixel_values,
            output_attentions=vision_output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=True,
        )

        # The text side is hook-based: it reads the output flags from its keyword arguments. Because this
        # method declares `output_attentions` / `output_hidden_states` explicitly, they never reach
        # `kwargs`, so they must be forwarded here. Otherwise `text_model_output.hidden_states` and
        # `.attentions` would silently stay `None` even when the caller requests them.
        text_outputs: BaseModelOutputWithPooling = self.text_model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            position_ids=position_ids,
            output_attentions=text_output_attentions,
            output_hidden_states=output_hidden_states,
            **kwargs,
        )

        image_embeds = self.visual_projection(vision_outputs.pooler_output)
        text_embeds = self.text_projection(text_outputs.pooler_output)

        # normalized features
        image_embeds = image_embeds / image_embeds.norm(dim=-1, keepdim=True)
        text_embeds = text_embeds / text_embeds.norm(dim=-1, keepdim=True)

        # cosine similarity as logits
        logit_scale = self.logit_scale.exp()
        logits_per_text = torch.matmul(text_embeds, image_embeds.t()) * logit_scale
        logits_per_image = logits_per_text.t()

        seg_logits = None
        if output_segmentation:
            # grouped features
            # [batch_size_image, num_group, hidden_size]
            image_group_embeds = vision_outputs.last_hidden_state
            # [batch_size_image*num_group, hidden_size]
            image_group_embeds = self.visual_projection(image_group_embeds.reshape(-1, image_group_embeds.shape[-1]))
            attentions = vision_outputs.attentions
            # [batch_size_image, num_group, height, width]
            grouping = get_grouping_from_attentions(attentions, pixel_values.shape[2:])

            # normalized features
            image_group_embeds = image_group_embeds / image_group_embeds.norm(dim=-1, keepdim=True)
            # [batch_size_image x num_group, batch_size_text]
            logits_per_image_group = torch.matmul(image_group_embeds, text_embeds.t()) * logit_scale
            # [batch_size_image, batch_size_text, num_group]
            logits_per_image_group = logits_per_image_group.reshape(
                image_embeds.shape[0], -1, text_embeds.shape[0]
            ).permute(0, 2, 1)

            # [batch_size_image, num_group, height x width]
            flatten_grouping = grouping.reshape(grouping.shape[0], grouping.shape[1], -1)

            # matmul gives [batch_size_image, batch_size_text, height x width];
            # reshaped to [batch_size_image, batch_size_text, height, width]
            seg_logits = torch.matmul(logits_per_image_group, flatten_grouping) * logit_scale
            seg_logits = seg_logits.reshape(
                seg_logits.shape[0], seg_logits.shape[1], grouping.shape[2], grouping.shape[3]
            )

        loss = None
        if return_loss:
            loss = groupvit_loss(logits_per_text)

        return GroupViTModelOutput(
            loss=loss,
            logits_per_image=logits_per_image,
            logits_per_text=logits_per_text,
            segmentation_logits=seg_logits,
            text_embeds=text_embeds,
            image_embeds=image_embeds,
            text_model_output=text_outputs,
            vision_model_output=vision_outputs,
        )
# Public API of this module.
__all__ = [
    "GroupViTModel",
    "GroupViTPreTrainedModel",
    "GroupViTTextModel",
    "GroupViTVisionModel",
]
|