# modeling_llava.py (18 KB)
  1. # Copyright 2023 the HuggingFace Inc. team. All rights reserved.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. """PyTorch Llava model."""
  15. from dataclasses import dataclass
  16. import torch
  17. from torch import nn
  18. from ...activations import ACT2FN
  19. from ...cache_utils import Cache
  20. from ...generation import GenerationMixin
  21. from ...modeling_outputs import BaseModelOutputWithPast, BaseModelOutputWithPooling, ModelOutput
  22. from ...modeling_utils import PreTrainedModel
  23. from ...processing_utils import Unpack
  24. from ...utils import TransformersKwargs, auto_docstring, logging, torch_compilable_check
  25. from ...utils.generic import can_return_tuple, merge_with_config_defaults
  26. from ..auto import AutoModel
  27. from .configuration_llava import LlavaConfig
# Module-level logger, keyed on this module's import name.
logger = logging.get_logger(__name__)
@dataclass
@auto_docstring(
    custom_intro="""
    Base class for Llava outputs, with hidden states and attentions.
    """
)
class LlavaModelOutputWithPast(BaseModelOutputWithPast):
    r"""
    past_key_values (`Cache`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
        It is a [`~cache_utils.Cache`] instance. For more details, see our [kv cache guide](https://huggingface.co/docs/transformers/en/kv_cache).

        Contains pre-computed hidden-states (key and values in the self-attention blocks) that can be used (see
        `past_key_values` input) to speed up sequential decoding.
    image_hidden_states (`torch.FloatTensor`, *optional*):
        A `torch.FloatTensor` of size `(batch_size, num_images, sequence_length, hidden_size)`.
        image_hidden_states of the model produced by the vision encoder and after projecting the last hidden state.
    """

    # The only field added on top of `BaseModelOutputWithPast`. `LlavaModel.forward` fills it with the
    # projected image features and leaves it as None when no `pixel_values` were passed.
    image_hidden_states: torch.FloatTensor | None = None
@dataclass
@auto_docstring(
    custom_intro="""
    Base class for Llava causal language model (or autoregressive) outputs.
    """
)
class LlavaCausalLMOutputWithPast(ModelOutput):
    r"""
    loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` is provided):
        Language modeling loss (for next-token prediction).
    logits (`torch.FloatTensor` of shape `(batch_size, sequence_length, config.vocab_size)`):
        Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax).
    past_key_values (`Cache`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
        It is a [`~cache_utils.Cache`] instance. For more details, see our [kv cache guide](https://huggingface.co/docs/transformers/en/kv_cache).

        Contains pre-computed hidden-states (key and values in the self-attention blocks) that can be used (see
        `past_key_values` input) to speed up sequential decoding.
    image_hidden_states (`torch.FloatTensor`, *optional*):
        A `torch.FloatTensor` of size `(batch_size, num_images, sequence_length, hidden_size)`.
        image_hidden_states of the model produced by the vision encoder and after projecting the last hidden state.
    """

    # NOTE(review): field order is presumably significant for tuple-style access on `ModelOutput`
    # (defined outside this file) -- do not reorder without confirming.
    loss: torch.FloatTensor | None = None
    logits: torch.FloatTensor | None = None
    past_key_values: Cache | None = None
    hidden_states: tuple[torch.FloatTensor] | None = None
    attentions: tuple[torch.FloatTensor] | None = None
    # Copied through from the wrapped `LlavaModel` output by `LlavaForConditionalGeneration.forward`.
    image_hidden_states: torch.FloatTensor | None = None
  72. class LlavaMultiModalProjector(nn.Module):
  73. def __init__(self, config: LlavaConfig):
  74. super().__init__()
  75. # We have hidden_size * the number of vision feature layers
  76. num_feature_layers = 1 if isinstance(config.vision_feature_layer, int) else len(config.vision_feature_layer)
  77. self.linear_1 = nn.Linear(
  78. config.vision_config.hidden_size * num_feature_layers,
  79. config.text_config.hidden_size,
  80. bias=config.multimodal_projector_bias,
  81. )
  82. self.act = ACT2FN[config.projector_hidden_act]
  83. self.linear_2 = nn.Linear(
  84. config.text_config.hidden_size, config.text_config.hidden_size, bias=config.multimodal_projector_bias
  85. )
  86. def forward(self, image_features):
  87. hidden_states = self.linear_1(image_features)
  88. hidden_states = self.act(hidden_states)
  89. hidden_states = self.linear_2(hidden_states)
  90. return hidden_states
# Shared base for the Llava model classes in this file: it carries only class-level configuration
# and defines no methods of its own.
@auto_docstring
class LlavaPreTrainedModel(PreTrainedModel):
    config: LlavaConfig
    # `LlavaForConditionalGeneration` stores its wrapped `LlavaModel` under the attribute `model`.
    base_model_prefix = "model"
    input_modalities = ("image", "text")
    supports_gradient_checkpointing = True
    _skip_keys_device_placement = "past_key_values"

    # NOTE(review): the capability flags below are presumably read by `PreTrainedModel` (defined
    # outside this file) to decide which attention backends / compile modes are allowed -- confirm there.
    _supports_flash_attn = True
    _supports_sdpa = True

    _can_compile_fullgraph = True
    _supports_flex_attn = True
    _supports_attention_backend = True
  103. @auto_docstring(
  104. custom_intro="""
  105. The Llava model which consists of a vision backbone and a language model, without a language modeling head.
  106. """
  107. )
  108. class LlavaModel(LlavaPreTrainedModel):
  109. def __init__(self, config: LlavaConfig):
  110. super().__init__(config)
  111. self.vision_tower = AutoModel.from_config(config.vision_config)
  112. self.multi_modal_projector = LlavaMultiModalProjector(config)
  113. self.language_model = AutoModel.from_config(config.text_config)
  114. self.post_init()
  115. def get_input_embeddings(self):
  116. return self.language_model.get_input_embeddings()
  117. def set_input_embeddings(self, value):
  118. self.language_model.set_input_embeddings(value)
  119. @merge_with_config_defaults
  120. @can_return_tuple
  121. @auto_docstring(
  122. custom_intro="Obtains image last hidden states from the vision tower and apply multimodal projection."
  123. )
  124. def get_image_features(
  125. self,
  126. pixel_values: torch.FloatTensor,
  127. vision_feature_layer: int | list[int] | list[int] | None = None,
  128. vision_feature_select_strategy: str | None = None,
  129. output_hidden_states: bool | None = None,
  130. **kwargs: Unpack[TransformersKwargs],
  131. ) -> tuple | BaseModelOutputWithPooling:
  132. kwargs = {k: v for k, v in kwargs.items() if v is not None}
  133. # this is not memory efficient at all (output_hidden_states=True) will save all the hidden states.
  134. image_outputs = self.vision_tower(
  135. pixel_values,
  136. output_hidden_states=True, # Ignore arg on purpose
  137. return_dict=True,
  138. **kwargs,
  139. )
  140. # If we have one vision feature layer, return the corresponding hidden states,
  141. # otherwise, select the hidden states of each feature layer and concatenate them
  142. if isinstance(vision_feature_layer, int):
  143. selected_image_feature = image_outputs.hidden_states[vision_feature_layer]
  144. if vision_feature_select_strategy == "default":
  145. selected_image_feature = selected_image_feature[:, 1:]
  146. else:
  147. hs_pool = [image_outputs.hidden_states[layer_idx] for layer_idx in vision_feature_layer]
  148. # For default; crop CLS from each hidden state in the hidden state pool
  149. if vision_feature_select_strategy == "default":
  150. hs_pool = [hs[:, 1:] for hs in hs_pool]
  151. selected_image_feature = torch.cat(hs_pool, dim=-1)
  152. image_features = self.multi_modal_projector(selected_image_feature)
  153. # If image_sizes is provided, we need to split the image features accordingly,
  154. # but only if the image_sizes is not None (the default in this and related architectures)
  155. if kwargs.get("image_sizes") is not None:
  156. split_sizes = (
  157. (torch.as_tensor(kwargs["image_sizes"], device=image_features.device) // self.vision_tower.patch_size)
  158. .prod(dim=-1)
  159. .tolist()
  160. )
  161. image_features = torch.split(image_features.squeeze(0), split_sizes)
  162. else:
  163. image_features = list(image_features)
  164. image_outputs.pooler_output = image_features
  165. return image_outputs
  166. def get_placeholder_mask(
  167. self, input_ids: torch.LongTensor, inputs_embeds: torch.FloatTensor, image_features: torch.FloatTensor
  168. ):
  169. """
  170. Obtains multimodal placeholder mask from `input_ids` or `inputs_embeds`, and checks that the placeholder token count is
  171. equal to the length of multimodal features. If the lengths are different, an error is raised.
  172. """
  173. if input_ids is None:
  174. special_image_mask = inputs_embeds == self.get_input_embeddings()(
  175. torch.tensor(self.config.image_token_id, dtype=torch.long, device=inputs_embeds.device)
  176. )
  177. special_image_mask = special_image_mask.all(-1)
  178. else:
  179. special_image_mask = input_ids == self.config.image_token_id
  180. n_image_tokens = special_image_mask.sum()
  181. n_image_features = image_features.shape[0] * image_features.shape[1]
  182. special_image_mask = special_image_mask.unsqueeze(-1).expand_as(inputs_embeds).to(inputs_embeds.device)
  183. torch_compilable_check(
  184. inputs_embeds[special_image_mask].numel() == image_features.numel(),
  185. f"Image features and image tokens do not match, tokens: {n_image_tokens}, features: {n_image_features}",
  186. )
  187. return special_image_mask
  188. @can_return_tuple
  189. @auto_docstring
  190. def forward(
  191. self,
  192. input_ids: torch.LongTensor | None = None,
  193. pixel_values: torch.FloatTensor | None = None,
  194. attention_mask: torch.Tensor | None = None,
  195. position_ids: torch.LongTensor | None = None,
  196. past_key_values: Cache | None = None,
  197. inputs_embeds: torch.FloatTensor | None = None,
  198. vision_feature_layer: int | list[int] | list[int] | None = None,
  199. vision_feature_select_strategy: str | None = None,
  200. image_sizes: torch.Tensor | None = None,
  201. **kwargs: Unpack[TransformersKwargs],
  202. ) -> tuple | LlavaModelOutputWithPast:
  203. if (input_ids is None) ^ (inputs_embeds is not None):
  204. raise ValueError("You must specify exactly one of input_ids or inputs_embeds")
  205. if inputs_embeds is None:
  206. inputs_embeds = self.get_input_embeddings()(input_ids)
  207. if pixel_values is not None:
  208. image_features = self.get_image_features(
  209. pixel_values=pixel_values,
  210. vision_feature_layer=vision_feature_layer,
  211. vision_feature_select_strategy=vision_feature_select_strategy,
  212. image_sizes=image_sizes,
  213. return_dict=True,
  214. ).pooler_output
  215. image_features = torch.cat(image_features, dim=0).to(inputs_embeds.device, inputs_embeds.dtype)
  216. special_image_mask = self.get_placeholder_mask(
  217. input_ids, inputs_embeds=inputs_embeds, image_features=image_features
  218. )
  219. inputs_embeds = inputs_embeds.masked_scatter(special_image_mask, image_features)
  220. outputs = self.language_model(
  221. attention_mask=attention_mask,
  222. position_ids=position_ids,
  223. past_key_values=past_key_values,
  224. inputs_embeds=inputs_embeds,
  225. **kwargs,
  226. )
  227. return LlavaModelOutputWithPast(
  228. last_hidden_state=outputs.last_hidden_state,
  229. past_key_values=outputs.past_key_values,
  230. hidden_states=outputs.hidden_states,
  231. attentions=outputs.attentions,
  232. image_hidden_states=image_features if pixel_values is not None else None,
  233. )
  234. @auto_docstring(
  235. custom_intro="""
  236. The LLAVA model which consists of a vision backbone and a language model.
  237. """
  238. )
  239. class LlavaForConditionalGeneration(LlavaPreTrainedModel, GenerationMixin):
  240. _tied_weights_keys = {"lm_head.weight": "model.language_model.embed_tokens.weight"}
  241. def __init__(self, config: LlavaConfig):
  242. super().__init__(config)
  243. self.model = LlavaModel(config)
  244. self.lm_head = nn.Linear(config.text_config.hidden_size, config.text_config.vocab_size, bias=False)
  245. self.post_init()
  246. def get_input_embeddings(self):
  247. return self.model.get_input_embeddings()
  248. def set_input_embeddings(self, value):
  249. self.model.set_input_embeddings(value)
  250. def get_output_embeddings(self) -> nn.Module:
  251. return self.lm_head
  252. @auto_docstring
  253. def get_image_features(
  254. self,
  255. pixel_values: torch.FloatTensor,
  256. vision_feature_layer: int | list[int] | list[int] | None = None,
  257. vision_feature_select_strategy: str | None = None,
  258. **kwargs: Unpack[TransformersKwargs],
  259. ) -> tuple | BaseModelOutputWithPooling:
  260. return self.model.get_image_features(
  261. pixel_values=pixel_values,
  262. vision_feature_layer=vision_feature_layer,
  263. vision_feature_select_strategy=vision_feature_select_strategy,
  264. **kwargs,
  265. )
  266. @can_return_tuple
  267. @auto_docstring
  268. def forward(
  269. self,
  270. input_ids: torch.LongTensor | None = None,
  271. pixel_values: torch.FloatTensor | None = None,
  272. attention_mask: torch.Tensor | None = None,
  273. position_ids: torch.LongTensor | None = None,
  274. past_key_values: Cache | None = None,
  275. inputs_embeds: torch.FloatTensor | None = None,
  276. vision_feature_layer: int | list[int] | list[int] | None = None,
  277. vision_feature_select_strategy: str | None = None,
  278. labels: torch.LongTensor | None = None,
  279. logits_to_keep: int | torch.Tensor = 0,
  280. image_sizes: torch.Tensor | None = None,
  281. **kwargs: Unpack[TransformersKwargs],
  282. ) -> tuple | LlavaCausalLMOutputWithPast:
  283. r"""
  284. labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
  285. Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
  286. config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
  287. (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
  288. Example:
  289. ```python
  290. >>> from PIL import Image
  291. >>> import httpx
  292. >>> from io import BytesIO
  293. >>> from transformers import AutoProcessor, LlavaForConditionalGeneration
  294. >>> model = LlavaForConditionalGeneration.from_pretrained("llava-hf/llava-1.5-7b-hf")
  295. >>> processor = AutoProcessor.from_pretrained("llava-hf/llava-1.5-7b-hf")
  296. >>> prompt = "USER: <image>\nWhat's the content of the image? ASSISTANT:"
  297. >>> url = "https://www.ilankelman.org/stopsigns/australia.jpg"
  298. >>> with httpx.stream("GET", url) as response:
  299. ... image = Image.open(BytesIO(response.read()))
  300. >>> inputs = processor(images=image, text=prompt, return_tensors="pt")
  301. >>> # Generate
  302. >>> generate_ids = model.generate(**inputs, max_new_tokens=15)
  303. >>> processor.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
  304. "USER: \nWhat's the content of the image? ASSISTANT: The image features a busy city street with a stop sign prominently displayed"
  305. ```"""
  306. outputs = self.model(
  307. input_ids=input_ids,
  308. pixel_values=pixel_values,
  309. attention_mask=attention_mask,
  310. position_ids=position_ids,
  311. past_key_values=past_key_values,
  312. inputs_embeds=inputs_embeds,
  313. vision_feature_layer=vision_feature_layer,
  314. vision_feature_select_strategy=vision_feature_select_strategy,
  315. image_sizes=image_sizes,
  316. **kwargs,
  317. )
  318. hidden_states = outputs[0]
  319. # Only compute necessary logits, and do not upcast them to float if we are not computing the loss
  320. slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep
  321. logits = self.lm_head(hidden_states[:, slice_indices, :])
  322. loss = None
  323. if labels is not None:
  324. loss = self.loss_function(
  325. logits=logits, labels=labels, vocab_size=self.config.text_config.vocab_size, **kwargs
  326. )
  327. return LlavaCausalLMOutputWithPast(
  328. loss=loss,
  329. logits=logits,
  330. past_key_values=outputs.past_key_values,
  331. hidden_states=outputs.hidden_states,
  332. attentions=outputs.attentions,
  333. image_hidden_states=outputs.image_hidden_states,
  334. )
  335. def prepare_inputs_for_generation(
  336. self,
  337. input_ids,
  338. past_key_values=None,
  339. inputs_embeds=None,
  340. pixel_values=None,
  341. attention_mask=None,
  342. logits_to_keep=None,
  343. is_first_iteration=False,
  344. **kwargs,
  345. ):
  346. # Overwritten -- in specific circumstances we don't want to forward image inputs to the model
  347. model_inputs = super().prepare_inputs_for_generation(
  348. input_ids,
  349. past_key_values=past_key_values,
  350. inputs_embeds=inputs_embeds,
  351. attention_mask=attention_mask,
  352. logits_to_keep=logits_to_keep,
  353. is_first_iteration=is_first_iteration,
  354. **kwargs,
  355. )
  356. if is_first_iteration or not kwargs.get("use_cache", True):
  357. # Pixel values are used only in the first iteration if available
  358. # In subsequent iterations, they are already merged with text and cached
  359. # NOTE: first iteration doesn't have to be prefill, it can be the first
  360. # iteration with a question and cached system prompt (continue generate from cache)
  361. model_inputs["pixel_values"] = pixel_values
  362. return model_inputs
# Names exported by `from <module> import *`; the output dataclasses and the projector are not listed.
__all__ = ["LlavaForConditionalGeneration", "LlavaPreTrainedModel", "LlavaModel"]