modeling_vipllava.py 19 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431
  1. # 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
  2. # This file was automatically generated from src/transformers/models/vipllava/modular_vipllava.py.
  3. # Do NOT edit this file manually as any edits will be overwritten by the generation of
  4. # the file from the modular. If any change should be done, please apply the change to the
  5. # modular_vipllava.py file directly. One of our CI enforces this.
  6. # 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
  7. # Copyright 2023 the HuggingFace Inc. team. All rights reserved.
  8. #
  9. # Licensed under the Apache License, Version 2.0 (the "License");
  10. # you may not use this file except in compliance with the License.
  11. # You may obtain a copy of the License at
  12. #
  13. # http://www.apache.org/licenses/LICENSE-2.0
  14. #
  15. # Unless required by applicable law or agreed to in writing, software
  16. # distributed under the License is distributed on an "AS IS" BASIS,
  17. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  18. # See the License for the specific language governing permissions and
  19. # limitations under the License.
  20. from dataclasses import dataclass
  21. import torch
  22. from torch import nn
  23. from ...activations import ACT2FN
  24. from ...cache_utils import Cache
  25. from ...generation import GenerationMixin
  26. from ...modeling_outputs import BaseModelOutputWithPast, BaseModelOutputWithPooling, ModelOutput
  27. from ...modeling_utils import PreTrainedModel
  28. from ...processing_utils import Unpack
  29. from ...utils import TransformersKwargs, auto_docstring, torch_compilable_check
  30. from ...utils.generic import can_return_tuple
  31. from ..auto import AutoModel
  32. from .configuration_vipllava import VipLlavaConfig
  33. @dataclass
  34. @auto_docstring(
  35. custom_intro="""
  36. Base class for VipLlava outputs, with hidden states and attentions.
  37. """
  38. )
  39. class VipLlavaModelOutputWithPast(BaseModelOutputWithPast):
  40. r"""
  41. past_key_values (`Cache`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
  42. It is a [`~cache_utils.Cache`] instance. For more details, see our [kv cache guide](https://huggingface.co/docs/transformers/en/kv_cache).
  43. Contains pre-computed hidden-states (key and values in the self-attention blocks) that can be used (see
  44. `past_key_values` input) to speed up sequential decoding.
  45. image_hidden_states (`torch.FloatTensor`, *optional*):
  46. A `torch.FloatTensor` of size `(batch_size, num_images, sequence_length, hidden_size)`.
  47. image_hidden_states of the model produced by the vision encoder and after projecting the last hidden state.
  48. """
  49. image_hidden_states: torch.FloatTensor | None = None
  50. @dataclass
  51. @auto_docstring(
  52. custom_intro="""
  53. Base class for VipLlava causal language model (or autoregressive) outputs.
  54. """
  55. )
  56. class VipLlavaCausalLMOutputWithPast(ModelOutput):
  57. r"""
  58. loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` is provided):
  59. Language modeling loss (for next-token prediction).
  60. logits (`torch.FloatTensor` of shape `(batch_size, sequence_length, config.vocab_size)`):
  61. Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax).
  62. past_key_values (`Cache`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
  63. It is a [`~cache_utils.Cache`] instance. For more details, see our [kv cache guide](https://huggingface.co/docs/transformers/en/kv_cache).
  64. Contains pre-computed hidden-states (key and values in the self-attention blocks) that can be used (see
  65. `past_key_values` input) to speed up sequential decoding.
  66. image_hidden_states (`torch.FloatTensor`, *optional*):
  67. A `torch.FloatTensor` of size `(batch_size, num_images, sequence_length, hidden_size)`.
  68. image_hidden_states of the model produced by the vision encoder and after projecting the last hidden state.
  69. """
  70. loss: torch.FloatTensor | None = None
  71. logits: torch.FloatTensor | None = None
  72. past_key_values: Cache | None = None
  73. hidden_states: tuple[torch.FloatTensor] | None = None
  74. attentions: tuple[torch.FloatTensor] | None = None
  75. image_hidden_states: torch.FloatTensor | None = None
  76. class VipLlavaMultiModalProjector(nn.Module):
  77. def __init__(self, config: VipLlavaConfig):
  78. super().__init__()
  79. num_feature_layers = 1 if isinstance(config.vision_feature_layers, int) else len(config.vision_feature_layers)
  80. self.projector_layernorm = nn.LayerNorm(
  81. num_feature_layers * config.vision_config.hidden_size, eps=config.projector_layernorm_eps
  82. )
  83. self.linear_1 = nn.Linear(
  84. num_feature_layers * config.vision_config.hidden_size,
  85. config.text_config.hidden_size,
  86. bias=True,
  87. )
  88. self.act = ACT2FN[config.projector_hidden_act]
  89. self.linear_2 = nn.Linear(config.text_config.hidden_size, config.text_config.hidden_size, bias=True)
  90. def forward(self, hidden_states):
  91. hidden_states = self.projector_layernorm(hidden_states)
  92. hidden_states = self.linear_1(hidden_states)
  93. hidden_states = self.act(hidden_states)
  94. hidden_states = self.linear_2(hidden_states)
  95. return hidden_states
  96. @auto_docstring
  97. class VipLlavaPreTrainedModel(PreTrainedModel):
  98. config: VipLlavaConfig
  99. base_model_prefix = "model"
  100. input_modalities = ("image", "text")
  101. supports_gradient_checkpointing = True
  102. _skip_keys_device_placement = "past_key_values"
  103. _supports_flash_attn = True
  104. _supports_sdpa = True
  105. _can_compile_fullgraph = True
  106. _supports_flex_attn = True
  107. _supports_attention_backend = True
  108. @auto_docstring(
  109. custom_intro="""
  110. The VipLlava model which consists of a vision backbone and a language model, without a language modeling head.
  111. """
  112. )
  113. class VipLlavaModel(VipLlavaPreTrainedModel):
  114. def __init__(self, config: VipLlavaConfig):
  115. super().__init__(config)
  116. self.vision_tower = AutoModel.from_config(config.vision_config)
  117. self.multi_modal_projector = VipLlavaMultiModalProjector(config)
  118. self.language_model = AutoModel.from_config(config.text_config)
  119. self.post_init()
  120. def get_input_embeddings(self):
  121. return self.language_model.get_input_embeddings()
  122. def set_input_embeddings(self, value):
  123. self.language_model.set_input_embeddings(value)
  124. @can_return_tuple
  125. @auto_docstring(
  126. custom_intro="Obtains image last hidden states from the vision tower and apply multimodal projection."
  127. )
  128. def get_image_features(
  129. self,
  130. pixel_values: torch.FloatTensor,
  131. vision_feature_layers: int | list[int] | None = None,
  132. **kwargs: Unpack[TransformersKwargs],
  133. ) -> tuple | BaseModelOutputWithPooling:
  134. r"""
  135. pixel_values (`torch.FloatTensor]` of shape `(batch_size, channels, height, width)`):
  136. The tensors corresponding to the input images.
  137. vision_feature_layers (`Union[int, list[int]]`, *optional*):
  138. The vision feature layer, or the list of indexes of the layers to select
  139. the vision feature.
  140. """
  141. vision_feature_layers = (
  142. vision_feature_layers if vision_feature_layers is not None else self.config.vision_feature_layers
  143. )
  144. # We need hidden states to select intermediate vision features by layer index below.
  145. kwargs["output_hidden_states"] = True
  146. image_outputs = self.vision_tower(
  147. pixel_values,
  148. **kwargs,
  149. )
  150. # If multiple feature layers are provided (which is usually the case)
  151. # then the image features are concatenated after the CLS is removed.
  152. if isinstance(vision_feature_layers, int):
  153. image_features = image_outputs.hidden_states[vision_feature_layers][:, 1:]
  154. else:
  155. # Usually, we select the features from index 1: the layers -2, -5, -8, -11 and 6
  156. image_features = [image_outputs.hidden_states[index][:, 1:] for index in vision_feature_layers]
  157. image_features = torch.cat(image_features, dim=-1)
  158. image_features = self.multi_modal_projector(image_features)
  159. image_outputs.pooler_output = image_features
  160. return image_outputs
  161. def get_placeholder_mask(
  162. self, input_ids: torch.LongTensor, inputs_embeds: torch.FloatTensor, image_features: torch.FloatTensor
  163. ):
  164. """
  165. Obtains multimodal placeholder mask from `input_ids` or `inputs_embeds`, and checks that the placeholder token count is
  166. equal to the length of multimodal features. If the lengths are different, an error is raised.
  167. """
  168. if input_ids is None:
  169. special_image_mask = inputs_embeds == self.get_input_embeddings()(
  170. torch.tensor(self.config.image_token_id, dtype=torch.long, device=inputs_embeds.device)
  171. )
  172. special_image_mask = special_image_mask.all(-1)
  173. else:
  174. special_image_mask = input_ids == self.config.image_token_id
  175. n_image_tokens = special_image_mask.sum()
  176. n_image_features = image_features.shape[0] * image_features.shape[1]
  177. special_image_mask = special_image_mask.unsqueeze(-1).expand_as(inputs_embeds).to(inputs_embeds.device)
  178. torch_compilable_check(
  179. inputs_embeds[special_image_mask].numel() == image_features.numel(),
  180. f"Image features and image tokens do not match, tokens: {n_image_tokens}, features: {n_image_features}",
  181. )
  182. return special_image_mask
  183. @can_return_tuple
  184. @auto_docstring
  185. def forward(
  186. self,
  187. input_ids: torch.LongTensor | None = None,
  188. pixel_values: torch.FloatTensor | None = None,
  189. attention_mask: torch.Tensor | None = None,
  190. position_ids: torch.LongTensor | None = None,
  191. past_key_values: Cache | None = None,
  192. inputs_embeds: torch.FloatTensor | None = None,
  193. vision_feature_layers: int | list[int] | None = None,
  194. use_cache: bool | None = None,
  195. **lm_kwargs: Unpack[TransformersKwargs],
  196. ) -> tuple | VipLlavaModelOutputWithPast:
  197. r"""
  198. vision_feature_layers (`Union[int, list[int]]`, *optional*):
  199. The vision feature layer, or the list of indexes of the layers to select
  200. the vision feature.
  201. """
  202. vision_feature_layers = (
  203. vision_feature_layers if vision_feature_layers is not None else self.config.vision_feature_layers
  204. )
  205. if (input_ids is None) ^ (inputs_embeds is not None):
  206. raise ValueError("You must specify exactly one of input_ids or inputs_embeds")
  207. if inputs_embeds is None:
  208. inputs_embeds = self.get_input_embeddings()(input_ids)
  209. if pixel_values is not None:
  210. image_features = self.get_image_features(
  211. pixel_values=pixel_values, vision_feature_layers=vision_feature_layers
  212. ).pooler_output
  213. image_features = image_features.to(inputs_embeds.device, inputs_embeds.dtype)
  214. special_image_mask = self.get_placeholder_mask(
  215. input_ids, inputs_embeds=inputs_embeds, image_features=image_features
  216. )
  217. inputs_embeds = inputs_embeds.masked_scatter(special_image_mask, image_features)
  218. outputs: BaseModelOutputWithPast = self.language_model(
  219. attention_mask=attention_mask,
  220. position_ids=position_ids,
  221. past_key_values=past_key_values,
  222. inputs_embeds=inputs_embeds,
  223. use_cache=use_cache,
  224. **lm_kwargs,
  225. )
  226. output = VipLlavaModelOutputWithPast(
  227. last_hidden_state=outputs.last_hidden_state,
  228. past_key_values=outputs.past_key_values,
  229. hidden_states=outputs.hidden_states,
  230. attentions=outputs.attentions,
  231. image_hidden_states=image_features if pixel_values is not None else None,
  232. )
  233. return output
  234. @auto_docstring(
  235. custom_intro="""
  236. The VIPLLAVA model which consists of a vision backbone and a language model.
  237. """
  238. )
  239. class VipLlavaForConditionalGeneration(VipLlavaPreTrainedModel, GenerationMixin):
  240. _tied_weights_keys = {"lm_head.weight": "model.language_model.embed_tokens.weight"}
  241. def __init__(self, config: VipLlavaConfig):
  242. super().__init__(config)
  243. self.model = VipLlavaModel(config)
  244. self.lm_head = nn.Linear(config.text_config.hidden_size, config.text_config.vocab_size, bias=False)
  245. self.post_init()
  246. def get_input_embeddings(self):
  247. return self.model.get_input_embeddings()
  248. def set_input_embeddings(self, value):
  249. self.model.set_input_embeddings(value)
  250. def get_output_embeddings(self) -> nn.Module:
  251. return self.lm_head
  252. @auto_docstring
  253. def get_image_features(
  254. self,
  255. pixel_values: torch.FloatTensor,
  256. vision_feature_layers: int | list[int] | None = None,
  257. **kwargs: Unpack[TransformersKwargs],
  258. ) -> tuple | BaseModelOutputWithPooling:
  259. r"""
  260. pixel_values (`torch.FloatTensor]` of shape `(batch_size, channels, height, width)`):
  261. The tensors corresponding to the input images.
  262. vision_feature_layers (`Union[int, list[int]]`, *optional*):
  263. The vision feature layer, or the list of indexes of the layers to select
  264. the vision feature.
  265. """
  266. return self.model.get_image_features(
  267. pixel_values=pixel_values, vision_feature_layers=vision_feature_layers, **kwargs
  268. )
  269. @can_return_tuple
  270. @auto_docstring
  271. def forward(
  272. self,
  273. input_ids: torch.LongTensor | None = None,
  274. pixel_values: torch.FloatTensor | None = None,
  275. attention_mask: torch.Tensor | None = None,
  276. position_ids: torch.LongTensor | None = None,
  277. past_key_values: Cache | None = None,
  278. inputs_embeds: torch.FloatTensor | None = None,
  279. vision_feature_layers: int | list[int] | None = None,
  280. labels: torch.LongTensor | None = None,
  281. use_cache: bool | None = None,
  282. logits_to_keep: int | torch.Tensor = 0,
  283. **lm_kwargs: Unpack[TransformersKwargs],
  284. ) -> tuple | VipLlavaCausalLMOutputWithPast:
  285. r"""
  286. vision_feature_layers (`Union[int, list[int]]`, *optional*):
  287. The vision feature layer, or the list of indexes of the layers to select
  288. the vision feature.
  289. labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
  290. Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
  291. config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
  292. (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
  293. Example:
  294. ```python
  295. >>> import torch
  296. >>> from PIL import Image
  297. >>> import httpx
  298. >>> from io import BytesIO
  299. >>> from transformers import AutoProcessor, VipLlavaForConditionalGeneration
  300. >>> model = VipLlavaForConditionalGeneration.from_pretrained("llava-hf/vip-llava-7b-hf", device_map="auto", dtype=torch.float16)
  301. >>> processor = AutoProcessor.from_pretrained("llava-hf/vip-llava-7b-hf")
  302. >>> prompt = "A chat between a curious human and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the human's questions.###Human: <image>\n{}###Assistant:"
  303. >>> question = "Can you please describe this image?"
  304. >>> prompt = prompt.format(question)
  305. >>> url = "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/compel-neg.png"
  306. >>> with httpx.stream("GET", url) as response:
  307. ... image = Image.open(BytesIO(response.read()))
  308. >>> inputs = processor(text=text, images=image, return_tensors="pt").to(0, torch.float16)
  309. >>> # Generate
  310. >>> generate_ids = model.generate(**inputs, max_new_tokens=20)
  311. >>> processor.decode(generate_ids[0][len(inputs["input_ids"][0]):], skip_special_tokens=True)
  312. The image features a brown and white cat sitting on a green surface, with a red ball in its
  313. ```"""
  314. vision_feature_layers = (
  315. vision_feature_layers if vision_feature_layers is not None else self.config.vision_feature_layers
  316. )
  317. outputs: VipLlavaModelOutputWithPast = self.model(
  318. input_ids=input_ids,
  319. pixel_values=pixel_values,
  320. attention_mask=attention_mask,
  321. position_ids=position_ids,
  322. past_key_values=past_key_values,
  323. inputs_embeds=inputs_embeds,
  324. use_cache=use_cache,
  325. vision_feature_layers=vision_feature_layers,
  326. **lm_kwargs,
  327. )
  328. hidden_states = outputs.last_hidden_state
  329. # Only compute necessary logits, and do not upcast them to float if we are not computing the loss
  330. slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep
  331. logits = self.lm_head(hidden_states[:, slice_indices, :])
  332. loss = None
  333. if labels is not None:
  334. loss = self.loss_function(logits=logits, labels=labels, vocab_size=self.config.text_config.vocab_size)
  335. return VipLlavaCausalLMOutputWithPast(
  336. loss=loss,
  337. logits=logits,
  338. past_key_values=outputs.past_key_values,
  339. hidden_states=outputs.hidden_states,
  340. attentions=outputs.attentions,
  341. image_hidden_states=outputs.image_hidden_states,
  342. )
  343. def prepare_inputs_for_generation(
  344. self,
  345. input_ids,
  346. past_key_values=None,
  347. inputs_embeds=None,
  348. pixel_values=None,
  349. attention_mask=None,
  350. logits_to_keep=None,
  351. is_first_iteration=False,
  352. **kwargs,
  353. ):
  354. # Overwritten -- in specific circumstances we don't want to forward image inputs to the model
  355. model_inputs = super().prepare_inputs_for_generation(
  356. input_ids,
  357. past_key_values=past_key_values,
  358. inputs_embeds=inputs_embeds,
  359. attention_mask=attention_mask,
  360. logits_to_keep=logits_to_keep,
  361. is_first_iteration=is_first_iteration,
  362. **kwargs,
  363. )
  364. if is_first_iteration or not kwargs.get("use_cache", True):
  365. # Pixel values are used only in the first iteration if available
  366. # In subsequent iterations, they are already merged with text and cached
  367. # NOTE: first iteration doesn't have to be prefill, it can be the first
  368. # iteration with a question and cached system prompt (continue generate from cache)
  369. model_inputs["pixel_values"] = pixel_values
  370. return model_inputs
  371. __all__ = ["VipLlavaModel", "VipLlavaForConditionalGeneration", "VipLlavaPreTrainedModel"]