modeling_kosmos2.py 71 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
37047057067077087097107117127137147157167177187197207217227237247257267277287297307317327337347357367377387397407417427437447457467477487497507517527537547557567577587597607617627637647657667677687697707717727737747757767777787797807817827837847857867877887897907917927937947957967977987998008018028038048058068078088098108118128138148158168178188198208218228238248258268278288298308318328338348358368378388398408418428438448458468478488498508518528538548558568578588598608618628638648658668678688698708718728738748758768778788798808818828838848858868878888898908918928938948958968978988999009019029039049059069079089099109119129139149159169179189199209219229239249259269279289299309319329339349359369379389399409419429439449459469479489499509519529539549559569579589599609619629639649659669679689699709719729739749759769779789799809819829839849859869879889899909919929939949959969979989991000100110021003100410051006100710081009101010111012101310141015101610171018101910201021102210231024102510261027102810291030103110321033103410351036103710381039104010411042104310441045104610471048104910501051105210531054105510561057105810591060106110621063106410651066106710681069107010711072107310741075107610771078107910801081108210831084108510861087108810891090109110921093109410951096109710981099110011011102110311041105110611071108110911101111111211131114111511161117111811191120112111221123112411251126112711281129113011311132113311341135113611371138113911401141114211431144114511461147114811491150115111521153115411551156115711581159116011611162116311641165116611671168116911701171117211731174117511761177117811791180118111821183118411851186118711881189119011911192119311941195119611971198119912001201120212031204120512061207120812091210121112121213121412151216121712181219122012211222122312241225122612271228122912301231123212331234123512361237123812391240124112421243124412451246124712481249125012511252125312541255125612571258125912601261126212631264126512661267126812691270127112721273127412751276127
712781279128012811282128312841285128612871288128912901291129212931294129512961297129812991300130113021303130413051306130713081309131013111312131313141315131613171318131913201321132213231324132513261327132813291330133113321333133413351336133713381339134013411342134313441345134613471348134913501351135213531354135513561357135813591360136113621363136413651366136713681369137013711372137313741375137613771378137913801381138213831384138513861387138813891390139113921393139413951396139713981399140014011402140314041405140614071408140914101411141214131414141514161417141814191420142114221423142414251426142714281429143014311432143314341435143614371438143914401441144214431444144514461447144814491450145114521453145414551456145714581459146014611462146314641465146614671468146914701471147214731474147514761477147814791480148114821483148414851486148714881489149014911492149314941495149614971498149915001501150215031504150515061507150815091510151115121513151415151516151715181519152015211522152315241525152615271528152915301531153215331534153515361537153815391540154115421543154415451546154715481549155015511552155315541555155615571558155915601561156215631564156515661567156815691570157115721573157415751576157715781579158015811582158315841585158615871588158915901591159215931594159515961597159815991600160116021603160416051606160716081609161016111612161316141615161616171618161916201621162216231624162516261627162816291630163116321633163416351636163716381639164016411642164316441645164616471648164916501651165216531654165516561657165816591660166116621663166416651666166716681669167016711672167316741675167616771678167916801681168216831684168516861687168816891690169116921693
  1. # Copyright 2023 Microsoft Research and The HuggingFace Inc. team. All rights reserved.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. """PyTorch KOSMOS-2 model."""
  15. import math
  16. import warnings
  17. from collections.abc import Callable
  18. from dataclasses import dataclass
  19. from typing import Any
  20. import torch
  21. from torch import nn
  22. from ... import initialization as init
  23. from ...activations import ACT2FN
  24. from ...cache_utils import Cache, DynamicCache, EncoderDecoderCache
  25. from ...generation import GenerationMixin
  26. from ...modeling_flash_attention_utils import FlashAttentionKwargs
  27. from ...modeling_layers import GradientCheckpointingLayer
  28. from ...modeling_outputs import (
  29. BaseModelOutput,
  30. BaseModelOutputWithPastAndCrossAttentions,
  31. BaseModelOutputWithPooling,
  32. CausalLMOutputWithCrossAttentions,
  33. )
  34. from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
  35. from ...processing_utils import Unpack
  36. from ...utils import ModelOutput, TransformersKwargs, auto_docstring, can_return_tuple, logging, torch_int
  37. from ...utils.generic import merge_with_config_defaults
  38. from ...utils.output_capturing import OutputRecorder, capture_outputs
  39. from .configuration_kosmos2 import Kosmos2Config, Kosmos2TextConfig, Kosmos2VisionConfig
# Module-level logger for this file, obtained from the library's `logging.get_logger`.
logger = logging.get_logger(__name__)
@auto_docstring
class Kosmos2PreTrainedModel(PreTrainedModel):
    # Common base for the KOSMOS-2 models in this file: declares the config class, the supported input
    # modalities / attention backends, and the weight-initialization scheme in `_init_weights`.
    config: Kosmos2Config
    input_modalities = ("image", "text")
    supports_gradient_checkpointing = True
    _no_split_modules = ["Kosmos2VisionEncoderLayer", "Kosmos2TextBlock"]
    _supports_attention_backend = True
    _supports_flash_attn = False  # cuda device errors
    _supports_sdpa = True

    @torch.no_grad()
    def _init_weights(self, module: nn.Module):
        """
        Initialize the weights of `module` in place.

        Vision-tower modules use standard deviations derived from their hidden size, scaled by the vision
        `initializer_factor` (`factor`). Text-tower modules, the LM head and the image-to-text projection use the
        text `init_std` (`std`). Finally, the bias of any `nn.Linear` is zeroed, whichever branch matched.
        """
        # `factor` is read from the config itself or from its nested `vision_config`.
        # NOTE(review): if the config exposes neither attribute, `factor` stays unbound and the vision branches
        # below would raise UnboundLocalError; the same holds for `std` and the text branches. This appears to rely
        # on each model only containing modules of the tower its config describes — confirm before reusing.
        if hasattr(self.config, "initializer_factor"):
            factor = self.config.initializer_factor
        elif hasattr(self.config, "vision_config"):
            factor = self.config.vision_config.initializer_factor
        # `std` is read from `init_std` directly or from the nested `text_config`.
        if hasattr(self.config, "init_std"):
            std = self.config.init_std
        elif hasattr(self.config, "text_config"):
            std = self.config.text_config.init_std
        if isinstance(module, Kosmos2VisionEmbeddings):
            init.normal_(module.class_embedding, mean=0.0, std=module.embed_dim**-0.5 * factor)
            init.normal_(module.patch_embedding.weight, std=module.config.initializer_range * factor)
            init.normal_(module.position_embedding.weight, std=module.config.initializer_range * factor)
            # `position_ids` is a non-persistent buffer; restore it to 0..num_positions-1.
            init.copy_(module.position_ids, torch.arange(module.position_ids.shape[-1]).expand((1, -1)))
        elif isinstance(module, Kosmos2VisionAttention):
            # Input projections are additionally scaled down by the encoder depth (2 * num_hidden_layers).
            in_proj_std = (module.embed_dim**-0.5) * ((2 * module.config.num_hidden_layers) ** -0.5) * factor
            out_proj_std = (module.embed_dim**-0.5) * factor
            init.normal_(module.q_proj.weight, std=in_proj_std)
            init.normal_(module.k_proj.weight, std=in_proj_std)
            init.normal_(module.v_proj.weight, std=in_proj_std)
            init.normal_(module.out_proj.weight, std=out_proj_std)
        elif isinstance(module, Kosmos2VisionMLP):
            in_proj_std = (module.config.hidden_size**-0.5) * ((2 * module.config.num_hidden_layers) ** -0.5) * factor
            fc_std = (2 * module.config.hidden_size) ** -0.5 * factor
            init.normal_(module.fc1.weight, std=fc_std)
            init.normal_(module.fc2.weight, std=in_proj_std)
        elif isinstance(module, KosmosTextAttention):
            init.normal_(module.q_proj.weight, std=std)
            init.normal_(module.k_proj.weight, std=std)
            init.normal_(module.v_proj.weight, std=std)
            init.normal_(module.out_proj.weight, std=std)
        elif isinstance(module, Kosmos2TextFFN):
            init.normal_(module.fc1.weight, std=std)
            init.normal_(module.fc2.weight, std=std)
        elif isinstance(module, Kosmos2TextForCausalLM):
            init.normal_(module.lm_head.weight, std=std)
        elif isinstance(module, Kosmos2ImageToTextProjection):
            init.normal_(module.dense.weight, std=std)
            # No explicit mean/std: uses `init.normal_`'s defaults (not visible here — TODO confirm they are 0/1).
            init.normal_(module.latent_query)
        elif isinstance(module, Kosmos2TextTransformer):
            init.normal_(module.embed_tokens.weight, mean=0.0, std=std)
            # Keep the padding token's embedding row at zero.
            if module.embed_tokens.padding_idx is not None:
                init.zeros_(module.embed_tokens.weight[module.embed_tokens.padding_idx])
        elif isinstance(module, nn.LayerNorm):
            init.ones_(module.weight)
            init.zeros_(module.bias)
        elif isinstance(module, Kosmos2TextSinusoidalPositionalEmbedding):
            # The sinusoidal table is deterministic: rebuild it instead of sampling.
            emb_weights = module.get_embedding(
                module.num_positions + module.offset, module.embedding_dim, module.padding_idx
            )
            init.copy_(module.weights, emb_weights)
        # Applies on top of any branch above: every Linear bias starts at zero.
        if isinstance(module, nn.Linear) and module.bias is not None:
            init.zeros_(module.bias)
  105. def _expand_mask(mask: torch.Tensor, dtype: torch.dtype, tgt_len: int | None = None):
  106. """
  107. Expands attention_mask from `[bsz, seq_len]` to `[bsz, 1, tgt_seq_len, src_seq_len]`.
  108. """
  109. bsz, src_len = mask.size()
  110. tgt_len = tgt_len if tgt_len is not None else src_len
  111. expanded_mask = mask[:, None, None, :].expand(bsz, 1, tgt_len, src_len).to(dtype)
  112. inverted_mask = 1.0 - expanded_mask
  113. return inverted_mask.masked_fill(inverted_mask.to(torch.bool), torch.finfo(dtype).min)
  114. def _make_causal_mask(
  115. input_ids_shape: torch.Size, dtype: torch.dtype, device: torch.device, past_key_values_length: int = 0
  116. ):
  117. """
  118. Make causal mask used for bi-directional self-attention.
  119. """
  120. bsz, tgt_len = input_ids_shape
  121. mask = torch.full((tgt_len, tgt_len), torch.finfo(dtype).min, device=device)
  122. mask_cond = torch.arange(mask.size(-1), device=device)
  123. mask.masked_fill_(mask_cond < (mask_cond + 1).view(mask.size(-1), 1), 0)
  124. mask = mask.to(dtype)
  125. if past_key_values_length > 0:
  126. mask = torch.cat([torch.zeros(tgt_len, past_key_values_length, dtype=dtype, device=device), mask], dim=-1)
  127. return mask[None, None, :, :].expand(bsz, 1, tgt_len, tgt_len + past_key_values_length)
@dataclass
@auto_docstring
class BaseModelOutputWithProjectionAttentions(BaseModelOutputWithPooling):
    r"""
    projection_attentions (`tuple(torch.FloatTensor)`):
        Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
        sequence_length)`.

        Attentions weights given by `Kosmos2ImageToTextProjection`, after the attention softmax, used to compute
        the weighted average in the self-attention heads.
    """

    # Extra field on top of the inherited `BaseModelOutputWithPooling` fields; as a dataclass subclass field it
    # comes after all inherited fields in field order.
    projection_attentions: tuple[torch.FloatTensor] | None = None
  139. @dataclass
  140. @auto_docstring(
  141. custom_intro="""
  142. Base class for text model's outputs that also contains a pooling of the last hidden states.
  143. """
  144. )
  145. class Kosmos2ModelOutput(ModelOutput):
  146. r"""
  147. past_key_values (`Cache`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
  148. It is a [`~cache_utils.Cache`] instance. For more details, see our [kv cache guide](https://huggingface.co/docs/transformers/en/kv_cache).
  149. Contains pre-computed hidden-states (key and values in the self-attention blocks and optionally if
  150. `config.is_encoder_decoder=True` in the cross-attention blocks) that can be used (see `past_key_values`
  151. input) to speed up sequential decoding.
  152. image_embeds (`torch.FloatTensor` of shape `(batch_size, latent_query_num, hidden_size)`, *optional*):
  153. Sequence of hidden-states at the output of `Kosmos2ImageToTextProjection`.
  154. projection_attentions (`tuple(torch.FloatTensor)`, *optional*):
  155. Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
  156. sequence_length)`.
  157. Attentions weights given by `Kosmos2ImageToTextProjection`, after the attention softmax, used to compute
  158. the weighted average in the self-attention heads.
  159. vision_model_output (`BaseModelOutputWithPooling`, *optional*):
  160. The output of the [`Kosmos2VisionModel`].
  161. """
  162. last_hidden_state: torch.FloatTensor | None = None
  163. past_key_values: Cache | None = None
  164. hidden_states: tuple[torch.FloatTensor] | None = None
  165. attentions: tuple[torch.FloatTensor] | None = None
  166. image_embeds: torch.FloatTensor | None = None
  167. projection_attentions: tuple[torch.FloatTensor] | None = None
  168. vision_model_output: BaseModelOutputWithPooling = None
  169. def to_tuple(self) -> tuple[Any]:
  170. return tuple(
  171. self[k] if k not in ["text_model_output", "vision_model_output"] else getattr(self, k).to_tuple()
  172. for k in self.keys()
  173. )
  174. @dataclass
  175. @auto_docstring(
  176. custom_intro="""
  177. Model output class for `Kosmos2ForConditionalGeneration`.
  178. """
  179. )
  180. class Kosmos2ForConditionalGenerationModelOutput(ModelOutput):
  181. r"""
  182. loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` is provided):
  183. Language modeling loss (for next-token prediction).
  184. logits (`torch.FloatTensor` of shape `(batch_size, sequence_length, config.vocab_size)`):
  185. Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax).
  186. past_key_values (`Cache`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
  187. It is a [`~cache_utils.Cache`] instance. For more details, see our [kv cache guide](https://huggingface.co/docs/transformers/en/kv_cache).
  188. Contains pre-computed hidden-states (key and values in the self-attention blocks and optionally if
  189. `config.is_encoder_decoder=True` in the cross-attention blocks) that can be used (see `past_key_values`
  190. input) to speed up sequential decoding.
  191. image_embeds (`torch.FloatTensor` of shape `(batch_size, latent_query_num, hidden_size)`, *optional*):
  192. Sequence of hidden-states at the output of `Kosmos2ImageToTextProjection`.
  193. projection_attentions (`tuple(torch.FloatTensor)`, *optional*):
  194. Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
  195. sequence_length)`.
  196. Attentions weights given by `Kosmos2ImageToTextProjection`, after the attention softmax, used to compute
  197. the weighted average in the self-attention heads.
  198. vision_model_output (`BaseModelOutputWithPooling`, *optional*):
  199. The output of the [`Kosmos2VisionModel`].
  200. """
  201. loss: torch.FloatTensor | None = None
  202. logits: torch.FloatTensor | None = None
  203. past_key_values: Cache | None = None
  204. hidden_states: tuple[torch.FloatTensor] | None = None
  205. attentions: tuple[torch.FloatTensor] | None = None
  206. image_embeds: torch.FloatTensor | None = None
  207. projection_attentions: tuple[torch.FloatTensor] | None = None
  208. vision_model_output: BaseModelOutputWithPooling = None
  209. def to_tuple(self) -> tuple[Any]:
  210. return tuple(
  211. self[k] if k not in ["text_model_output", "vision_model_output"] else getattr(self, k).to_tuple()
  212. for k in self.keys()
  213. )
# Copied from transformers.models.clip.modeling_clip.CLIPVisionEmbeddings with CLIP->Kosmos2
class Kosmos2VisionEmbeddings(nn.Module):
    """
    Turn images into the vision encoder's input sequence.

    The sequence is a learnable class embedding followed by one embedding per non-overlapping
    `patch_size` x `patch_size` patch, with learned absolute position embeddings added.
    """

    def __init__(self, config: Kosmos2VisionConfig):
        super().__init__()
        self.config = config
        self.embed_dim = config.hidden_size
        self.image_size = config.image_size
        self.patch_size = config.patch_size
        # Learnable vector prepended to every patch sequence in `forward`.
        self.class_embedding = nn.Parameter(torch.randn(self.embed_dim))
        # kernel_size == stride == patch_size, so each output location covers exactly one non-overlapping patch.
        self.patch_embedding = nn.Conv2d(
            in_channels=config.num_channels,
            out_channels=self.embed_dim,
            kernel_size=self.patch_size,
            stride=self.patch_size,
            bias=False,
        )
        self.num_patches = (self.image_size // self.patch_size) ** 2
        # One extra position for the class embedding.
        self.num_positions = self.num_patches + 1
        self.position_embedding = nn.Embedding(self.num_positions, self.embed_dim)
        # Non-persistent buffer of shape (1, num_positions) holding 0..num_positions-1.
        self.register_buffer("position_ids", torch.arange(self.num_positions).expand((1, -1)), persistent=False)

    def interpolate_pos_encoding(self, embeddings: torch.Tensor, height: int, width: int) -> torch.Tensor:
        """
        This method allows to interpolate the pre-trained position encodings, to be able to use the model on higher resolution
        images. This method is also adapted to support torch.jit tracing.

        Adapted from:
        - https://github.com/facebookresearch/dino/blob/de9ee3df6cf39fac952ab558447af1fa1365362a/vision_transformer.py#L174-L194, and
        - https://github.com/facebookresearch/dinov2/blob/e1277af2ba9496fbadf7aec6eba56e8d882d1e35/dinov2/models/vision_transformer.py#L179-L211

        Args:
            embeddings (`torch.Tensor`): class + patch embeddings. Only `shape[1]` (sequence length, class token
                included) and `shape[-1]` (hidden size) are read.
            height (`int`): pixel height of the input image.
            width (`int`): pixel width of the input image.

        Returns:
            `torch.Tensor` of shape `(1, 1 + (height // patch_size) * (width // patch_size), hidden_size)`: the class
            position embedding followed by the bicubically resized patch position embeddings. When the patch count
            already matches, the image is square and we are not tracing, the stored embeddings for `position_ids`
            are returned unchanged.
        """
        num_patches = embeddings.shape[1] - 1
        position_embedding = self.position_embedding.weight.unsqueeze(0)
        num_positions = position_embedding.shape[1] - 1
        # always interpolate when tracing to ensure the exported model works for dynamic input shapes
        if not torch.jit.is_tracing() and num_patches == num_positions and height == width:
            return self.position_embedding(self.position_ids)
        class_pos_embed = position_embedding[:, :1]
        patch_pos_embed = position_embedding[:, 1:]
        dim = embeddings.shape[-1]
        new_height = height // self.patch_size
        new_width = width // self.patch_size
        # The stored patch position embeddings are assumed to form a square sqrt(N) x sqrt(N) grid.
        sqrt_num_positions = torch_int(num_positions**0.5)
        patch_pos_embed = patch_pos_embed.reshape(1, sqrt_num_positions, sqrt_num_positions, dim)
        # (1, grid, grid, dim) -> (1, dim, grid, grid): channels-first layout expected by `interpolate`.
        patch_pos_embed = patch_pos_embed.permute(0, 3, 1, 2)
        patch_pos_embed = nn.functional.interpolate(
            patch_pos_embed,
            size=(new_height, new_width),
            mode="bicubic",
            align_corners=False,
        )
        # Back to a flat sequence: (1, new_height * new_width, dim).
        patch_pos_embed = patch_pos_embed.permute(0, 2, 3, 1).view(1, -1, dim)
        return torch.cat((class_pos_embed, patch_pos_embed), dim=1)

    def forward(self, pixel_values: torch.FloatTensor, interpolate_pos_encoding=False) -> torch.Tensor:
        """
        Embed a batch of images.

        Args:
            pixel_values (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`): input images.
            interpolate_pos_encoding (`bool`, defaults to `False`): resize the position embeddings to the input
                resolution instead of requiring `height == width == config.image_size`.

        Returns:
            `torch.Tensor` of shape `(batch_size, 1 + num_patches, hidden_size)`: the class embedding followed by the
            patch embeddings, with position embeddings added.

        Raises:
            ValueError: if the image size differs from `config.image_size` and `interpolate_pos_encoding` is False.
        """
        batch_size, _, height, width = pixel_values.shape
        if not interpolate_pos_encoding and (height != self.image_size or width != self.image_size):
            raise ValueError(
                f"Input image size ({height}*{width}) doesn't match model ({self.image_size}*{self.image_size})."
            )
        # Cast the input to the conv weight's dtype (e.g. fp16 weights with fp32 pixels).
        target_dtype = self.patch_embedding.weight.dtype
        patch_embeds = self.patch_embedding(pixel_values.to(dtype=target_dtype))  # shape = [*, width, grid, grid]
        # (batch, dim, grid_h, grid_w) -> (batch, grid_h * grid_w, dim)
        patch_embeds = patch_embeds.flatten(2).transpose(1, 2)
        class_embeds = self.class_embedding.expand(batch_size, 1, -1)
        embeddings = torch.cat([class_embeds, patch_embeds], dim=1)
        if interpolate_pos_encoding:
            embeddings = embeddings + self.interpolate_pos_encoding(embeddings, height, width)
        else:
            embeddings = embeddings + self.position_embedding(self.position_ids)
        return embeddings
  280. # Adapted from transformers.models.siglip.modeling_siglip.eager_attention_forward -> Kosmos2 doesn't cast attn weights to fp32
  281. def eager_attention_forward(
  282. module: nn.Module,
  283. query: torch.Tensor,
  284. key: torch.Tensor,
  285. value: torch.Tensor,
  286. attention_mask: torch.Tensor | None,
  287. scaling: float,
  288. dropout: float = 0.0,
  289. **kwargs,
  290. ):
  291. attn_weights = torch.matmul(query, key.transpose(-1, -2)) * scaling
  292. if attention_mask is not None:
  293. attn_weights = attn_weights + attention_mask
  294. attn_weights = nn.functional.softmax(attn_weights, dim=-1)
  295. attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training)
  296. attn_output = torch.matmul(attn_weights, value)
  297. attn_output = attn_output.transpose(1, 2).contiguous()
  298. return attn_output, attn_weights
  299. class Kosmos2VisionAttention(nn.Module):
  300. """Multi-headed attention from 'Attention Is All You Need' paper"""
  301. def __init__(self, config):
  302. super().__init__()
  303. self.config = config
  304. self.embed_dim = config.hidden_size
  305. self.num_heads = config.num_attention_heads
  306. self.head_dim = self.embed_dim // self.num_heads
  307. if self.head_dim * self.num_heads != self.embed_dim:
  308. raise ValueError(
  309. f"embed_dim must be divisible by num_heads (got `embed_dim`: {self.embed_dim} and `num_heads`:"
  310. f" {self.num_heads})."
  311. )
  312. self.scale = self.head_dim**-0.5
  313. self.dropout = config.attention_dropout
  314. self.is_causal = False
  315. self.k_proj = nn.Linear(self.embed_dim, self.embed_dim)
  316. self.v_proj = nn.Linear(self.embed_dim, self.embed_dim)
  317. self.q_proj = nn.Linear(self.embed_dim, self.embed_dim)
  318. self.out_proj = nn.Linear(self.embed_dim, self.embed_dim)
  319. def forward(
  320. self,
  321. hidden_states: torch.Tensor,
  322. attention_mask: torch.Tensor | None = None,
  323. **kwargs: Unpack[TransformersKwargs],
  324. ) -> tuple[torch.Tensor, torch.Tensor | None]:
  325. """Input shape: Batch x Time x Channel"""
  326. input_shape = hidden_states.shape[:-1]
  327. hidden_shape = (*input_shape, -1, self.head_dim)
  328. queries = self.q_proj(hidden_states)
  329. keys = self.k_proj(hidden_states)
  330. values = self.v_proj(hidden_states)
  331. queries = queries.view(hidden_shape).transpose(1, 2)
  332. keys = keys.view(hidden_shape).transpose(1, 2)
  333. values = values.view(hidden_shape).transpose(1, 2)
  334. attention_interface: Callable = ALL_ATTENTION_FUNCTIONS.get_interface(
  335. self.config._attn_implementation, eager_attention_forward
  336. )
  337. attn_output, attn_weights = attention_interface(
  338. self,
  339. queries,
  340. keys,
  341. values,
  342. attention_mask,
  343. is_causal=self.is_causal,
  344. scaling=self.scale,
  345. dropout=0.0 if not self.training else self.dropout,
  346. **kwargs,
  347. )
  348. attn_output = attn_output.reshape(*input_shape, -1).contiguous()
  349. attn_output = self.out_proj(attn_output)
  350. return attn_output, attn_weights
# Copied from transformers.models.clip.modeling_clip.CLIPMLP with CLIP->Kosmos2Vision
class Kosmos2VisionMLP(nn.Module):
    """
    Two-layer feed-forward block.

    Maps hidden_size -> intermediate_size -> hidden_size, applying the activation selected by
    `config.hidden_act` between the two linear layers.
    """

    def __init__(self, config):
        super().__init__()
        self.config = config
        self.activation_fn = ACT2FN[config.hidden_act]
        self.fc1 = nn.Linear(config.hidden_size, config.intermediate_size)
        self.fc2 = nn.Linear(config.intermediate_size, config.hidden_size)

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        """Apply fc1 -> activation -> fc2 over the last dimension of `hidden_states`."""
        hidden_states = self.fc1(hidden_states)
        hidden_states = self.activation_fn(hidden_states)
        hidden_states = self.fc2(hidden_states)
        return hidden_states
# Copied from transformers.models.altclip.modeling_altclip.AltCLIPEncoderLayer with AltCLIP->Kosmos2Vision
class Kosmos2VisionEncoderLayer(GradientCheckpointingLayer):
    """
    One pre-norm transformer encoder layer of the vision tower.

    The layer runs LayerNorm -> self-attention -> residual add, then LayerNorm -> MLP -> residual add.
    """

    def __init__(self, config: Kosmos2VisionConfig):
        super().__init__()
        self.embed_dim = config.hidden_size
        self.self_attn = Kosmos2VisionAttention(config)
        self.layer_norm1 = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_eps)
        self.mlp = Kosmos2VisionMLP(config)
        self.layer_norm2 = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_eps)

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: torch.Tensor,
        **kwargs: Unpack[TransformersKwargs],
    ) -> torch.Tensor:
        """
        Args:
            hidden_states: input sequence, presumably `(batch_size, seq_len, hidden_size)` — TODO confirm.
            attention_mask: forwarded unchanged to `Kosmos2VisionAttention`. The vision encoder passes None by
                default.
            **kwargs: forwarded to the attention module.

        Returns:
            The updated hidden states as a single tensor (attention weights are not returned).
        """
        residual = hidden_states
        hidden_states = self.layer_norm1(hidden_states)
        # Attention weights are discarded here. `Kosmos2VisionTransformer` lists `Kosmos2VisionAttention` in
        # `_can_record_outputs`, which presumably captures them via hooks instead.
        hidden_states, _ = self.self_attn(
            hidden_states=hidden_states,
            attention_mask=attention_mask,
            **kwargs,
        )
        hidden_states = residual + hidden_states
        residual = hidden_states
        hidden_states = self.layer_norm2(hidden_states)
        hidden_states = self.mlp(hidden_states)
        hidden_states = residual + hidden_states
        return hidden_states
  392. class Kosmos2VisionEncoder(nn.Module):
  393. """
  394. Transformer encoder consisting of `config.num_hidden_layers` self attention layers. Each layer is a
  395. [`Kosmos2VisionEncoderLayer`].
  396. Args:
  397. config: Kosmos2VisionConfig
  398. """
  399. def __init__(self, config: Kosmos2VisionConfig):
  400. super().__init__()
  401. self.config = config
  402. self.layers = nn.ModuleList([Kosmos2VisionEncoderLayer(config) for _ in range(config.num_hidden_layers)])
  403. self.gradient_checkpointing = False
  404. def forward(
  405. self,
  406. inputs_embeds,
  407. attention_mask: torch.Tensor | None = None,
  408. **kwargs: Unpack[TransformersKwargs],
  409. ) -> tuple | BaseModelOutput:
  410. r"""
  411. Args:
  412. inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`):
  413. Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation.
  414. This is useful if you want more control over how to convert `input_ids` indices into associated vectors
  415. than the model's internal embedding lookup matrix.
  416. attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
  417. Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
  418. - 1 for tokens that are **not masked**,
  419. - 0 for tokens that are **masked**.
  420. [What are attention masks?](../glossary#attention-mask)
  421. """
  422. hidden_states = inputs_embeds
  423. for encoder_layer in self.layers:
  424. hidden_states = encoder_layer(
  425. hidden_states,
  426. attention_mask,
  427. **kwargs,
  428. )
  429. return BaseModelOutputWithProjectionAttentions(
  430. last_hidden_state=hidden_states,
  431. )
  432. # Similar to `transformers.models.clip.modeling_clip.CLIPVisionTransformer` but without docstring for `forward`
  433. class Kosmos2VisionTransformer(Kosmos2PreTrainedModel):
  434. _can_record_outputs = {
  435. "hidden_states": Kosmos2VisionEncoderLayer,
  436. "attentions": Kosmos2VisionAttention,
  437. }
  438. def __init__(self, config: Kosmos2VisionConfig):
  439. super().__init__(config)
  440. embed_dim = config.hidden_size
  441. self.embeddings = Kosmos2VisionEmbeddings(config)
  442. self.pre_layrnorm = nn.LayerNorm(embed_dim, eps=config.layer_norm_eps)
  443. self.encoder = Kosmos2VisionEncoder(config)
  444. self.post_layernorm = nn.LayerNorm(embed_dim, eps=config.layer_norm_eps)
  445. self.post_init()
  446. @merge_with_config_defaults
  447. @capture_outputs(tie_last_hidden_states=False)
  448. @auto_docstring
  449. def forward(
  450. self,
  451. pixel_values: torch.FloatTensor | None = None,
  452. interpolate_pos_encoding: bool = False,
  453. **kwargs: Unpack[TransformersKwargs],
  454. ) -> BaseModelOutputWithPooling:
  455. if pixel_values is None:
  456. raise ValueError("You have to specify pixel_values")
  457. hidden_states = self.embeddings(pixel_values, interpolate_pos_encoding=interpolate_pos_encoding)
  458. hidden_states = self.pre_layrnorm(hidden_states)
  459. encoder_outputs = self.encoder(
  460. inputs_embeds=hidden_states,
  461. **kwargs,
  462. )
  463. last_hidden_state = encoder_outputs[0]
  464. pooled_output = last_hidden_state[:, 0, :]
  465. pooled_output = self.post_layernorm(pooled_output)
  466. return BaseModelOutputWithPooling(
  467. last_hidden_state=last_hidden_state,
  468. pooler_output=pooled_output,
  469. )
  470. # Similar to `transformers.models.m2m_100.modeling_m2m_100.M2M100SinusoidalPositionalEmbedding` but allowing to pass `position_ids`
  471. class Kosmos2TextSinusoidalPositionalEmbedding(nn.Module):
  472. """This module produces sinusoidal positional embeddings of any length."""
    # Copied from transformers.models.m2m_100.modeling_m2m_100.M2M100SinusoidalPositionalEmbedding.__init__
    def __init__(self, num_positions: int, embedding_dim: int, padding_idx: int | None = None):
        """
        Build the initial sinusoidal table.

        Args:
            num_positions: number of positions the table is initially built for (before adding `offset`).
            embedding_dim: size of each position embedding.
            padding_idx: row of the table that is zeroed, if given.
        """
        super().__init__()
        # The table is allocated with `offset` extra rows beyond `num_positions` (see the `make_weights` call).
        self.offset = 2
        self.num_positions = num_positions
        self.embedding_dim = embedding_dim
        self.padding_idx = padding_idx
        self.make_weights(num_positions + self.offset, embedding_dim, padding_idx)
    # Copied from transformers.models.m2m_100.modeling_m2m_100.M2M100SinusoidalPositionalEmbedding.make_weights
    def make_weights(self, num_embeddings: int, embedding_dim: int, padding_idx: int | None = None):
        """
        (Re)build the sinusoidal table and store it in the non-persistent `weights` buffer.

        When a `weights` buffer already exists (e.g. when the table is grown in `forward`), the new table is first
        moved to that buffer's dtype and device, so growing it does not change either.
        """
        emb_weights = self.get_embedding(num_embeddings, embedding_dim, padding_idx)
        if hasattr(self, "weights"):
            # in forward put the weights on the correct dtype and device of the param
            emb_weights = emb_weights.to(dtype=self.weights.dtype, device=self.weights.device)
        # persistent=False: the table is recomputable, so it is kept out of the state dict.
        self.register_buffer("weights", emb_weights, persistent=False)
    @staticmethod
    # Copied from transformers.models.m2m_100.modeling_m2m_100.M2M100SinusoidalPositionalEmbedding.get_embedding
    def get_embedding(num_embeddings: int, embedding_dim: int, padding_idx: int | None = None):
        """
        Build sinusoidal embeddings.

        This matches the implementation in tensor2tensor, but differs slightly from the description in Section 3.5 of
        "Attention Is All You Need".

        Returns a `(num_embeddings, embedding_dim)` tensor in `torch.get_default_dtype()`. The first
        `embedding_dim // 2` columns are sines and the next ones are cosines (concatenated, not interleaved). For an
        odd `embedding_dim` a zero column is appended. The `padding_idx` row, if given, is all zeros.

        NOTE(review): `embedding_dim` of 2 or 3 makes `half_dim - 1` zero, and the frequency computation below then
        raises ZeroDivisionError.
        """
        half_dim = embedding_dim // 2
        # Log-spaced inverse frequencies from 1 down to 1/10000.
        emb = math.log(10000) / (half_dim - 1)
        emb = torch.exp(torch.arange(half_dim, dtype=torch.int64).float() * -emb)
        # Outer product: angle[position, i] = position * frequency[i].
        emb = torch.arange(num_embeddings, dtype=torch.int64).float().unsqueeze(1) * emb.unsqueeze(0)
        emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=1).view(num_embeddings, -1)
        if embedding_dim % 2 == 1:
            # zero pad
            emb = torch.cat([emb, torch.zeros(num_embeddings, 1)], dim=1)
        if padding_idx is not None:
            emb[padding_idx, :] = 0
        return emb.to(torch.get_default_dtype())
  507. @torch.no_grad()
  508. def forward(
  509. self,
  510. input_ids: torch.Tensor | None = None,
  511. inputs_embeds: torch.Tensor | None = None,
  512. past_key_values_length: int = 0,
  513. position_ids: torch.Tensor | None = None,
  514. ):
  515. if input_ids is not None:
  516. bsz, seq_len = input_ids.size()
  517. if position_ids is None:
  518. # Create the position ids from the input token ids. Any padded tokens remain padded.
  519. position_ids = self.create_position_ids_from_input_ids(
  520. input_ids, self.padding_idx, past_key_values_length
  521. ).to(input_ids.device)
  522. else:
  523. bsz, seq_len = inputs_embeds.size()[:-1]
  524. if position_ids is None:
  525. position_ids = self.create_position_ids_from_inputs_embeds(
  526. inputs_embeds, past_key_values_length, self.padding_idx
  527. )
  528. # expand embeddings if needed
  529. max_pos = self.padding_idx + 1 + seq_len + past_key_values_length
  530. if max_pos > self.weights.size(0):
  531. self.make_weights(max_pos + self.offset, self.embedding_dim, self.padding_idx)
  532. return self.weights.index_select(0, position_ids.view(-1)).view(bsz, seq_len, self.weights.shape[-1]).detach()
  533. @staticmethod
  534. # Copied from transformers.models.m2m_100.modeling_m2m_100.M2M100SinusoidalPositionalEmbedding.create_position_ids_from_inputs_embeds
  535. def create_position_ids_from_inputs_embeds(inputs_embeds, past_key_values_length, padding_idx):
  536. """
  537. We are provided embeddings directly. We cannot infer which are padded so just generate sequential position ids.
  538. Args:
  539. inputs_embeds: torch.Tensor
  540. Returns: torch.Tensor
  541. """
  542. input_shape = inputs_embeds.size()[:-1]
  543. sequence_length = input_shape[1]
  544. position_ids = torch.arange(
  545. padding_idx + 1, sequence_length + padding_idx + 1, dtype=torch.long, device=inputs_embeds.device
  546. )
  547. return position_ids.unsqueeze(0).expand(input_shape).contiguous() + past_key_values_length
  548. @staticmethod
  549. # Copied from transformers.models.roberta.modeling_roberta.RobertaEmbeddings.create_position_ids_from_input_ids
  550. def create_position_ids_from_input_ids(input_ids, padding_idx, past_key_values_length=0):
  551. """
  552. Replace non-padding symbols with their position numbers. Position numbers begin at padding_idx+1. Padding symbols
  553. are ignored. This is modified from fairseq's `utils.make_positions`.
  554. Args:
  555. x: torch.Tensor x:
  556. Returns: torch.Tensor
  557. """
  558. # The series of casts and type-conversions here are carefully balanced to both work with ONNX export and XLA.
  559. mask = input_ids.ne(padding_idx).int()
  560. incremental_indices = (torch.cumsum(mask, dim=1).type_as(mask) + past_key_values_length) * mask
  561. return incremental_indices.long() + padding_idx
  562. class KosmosTextAttention(nn.Module):
  563. """Multi-headed attention from 'Attention Is All You Need' paper"""
  564. # Similar to transformers.models.bart.modeling_bart.BartAttention.__init__ except an additional `inner_attn_ln`.
  565. def __init__(
  566. self,
  567. config,
  568. embed_dim: int,
  569. num_heads: int,
  570. dropout: float = 0.0,
  571. is_decoder: bool | None = False,
  572. add_inner_attn_layernorm: bool | None = False,
  573. bias: bool | None = True,
  574. layer_idx: bool | None = None,
  575. ):
  576. super().__init__()
  577. self.config = config
  578. self.embed_dim = embed_dim
  579. self.num_heads = num_heads
  580. self.dropout = dropout
  581. self.head_dim = embed_dim // num_heads
  582. self.is_causal = True
  583. if (self.head_dim * num_heads) != self.embed_dim:
  584. raise ValueError(
  585. f"embed_dim must be divisible by num_heads (got `embed_dim`: {self.embed_dim}"
  586. f" and `num_heads`: {num_heads})."
  587. )
  588. self.scaling = self.head_dim**-0.5
  589. self.is_decoder = is_decoder
  590. self.layer_idx = layer_idx
  591. self.k_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
  592. self.v_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
  593. self.q_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
  594. self.out_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
  595. # End opy
  596. self.inner_attn_ln = None
  597. if add_inner_attn_layernorm:
  598. self.inner_attn_ln = nn.LayerNorm(embed_dim, eps=config.layer_norm_eps)
  599. def forward(
  600. self,
  601. hidden_states: torch.Tensor,
  602. encoder_hidden_states: torch.Tensor | None = None,
  603. past_key_values: Cache | None = None,
  604. attention_mask: torch.Tensor | None = None,
  605. **kwargs,
  606. ) -> tuple[torch.Tensor, torch.Tensor | None, Cache | None]:
  607. """Input shape: Batch x Time x Channel"""
  608. # if key_value_states are provided this layer is used as a cross-attention layer
  609. # for the decoder
  610. is_cross_attention = encoder_hidden_states is not None
  611. input_shape = hidden_states.shape[:-1]
  612. hidden_shape = (*input_shape, -1, self.head_dim)
  613. query_states = self.q_proj(hidden_states)
  614. query_states = query_states.view(hidden_shape).transpose(1, 2)
  615. is_updated = False
  616. if past_key_values is not None:
  617. if isinstance(past_key_values, EncoderDecoderCache):
  618. is_updated = past_key_values.is_updated.get(self.layer_idx)
  619. if is_cross_attention:
  620. # after the first generated id, we can subsequently re-use all key/value_states from cache
  621. curr_past_key_values = past_key_values.cross_attention_cache
  622. else:
  623. curr_past_key_values = past_key_values.self_attention_cache
  624. else:
  625. curr_past_key_values = past_key_values
  626. current_states = encoder_hidden_states if is_cross_attention else hidden_states
  627. if is_cross_attention and past_key_values is not None and is_updated:
  628. # reuse k,v, cross_attentions
  629. key_states = curr_past_key_values.layers[self.layer_idx].keys
  630. value_states = curr_past_key_values.layers[self.layer_idx].values
  631. else:
  632. kv_shape = (*current_states.shape[:-1], -1, self.head_dim)
  633. key_states = self.k_proj(current_states).view(kv_shape).transpose(1, 2)
  634. value_states = self.v_proj(current_states).view(kv_shape).transpose(1, 2)
  635. if past_key_values is not None:
  636. # save all key/value_states to cache to be re-used for fast auto-regressive generation
  637. key_states, value_states = curr_past_key_values.update(key_states, value_states, self.layer_idx)
  638. # set flag that curr layer for cross-attn is already updated so we can re-use in subsequent calls
  639. if is_cross_attention and isinstance(past_key_values, EncoderDecoderCache):
  640. past_key_values.is_updated[self.layer_idx] = True
  641. attention_interface: Callable = ALL_ATTENTION_FUNCTIONS.get_interface(
  642. self.config._attn_implementation, eager_attention_forward
  643. )
  644. attn_output, attn_weights = attention_interface(
  645. self,
  646. query_states,
  647. key_states,
  648. value_states,
  649. attention_mask,
  650. dropout=0.0 if not self.training else self.dropout,
  651. scaling=self.scaling,
  652. **kwargs,
  653. )
  654. attn_output = attn_output.reshape(*input_shape, -1).contiguous()
  655. if self.inner_attn_ln is not None:
  656. attn_output = self.inner_attn_ln(attn_output)
  657. attn_output = self.out_proj(attn_output)
  658. return attn_output, attn_weights
  659. class Kosmos2TextFFN(nn.Module):
  660. def __init__(self, config: Kosmos2TextConfig):
  661. super().__init__()
  662. self.dropout = config.dropout
  663. self.activation_fn = ACT2FN[config.activation_function]
  664. self.activation_dropout = config.activation_dropout
  665. self.fc1 = nn.Linear(config.embed_dim, config.ffn_dim)
  666. self.fc2 = nn.Linear(config.ffn_dim, config.embed_dim)
  667. self.ffn_layernorm = nn.LayerNorm(config.ffn_dim, eps=config.layer_norm_eps)
  668. def forward(self, hidden_states):
  669. hidden_states = self.activation_fn(self.fc1(hidden_states))
  670. hidden_states = nn.functional.dropout(hidden_states, p=self.activation_dropout, training=self.training)
  671. hidden_states = self.ffn_layernorm(hidden_states)
  672. hidden_states = self.fc2(hidden_states)
  673. hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training)
  674. return hidden_states
  675. class Kosmos2TextBlock(GradientCheckpointingLayer):
  676. def __init__(self, config: Kosmos2TextConfig, layer_idx=None):
  677. super().__init__()
  678. self.embed_dim = config.embed_dim
  679. self.self_attn = KosmosTextAttention(
  680. config,
  681. embed_dim=self.embed_dim,
  682. num_heads=config.attention_heads,
  683. dropout=config.attention_dropout,
  684. is_decoder=True,
  685. add_inner_attn_layernorm=True,
  686. layer_idx=layer_idx,
  687. )
  688. self.dropout = config.dropout
  689. self.self_attn_layer_norm = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_eps)
  690. if config.add_cross_attention:
  691. self.encoder_attn = KosmosTextAttention(
  692. config,
  693. embed_dim=self.embed_dim,
  694. num_heads=config.attention_heads,
  695. dropout=config.attention_dropout,
  696. is_decoder=True,
  697. add_inner_attn_layernorm=False,
  698. layer_idx=layer_idx,
  699. )
  700. self.encoder_attn_layer_norm = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_eps)
  701. self.ffn = Kosmos2TextFFN(config)
  702. self.final_layer_norm = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_eps)
  703. def forward(
  704. self,
  705. hidden_states: torch.Tensor,
  706. attention_mask: torch.Tensor | None = None,
  707. encoder_hidden_states: torch.Tensor | None = None,
  708. encoder_attention_mask: torch.Tensor | None = None,
  709. past_key_values: Cache | None = None,
  710. output_attentions: bool | None = False,
  711. **kwargs,
  712. ) -> tuple[torch.FloatTensor, tuple[torch.FloatTensor, torch.FloatTensor] | None]:
  713. residual = hidden_states
  714. hidden_states = self.self_attn_layer_norm(hidden_states)
  715. hidden_states, _ = self.self_attn(
  716. hidden_states=hidden_states,
  717. past_key_values=past_key_values,
  718. attention_mask=attention_mask,
  719. output_attentions=output_attentions,
  720. **kwargs,
  721. )
  722. hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training)
  723. hidden_states = residual + hidden_states
  724. # Cross-Attention Block
  725. if encoder_hidden_states is not None:
  726. if not hasattr(self, "encoder_attn"):
  727. raise ValueError(
  728. f"If `encoder_hidden_states` are passed, {self} has to be instantiated with cross-attention layers"
  729. " by setting `config.add_cross_attention=True`"
  730. )
  731. residual = hidden_states
  732. hidden_states = self.encoder_attn_layer_norm(hidden_states)
  733. hidden_states, _ = self.encoder_attn(
  734. hidden_states=hidden_states,
  735. encoder_hidden_states=encoder_hidden_states,
  736. attention_mask=encoder_attention_mask,
  737. past_key_values=past_key_values,
  738. output_attentions=output_attentions,
  739. **kwargs,
  740. )
  741. hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training)
  742. hidden_states = residual + hidden_states
  743. # Fully Connected
  744. residual = hidden_states
  745. hidden_states = self.final_layer_norm(hidden_states)
  746. # FFN
  747. hidden_states = self.ffn(hidden_states)
  748. hidden_states = residual + hidden_states
  749. return hidden_states
  750. class Kosmos2TextTransformer(Kosmos2PreTrainedModel):
  751. config: Kosmos2TextConfig
  752. input_modalities = ("text",)
  753. _can_record_outputs = {
  754. "hidden_states": Kosmos2TextBlock,
  755. "attentions": OutputRecorder(KosmosTextAttention, index=1, layer_name="self_attn"),
  756. "cross_attentions": OutputRecorder(KosmosTextAttention, index=1, layer_name="encoder_attn"),
  757. }
  758. def __init__(self, config: Kosmos2TextConfig):
  759. super().__init__(config)
  760. self.dropout = config.dropout
  761. self.layerdrop = config.layerdrop
  762. self.embed_scale = math.sqrt(config.embed_dim) if config.scale_embedding else 1.0
  763. self.embed_tokens = nn.Embedding(config.vocab_size, config.embed_dim, padding_idx=config.pad_token_id)
  764. self.embed_positions = Kosmos2TextSinusoidalPositionalEmbedding(
  765. num_positions=config.max_position_embeddings,
  766. embedding_dim=config.embed_dim,
  767. padding_idx=config.pad_token_id,
  768. )
  769. self.layers = nn.ModuleList([Kosmos2TextBlock(config, layer_idx=i) for i in range(config.layers)])
  770. self.layer_norm = nn.LayerNorm(config.embed_dim, config.layer_norm_eps)
  771. self.gradient_checkpointing = False
  772. self.post_init()
  773. def _prepare_decoder_attention_mask(self, attention_mask, input_shape, inputs_embeds, past_key_values_length):
  774. # create causal mask
  775. # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
  776. combined_attention_mask = None
  777. if input_shape[-1] > 1:
  778. combined_attention_mask = _make_causal_mask(
  779. input_shape,
  780. inputs_embeds.dtype,
  781. device=inputs_embeds.device,
  782. past_key_values_length=past_key_values_length,
  783. )
  784. if attention_mask is not None:
  785. # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
  786. expanded_attn_mask = _expand_mask(attention_mask, inputs_embeds.dtype, tgt_len=input_shape[-1]).to(
  787. inputs_embeds.device
  788. )
  789. combined_attention_mask = (
  790. expanded_attn_mask if combined_attention_mask is None else expanded_attn_mask + combined_attention_mask
  791. )
  792. return combined_attention_mask
  793. def forward_embedding(
  794. self,
  795. input_ids,
  796. inputs_embeds: torch.Tensor | None = None,
  797. image_embeds: torch.Tensor | None = None,
  798. img_input_mask: torch.Tensor | None = None,
  799. past_key_values_length: int = 0,
  800. position_ids: torch.Tensor | None = None,
  801. ):
  802. # The argument `inputs_embeds` should be the one without being multiplied by `self.embed_scale`.
  803. if inputs_embeds is None:
  804. inputs_embeds = self.embed_tokens(input_ids)
  805. if image_embeds is not None:
  806. inputs_embeds[img_input_mask.to(dtype=torch.bool)] = image_embeds.to(inputs_embeds.device).view(
  807. -1, image_embeds.size(-1)
  808. )
  809. inputs_embeds = inputs_embeds * self.embed_scale
  810. # embed positions
  811. positions = self.embed_positions(
  812. input_ids=input_ids,
  813. inputs_embeds=inputs_embeds,
  814. past_key_values_length=past_key_values_length,
  815. position_ids=position_ids,
  816. )
  817. positions = positions.to(inputs_embeds.device)
  818. hidden_states = inputs_embeds + positions
  819. hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training)
  820. return hidden_states
  821. @merge_with_config_defaults
  822. @capture_outputs
  823. @auto_docstring
  824. def forward(
  825. self,
  826. input_ids: torch.Tensor | None = None,
  827. attention_mask: torch.Tensor | None = None,
  828. image_embeds: torch.Tensor | None = None,
  829. image_embeds_position_mask: torch.Tensor | None = None,
  830. encoder_hidden_states: torch.Tensor | None = None,
  831. encoder_attention_mask: torch.Tensor | None = None,
  832. past_key_values: Cache | None = None,
  833. inputs_embeds: torch.Tensor | None = None,
  834. position_ids: torch.Tensor | None = None,
  835. use_cache: bool | None = None,
  836. output_attentions: bool | None = None,
  837. output_hidden_states: bool | None = None,
  838. return_dict: bool | None = None,
  839. **kwargs: Unpack[FlashAttentionKwargs],
  840. ) -> tuple | BaseModelOutputWithPastAndCrossAttentions:
  841. r"""
  842. image_embeds (`torch.FloatTensor` of shape `(batch_size, latent_query_num, hidden_size)`, *optional*):
  843. Sequence of hidden-states at the output of `Kosmos2ImageToTextProjection`.
  844. image_embeds_position_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
  845. Mask to indicate the location in a sequence to insert the image features . Mask values selected in `[0,
  846. 1]`:
  847. - 1 for places where to put the image features,
  848. - 0 for places that are not for image features (i.e. for text tokens).
  849. """
  850. if input_ids is not None and inputs_embeds is not None:
  851. raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
  852. elif input_ids is not None:
  853. input_shape = input_ids.shape
  854. input_ids = input_ids.view(-1, input_shape[-1])
  855. elif inputs_embeds is not None:
  856. input_shape = inputs_embeds.size()[:-1]
  857. else:
  858. raise ValueError("You have to specify either input_ids or inputs_embeds")
  859. if use_cache and past_key_values is None:
  860. past_key_values = (
  861. EncoderDecoderCache(DynamicCache(config=self.config), DynamicCache(config=self.config))
  862. if encoder_hidden_states is not None or self.config.is_encoder_decoder
  863. else DynamicCache(config=self.config)
  864. )
  865. past_key_values_length = past_key_values.get_seq_length() if past_key_values is not None else 0
  866. # We don't need img info. when `past_key_values_length` > 0
  867. if past_key_values_length > 0:
  868. image_embeds = None
  869. image_embeds_position_mask = None
  870. hidden_states = self.forward_embedding(
  871. input_ids=input_ids,
  872. inputs_embeds=inputs_embeds,
  873. image_embeds=image_embeds,
  874. img_input_mask=image_embeds_position_mask,
  875. past_key_values_length=past_key_values_length,
  876. position_ids=position_ids,
  877. )
  878. attention_mask = self._prepare_decoder_attention_mask(
  879. attention_mask, input_shape, hidden_states, past_key_values_length
  880. )
  881. # expand encoder attention mask
  882. if encoder_hidden_states is not None and encoder_attention_mask is not None:
  883. # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
  884. encoder_attention_mask = _expand_mask(encoder_attention_mask, inputs_embeds.dtype, tgt_len=input_shape[-1])
  885. hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training)
  886. for decoder_layer in self.layers:
  887. # add LayerDrop (see https://huggingface.co/papers/1909.11556 for description)
  888. if self.training:
  889. dropout_probability = torch.rand([])
  890. if dropout_probability < self.layerdrop:
  891. continue
  892. hidden_states = decoder_layer(
  893. hidden_states,
  894. attention_mask,
  895. encoder_hidden_states,
  896. encoder_attention_mask=encoder_attention_mask,
  897. past_key_values=past_key_values,
  898. output_attentions=output_attentions,
  899. use_cache=use_cache,
  900. **kwargs,
  901. )
  902. # add final layer norm
  903. hidden_states = self.layer_norm(hidden_states)
  904. return BaseModelOutputWithPastAndCrossAttentions(
  905. last_hidden_state=hidden_states,
  906. past_key_values=past_key_values,
  907. )
  908. class Kosmos2VisionModel(Kosmos2PreTrainedModel):
  909. config: Kosmos2VisionConfig
  910. main_input_name = "pixel_values"
  911. input_modalities = ("image",)
  912. # Copied from transformers.models.clip.modeling_clip.CLIPVisionModel.__init__ with CLIP_VISION->KOSMOS2_VISION,CLIP->Kosmos2,self.vision_model->self.model
  913. def __init__(self, config: Kosmos2VisionConfig):
  914. super().__init__(config)
  915. self.model = Kosmos2VisionTransformer(config)
  916. # Initialize weights and apply final processing
  917. self.post_init()
  918. # Copied from transformers.models.clip.modeling_clip.CLIPVisionModel.get_input_embeddings with CLIP_VISION->KOSMOS2_VISION,CLIP->Kosmos2,self.vision_model->self.model
  919. def get_input_embeddings(self) -> nn.Module:
  920. return self.model.embeddings.patch_embedding
  921. @can_return_tuple
  922. @auto_docstring
  923. def forward(
  924. self,
  925. pixel_values: torch.FloatTensor | None = None,
  926. interpolate_pos_encoding: bool = False,
  927. **kwargs: Unpack[TransformersKwargs],
  928. ) -> tuple | BaseModelOutputWithProjectionAttentions:
  929. return self.model(
  930. pixel_values=pixel_values,
  931. interpolate_pos_encoding=interpolate_pos_encoding,
  932. **kwargs,
  933. )
  934. class Kosmos2TextModel(Kosmos2PreTrainedModel):
  935. config: Kosmos2TextConfig
  936. input_modalities = ("text",)
  937. def __init__(self, config: Kosmos2TextConfig):
  938. super().__init__(config)
  939. self.model = Kosmos2TextTransformer(config)
  940. # Initialize weights and apply final processing
  941. self.post_init()
  942. def get_input_embeddings(self) -> nn.Module:
  943. return self.model.embed_tokens
  944. @can_return_tuple
  945. @auto_docstring
  946. def forward(
  947. self,
  948. input_ids: torch.Tensor | None = None,
  949. attention_mask: torch.Tensor | None = None,
  950. image_embeds: torch.Tensor | None = None,
  951. image_embeds_position_mask: torch.Tensor | None = None,
  952. encoder_hidden_states: torch.Tensor | None = None,
  953. encoder_attention_mask: torch.Tensor | None = None,
  954. past_key_values: Cache | None = None,
  955. inputs_embeds: torch.Tensor | None = None,
  956. position_ids: torch.Tensor | None = None,
  957. use_cache: bool | None = None,
  958. **kwargs: Unpack[TransformersKwargs],
  959. ) -> tuple | BaseModelOutputWithPastAndCrossAttentions:
  960. r"""
  961. image_embeds (`torch.FloatTensor` of shape `(batch_size, latent_query_num, hidden_size)`, *optional*):
  962. Sequence of hidden-states at the output of `Kosmos2ImageToTextProjection`.
  963. image_embeds_position_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
  964. Mask to indicate the location in a sequence to insert the image features . Mask values selected in `[0,
  965. 1]`:
  966. - 1 for places where to put the image features,
  967. - 0 for places that are not for image features (i.e. for text tokens).
  968. """
  969. return self.model(
  970. input_ids=input_ids,
  971. attention_mask=attention_mask,
  972. image_embeds=image_embeds,
  973. image_embeds_position_mask=image_embeds_position_mask,
  974. encoder_hidden_states=encoder_hidden_states,
  975. encoder_attention_mask=encoder_attention_mask,
  976. past_key_values=past_key_values,
  977. inputs_embeds=inputs_embeds,
  978. position_ids=position_ids,
  979. use_cache=use_cache,
  980. **kwargs,
  981. )
  982. @auto_docstring(
  983. custom_intro="""
  984. The text model from KOSMOS-2 with a language modeling head on top (linear layer with weights tied to the input
  985. embeddings).
  986. """
  987. )
  988. class Kosmos2TextForCausalLM(Kosmos2PreTrainedModel, GenerationMixin):
  989. config: Kosmos2TextConfig
  990. _tied_weights_keys = {"lm_head.weight": "model.embed_tokens.weight"}
  991. def __init__(self, config: Kosmos2TextConfig):
  992. super().__init__(config)
  993. self.model = Kosmos2TextTransformer(config)
  994. self.lm_head = nn.Linear(in_features=config.embed_dim, out_features=config.vocab_size, bias=False)
  995. # Initialize weights and apply final processing
  996. self.post_init()
  997. def get_input_embeddings(self) -> nn.Module:
  998. return self.model.embed_tokens
  999. def get_output_embeddings(self) -> nn.Module:
  1000. return self.lm_head
  1001. @can_return_tuple
  1002. @auto_docstring
  1003. def forward(
  1004. self,
  1005. input_ids: torch.Tensor | None = None,
  1006. attention_mask: torch.Tensor | None = None,
  1007. image_embeds: torch.Tensor | None = None,
  1008. image_embeds_position_mask: torch.Tensor | None = None,
  1009. encoder_hidden_states: torch.Tensor | None = None,
  1010. encoder_attention_mask: torch.Tensor | None = None,
  1011. past_key_values: Cache | None = None,
  1012. inputs_embeds: torch.Tensor | None = None,
  1013. position_ids: torch.Tensor | None = None,
  1014. labels: torch.LongTensor | None = None,
  1015. use_cache: bool | None = None,
  1016. logits_to_keep: int | torch.Tensor = 0,
  1017. **kwargs: Unpack[TransformersKwargs],
  1018. ) -> tuple | CausalLMOutputWithCrossAttentions:
  1019. r"""
  1020. image_embeds (`torch.FloatTensor` of shape `(batch_size, latent_query_num, hidden_size)`, *optional*):
  1021. Sequence of hidden-states at the output of `Kosmos2ImageToTextProjection`.
  1022. image_embeds_position_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
  1023. Mask to indicate the location in a sequence to insert the image features . Mask values selected in `[0,
  1024. 1]`:
  1025. - 1 for places where to put the image features,
  1026. - 0 for places that are not for image features (i.e. for text tokens).
  1027. labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
  1028. Labels for computing the left-to-right language modeling loss (next word prediction). Indices should be in
  1029. `[-100, 0, ..., config.vocab_size]` (see `input_ids` docstring) Tokens with indices set to `-100` are
  1030. ignored (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`
  1031. """
  1032. if labels is not None:
  1033. if use_cache:
  1034. logger.warning("The `use_cache` argument is changed to `False` since `labels` is provided.")
  1035. use_cache = False
  1036. outputs: BaseModelOutputWithPastAndCrossAttentions = self.model(
  1037. input_ids=input_ids,
  1038. attention_mask=attention_mask,
  1039. image_embeds=image_embeds,
  1040. image_embeds_position_mask=image_embeds_position_mask,
  1041. encoder_hidden_states=encoder_hidden_states,
  1042. encoder_attention_mask=encoder_attention_mask,
  1043. past_key_values=past_key_values,
  1044. inputs_embeds=inputs_embeds,
  1045. position_ids=position_ids,
  1046. use_cache=use_cache,
  1047. **kwargs,
  1048. )
  1049. hidden_states = outputs.last_hidden_state
  1050. # Only compute necessary logits, and do not upcast them to float if we are not computing the loss
  1051. slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep
  1052. logits = self.lm_head(hidden_states[:, slice_indices, :])
  1053. loss = None
  1054. if labels is not None:
  1055. loss = self.loss_function(logits=logits, labels=labels, vocab_size=self.config.vocab_size, **kwargs)
  1056. return CausalLMOutputWithCrossAttentions(
  1057. loss=loss,
  1058. logits=logits,
  1059. past_key_values=outputs.past_key_values,
  1060. hidden_states=outputs.hidden_states,
  1061. attentions=outputs.attentions,
  1062. cross_attentions=outputs.cross_attentions,
  1063. )
  1064. def prepare_inputs_for_generation(
  1065. self,
  1066. input_ids,
  1067. image_embeds=None,
  1068. image_embeds_position_mask=None,
  1069. past_key_values=None,
  1070. attention_mask=None,
  1071. inputs_embeds=None,
  1072. use_cache=None,
  1073. is_first_iteration=False,
  1074. **model_kwargs,
  1075. ):
  1076. # Overwritten -- in specific circumstances we don't want to forward image inputs to the model
  1077. # Pixel values are used only in the first iteration if available
  1078. # In subsequent iterations, they are already merged with text and cached
  1079. # NOTE: first iteration doesn't have to be prefill, it can be the first
  1080. # iteration with a question and cached system prompt (continue generate from cache)
  1081. if not is_first_iteration and use_cache:
  1082. image_embeds = None
  1083. image_embeds_position_mask = None
  1084. # appending `False` to `image_embeds_position_mask` (because sequence grows during generation)
  1085. elif image_embeds_position_mask is not None:
  1086. batch_size, seq_len = inputs_embeds.size()[:-1] if inputs_embeds is not None else attention_mask.size()
  1087. mask_len = image_embeds_position_mask.size()[-1]
  1088. image_embeds_position_mask = torch.cat(
  1089. (
  1090. image_embeds_position_mask,
  1091. torch.zeros(size=(batch_size, seq_len - mask_len), dtype=torch.bool, device=input_ids.device),
  1092. ),
  1093. dim=1,
  1094. )
  1095. model_inputs = super().prepare_inputs_for_generation(
  1096. input_ids,
  1097. past_key_values=past_key_values,
  1098. attention_mask=attention_mask,
  1099. image_embeds=image_embeds,
  1100. image_embeds_position_mask=image_embeds_position_mask,
  1101. inputs_embeds=inputs_embeds,
  1102. use_cache=use_cache,
  1103. is_first_iteration=is_first_iteration,
  1104. **model_kwargs,
  1105. )
  1106. # Kosmos2 has offset for position ids, so we need to create them correctly in PositionEmbedding layer
  1107. model_inputs.pop("position_ids", None)
  1108. return model_inputs
  1109. class Kosmos2ImageToTextProjection(nn.Module):
  1110. """The layer that transforms the image model's output to part of the text model's input (namely, image features)"""
  1111. def __init__(self, config: Kosmos2Config):
  1112. super().__init__()
  1113. self.dense = nn.Linear(config.vision_config.hidden_size, config.text_config.embed_dim)
  1114. self.latent_query = nn.Parameter(torch.randn(config.latent_query_num, config.text_config.embed_dim))
  1115. self.x_attn = KosmosTextAttention(
  1116. config.text_config,
  1117. config.text_config.embed_dim,
  1118. config.text_config.attention_heads,
  1119. dropout=config.text_config.attention_dropout,
  1120. is_decoder=False,
  1121. add_inner_attn_layernorm=False,
  1122. )
  1123. def forward(self, features):
  1124. hidden_states = self.dense(features)
  1125. # shape = [batch, latent_query_num, h_dim]
  1126. latent_query = self.latent_query.unsqueeze(0).expand(hidden_states.size(0), -1, -1)
  1127. key_value_states = torch.cat([hidden_states, latent_query], dim=1)
  1128. hidden_states, attn_weights = self.x_attn(
  1129. hidden_states=latent_query,
  1130. encoder_hidden_states=key_value_states,
  1131. past_key_values=None,
  1132. attention_mask=None,
  1133. output_attentions=None,
  1134. )
  1135. return hidden_states, attn_weights
  1136. @auto_docstring(
  1137. custom_intro="""
  1138. KOSMOS-2 Model for generating text and image features. The model consists of a vision encoder and a language model.
  1139. """
  1140. )
  1141. class Kosmos2Model(Kosmos2PreTrainedModel):
  1142. config: Kosmos2Config
  1143. main_input_name = "pixel_values"
  1144. def __init__(self, config: Kosmos2Config):
  1145. super().__init__(config)
  1146. self.text_model = Kosmos2TextModel(config.text_config)
  1147. self.vision_model = Kosmos2VisionModel(config.vision_config)
  1148. self.image_to_text_projection = Kosmos2ImageToTextProjection(config)
  1149. # Initialize weights and apply final processing
  1150. self.post_init()
  1151. def get_input_embeddings(self) -> nn.Module:
  1152. return self.text_model.model.embed_tokens
  1153. def set_input_embeddings(self, value):
  1154. self.text_model.model.embed_tokens = value
  1155. @can_return_tuple
  1156. @auto_docstring
  1157. def get_image_features(
  1158. self,
  1159. pixel_values: torch.FloatTensor,
  1160. interpolate_pos_encoding: bool | None = False,
  1161. **kwargs: Unpack[TransformersKwargs],
  1162. ) -> tuple | BaseModelOutputWithProjectionAttentions:
  1163. if "return_attentions" in kwargs:
  1164. warnings.warn(
  1165. "`return_attentions` is deprecated and will be removed in a future version. Please use `return_dict`"
  1166. " and access `projection_attentions` from the returned `ModelOutput` instead.",
  1167. FutureWarning,
  1168. )
  1169. kwargs.pop("return_attentions", None)
  1170. vision_output: BaseModelOutputWithProjectionAttentions = self.vision_model(
  1171. pixel_values=pixel_values,
  1172. interpolate_pos_encoding=interpolate_pos_encoding,
  1173. return_dict=True,
  1174. **kwargs,
  1175. )
  1176. # The whole `last_hidden_state` through `post_layernorm` instead of just `pooled_output`.
  1177. image_embeds = self.vision_model.model.post_layernorm(vision_output[0])
  1178. # normalized features
  1179. image_embeds = nn.functional.normalize(image_embeds, dim=-1)
  1180. image_embeds, projection_attentions = self.image_to_text_projection(image_embeds)
  1181. vision_output.pooler_output = image_embeds
  1182. vision_output.projection_attentions = projection_attentions
  1183. return vision_output
  1184. @can_return_tuple
  1185. @auto_docstring
  1186. def forward(
  1187. self,
  1188. pixel_values: torch.Tensor | None = None,
  1189. input_ids: torch.Tensor | None = None,
  1190. image_embeds_position_mask: torch.Tensor | None = None,
  1191. attention_mask: torch.Tensor | None = None,
  1192. past_key_values: Cache | None = None,
  1193. image_embeds: torch.Tensor | None = None,
  1194. inputs_embeds: torch.Tensor | None = None,
  1195. position_ids: torch.Tensor | None = None,
  1196. use_cache: bool | None = None,
  1197. interpolate_pos_encoding: bool = False,
  1198. **kwargs: Unpack[TransformersKwargs],
  1199. ) -> tuple | Kosmos2ModelOutput:
  1200. r"""
  1201. image_embeds_position_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
  1202. Mask to indicate the location in a sequence to insert the image features . Mask values selected in `[0,
  1203. 1]`:
  1204. - 1 for places where to put the image features,
  1205. - 0 for places that are not for image features (i.e. for text tokens).
  1206. image_embeds (`torch.FloatTensor` of shape `(batch_size, latent_query_num, hidden_size)`, *optional*):
  1207. Sequence of hidden-states at the output of `Kosmos2ImageToTextProjection`.
  1208. Examples:
  1209. ```python
  1210. >>> from PIL import Image
  1211. >>> import httpx
  1212. >>> from io import BytesIO
  1213. >>> from transformers import AutoProcessor, Kosmos2Model
  1214. >>> model = Kosmos2Model.from_pretrained("microsoft/kosmos-2-patch14-224")
  1215. >>> processor = AutoProcessor.from_pretrained("microsoft/kosmos-2-patch14-224")
  1216. >>> url = "https://huggingface.co/microsoft/kosmos-2-patch14-224/resolve/main/snowman.jpg"
  1217. >>> with httpx.stream("GET", url) as response:
  1218. ... image = Image.open(BytesIO(response.read()))
  1219. >>> text = (
  1220. ... "<grounding> An image of<phrase> a snowman</phrase><object><patch_index_0044><patch_index_0863>"
  1221. ... "</object> warming himself by<phrase> a fire</phrase><object><patch_index_0005><patch_index_0911>"
  1222. ... "</object>"
  1223. ... )
  1224. >>> inputs = processor(text=text, images=image, return_tensors="pt", add_eos_token=True)
  1225. >>> last_hidden_state = model(
  1226. ... pixel_values=inputs["pixel_values"],
  1227. ... input_ids=inputs["input_ids"],
  1228. ... attention_mask=inputs["attention_mask"],
  1229. ... image_embeds_position_mask=inputs["image_embeds_position_mask"],
  1230. ... ).last_hidden_state
  1231. >>> list(last_hidden_state.shape)
  1232. [1, 91, 2048]
  1233. ```"""
  1234. vision_model_output = None
  1235. projection_attentions = None
  1236. if image_embeds is None:
  1237. if pixel_values is None:
  1238. raise ValueError("You have to specify either `pixel_values` or `image_embeds`.")
  1239. image_features = self.get_image_features(
  1240. pixel_values, interpolate_pos_encoding=interpolate_pos_encoding, return_dict=True, **kwargs
  1241. )
  1242. image_embeds = image_features.pooler_output
  1243. projection_attentions = image_features.projection_attentions
  1244. outputs = self.text_model(
  1245. input_ids=input_ids,
  1246. attention_mask=attention_mask,
  1247. image_embeds=image_embeds,
  1248. image_embeds_position_mask=image_embeds_position_mask,
  1249. past_key_values=past_key_values,
  1250. inputs_embeds=inputs_embeds,
  1251. position_ids=position_ids,
  1252. use_cache=use_cache,
  1253. return_dict=True,
  1254. **kwargs,
  1255. )
  1256. return Kosmos2ModelOutput(
  1257. last_hidden_state=outputs.last_hidden_state,
  1258. past_key_values=outputs.past_key_values,
  1259. hidden_states=outputs.hidden_states,
  1260. attentions=outputs.attentions,
  1261. image_embeds=image_embeds,
  1262. projection_attentions=projection_attentions,
  1263. vision_model_output=vision_model_output,
  1264. )
  1265. @auto_docstring(
  1266. custom_intro="""
  1267. KOSMOS-2 Model for generating text and bounding boxes given an image. The model consists of a vision encoder and a
  1268. language model.
  1269. """
  1270. )
  1271. class Kosmos2ForConditionalGeneration(Kosmos2PreTrainedModel, GenerationMixin):
  1272. config: Kosmos2Config
  1273. main_input_name = "pixel_values"
  1274. _tied_weights_keys = {"text_model.lm_head.weight": "text_model.model.embed_tokens.weight"}
  1275. def __init__(self, config: Kosmos2Config):
  1276. super().__init__(config)
  1277. self.text_model = Kosmos2TextForCausalLM(config.text_config)
  1278. self.vision_model = Kosmos2VisionModel(config.vision_config)
  1279. self.image_to_text_projection = Kosmos2ImageToTextProjection(config)
  1280. # Initialize weights and apply final processing
  1281. self.post_init()
  1282. def get_input_embeddings(self) -> nn.Module:
  1283. return self.text_model.model.embed_tokens
  1284. def set_input_embeddings(self, value):
  1285. self.text_model.model.embed_tokens = value
  1286. def get_output_embeddings(self) -> nn.Module:
  1287. return self.text_model.get_output_embeddings()
  1288. def set_output_embeddings(self, new_embeddings):
  1289. self.text_model.set_output_embeddings(new_embeddings)
  1290. @can_return_tuple
  1291. @auto_docstring
  1292. def forward(
  1293. self,
  1294. pixel_values: torch.Tensor | None = None,
  1295. input_ids: torch.Tensor | None = None,
  1296. image_embeds_position_mask: torch.Tensor | None = None,
  1297. attention_mask: torch.Tensor | None = None,
  1298. past_key_values: Cache | None = None,
  1299. image_embeds: torch.Tensor | None = None,
  1300. inputs_embeds: torch.Tensor | None = None,
  1301. position_ids: torch.Tensor | None = None,
  1302. labels: torch.LongTensor | None = None,
  1303. use_cache: bool | None = None,
  1304. logits_to_keep: int | torch.Tensor = 0,
  1305. **kwargs: Unpack[TransformersKwargs],
  1306. ) -> tuple | Kosmos2ForConditionalGenerationModelOutput:
  1307. r"""
  1308. image_embeds_position_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
  1309. Mask to indicate the location in a sequence to insert the image features . Mask values selected in `[0,
  1310. 1]`:
  1311. - 1 for places where to put the image features,
  1312. - 0 for places that are not for image features (i.e. for text tokens).
  1313. image_embeds (`torch.FloatTensor` of shape `(batch_size, latent_query_num, hidden_size)`, *optional*):
  1314. Sequence of hidden-states at the output of `Kosmos2ImageToTextProjection`.
  1315. labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
  1316. Labels for computing the left-to-right language modeling loss (next word prediction). Indices should be in
  1317. `[-100, 0, ..., config.vocab_size]` (see `input_ids` docstring) Tokens with indices set to `-100` are
  1318. ignored (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`
  1319. Examples:
  1320. ```python
  1321. >>> from PIL import Image
  1322. >>> import httpx
  1323. >>> from io import BytesIO
  1324. >>> from transformers import AutoProcessor, Kosmos2ForConditionalGeneration
  1325. >>> model = Kosmos2ForConditionalGeneration.from_pretrained("microsoft/kosmos-2-patch14-224")
  1326. >>> processor = AutoProcessor.from_pretrained("microsoft/kosmos-2-patch14-224")
  1327. >>> url = "https://huggingface.co/microsoft/kosmos-2-patch14-224/resolve/main/snowman.jpg"
  1328. >>> with httpx.stream("GET", url) as response:
  1329. ... image = Image.open(BytesIO(response.read()))
  1330. >>> prompt = "<grounding> An image of"
  1331. >>> inputs = processor(text=prompt, images=image, return_tensors="pt")
  1332. >>> generated_ids = model.generate(
  1333. ... pixel_values=inputs["pixel_values"],
  1334. ... input_ids=inputs["input_ids"],
  1335. ... attention_mask=inputs["attention_mask"],
  1336. ... image_embeds=None,
  1337. ... image_embeds_position_mask=inputs["image_embeds_position_mask"],
  1338. ... use_cache=True,
  1339. ... max_new_tokens=64,
  1340. ... )
  1341. >>> generated_text = processor.batch_decode(generated_ids, skip_special_tokens=True)[0]
  1342. >>> processed_text = processor.post_process_generation(generated_text, cleanup_and_extract=False)
  1343. >>> processed_text
  1344. '<grounding> An image of<phrase> a snowman</phrase><object><patch_index_0044><patch_index_0863></object> warming himself by<phrase> a fire</phrase><object><patch_index_0005><patch_index_0911></object>.'
  1345. >>> caption, entities = processor.post_process_generation(generated_text)
  1346. >>> caption
  1347. 'An image of a snowman warming himself by a fire.'
  1348. >>> entities
  1349. [('a snowman', (12, 21), [(0.390625, 0.046875, 0.984375, 0.828125)]), ('a fire', (41, 47), [(0.171875, 0.015625, 0.484375, 0.890625)])]
  1350. ```"""
  1351. vision_model_output = None
  1352. projection_attentions = None
  1353. if image_embeds is None:
  1354. if pixel_values is None:
  1355. raise ValueError("You have to specify either `pixel_values` or `image_embeds`.")
  1356. vision_model_output = self.vision_model(
  1357. pixel_values=pixel_values,
  1358. )
  1359. # The whole `last_hidden_state` through `post_layernorm` instead of just `pooled_output`.
  1360. image_embeds = self.vision_model.model.post_layernorm(vision_model_output[0])
  1361. # normalized features
  1362. image_embeds = nn.functional.normalize(image_embeds, dim=-1)
  1363. image_embeds, projection_attentions = self.image_to_text_projection(image_embeds)
  1364. lm_outputs: CausalLMOutputWithCrossAttentions = self.text_model(
  1365. input_ids=input_ids,
  1366. attention_mask=attention_mask,
  1367. image_embeds=image_embeds,
  1368. image_embeds_position_mask=image_embeds_position_mask,
  1369. past_key_values=past_key_values,
  1370. inputs_embeds=inputs_embeds,
  1371. position_ids=position_ids,
  1372. labels=labels,
  1373. use_cache=use_cache,
  1374. logits_to_keep=logits_to_keep,
  1375. **kwargs,
  1376. )
  1377. return Kosmos2ForConditionalGenerationModelOutput(
  1378. loss=lm_outputs.loss,
  1379. logits=lm_outputs.logits,
  1380. past_key_values=lm_outputs.past_key_values,
  1381. hidden_states=lm_outputs.hidden_states,
  1382. attentions=lm_outputs.attentions,
  1383. image_embeds=image_embeds,
  1384. projection_attentions=projection_attentions,
  1385. vision_model_output=vision_model_output,
  1386. )
  1387. @torch.no_grad()
  1388. def generate(
  1389. self,
  1390. pixel_values: torch.Tensor | None = None,
  1391. image_embeds_position_mask: torch.Tensor | None = None,
  1392. input_ids: torch.Tensor | None = None,
  1393. attention_mask: torch.Tensor | None = None,
  1394. image_embeds: torch.Tensor | None = None,
  1395. inputs_embeds: torch.Tensor | None = None,
  1396. **kwargs,
  1397. ):
  1398. # in order to allow `inputs` argument (as in `GenerationMixin`)
  1399. inputs = kwargs.pop("inputs", None)
  1400. if pixel_values is not None and inputs is not None:
  1401. raise ValueError(
  1402. f"`inputs`: {inputs} were passed alongside `pixel_values` which is not allowed."
  1403. f"Make sure to either pass `inputs` or pixel_values=..."
  1404. )
  1405. if pixel_values is None and inputs is not None:
  1406. pixel_values = inputs
  1407. if image_embeds is None:
  1408. vision_model_output = self.vision_model(pixel_values)
  1409. # The whole `last_hidden_state` through `post_layernorm` instead of just `pooled_output`.
  1410. image_embeds = self.vision_model.model.post_layernorm(vision_model_output[0])
  1411. # normalized features
  1412. image_embeds = nn.functional.normalize(image_embeds, dim=-1)
  1413. image_embeds, projection_attentions = self.image_to_text_projection(image_embeds)
  1414. output = self.text_model.generate(
  1415. input_ids=input_ids,
  1416. attention_mask=attention_mask,
  1417. image_embeds=image_embeds,
  1418. image_embeds_position_mask=image_embeds_position_mask,
  1419. inputs_embeds=inputs_embeds,
  1420. **kwargs,
  1421. )
  1422. return output
  1423. __all__ = ["Kosmos2ForConditionalGeneration", "Kosmos2Model", "Kosmos2PreTrainedModel"]