modeling_git.py 49 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
3704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792793794795796797798799800801802803804805806807808809810811812813814815816817818819820821822823824825826827828829830831832833834835836837838839840841842843844845846847848849850851852853854855856857858859860861862863864865866867868869870871872873874875876877878879880881882883884885886887888889890891892893894895896897898899900901902903904905906907908909910911912913914915916917918919920921922923924925926927928929930931932933934935936937938939940941942943944945946947948949950951952953954955956957958959960961962963964965966967968969970971972973974975976977978979980981982983984985986987988989990991992993994995996997998999100010011002100310041005100610071008100910101011101210131014101510161017101810191020102110221023102410251026102710281029103010311032103310341035103610371038103910401041104210431044104510461047104810491050105110521053105410551056105710581059106010611062106310641065106610671068106910701071107210731074107510761077107810791080108110821083108410851086108710881089109010911092109310941095109610971098109911001101110211031104110511061107110811091110111111121113111411151116111711181119112011211122112311241125112611271128112911301131113211331134113511361137113811391140114111421143114411451146114711481149115011511152115311541155115611571158115911601161116211631164116511661167116811691170117111721173117411751176117711781179118011811182118311841185118611871188118911901191119211931194119511961197119811991200120112021203120412051206120712081209121012111212121312141215121612171218121912201221122212231224122512261227122812291230123112321233
  1. # Copyright 2022 Microsoft Research and The HuggingFace Inc. team.
  2. # All rights reserved.
  3. #
  4. # Licensed under the Apache License, Version 2.0 (the "License");
  5. # you may not use this file except in compliance with the License.
  6. # You may obtain a copy of the License at
  7. #
  8. # http://www.apache.org/licenses/LICENSE-2.0
  9. #
  10. # Unless required by applicable law or agreed to in writing, software
  11. # distributed under the License is distributed on an "AS IS" BASIS,
  12. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. # See the License for the specific language governing permissions and
  14. # limitations under the License.
  15. """PyTorch GIT model."""
  16. import math
  17. from collections.abc import Callable
  18. from dataclasses import dataclass
  19. import torch
  20. from torch import nn
  21. from ... import initialization as init
  22. from ...activations import ACT2FN
  23. from ...cache_utils import Cache, DynamicCache
  24. from ...configuration_utils import PreTrainedConfig
  25. from ...generation import GenerationMixin
  26. from ...masking_utils import create_masks_for_generate
  27. from ...modeling_layers import GradientCheckpointingLayer
  28. from ...modeling_outputs import (
  29. BaseModelOutput,
  30. BaseModelOutputWithPast,
  31. BaseModelOutputWithPooling,
  32. CausalLMOutputWithPast,
  33. )
  34. from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
  35. from ...processing_utils import Unpack
  36. from ...pytorch_utils import apply_chunking_to_forward
  37. from ...utils import (
  38. ModelOutput,
  39. TransformersKwargs,
  40. auto_docstring,
  41. logging,
  42. torch_int,
  43. )
  44. from ...utils.deprecation import deprecate_kwarg
  45. from ...utils.generic import merge_with_config_defaults
  46. from ...utils.output_capturing import capture_outputs
  47. from .configuration_git import GitConfig, GitVisionConfig
  48. logger = logging.get_logger(__name__)
@dataclass
@auto_docstring(
    custom_intro="""
    Base class for vision model's outputs that also contains image embeddings of the pooling of the last hidden states.
    """
)
# Copied from transformers.models.clip.modeling_clip.CLIPVisionModelOutput with CLIP->Git
class GitVisionModelOutput(ModelOutput):
    r"""
    image_embeds (`torch.FloatTensor` of shape `(batch_size, output_dim)` *optional* returned when model is initialized with `with_projection=True`):
        The image embeddings obtained by applying the projection layer to the pooler_output.
    """

    # NOTE(review): ModelOutput presumably exposes these fields positionally in declaration order — do not reorder.
    # Projected pooled image embedding (only present when a projection layer is used).
    image_embeds: torch.FloatTensor | None = None
    # Final hidden states of the vision encoder.
    last_hidden_state: torch.FloatTensor | None = None
    # Optional per-layer hidden states.
    hidden_states: tuple[torch.FloatTensor, ...] | None = None
    # Optional per-layer attention weights.
    attentions: tuple[torch.FloatTensor, ...] | None = None
  65. # Copied from transformers.models.gemma3.modeling_gemma3.token_type_ids_mask_function
  66. def token_type_ids_mask_function(group_ids: torch.Tensor) -> Callable:
  67. """
  68. This function adds the correct offsets to the `q_idx` and `kv_idx` as the torch API can only accept lengths,
  69. not start and end indices.
  70. Args:
  71. group_ids (`torch.Tensor`):
  72. A tensor of shape `(bs, len)` assigning each token to a vision group. Tokens with the same group
  73. come from the same input image. Text is denoted by `-1`.
  74. """
  75. def inner_mask(batch_idx: int, head_idx: int, q_idx: int, kv_idx: int) -> bool:
  76. seq_length = group_ids.shape[-1]
  77. # clamp indices because with static cache they can go beyond `group_ids.shape[-1]`
  78. q_idx_clamped = q_idx.clamp(max=seq_length - 1)
  79. kv_idx_clamped = kv_idx.clamp(max=seq_length - 1)
  80. # Unmask if the q and kv come from same group which is not -1 (i.e. non-text)
  81. q_group = group_ids[batch_idx, q_idx_clamped]
  82. kv_group = group_ids[batch_idx, kv_idx_clamped]
  83. q_group = torch.where(q_idx < seq_length, q_group, -1)
  84. kv_group = torch.where(kv_idx < seq_length, kv_group, -1)
  85. return (q_group == kv_group) & (q_group >= 0)
  86. return inner_mask
  87. @deprecate_kwarg("input_embeds", version="5.6.0", new_name="inputs_embeds")
  88. # Copied from transformers.models.gemma3.modeling_gemma3.create_causal_mask_mapping
  89. def create_causal_mask_mapping(
  90. config: PreTrainedConfig,
  91. inputs_embeds: torch.Tensor,
  92. attention_mask: torch.Tensor | None,
  93. past_key_values: Cache | None,
  94. position_ids: torch.Tensor | None,
  95. token_type_ids: torch.Tensor | None = None,
  96. pixel_values: torch.FloatTensor | None = None,
  97. is_training: bool = False,
  98. is_first_iteration: bool | None = None,
  99. **kwargs,
  100. ) -> dict:
  101. """
  102. Overwrites the base `create_masks_for_generate` with `token_type_ids` masking to create the causal mask mapping
  103. for all kinds of forward passes. Gemma3 uses a bidirectional mask for images.
  104. Uses `pixel_values` as an optional input to disambiguate edge cases.
  105. """
  106. if is_training and token_type_ids is None:
  107. raise ValueError("`token_type_ids` is required as a model input when training")
  108. mask_kwargs = {
  109. "config": config.get_text_config(),
  110. "inputs_embeds": inputs_embeds,
  111. "attention_mask": attention_mask,
  112. "past_key_values": past_key_values,
  113. "position_ids": position_ids,
  114. }
  115. # NOTE: this `may_have_image_input` logic is not flawless, it fails when we're using a cache eagerly initialized
  116. # (e.g. compiled prefill) AND `pixel_values` are not provided (i.e. the image data is provided through other
  117. # means). Determining prefill in that case requires checking data values, which is not compile-compatible.
  118. is_first_iteration = (
  119. is_first_iteration
  120. if is_first_iteration is not None
  121. else (past_key_values is None or not past_key_values.is_initialized or pixel_values is not None)
  122. )
  123. if token_type_ids is not None and is_first_iteration:
  124. # We need to pass an additional mask function to account for token type ids, and it needs to be an `or` (to
  125. # undo the causal masking)
  126. # First find where a new image block starts: 1 if image and previous not image
  127. # The images cannot attend to future images, but can attend to all prev images and to itself bidirectionally
  128. is_image = (token_type_ids == 1).to(inputs_embeds.device)
  129. is_previous_image = nn.functional.pad(is_image, (1, 0), value=0)[:, :-1]
  130. new_image_start = is_image & ~is_previous_image
  131. group_ids = torch.cumsum(new_image_start.int(), dim=1) - 1
  132. group_ids = torch.where(is_image, group_ids, -1)
  133. mask_kwargs["or_mask_function"] = token_type_ids_mask_function(group_ids)
  134. return create_masks_for_generate(**mask_kwargs)
  135. class GitEmbeddings(nn.Module):
  136. """Construct the embeddings from word and position embeddings."""
  137. def __init__(self, config):
  138. super().__init__()
  139. self.word_embeddings = nn.Embedding(config.vocab_size, config.hidden_size, padding_idx=config.pad_token_id)
  140. self.position_embeddings = nn.Embedding(config.max_position_embeddings, config.hidden_size)
  141. self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
  142. self.dropout = nn.Dropout(config.hidden_dropout_prob)
  143. # position_ids (1, len position emb) is contiguous in memory and exported when serialized
  144. self.register_buffer(
  145. "position_ids", torch.arange(config.max_position_embeddings).expand((1, -1)), persistent=False
  146. )
  147. def forward(
  148. self,
  149. input_ids: torch.LongTensor | None = None,
  150. position_ids: torch.LongTensor | None = None,
  151. inputs_embeds: torch.FloatTensor | None = None,
  152. past_key_values_length: int = 0,
  153. ) -> torch.Tensor:
  154. if input_ids is not None:
  155. input_shape = input_ids.size()
  156. else:
  157. input_shape = inputs_embeds.size()[:-1]
  158. seq_length = input_shape[1]
  159. if position_ids is None:
  160. position_ids = self.position_ids[:, past_key_values_length : seq_length + past_key_values_length]
  161. if inputs_embeds is None:
  162. embeddings = self.word_embeddings(input_ids)
  163. else:
  164. embeddings = inputs_embeds
  165. position_embeddings = self.position_embeddings(position_ids)
  166. embeddings += position_embeddings
  167. embeddings = self.LayerNorm(embeddings)
  168. embeddings = self.dropout(embeddings)
  169. return embeddings
  170. class GitSelfAttention(nn.Module):
  171. def __init__(self, config, layer_idx=None):
  172. super().__init__()
  173. if config.hidden_size % config.num_attention_heads != 0 and not hasattr(config, "embedding_size"):
  174. raise ValueError(
  175. f"The hidden size ({config.hidden_size}) is not a multiple of the number of attention "
  176. f"heads ({config.num_attention_heads})"
  177. )
  178. self.layer_idx = layer_idx
  179. if layer_idx is None:
  180. logger.warning_once(
  181. f"Instantiating {self.__class__.__name__} without passing a `layer_idx` is not recommended and will "
  182. "lead to errors during the forward call if caching is used. Please make sure to provide a `layer_idx` "
  183. "when creating this class."
  184. )
  185. self.num_attention_heads = config.num_attention_heads
  186. self.attention_head_size = int(config.hidden_size / config.num_attention_heads)
  187. self.all_head_size = self.num_attention_heads * self.attention_head_size
  188. self.image_patch_tokens = int((config.vision_config.image_size / config.vision_config.patch_size) ** 2 + 1)
  189. if config.num_image_with_embedding is not None:
  190. self.image_patch_tokens *= config.num_image_with_embedding
  191. self.query = nn.Linear(config.hidden_size, self.all_head_size)
  192. self.key = nn.Linear(config.hidden_size, self.all_head_size)
  193. self.value = nn.Linear(config.hidden_size, self.all_head_size)
  194. self.dropout = nn.Dropout(config.attention_probs_dropout_prob)
  195. def forward(
  196. self,
  197. hidden_states: torch.Tensor,
  198. attention_mask: torch.FloatTensor | None = None,
  199. past_key_values: Cache | None = None,
  200. **kwargs: Unpack[TransformersKwargs],
  201. ) -> tuple[torch.Tensor]:
  202. input_shape = hidden_states.shape[:-1]
  203. hidden_shape = (*input_shape, -1, self.attention_head_size)
  204. query_layer = self.query(hidden_states).view(hidden_shape).transpose(1, 2)
  205. key_layer = self.key(hidden_states).view(hidden_shape).transpose(1, 2)
  206. value_layer = self.value(hidden_states).view(hidden_shape).transpose(1, 2)
  207. if past_key_values is not None:
  208. key_layer, value_layer = past_key_values.update(key_layer, value_layer, self.layer_idx)
  209. # Take the dot product between "query" and "key" to get the raw attention scores.
  210. attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2))
  211. attention_scores = attention_scores / math.sqrt(self.attention_head_size)
  212. if attention_mask is not None:
  213. # Apply the attention mask is (precomputed for all layers in GitModel forward() function)
  214. attention_scores = attention_scores + attention_mask
  215. # Normalize the attention scores to probabilities.
  216. attention_probs = nn.functional.softmax(attention_scores, dim=-1)
  217. # This is actually dropping out entire tokens to attend to, which might
  218. # seem a bit unusual, but is taken from the original Transformer paper.
  219. attention_probs = self.dropout(attention_probs)
  220. context_layer = torch.matmul(attention_probs, value_layer)
  221. context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
  222. new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,)
  223. context_layer = context_layer.view(new_context_layer_shape)
  224. return context_layer, attention_probs
  225. # Copied from transformers.models.bert.modeling_bert.BertSelfOutput
  226. class GitSelfOutput(nn.Module):
  227. def __init__(self, config):
  228. super().__init__()
  229. self.dense = nn.Linear(config.hidden_size, config.hidden_size)
  230. self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
  231. self.dropout = nn.Dropout(config.hidden_dropout_prob)
  232. def forward(self, hidden_states: torch.Tensor, input_tensor: torch.Tensor) -> torch.Tensor:
  233. hidden_states = self.dense(hidden_states)
  234. hidden_states = self.dropout(hidden_states)
  235. hidden_states = self.LayerNorm(hidden_states + input_tensor)
  236. return hidden_states
  237. GIT_SELF_ATTENTION_CLASSES = {
  238. "eager": GitSelfAttention,
  239. }
  240. class GitAttention(nn.Module):
  241. def __init__(self, config, layer_idx=None):
  242. super().__init__()
  243. self.self = GIT_SELF_ATTENTION_CLASSES[config._attn_implementation](config, layer_idx=layer_idx)
  244. self.output = GitSelfOutput(config)
  245. def forward(
  246. self,
  247. hidden_states: torch.Tensor,
  248. attention_mask: torch.FloatTensor | None = None,
  249. past_key_values: Cache | None = None,
  250. **kwargs: Unpack[TransformersKwargs],
  251. ) -> tuple[torch.Tensor]:
  252. attn_output, _ = self.self(
  253. hidden_states,
  254. attention_mask,
  255. past_key_values,
  256. **kwargs,
  257. )
  258. attention_output = self.output(attn_output, hidden_states)
  259. return attention_output
  260. # Copied from transformers.models.bert.modeling_bert.BertIntermediate
  261. class GitIntermediate(nn.Module):
  262. def __init__(self, config):
  263. super().__init__()
  264. self.dense = nn.Linear(config.hidden_size, config.intermediate_size)
  265. if isinstance(config.hidden_act, str):
  266. self.intermediate_act_fn = ACT2FN[config.hidden_act]
  267. else:
  268. self.intermediate_act_fn = config.hidden_act
  269. def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
  270. hidden_states = self.dense(hidden_states)
  271. hidden_states = self.intermediate_act_fn(hidden_states)
  272. return hidden_states
  273. # Copied from transformers.models.bert.modeling_bert.BertOutput
  274. class GitOutput(nn.Module):
  275. def __init__(self, config):
  276. super().__init__()
  277. self.dense = nn.Linear(config.intermediate_size, config.hidden_size)
  278. self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
  279. self.dropout = nn.Dropout(config.hidden_dropout_prob)
  280. def forward(self, hidden_states: torch.Tensor, input_tensor: torch.Tensor) -> torch.Tensor:
  281. hidden_states = self.dense(hidden_states)
  282. hidden_states = self.dropout(hidden_states)
  283. hidden_states = self.LayerNorm(hidden_states + input_tensor)
  284. return hidden_states
  285. class GitLayer(GradientCheckpointingLayer):
  286. def __init__(self, config, layer_idx=None):
  287. super().__init__()
  288. self.chunk_size_feed_forward = config.chunk_size_feed_forward
  289. self.seq_len_dim = 1
  290. self.attention = GitAttention(config, layer_idx=layer_idx)
  291. self.intermediate = GitIntermediate(config)
  292. self.output = GitOutput(config)
  293. def forward(
  294. self,
  295. hidden_states: torch.Tensor,
  296. attention_mask: torch.FloatTensor | None = None,
  297. past_key_values: Cache | None = None,
  298. **kwargs: Unpack[TransformersKwargs],
  299. ) -> tuple[torch.Tensor]:
  300. attention_output = self.attention(
  301. hidden_states,
  302. attention_mask,
  303. past_key_values=past_key_values,
  304. **kwargs,
  305. )
  306. layer_output = apply_chunking_to_forward(
  307. self.feed_forward_chunk, self.chunk_size_feed_forward, self.seq_len_dim, attention_output
  308. )
  309. return layer_output
  310. def feed_forward_chunk(self, attention_output):
  311. intermediate_output = self.intermediate(attention_output)
  312. layer_output = self.output(intermediate_output, attention_output)
  313. return layer_output
  314. class GitEncoder(nn.Module):
  315. def __init__(self, config):
  316. super().__init__()
  317. self.config = config
  318. self.layer = nn.ModuleList([GitLayer(config, i) for i in range(config.num_hidden_layers)])
  319. self.gradient_checkpointing = False
  320. def forward(
  321. self,
  322. hidden_states: torch.Tensor,
  323. attention_mask: torch.FloatTensor | None = None,
  324. past_key_values: Cache | None = None,
  325. use_cache: bool | None = None,
  326. **kwargs: Unpack[TransformersKwargs],
  327. ) -> BaseModelOutputWithPast:
  328. for layer_module in self.layer:
  329. hidden_states = layer_module(
  330. hidden_states,
  331. attention_mask,
  332. past_key_values,
  333. **kwargs,
  334. )
  335. return BaseModelOutputWithPast(
  336. last_hidden_state=hidden_states,
  337. past_key_values=past_key_values,
  338. )
@auto_docstring
class GitPreTrainedModel(PreTrainedModel):
    # Class-level settings for the shared GIT base class. NOTE(review): these are consumed by `PreTrainedModel`
    # machinery defined outside this file.
    config: GitConfig
    base_model_prefix = "git"
    input_modalities = ("image", "text")
    supports_gradient_checkpointing = True

    @torch.no_grad()
    def _init_weights(self, module):
        """
        Initialize the weights of `module` in place, with gradients disabled.

        - `GitVisionEmbeddings`: class, patch and position embeddings are re-drawn from a normal distribution with
          `std=config.initializer_range`; its `position_ids` buffer is reset to `arange(num_positions)`.
        - `nn.Linear`: weights ~ normal(0, initializer_range); bias (if any) zeroed.
        - `nn.Embedding`: weights ~ normal(0, initializer_range); the padding row is zeroed unless the weight is
          already flagged `_is_hf_initialized`.
        - `nn.LayerNorm`: weight set to ones, bias to zeros.
        - `GitEmbeddings`: `position_ids` buffer reset to `arange(max_position_embeddings)`.
        """
        # Separate `if` (not part of the elif chain below): GitVisionEmbeddings is a plain nn.Module container.
        if isinstance(module, GitVisionEmbeddings):
            init.normal_(module.class_embedding, mean=0.0, std=self.config.initializer_range)
            init.normal_(module.patch_embedding.weight, std=self.config.initializer_range)
            init.normal_(module.position_embedding.weight, std=self.config.initializer_range)
            init.copy_(module.position_ids, torch.arange(module.position_ids.shape[-1]).expand((1, -1)))
        if isinstance(module, nn.Linear):
            init.normal_(module.weight, mean=0.0, std=self.config.initializer_range)
            if module.bias is not None:
                init.zeros_(module.bias)
        elif isinstance(module, nn.Embedding):
            init.normal_(module.weight, mean=0.0, std=self.config.initializer_range)
            # Here we need the check explicitly, as we slice the weight in the `zeros_` call, so it loses the flag
            if module.padding_idx is not None and not getattr(module.weight, "_is_hf_initialized", False):
                init.zeros_(module.weight[module.padding_idx])
        elif isinstance(module, nn.LayerNorm):
            init.zeros_(module.bias)
            init.ones_(module.weight)
        elif isinstance(module, GitEmbeddings):
            init.copy_(module.position_ids, torch.arange(module.position_ids.shape[-1]).expand((1, -1)))
# Copied from transformers.models.clip.modeling_clip.CLIPVisionEmbeddings with CLIP->Git
class GitVisionEmbeddings(nn.Module):
    """
    Turn pixel values into the token sequence consumed by the vision encoder.

    A strided, bias-free `nn.Conv2d` cuts the image into non-overlapping `patch_size` x `patch_size` patches. A
    learned class embedding is prepended, and learned position embeddings (one per patch plus one for the class
    token) are added.
    """

    def __init__(self, config: GitVisionConfig):
        super().__init__()
        self.config = config
        self.embed_dim = config.hidden_size
        self.image_size = config.image_size
        self.patch_size = config.patch_size
        # Learned vector prepended to the patch sequence in `forward`.
        self.class_embedding = nn.Parameter(torch.randn(self.embed_dim))
        # kernel_size == stride -> non-overlapping patches, one output position per patch.
        self.patch_embedding = nn.Conv2d(
            in_channels=config.num_channels,
            out_channels=self.embed_dim,
            kernel_size=self.patch_size,
            stride=self.patch_size,
            bias=False,
        )
        self.num_patches = (self.image_size // self.patch_size) ** 2
        self.num_positions = self.num_patches + 1  # +1 for the class token
        self.position_embedding = nn.Embedding(self.num_positions, self.embed_dim)
        # Non-persistent buffer: not written to the state dict.
        self.register_buffer("position_ids", torch.arange(self.num_positions).expand((1, -1)), persistent=False)

    def interpolate_pos_encoding(self, embeddings: torch.Tensor, height: int, width: int) -> torch.Tensor:
        """
        This method allows to interpolate the pre-trained position encodings, to be able to use the model on higher resolution
        images. This method is also adapted to support torch.jit tracing.

        Args:
            embeddings (`torch.Tensor`):
                Patch embeddings with the class token prepended. Only `shape[1]` (1 + number of patches) and
                `shape[-1]` (embedding size) are read.
            height (`int`), width (`int`):
                Pixel size of the input image; each is divided by `patch_size` to get the target patch grid.

        Returns:
            `torch.Tensor` of shape `(1, 1 + (height // patch_size) * (width // patch_size), embed_dim)`. When no
            resizing is needed (and not tracing), returns the stored `(1, num_positions, embed_dim)` table instead.

        Adapted from:
        - https://github.com/facebookresearch/dino/blob/de9ee3df6cf39fac952ab558447af1fa1365362a/vision_transformer.py#L174-L194, and
        - https://github.com/facebookresearch/dinov2/blob/e1277af2ba9496fbadf7aec6eba56e8d882d1e35/dinov2/models/vision_transformer.py#L179-L211
        """
        num_patches = embeddings.shape[1] - 1
        position_embedding = self.position_embedding.weight.unsqueeze(0)
        num_positions = position_embedding.shape[1] - 1
        # always interpolate when tracing to ensure the exported model works for dynamic input shapes
        if not torch.jit.is_tracing() and num_patches == num_positions and height == width:
            return self.position_embedding(self.position_ids)
        # Keep the class-token position as-is; only the patch grid is resized.
        class_pos_embed = position_embedding[:, :1]
        patch_pos_embed = position_embedding[:, 1:]
        dim = embeddings.shape[-1]
        new_height = height // self.patch_size
        new_width = width // self.patch_size
        # The pre-trained patch positions are assumed to form a square grid.
        sqrt_num_positions = torch_int(num_positions**0.5)
        patch_pos_embed = patch_pos_embed.reshape(1, sqrt_num_positions, sqrt_num_positions, dim)
        # (1, H, W, dim) -> (1, dim, H, W), the layout expected by `interpolate`.
        patch_pos_embed = patch_pos_embed.permute(0, 3, 1, 2)
        patch_pos_embed = nn.functional.interpolate(
            patch_pos_embed,
            size=(new_height, new_width),
            mode="bicubic",
            align_corners=False,
        )
        # Back to (1, new_H * new_W, dim).
        patch_pos_embed = patch_pos_embed.permute(0, 2, 3, 1).view(1, -1, dim)
        return torch.cat((class_pos_embed, patch_pos_embed), dim=1)

    def forward(self, pixel_values: torch.FloatTensor, interpolate_pos_encoding=False) -> torch.Tensor:
        """
        Embed `pixel_values` (read as `(batch, channels, height, width)`) into `(batch, 1 + num_patches, embed_dim)`.

        Raises:
            ValueError: if the image size differs from `config.image_size` and `interpolate_pos_encoding` is False.
        """
        batch_size, _, height, width = pixel_values.shape
        if not interpolate_pos_encoding and (height != self.image_size or width != self.image_size):
            raise ValueError(
                f"Input image size ({height}*{width}) doesn't match model ({self.image_size}*{self.image_size})."
            )
        # Cast inputs to the conv weight dtype (e.g. when the model runs in half precision).
        target_dtype = self.patch_embedding.weight.dtype
        patch_embeds = self.patch_embedding(pixel_values.to(dtype=target_dtype))  # shape = [*, width, grid, grid]
        # Flatten the patch grid into a sequence: (batch, embed_dim, gh, gw) -> (batch, gh * gw, embed_dim).
        patch_embeds = patch_embeds.flatten(2).transpose(1, 2)

        class_embeds = self.class_embedding.expand(batch_size, 1, -1)
        embeddings = torch.cat([class_embeds, patch_embeds], dim=1)
        if interpolate_pos_encoding:
            embeddings = embeddings + self.interpolate_pos_encoding(embeddings, height, width)
        else:
            embeddings = embeddings + self.position_embedding(self.position_ids)
        return embeddings
  433. class GitVisionMLP(nn.Module):
  434. def __init__(self, config):
  435. super().__init__()
  436. self.config = config
  437. self.activation_fn = ACT2FN[config.hidden_act]
  438. self.fc1 = nn.Linear(config.hidden_size, config.intermediate_size)
  439. self.fc2 = nn.Linear(config.intermediate_size, config.hidden_size)
  440. def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
  441. hidden_states = self.fc1(hidden_states)
  442. hidden_states = self.activation_fn(hidden_states)
  443. hidden_states = self.fc2(hidden_states)
  444. return hidden_states
  445. # Copied from transformers.models.siglip.modeling_siglip.eager_attention_forward
  446. def eager_attention_forward(
  447. module: nn.Module,
  448. query: torch.Tensor,
  449. key: torch.Tensor,
  450. value: torch.Tensor,
  451. attention_mask: torch.Tensor | None,
  452. scaling: float,
  453. dropout: float = 0.0,
  454. **kwargs,
  455. ):
  456. attn_weights = torch.matmul(query, key.transpose(-1, -2)) * scaling
  457. if attention_mask is not None:
  458. attn_weights = attn_weights + attention_mask
  459. attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype)
  460. attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training)
  461. attn_output = torch.matmul(attn_weights, value)
  462. attn_output = attn_output.transpose(1, 2).contiguous()
  463. return attn_output, attn_weights
  464. class GitVisionAttention(nn.Module):
  465. """Multi-headed attention from 'Attention Is All You Need' paper"""
  466. def __init__(self, config):
  467. super().__init__()
  468. self.config = config
  469. self.embed_dim = config.hidden_size
  470. self.num_heads = config.num_attention_heads
  471. self.head_dim = self.embed_dim // self.num_heads
  472. if self.head_dim * self.num_heads != self.embed_dim:
  473. raise ValueError(
  474. f"embed_dim must be divisible by num_heads (got `embed_dim`: {self.embed_dim} and `num_heads`:"
  475. f" {self.num_heads})."
  476. )
  477. self.scale = self.head_dim**-0.5
  478. self.dropout = config.attention_dropout
  479. self.is_causal = False
  480. self.k_proj = nn.Linear(self.embed_dim, self.embed_dim)
  481. self.v_proj = nn.Linear(self.embed_dim, self.embed_dim)
  482. self.q_proj = nn.Linear(self.embed_dim, self.embed_dim)
  483. self.out_proj = nn.Linear(self.embed_dim, self.embed_dim)
  484. def forward(
  485. self,
  486. hidden_states: torch.Tensor,
  487. attention_mask: torch.Tensor | None = None,
  488. **kwargs: Unpack[TransformersKwargs],
  489. ) -> tuple[torch.Tensor, torch.Tensor | None]:
  490. """Input shape: Batch x Time x Channel"""
  491. input_shape = hidden_states.shape[:-1]
  492. hidden_shape = (*input_shape, -1, self.head_dim)
  493. queries = self.q_proj(hidden_states).view(hidden_shape).transpose(1, 2)
  494. keys = self.k_proj(hidden_states).view(hidden_shape).transpose(1, 2)
  495. values = self.v_proj(hidden_states).view(hidden_shape).transpose(1, 2)
  496. attention_interface: Callable = ALL_ATTENTION_FUNCTIONS.get_interface(
  497. self.config._attn_implementation, eager_attention_forward
  498. )
  499. attn_output, attn_weights = attention_interface(
  500. self,
  501. queries,
  502. keys,
  503. values,
  504. attention_mask,
  505. is_causal=self.is_causal,
  506. scaling=self.scale,
  507. dropout=0.0 if not self.training else self.dropout,
  508. **kwargs,
  509. )
  510. attn_output = attn_output.reshape(*input_shape, -1).contiguous()
  511. attn_output = self.out_proj(attn_output)
  512. return attn_output, attn_weights
  513. # Copied from transformers.models.altclip.modeling_altclip.AltCLIPEncoderLayer with AltCLIP->GitVision
  514. class GitVisionEncoderLayer(GradientCheckpointingLayer):
  515. def __init__(self, config: GitVisionConfig):
  516. super().__init__()
  517. self.embed_dim = config.hidden_size
  518. self.self_attn = GitVisionAttention(config)
  519. self.layer_norm1 = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_eps)
  520. self.mlp = GitVisionMLP(config)
  521. self.layer_norm2 = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_eps)
  522. def forward(
  523. self,
  524. hidden_states: torch.Tensor,
  525. attention_mask: torch.Tensor,
  526. **kwargs: Unpack[TransformersKwargs],
  527. ) -> tuple[torch.FloatTensor, torch.Tensor | None]:
  528. residual = hidden_states
  529. hidden_states = self.layer_norm1(hidden_states)
  530. hidden_states, _ = self.self_attn(
  531. hidden_states=hidden_states,
  532. attention_mask=attention_mask,
  533. **kwargs,
  534. )
  535. hidden_states = residual + hidden_states
  536. residual = hidden_states
  537. hidden_states = self.layer_norm2(hidden_states)
  538. hidden_states = self.mlp(hidden_states)
  539. hidden_states = residual + hidden_states
  540. return hidden_states
  541. # Copied from transformers.models.altclip.modeling_altclip.AltCLIPEncoder with AltCLIP->GitVision, CLIPConfig
  542. class GitVisionEncoder(nn.Module):
  543. """
  544. Transformer encoder consisting of `config.num_hidden_layers` self attention layers. Each layer is a
  545. [`GitVisionEncoderLayer`].
  546. Args:
  547. config: GitVisionConfig
  548. """
  549. def __init__(self, config: GitVisionConfig):
  550. super().__init__()
  551. self.config = config
  552. self.layers = nn.ModuleList([GitVisionEncoderLayer(config) for _ in range(config.num_hidden_layers)])
  553. self.gradient_checkpointing = False
  554. def forward(
  555. self,
  556. inputs_embeds,
  557. attention_mask: torch.Tensor | None = None,
  558. **kwargs: Unpack[TransformersKwargs],
  559. ) -> tuple | BaseModelOutput:
  560. r"""
  561. Args:
  562. inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`):
  563. Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation.
  564. This is useful if you want more control over how to convert `input_ids` indices into associated vectors
  565. than the model's internal embedding lookup matrix.
  566. attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
  567. Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
  568. - 1 for tokens that are **not masked**,
  569. - 0 for tokens that are **masked**.
  570. [What are attention masks?](../glossary#attention-mask)
  571. """
  572. hidden_states = inputs_embeds
  573. for encoder_layer in self.layers:
  574. hidden_states = encoder_layer(
  575. hidden_states,
  576. attention_mask,
  577. **kwargs,
  578. )
  579. return BaseModelOutput(
  580. last_hidden_state=hidden_states,
  581. )
  582. class GitVisionTransformer(nn.Module):
  583. # Copied from transformers.models.altclip.modeling_altclip.AltCLIPVisionTransformer.__init__ with AltCLIPEncoder->GitVisionEncoder, AltCLIP->Git
  584. def __init__(self, config: GitVisionConfig):
  585. super().__init__()
  586. self.config = config
  587. embed_dim = config.hidden_size
  588. self.embeddings = GitVisionEmbeddings(config)
  589. self.pre_layrnorm = nn.LayerNorm(embed_dim, eps=config.layer_norm_eps)
  590. self.encoder = GitVisionEncoder(config)
  591. self.post_layernorm = nn.LayerNorm(embed_dim, eps=config.layer_norm_eps)
  592. @auto_docstring
  593. def forward(
  594. self,
  595. pixel_values: torch.FloatTensor | None = None,
  596. interpolate_pos_encoding: bool | None = False,
  597. **kwargs: Unpack[TransformersKwargs],
  598. ) -> BaseModelOutput:
  599. if pixel_values is None:
  600. raise ValueError("You have to specify pixel_values")
  601. hidden_states = self.embeddings(pixel_values, interpolate_pos_encoding=interpolate_pos_encoding)
  602. hidden_states = self.pre_layrnorm(hidden_states)
  603. encoder_outputs = self.encoder(
  604. inputs_embeds=hidden_states,
  605. **kwargs,
  606. )
  607. last_hidden_state = encoder_outputs.last_hidden_state
  608. last_hidden_state = self.post_layernorm(last_hidden_state)
  609. return BaseModelOutput(
  610. last_hidden_state=last_hidden_state,
  611. )
  612. @auto_docstring(
  613. custom_intro="""
  614. The vision model from CLIP, used in GIT, without any head or projection on top.
  615. """
  616. )
  617. class GitVisionModel(GitPreTrainedModel):
  618. config: GitVisionConfig
  619. main_input_name = "pixel_values"
  620. input_modalities = ("image",)
  621. _can_record_outputs = {
  622. "hidden_states": GitVisionEncoderLayer,
  623. "attentions": GitVisionAttention,
  624. }
  625. # Copied from transformers.models.clip.modeling_clip.CLIPVisionModel.__init__ with CLIP->Git
  626. def __init__(self, config: GitVisionConfig):
  627. super().__init__(config)
  628. self.vision_model = GitVisionTransformer(config)
  629. # Initialize weights and apply final processing
  630. self.post_init()
  631. def get_input_embeddings(self) -> nn.Module:
  632. return self.vision_model.embeddings.patch_embedding
  633. @merge_with_config_defaults
  634. @capture_outputs(tie_last_hidden_states=False)
  635. @auto_docstring
  636. def forward(
  637. self,
  638. pixel_values: torch.FloatTensor | None = None,
  639. interpolate_pos_encoding: bool = False,
  640. **kwargs: Unpack[TransformersKwargs],
  641. ) -> tuple | BaseModelOutput:
  642. r"""
  643. Examples:
  644. ```python
  645. >>> from PIL import Image
  646. >>> import httpx
  647. >>> from io import BytesIO
  648. >>> from transformers import AutoProcessor, GitVisionModel
  649. >>> processor = AutoProcessor.from_pretrained("microsoft/git-base")
  650. >>> model = GitVisionModel.from_pretrained("microsoft/git-base")
  651. >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg"
  652. >>> with httpx.stream("GET", url) as response:
  653. ... image = Image.open(BytesIO(response.read()))
  654. >>> inputs = processor(images=image, return_tensors="pt")
  655. >>> outputs = model(**inputs)
  656. >>> last_hidden_state = outputs.last_hidden_state
  657. ```"""
  658. return self.vision_model(
  659. pixel_values=pixel_values,
  660. interpolate_pos_encoding=interpolate_pos_encoding,
  661. **kwargs,
  662. )
  663. class GitProjection(nn.Module):
  664. def __init__(self, config: GitConfig):
  665. super().__init__()
  666. self.config = config
  667. self.visual_projection = nn.Sequential(
  668. nn.Linear(config.vision_config.hidden_size, config.hidden_size),
  669. nn.LayerNorm(config.hidden_size, eps=config.vision_config.layer_norm_eps),
  670. )
  671. def forward(self, embeddings: torch.Tensor) -> torch.Tensor:
  672. return self.visual_projection(embeddings)
  673. @auto_docstring(
  674. custom_intro="""
  675. The bare GIT Model transformer consisting of a CLIP image encoder and text decoder outputting raw hidden-states
  676. """
  677. )
  678. class GitModel(GitPreTrainedModel):
  679. _can_record_outputs = {
  680. "hidden_states": GitLayer,
  681. "attentions": GitSelfAttention,
  682. }
  683. def __init__(self, config):
  684. super().__init__(config)
  685. self.config = config
  686. self.embeddings = GitEmbeddings(config)
  687. self.image_encoder = GitVisionModel(config.vision_config)
  688. self.encoder = GitEncoder(config)
  689. self.visual_projection = GitProjection(config)
  690. if config.num_image_with_embedding is not None:
  691. self.img_temporal_embedding = nn.ParameterList(
  692. nn.Parameter(torch.zeros(1, 1, config.vision_config.hidden_size))
  693. for _ in range(config.num_image_with_embedding)
  694. )
  695. # Initialize weights and apply final processing
  696. self.post_init()
  697. def get_input_embeddings(self):
  698. return self.embeddings.word_embeddings
  699. def set_input_embeddings(self, value):
  700. self.embeddings.word_embeddings = value
  701. @merge_with_config_defaults
  702. @capture_outputs
  703. @auto_docstring
  704. def forward(
  705. self,
  706. input_ids: torch.Tensor | None = None,
  707. attention_mask: torch.Tensor | None = None,
  708. position_ids: torch.Tensor | None = None,
  709. pixel_values: torch.Tensor | None = None,
  710. inputs_embeds: torch.Tensor | None = None,
  711. past_key_values: Cache | None = None,
  712. use_cache: bool | None = None,
  713. interpolate_pos_encoding: bool = False,
  714. **kwargs: Unpack[TransformersKwargs],
  715. ) -> tuple[torch.Tensor] | BaseModelOutputWithPooling:
  716. r"""
  717. Examples:
  718. ```python
  719. >>> from transformers import AutoProcessor, AutoModel
  720. >>> import httpx
  721. >>> from io import BytesIO
  722. >>> from PIL import Image
  723. >>> processor = AutoProcessor.from_pretrained("microsoft/git-base")
  724. >>> model = AutoModel.from_pretrained("microsoft/git-base")
  725. >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg"
  726. >>> with httpx.stream("GET", url) as response:
  727. ... image = Image.open(BytesIO(response.read()))
  728. >>> text = "this is an image of two cats"
  729. >>> inputs = processor(images=image, text=text, return_tensors="pt")
  730. >>> outputs = model(**inputs)
  731. >>> last_hidden_state = outputs.last_hidden_state
  732. ```"""
  733. if (input_ids is None) ^ (inputs_embeds is not None):
  734. raise ValueError("You must specify exactly one of input_ids or inputs_embeds")
  735. if use_cache and past_key_values is None:
  736. past_key_values = DynamicCache(config=self.config)
  737. # past_key_values_length
  738. past_key_values_length = 0
  739. if past_key_values is not None:
  740. past_key_values_length = (
  741. past_key_values.get_seq_length()
  742. if not isinstance(past_key_values, Cache)
  743. else past_key_values.get_seq_length()
  744. )
  745. # Adjust position ids by adding image seq length
  746. if pixel_values is None and past_key_values is not None and input_ids.shape[1] == 1:
  747. position_ids = position_ids + past_key_values_length
  748. embedding_output = self.embeddings(
  749. input_ids=input_ids,
  750. position_ids=position_ids,
  751. inputs_embeds=inputs_embeds,
  752. past_key_values_length=past_key_values_length,
  753. )
  754. # Always create `token_type_ids` so we can re-use Gemma3 style mask preparation fn
  755. token_type_ids = torch.zeros_like(embedding_output, dtype=torch.int)[..., 0]
  756. if pixel_values is not None:
  757. if pixel_values.ndim == 4:
  758. # here we assume pixel_values is of shape (batch_size, num_channels, height, width)
  759. visual_features = self.image_encoder(
  760. pixel_values, interpolate_pos_encoding=interpolate_pos_encoding
  761. ).last_hidden_state
  762. elif pixel_values.ndim == 5:
  763. # here we assume pixel_values is of shape (batch_size, num_frames, num_channels, height, width)
  764. visual_features = []
  765. for frame_idx in range(pixel_values.shape[1]):
  766. visual_features_frame = self.image_encoder(
  767. pixel_values[:, frame_idx, :, :], interpolate_pos_encoding=interpolate_pos_encoding
  768. ).last_hidden_state
  769. visual_features_frame += self.img_temporal_embedding[frame_idx]
  770. visual_features.append(visual_features_frame)
  771. # finally, concatenate all features along sequence dimension
  772. visual_features = torch.cat(visual_features, dim=1)
  773. else:
  774. raise ValueError("pixel_values must be of rank 4 or 5")
  775. projected_visual_features = self.visual_projection(visual_features)
  776. # Repeat visual features to match embedding batch size.
  777. projected_visual_features = projected_visual_features.repeat(
  778. embedding_output.size(0) // projected_visual_features.size(0), 1, 1
  779. )
  780. # concatenate patch token and text token embeddings
  781. embedding_output = torch.cat((projected_visual_features, embedding_output), dim=1)
  782. image_token_type_ids = torch.ones_like(projected_visual_features, dtype=torch.int)[..., 0]
  783. token_type_ids = torch.cat([image_token_type_ids, token_type_ids], dim=-1)
  784. if attention_mask is not None:
  785. attention_mask = torch.cat([torch.ones_like(image_token_type_ids), attention_mask], dim=-1)
  786. elif past_key_values is not None and input_ids.shape[1] == 1:
  787. # Expand attention mask and cache position with image tokens because GIT doesn't add image
  788. # placeholder tokens when processing. Doesn't worth the refactor, low usage!
  789. extended_attention_mask = torch.ones(
  790. (attention_mask.shape[0], past_key_values_length - attention_mask.shape[1] + 1),
  791. dtype=attention_mask.dtype,
  792. device=attention_mask.device,
  793. )
  794. attention_mask = torch.cat([extended_attention_mask, attention_mask], dim=-1)
  795. # Images attend each other bidirectionally while text remains causal
  796. causal_mask = create_causal_mask_mapping(
  797. self.config,
  798. embedding_output,
  799. attention_mask,
  800. past_key_values,
  801. None,
  802. token_type_ids,
  803. pixel_values,
  804. )
  805. hidden_states = embedding_output
  806. encoder_outputs: BaseModelOutputWithPast = self.encoder(
  807. hidden_states,
  808. attention_mask=causal_mask,
  809. past_key_values=past_key_values,
  810. use_cache=use_cache,
  811. **kwargs,
  812. )
  813. return BaseModelOutputWithPast(
  814. last_hidden_state=encoder_outputs.last_hidden_state,
  815. past_key_values=encoder_outputs.past_key_values,
  816. )
  817. @auto_docstring(
  818. custom_intro="""
  819. GIT Model with a `language modeling` head on top for autoregressive language modeling.
  820. """
  821. )
  822. class GitForCausalLM(GitPreTrainedModel, GenerationMixin):
  823. _tied_weights_keys = {"output.weight": "git.embeddings.word_embeddings.weight"}
  824. def __init__(self, config):
  825. super().__init__(config)
  826. self.git = GitModel(config)
  827. self.output = nn.Linear(config.hidden_size, config.vocab_size)
  828. # Initialize weights and apply final processing
  829. self.post_init()
  830. def get_output_embeddings(self):
  831. return self.output
  832. def set_output_embeddings(self, new_embeddings):
  833. self.output = new_embeddings
  834. @merge_with_config_defaults
  835. @capture_outputs
  836. @auto_docstring
  837. def forward(
  838. self,
  839. input_ids: torch.Tensor | None = None,
  840. attention_mask: torch.Tensor | None = None,
  841. position_ids: torch.Tensor | None = None,
  842. pixel_values: torch.Tensor | None = None,
  843. inputs_embeds: torch.Tensor | None = None,
  844. labels: torch.Tensor | None = None,
  845. past_key_values: Cache | None = None,
  846. use_cache: bool | None = None,
  847. interpolate_pos_encoding: bool = False,
  848. logits_to_keep: int | torch.Tensor = 0,
  849. **kwargs: Unpack[TransformersKwargs],
  850. ) -> tuple[torch.Tensor] | CausalLMOutputWithPast:
  851. r"""
  852. labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
  853. Labels for computing the left-to-right language modeling loss (next word prediction). Indices should be in
  854. `[-100, 0, ..., config.vocab_size]` (see `input_ids` docstring) Tokens with indices set to `-100` are
  855. ignored (masked), the loss is only computed for the tokens with labels n `[0, ..., config.vocab_size]`
  856. Examples:
  857. Image captioning example:
  858. ```python
  859. >>> from transformers import AutoProcessor, AutoModelForCausalLM
  860. >>> import httpx
  861. >>> from io import BytesIO
  862. >>> from PIL import Image
  863. >>> processor = AutoProcessor.from_pretrained("microsoft/git-base-coco")
  864. >>> model = AutoModelForCausalLM.from_pretrained("microsoft/git-base-coco")
  865. >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg"
  866. >>> with httpx.stream("GET", url) as response:
  867. ... image = Image.open(BytesIO(response.read()))
  868. >>> pixel_values = processor(images=image, return_tensors="pt").pixel_values
  869. >>> generated_ids = model.generate(pixel_values=pixel_values, max_length=50)
  870. >>> generated_caption = processor.batch_decode(generated_ids, skip_special_tokens=True)[0]
  871. >>> print(generated_caption)
  872. two cats sleeping on a pink blanket next to remotes.
  873. ```
  874. Visual question answering (VQA) example:
  875. ```python
  876. >>> from transformers import AutoProcessor, AutoModelForCausalLM
  877. >>> from huggingface_hub import hf_hub_download
  878. >>> from PIL import Image
  879. >>> processor = AutoProcessor.from_pretrained("microsoft/git-base-textvqa")
  880. >>> model = AutoModelForCausalLM.from_pretrained("microsoft/git-base-textvqa")
  881. >>> file_path = hf_hub_download(repo_id="nielsr/textvqa-sample", filename="bus.png", repo_type="dataset")
  882. >>> image = Image.open(file_path).convert("RGB")
  883. >>> pixel_values = processor(images=image, return_tensors="pt").pixel_values
  884. >>> question = "what does the front of the bus say at the top?"
  885. >>> input_ids = processor(text=question, add_special_tokens=False).input_ids
  886. >>> input_ids = [processor.tokenizer.cls_token_id] + input_ids
  887. >>> input_ids = torch.tensor(input_ids).unsqueeze(0)
  888. >>> generated_ids = model.generate(pixel_values=pixel_values, input_ids=input_ids, max_length=50)
  889. >>> print(processor.batch_decode(generated_ids, skip_special_tokens=True))
  890. ['what does the front of the bus say at the top? special']
  891. ```
  892. Video captioning example:
  893. ```python
  894. >>> import av
  895. >>> import numpy as np
  896. >>> from PIL import Image
  897. >>> from huggingface_hub import hf_hub_download
  898. >>> from transformers import AutoProcessor, AutoModelForCausalLM
  899. >>> processor = AutoProcessor.from_pretrained("microsoft/git-base-vatex")
  900. >>> model = AutoModelForCausalLM.from_pretrained("microsoft/git-base-vatex")
  901. >>> # set seed for reproducibility
  902. >>> np.random.seed(45)
  903. >>> def read_video_pyav(container, indices):
  904. ... '''
  905. ... Decode the video with PyAV decoder.
  906. ... Args:
  907. ... container (`av.container.input.InputContainer`): PyAV container.
  908. ... indices (`list[int]`): List of frame indices to decode.
  909. ... Returns:
  910. ... result (np.ndarray): np array of decoded frames of shape (num_frames, height, width, 3).
  911. ... '''
  912. ... frames = []
  913. ... container.seek(0)
  914. ... start_index = indices[0]
  915. ... end_index = indices[-1]
  916. ... for i, frame in enumerate(container.decode(video=0)):
  917. ... if i > end_index:
  918. ... break
  919. ... if i >= start_index and i in indices:
  920. ... frames.append(frame)
  921. ... return np.stack([x.to_ndarray(format="rgb24") for x in frames])
  922. >>> def sample_frame_indices(clip_len, frame_sample_rate, seg_len):
  923. ... '''
  924. ... Sample a given number of frame indices from the video.
  925. ... Args:
  926. ... clip_len (`int`): Total number of frames to sample.
  927. ... frame_sample_rate (`int`): Sample every n-th frame.
  928. ... seg_len (`int`): Maximum allowed index of sample's last frame.
  929. ... Returns:
  930. ... indices (`list[int]`): List of sampled frame indices
  931. ... '''
  932. ... converted_len = int(clip_len * frame_sample_rate)
  933. ... end_idx = np.random.randint(converted_len, seg_len)
  934. ... start_idx = end_idx - converted_len
  935. ... indices = np.linspace(start_idx, end_idx, num=clip_len)
  936. ... indices = np.clip(indices, start_idx, end_idx - 1).astype(np.int64)
  937. ... return indices
  938. >>> # load video
  939. >>> file_path = hf_hub_download(
  940. ... repo_id="nielsr/video-demo", filename="eating_spaghetti.mp4", repo_type="dataset"
  941. ... )
  942. >>> container = av.open(file_path)
  943. >>> # sample frames
  944. >>> num_frames = model.config.num_image_with_embedding
  945. >>> indices = sample_frame_indices(
  946. ... clip_len=num_frames, frame_sample_rate=4, seg_len=container.streams.video[0].frames
  947. ... )
  948. >>> frames = read_video_pyav(container, indices)
  949. >>> pixel_values = processor(images=list(frames), return_tensors="pt").pixel_values
  950. >>> generated_ids = model.generate(pixel_values=pixel_values, max_length=50)
  951. >>> print("Generated caption:", processor.batch_decode(generated_ids, skip_special_tokens=True))
  952. Generated caption: ['a woman is sitting at a table and she is talking about the food she is holding.']
  953. ```
  954. """
  955. if labels is not None:
  956. use_cache = False
  957. outputs: BaseModelOutputWithPast = self.git(
  958. input_ids,
  959. attention_mask=attention_mask,
  960. position_ids=position_ids,
  961. pixel_values=pixel_values,
  962. inputs_embeds=inputs_embeds,
  963. past_key_values=past_key_values,
  964. use_cache=use_cache,
  965. interpolate_pos_encoding=interpolate_pos_encoding,
  966. **kwargs,
  967. )
  968. hidden_states = outputs.last_hidden_state
  969. # Only compute necessary logits, and do not upcast them to float if we are not computing the loss
  970. slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep
  971. logits = self.output(hidden_states[:, slice_indices, :])
  972. loss = None
  973. if labels is not None:
  974. # we are doing next-token prediction; shift prediction scores and input ids by one
  975. num_image_tokens = self.git.encoder.layer[0].attention.self.image_patch_tokens
  976. shifted_logits = logits[:, num_image_tokens:-1, :].contiguous()
  977. labels = labels[:, 1:].contiguous()
  978. loss = self.loss_function(
  979. shifted_logits.view(-1, self.config.vocab_size),
  980. labels.view(-1),
  981. vocab_size=self.config.vocab_size,
  982. **kwargs,
  983. )
  984. return CausalLMOutputWithPast(
  985. loss=loss,
  986. logits=logits,
  987. past_key_values=outputs.past_key_values,
  988. hidden_states=outputs.hidden_states,
  989. attentions=outputs.attentions,
  990. )
  991. def prepare_inputs_for_generation(
  992. self,
  993. input_ids,
  994. past_key_values=None,
  995. pixel_values=None,
  996. attention_mask=None,
  997. use_cache=None,
  998. is_first_iteration=False,
  999. **kwargs,
  1000. ):
  1001. # Overwritten -- `git` has special `pixel_values` handling
  1002. model_inputs = super().prepare_inputs_for_generation(
  1003. input_ids,
  1004. past_key_values=past_key_values,
  1005. attention_mask=attention_mask,
  1006. use_cache=use_cache,
  1007. is_first_iteration=is_first_iteration,
  1008. **kwargs,
  1009. )
  1010. if is_first_iteration or not use_cache:
  1011. model_inputs["pixel_values"] = pixel_values
  1012. return model_inputs
  1013. __all__ = ["GitForCausalLM", "GitModel", "GitPreTrainedModel", "GitVisionModel"]