modular_ovis2.py 17 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448
  1. # Copyright 2025 The HuggingFace Inc. team. All rights reserved.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. import math
  15. from dataclasses import dataclass
  16. import torch
  17. from torch import nn
  18. from ... import initialization as init
  19. from ...cache_utils import Cache
  20. from ...generation import GenerationMixin
  21. from ...modeling_outputs import BaseModelOutput, BaseModelOutputWithPooling
  22. from ...modeling_utils import PreTrainedModel
  23. from ...processing_utils import Unpack
  24. from ...utils import TransformersKwargs, auto_docstring, can_return_tuple
  25. from ...utils.generic import merge_with_config_defaults
  26. from ...utils.output_capturing import capture_outputs
  27. from ..aimv2.modeling_aimv2 import Aimv2Attention, Aimv2EncoderLayer
  28. from ..auto import AutoModel
  29. from ..llama.modeling_llama import LlamaMLP, LlamaRMSNorm
  30. from ..llava.modeling_llava import LlavaForConditionalGeneration, LlavaModel
  31. from ..llava_next.modeling_llava_next import LlavaNextCausalLMOutputWithPast, LlavaNextModelOutputWithPast
  32. from ..siglip.modeling_siglip import SiglipEncoder, SiglipVisionEmbeddings
  33. from .configuration_ovis2 import Ovis2Config, Ovis2VisionConfig
  34. def hard_softmax(logits: torch.Tensor, dim: int):
  35. y_soft = logits.softmax(dim)
  36. # Straight through.
  37. index = y_soft.max(dim, keepdim=True)[1]
  38. y_hard = torch.zeros_like(logits, memory_format=torch.legacy_contiguous_format).scatter_(dim, index, 1.0)
  39. ret = y_hard - y_soft.detach() + y_soft
  40. return ret
# Pooled model output extended with one extra optional field, `visual_indicator_features`.
@dataclass
@auto_docstring
class BaseModelOutputWithVisualIndicatorFeatures(BaseModelOutputWithPooling):
    r"""
    visual_indicator_features (`torch.FloatTensor` of shape `(batch_size, visual_indicator_size)`):
        Visual indicator features extracted from the model, which can be used for auxiliary tasks or further processing.
    """

    # Defaults to None, so the output can be built without indicator features.
    visual_indicator_features: torch.FloatTensor | None = None
# Ovis2-named subclass of the LLaVA-NeXT base-model output; adds no fields or behavior.
class Ovis2ModelOutputWithPast(LlavaNextModelOutputWithPast):
    pass
# Ovis2-named subclass of the LLaVA-NeXT causal-LM output; adds no fields or behavior.
class Ovis2CausalLMOutputWithPast(LlavaNextCausalLMOutputWithPast):
    pass
# Ovis2-named subclass of the Llama RMS norm; the implementation is inherited unchanged.
class Ovis2RMSNorm(LlamaRMSNorm):
    pass
# Ovis2-named subclass of the Llama MLP; the implementation is inherited unchanged.
class Ovis2VisionMLP(LlamaMLP):
    pass
  57. class Ovis2VisionEmbeddings(SiglipVisionEmbeddings):
  58. def __init__(self, config: Ovis2VisionConfig):
  59. super().__init__(config)
  60. self.rms_norm = Ovis2RMSNorm(config.hidden_size, config.rms_norm_eps)
  61. def interpolate_pos_encoding(self):
  62. raise NotImplementedError("Not needed for Ovis2")
  63. def forward(self, pixel_values: torch.FloatTensor) -> torch.Tensor:
  64. target_dtype = self.patch_embedding.weight.dtype
  65. patch_embeds = self.patch_embedding(pixel_values.to(dtype=target_dtype))
  66. embeddings = patch_embeds.flatten(2).transpose(1, 2)
  67. embeddings = self.rms_norm(embeddings)
  68. embeddings = embeddings + self.position_embedding(self.position_ids)
  69. return embeddings
# Ovis2-named subclass of the AIMv2 attention module; the implementation is inherited unchanged.
class Ovis2VisionAttention(Aimv2Attention):
    pass
  72. class Ovis2VisionEncoderLayer(Aimv2EncoderLayer):
  73. def __init__(self, config: Ovis2VisionConfig):
  74. super().__init__()
  75. self.attention = Ovis2VisionAttention(config)
  76. class Ovis2VisionEncoder(SiglipEncoder):
  77. def __init__(self, config: Ovis2VisionConfig):
  78. super().__init__(config)
  79. self.layers = nn.ModuleList([Ovis2VisionEncoderLayer(config) for _ in range(config.num_hidden_layers)])
  80. @can_return_tuple
  81. @auto_docstring
  82. def forward(
  83. self,
  84. inputs_embeds,
  85. attention_mask: torch.Tensor | None = None,
  86. **kwargs: Unpack[TransformersKwargs],
  87. ) -> BaseModelOutput:
  88. hidden_states = inputs_embeds
  89. for encoder_layer in self.layers:
  90. hidden_states = encoder_layer(hidden_states, attention_mask, **kwargs)
  91. return BaseModelOutput(last_hidden_state=hidden_states)
  92. class Ovis2VisionTransformer(nn.Module):
  93. def __init__(self, config: Ovis2VisionConfig):
  94. super().__init__()
  95. self.config = config
  96. self.embeddings = Ovis2VisionEmbeddings(config)
  97. self.encoder = Ovis2VisionEncoder(config)
  98. self.rms_norm = Ovis2RMSNorm(config.hidden_size, config.rms_norm_eps)
  99. self.gradient_checkpointing = False
  100. @can_return_tuple
  101. def forward(
  102. self,
  103. pixel_values,
  104. attention_mask: torch.Tensor | None = None,
  105. **kwargs,
  106. ):
  107. hidden_states = self.embeddings(pixel_values)
  108. encoder_outputs: BaseModelOutput = self.encoder(
  109. inputs_embeds=hidden_states,
  110. attention_mask=attention_mask,
  111. **kwargs,
  112. )
  113. last_hidden_state = encoder_outputs.last_hidden_state
  114. last_hidden_state = self.rms_norm(last_hidden_state)
  115. return BaseModelOutput(last_hidden_state=last_hidden_state)
  116. class Ovis2VisualEmbeddingTable(nn.Embedding):
  117. def forward(self, visual_tokens: torch.Tensor) -> torch.Tensor:
  118. if visual_tokens.dtype in [torch.int8, torch.int16, torch.int32, torch.int64, torch.long]:
  119. return super().forward(visual_tokens)
  120. return torch.matmul(visual_tokens, self.weight)
  121. class Ovis2PreTrainedModel(PreTrainedModel):
  122. config: Ovis2Config
  123. base_model_prefix = "model"
  124. input_modalities = ("image", "text")
  125. supports_gradient_checkpointing = True
  126. _no_split_modules = ["Ovis2VisionAttention"]
  127. _skip_keys_device_placement = "past_key_values"
  128. _supports_cache_class = True
  129. _supports_flash_attn = True
  130. _supports_flex_attn = True
  131. _supports_sdpa = True
  132. _can_compile_fullgraph = True
  133. _supports_attention_backend = True
  134. def _init_weights(self, module):
  135. super()._init_weights(module)
  136. if isinstance(module, Ovis2VisionEmbeddings):
  137. init.copy_(module.position_ids, torch.arange(module.position_ids.shape[-1]).expand((1, -1)))
  138. class Ovis2VisionModel(Ovis2PreTrainedModel):
  139. config: Ovis2VisionConfig
  140. _can_record_outputs = {
  141. "hidden_states": Ovis2VisionEncoderLayer,
  142. "attentions": Ovis2VisionAttention,
  143. }
  144. def __init__(self, config: Ovis2VisionConfig):
  145. super().__init__(config)
  146. self.config = config
  147. self.transformer = Ovis2VisionTransformer(config)
  148. self.num_visual_indicator_tokens = config.num_visual_indicator_tokens
  149. self.vocab_size = config.vocab_size
  150. self.head_linear = nn.Linear(
  151. config.hidden_size * config.hidden_stride * config.hidden_stride,
  152. self.vocab_size - self.num_visual_indicator_tokens,
  153. bias=False,
  154. )
  155. self.head_norm = nn.LayerNorm(self.vocab_size - self.num_visual_indicator_tokens)
  156. self.post_init()
  157. @merge_with_config_defaults
  158. @capture_outputs
  159. def forward(
  160. self, pixel_values: torch.FloatTensor, **kwargs: Unpack[TransformersKwargs]
  161. ) -> tuple | BaseModelOutputWithVisualIndicatorFeatures:
  162. outputs = self.transformer(pixel_values, **kwargs)
  163. last_hidden_state = outputs[0]
  164. if self.config.hidden_stride > 1:
  165. num_images, seq_len, hidden_dim = last_hidden_state.shape
  166. hidden_stride = self.config.hidden_stride
  167. sqrt_l = int(math.sqrt(seq_len))
  168. if sqrt_l * sqrt_l != seq_len:
  169. raise ValueError("Token sequence length must be a perfect square")
  170. pad_size = (hidden_stride - (sqrt_l % hidden_stride)) % hidden_stride
  171. last_hidden_state = nn.functional.pad(last_hidden_state, (0, 0, 0, pad_size, 0, pad_size), "constant", 0)
  172. sqrt_l += pad_size
  173. last_hidden_state = last_hidden_state.reshape(
  174. num_images, sqrt_l // hidden_stride, hidden_stride, sqrt_l // hidden_stride, hidden_stride, hidden_dim
  175. )
  176. last_hidden_state = last_hidden_state.permute(0, 1, 3, 2, 4, 5)
  177. last_hidden_state = last_hidden_state.reshape(
  178. num_images, -1, hidden_stride * hidden_stride * hidden_dim
  179. ) # (n, (sqrt_l//hs)^2, hs^2*d)
  180. logits = self.head_linear(last_hidden_state)
  181. logits = self.head_norm(logits)
  182. if self.config.tokenize_function == "gumbel_argmax":
  183. prob_token = nn.functional.gumbel_softmax(logits, dim=-1, hard=True)
  184. elif self.config.tokenize_function == "st_argmax":
  185. prob_token = hard_softmax(logits, dim=-1)
  186. elif self.config.tokenize_function == "softmax":
  187. prob_token = nn.functional.softmax(logits, dim=-1)
  188. return BaseModelOutputWithVisualIndicatorFeatures(
  189. last_hidden_state=last_hidden_state,
  190. pooler_output=prob_token,
  191. )
  192. class Ovis2Model(LlavaModel):
  193. def __init__(self, config: Ovis2Config):
  194. super().__init__(config)
  195. self.vision_tower = Ovis2VisionModel(config.vision_config)
  196. self.visual_embeddings_table = Ovis2VisualEmbeddingTable(config.vision_config.vocab_size, config.hidden_size)
  197. self.visual_vocab_size = config.vision_config.vocab_size
  198. self.vocab_size = config.vocab_size
  199. self.visual_indicator_token_ids = config.visual_indicator_token_ids
  200. self.language_model = AutoModel.from_config(config.text_config)
  201. del self.multi_modal_projector
  202. @can_return_tuple
  203. @auto_docstring(
  204. custom_intro="Obtains image last hidden states from the vision tower and apply multimodal projection."
  205. )
  206. def get_image_features(
  207. self,
  208. pixel_values: torch.FloatTensor,
  209. **kwargs: Unpack[TransformersKwargs],
  210. ) -> tuple | BaseModelOutputWithVisualIndicatorFeatures:
  211. image_outputs = self.vision_tower(pixel_values, return_dict=True, **kwargs)
  212. image_features = image_outputs.pooler_output
  213. batch_size, img_seq_len, _ = image_features.shape
  214. padding_tensor = torch.zeros(
  215. (batch_size, img_seq_len, self.vision_tower.num_visual_indicator_tokens),
  216. dtype=image_features.dtype,
  217. device=image_features.device,
  218. requires_grad=False,
  219. layout=image_features.layout,
  220. )
  221. image_features = torch.cat([image_features, padding_tensor], dim=2)
  222. image_features = self.visual_embeddings_table(image_features)
  223. visual_indicator = torch.arange(
  224. self.visual_vocab_size - self.vision_tower.num_visual_indicator_tokens,
  225. self.visual_vocab_size,
  226. dtype=torch.long,
  227. ).to(image_features.device)
  228. image_outputs.pooler_output = image_features
  229. image_outputs.visual_indicator_features = self.visual_embeddings_table(visual_indicator)
  230. return image_outputs
  231. @can_return_tuple
  232. @auto_docstring
  233. def forward(
  234. self,
  235. input_ids: torch.LongTensor | None = None,
  236. pixel_values: torch.FloatTensor | None = None,
  237. attention_mask: torch.Tensor | None = None,
  238. position_ids: torch.LongTensor | None = None,
  239. past_key_values: Cache | None = None,
  240. inputs_embeds: torch.FloatTensor | None = None,
  241. labels: torch.LongTensor | None = None,
  242. use_cache: bool | None = None,
  243. logits_to_keep: int | torch.Tensor = 0,
  244. **kwargs,
  245. ) -> tuple | Ovis2ModelOutputWithPast:
  246. if (input_ids is None) ^ (inputs_embeds is not None):
  247. raise ValueError("You must specify exactly one of input_ids or inputs_embeds")
  248. if inputs_embeds is None:
  249. inputs_embeds = self.get_input_embeddings()(input_ids)
  250. if pixel_values is not None:
  251. image_outputs = self.get_image_features(pixel_values=pixel_values, return_dict=True)
  252. image_features = image_outputs.pooler_output
  253. visual_indicator_features = image_outputs.visual_indicator_features
  254. special_image_mask = self.get_placeholder_mask(
  255. input_ids,
  256. inputs_embeds=inputs_embeds,
  257. image_features=image_features,
  258. )
  259. inputs_embeds = inputs_embeds.masked_scatter(special_image_mask, image_features)
  260. for i, visual_indicator_id in enumerate(self.visual_indicator_token_ids):
  261. if input_ids is None:
  262. mask = inputs_embeds == self.get_input_embeddings()(
  263. torch.tensor(visual_indicator_id, dtype=torch.long, device=inputs_embeds.device)
  264. )
  265. mask = mask.all(-1)
  266. else:
  267. mask = (input_ids == visual_indicator_id).to(inputs_embeds.device)
  268. if mask.any():
  269. inputs_embeds[mask] = (
  270. visual_indicator_features[i]
  271. .expand_as(inputs_embeds[mask])
  272. .to(inputs_embeds.device, inputs_embeds.dtype)
  273. )
  274. outputs = self.language_model(
  275. attention_mask=attention_mask,
  276. position_ids=position_ids,
  277. past_key_values=past_key_values,
  278. inputs_embeds=inputs_embeds,
  279. use_cache=use_cache,
  280. logits_to_keep=logits_to_keep,
  281. **kwargs,
  282. )
  283. return Ovis2ModelOutputWithPast(
  284. last_hidden_state=outputs.last_hidden_state,
  285. past_key_values=outputs.past_key_values,
  286. hidden_states=outputs.hidden_states,
  287. attentions=outputs.attentions,
  288. image_hidden_states=image_features if pixel_values is not None else None,
  289. )
  290. @auto_docstring
  291. class Ovis2ForConditionalGeneration(LlavaForConditionalGeneration, GenerationMixin):
  292. def __init__(self, config: Ovis2Config):
  293. super().__init__(config)
  294. self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
  295. @auto_docstring
  296. def get_image_features(
  297. self, pixel_values: torch.FloatTensor, **kwargs: Unpack[TransformersKwargs]
  298. ) -> tuple | BaseModelOutputWithVisualIndicatorFeatures:
  299. return self.model.get_image_features(pixel_values=pixel_values, **kwargs)
  300. @can_return_tuple
  301. @auto_docstring
  302. def forward(
  303. self,
  304. input_ids: torch.LongTensor | None = None,
  305. pixel_values: torch.FloatTensor | None = None,
  306. attention_mask: torch.Tensor | None = None,
  307. position_ids: torch.LongTensor | None = None,
  308. past_key_values: Cache | None = None,
  309. inputs_embeds: torch.FloatTensor | None = None,
  310. labels: torch.LongTensor | None = None,
  311. use_cache: bool | None = None,
  312. logits_to_keep: int | torch.Tensor = 0,
  313. **kwargs,
  314. ) -> tuple | Ovis2CausalLMOutputWithPast:
  315. r"""
  316. labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
  317. Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
  318. config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
  319. (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
  320. Example:
  321. ```python
  322. >>> from PIL import Image
  323. >>> import httpx
  324. >>> from io import BytesIO
  325. >>> from transformers import AutoProcessor, Ovis2ForConditionalGeneration
  326. >>> model = Ovis2ForConditionalGeneration.from_pretrained("thisisiron/Ovis2-2B-hf")
  327. >>> processor = AutoProcessor.from_pretrained("thisisiron/Ovis2-2B-hf")
  328. >>> prompt = "<|im_start|>user\n<image>\nDescribe the image.<|im_end|>\n<|im_start|>assistant\n"
  329. >>> url = "http://images.cocodataset.org/val2014/COCO_val2014_000000537955.jpg"
  330. >>> with httpx.stream("GET", url) as response:
  331. ... image = Image.open(BytesIO(response.read()))
  332. >>> inputs = processor(images=image, text=prompt, return_tensors="pt")
  333. >>> # Generate
  334. >>> generate_ids = model.generate(**inputs, max_new_tokens=15)
  335. >>> processor.batch_decode(generate_ids, skip_special_tokens=True)[0]
  336. "user\n\nDescribe the image.\nassistant\nThe image features a brown dog standing on a wooden floor, looking up with"
  337. ```"""
  338. outputs = self.model(
  339. input_ids=input_ids,
  340. pixel_values=pixel_values,
  341. attention_mask=attention_mask,
  342. position_ids=position_ids,
  343. past_key_values=past_key_values,
  344. inputs_embeds=inputs_embeds,
  345. use_cache=use_cache,
  346. **kwargs,
  347. )
  348. hidden_states = outputs[0]
  349. # Only compute necessary logits, and do not upcast them to float if we are not computing the loss
  350. slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep
  351. logits = self.lm_head(hidden_states[:, slice_indices, :])
  352. loss = None
  353. if labels is not None:
  354. loss = self.loss_function(
  355. logits=logits, labels=labels, vocab_size=self.config.text_config.vocab_size, **kwargs
  356. )
  357. return Ovis2CausalLMOutputWithPast(
  358. loss=loss,
  359. logits=logits,
  360. past_key_values=outputs.past_key_values,
  361. hidden_states=outputs.hidden_states,
  362. attentions=outputs.attentions,
  363. image_hidden_states=outputs.image_hidden_states,
  364. )
# Names exported by a star-import of this module.
__all__ = ["Ovis2PreTrainedModel", "Ovis2Model", "Ovis2ForConditionalGeneration"]