modular_vipllava.py 12 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279
  1. # Copyright 2023 the HuggingFace Inc. team. All rights reserved.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. import torch
  15. from torch import nn
  16. from transformers.models.llava.modeling_llava import (
  17. LlavaCausalLMOutputWithPast,
  18. LlavaForConditionalGeneration,
  19. LlavaModel,
  20. LlavaModelOutputWithPast,
  21. LlavaPreTrainedModel,
  22. )
  23. from ...activations import ACT2FN
  24. from ...cache_utils import Cache
  25. from ...modeling_outputs import BaseModelOutputWithPast, BaseModelOutputWithPooling
  26. from ...processing_utils import Unpack
  27. from ...utils import TransformersKwargs, auto_docstring, logging
  28. from ...utils.generic import can_return_tuple
  29. from .configuration_vipllava import VipLlavaConfig
# Module-level logger obtained from the transformers logging utilities, keyed by this module's name.
# NOTE(review): it is not referenced by the code in this module as shown.
logger = logging.get_logger(__name__)
# Output container returned by `VipLlavaModel.forward`. The body adds nothing, so every field
# (last_hidden_state, past_key_values, hidden_states, attentions, image_hidden_states, ...) is
# inherited unchanged from `LlavaModelOutputWithPast`; only the class name is VipLlava-specific.
class VipLlavaModelOutputWithPast(LlavaModelOutputWithPast):
    pass
# Output container returned by `VipLlavaForConditionalGeneration.forward`. The body adds nothing,
# so all fields (loss, logits, past_key_values, hidden_states, attentions, image_hidden_states, ...)
# come from `LlavaCausalLMOutputWithPast`; only the class name is VipLlava-specific.
class VipLlavaCausalLMOutputWithPast(LlavaCausalLMOutputWithPast):
    pass
  35. class VipLlavaMultiModalProjector(nn.Module):
  36. def __init__(self, config: VipLlavaConfig):
  37. super().__init__()
  38. num_feature_layers = 1 if isinstance(config.vision_feature_layers, int) else len(config.vision_feature_layers)
  39. self.projector_layernorm = nn.LayerNorm(
  40. num_feature_layers * config.vision_config.hidden_size, eps=config.projector_layernorm_eps
  41. )
  42. self.linear_1 = nn.Linear(
  43. num_feature_layers * config.vision_config.hidden_size,
  44. config.text_config.hidden_size,
  45. bias=True,
  46. )
  47. self.act = ACT2FN[config.projector_hidden_act]
  48. self.linear_2 = nn.Linear(config.text_config.hidden_size, config.text_config.hidden_size, bias=True)
  49. def forward(self, hidden_states):
  50. hidden_states = self.projector_layernorm(hidden_states)
  51. hidden_states = self.linear_1(hidden_states)
  52. hidden_states = self.act(hidden_states)
  53. hidden_states = self.linear_2(hidden_states)
  54. return hidden_states
# Base pretrained-model class for VipLlava. The body adds nothing: all weight-initialisation and
# loading behaviour is inherited from `LlavaPreTrainedModel`; it is exported via `__all__`.
class VipLlavaPreTrainedModel(LlavaPreTrainedModel):
    pass
  57. class VipLlavaModel(LlavaModel):
  58. @can_return_tuple
  59. @auto_docstring(
  60. custom_intro="Obtains image last hidden states from the vision tower and apply multimodal projection."
  61. )
  62. def get_image_features(
  63. self,
  64. pixel_values: torch.FloatTensor,
  65. vision_feature_layers: int | list[int] | None = None,
  66. **kwargs: Unpack[TransformersKwargs],
  67. ) -> tuple | BaseModelOutputWithPooling:
  68. r"""
  69. pixel_values (`torch.FloatTensor]` of shape `(batch_size, channels, height, width)`):
  70. The tensors corresponding to the input images.
  71. vision_feature_layers (`Union[int, list[int]]`, *optional*):
  72. The vision feature layer, or the list of indexes of the layers to select
  73. the vision feature.
  74. """
  75. vision_feature_layers = (
  76. vision_feature_layers if vision_feature_layers is not None else self.config.vision_feature_layers
  77. )
  78. # We need hidden states to select intermediate vision features by layer index below.
  79. kwargs["output_hidden_states"] = True
  80. image_outputs = self.vision_tower(
  81. pixel_values,
  82. **kwargs,
  83. )
  84. # If multiple feature layers are provided (which is usually the case)
  85. # then the image features are concatenated after the CLS is removed.
  86. if isinstance(vision_feature_layers, int):
  87. image_features = image_outputs.hidden_states[vision_feature_layers][:, 1:]
  88. else:
  89. # Usually, we select the features from index 1: the layers -2, -5, -8, -11 and 6
  90. image_features = [image_outputs.hidden_states[index][:, 1:] for index in vision_feature_layers]
  91. image_features = torch.cat(image_features, dim=-1)
  92. image_features = self.multi_modal_projector(image_features)
  93. image_outputs.pooler_output = image_features
  94. return image_outputs
  95. @can_return_tuple
  96. @auto_docstring
  97. def forward(
  98. self,
  99. input_ids: torch.LongTensor | None = None,
  100. pixel_values: torch.FloatTensor | None = None,
  101. attention_mask: torch.Tensor | None = None,
  102. position_ids: torch.LongTensor | None = None,
  103. past_key_values: Cache | None = None,
  104. inputs_embeds: torch.FloatTensor | None = None,
  105. vision_feature_layers: int | list[int] | None = None,
  106. use_cache: bool | None = None,
  107. **lm_kwargs: Unpack[TransformersKwargs],
  108. ) -> tuple | VipLlavaModelOutputWithPast:
  109. r"""
  110. vision_feature_layers (`Union[int, list[int]]`, *optional*):
  111. The vision feature layer, or the list of indexes of the layers to select
  112. the vision feature.
  113. """
  114. vision_feature_layers = (
  115. vision_feature_layers if vision_feature_layers is not None else self.config.vision_feature_layers
  116. )
  117. if (input_ids is None) ^ (inputs_embeds is not None):
  118. raise ValueError("You must specify exactly one of input_ids or inputs_embeds")
  119. if inputs_embeds is None:
  120. inputs_embeds = self.get_input_embeddings()(input_ids)
  121. if pixel_values is not None:
  122. image_features = self.get_image_features(
  123. pixel_values=pixel_values, vision_feature_layers=vision_feature_layers
  124. ).pooler_output
  125. image_features = image_features.to(inputs_embeds.device, inputs_embeds.dtype)
  126. special_image_mask = self.get_placeholder_mask(
  127. input_ids, inputs_embeds=inputs_embeds, image_features=image_features
  128. )
  129. inputs_embeds = inputs_embeds.masked_scatter(special_image_mask, image_features)
  130. outputs: BaseModelOutputWithPast = self.language_model(
  131. attention_mask=attention_mask,
  132. position_ids=position_ids,
  133. past_key_values=past_key_values,
  134. inputs_embeds=inputs_embeds,
  135. use_cache=use_cache,
  136. **lm_kwargs,
  137. )
  138. output = VipLlavaModelOutputWithPast(
  139. last_hidden_state=outputs.last_hidden_state,
  140. past_key_values=outputs.past_key_values,
  141. hidden_states=outputs.hidden_states,
  142. attentions=outputs.attentions,
  143. image_hidden_states=image_features if pixel_values is not None else None,
  144. )
  145. return output
  146. class VipLlavaForConditionalGeneration(LlavaForConditionalGeneration):
  147. @auto_docstring
  148. def get_image_features(
  149. self,
  150. pixel_values: torch.FloatTensor,
  151. vision_feature_layers: int | list[int] | None = None,
  152. **kwargs: Unpack[TransformersKwargs],
  153. ) -> tuple | BaseModelOutputWithPooling:
  154. r"""
  155. pixel_values (`torch.FloatTensor]` of shape `(batch_size, channels, height, width)`):
  156. The tensors corresponding to the input images.
  157. vision_feature_layers (`Union[int, list[int]]`, *optional*):
  158. The vision feature layer, or the list of indexes of the layers to select
  159. the vision feature.
  160. """
  161. return self.model.get_image_features(
  162. pixel_values=pixel_values, vision_feature_layers=vision_feature_layers, **kwargs
  163. )
  164. @can_return_tuple
  165. @auto_docstring
  166. def forward(
  167. self,
  168. input_ids: torch.LongTensor | None = None,
  169. pixel_values: torch.FloatTensor | None = None,
  170. attention_mask: torch.Tensor | None = None,
  171. position_ids: torch.LongTensor | None = None,
  172. past_key_values: Cache | None = None,
  173. inputs_embeds: torch.FloatTensor | None = None,
  174. vision_feature_layers: int | list[int] | None = None,
  175. labels: torch.LongTensor | None = None,
  176. use_cache: bool | None = None,
  177. logits_to_keep: int | torch.Tensor = 0,
  178. **lm_kwargs: Unpack[TransformersKwargs],
  179. ) -> tuple | VipLlavaCausalLMOutputWithPast:
  180. r"""
  181. vision_feature_layers (`Union[int, list[int]]`, *optional*):
  182. The vision feature layer, or the list of indexes of the layers to select
  183. the vision feature.
  184. labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
  185. Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
  186. config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
  187. (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
  188. Example:
  189. ```python
  190. >>> import torch
  191. >>> from PIL import Image
  192. >>> import httpx
  193. >>> from io import BytesIO
  194. >>> from transformers import AutoProcessor, VipLlavaForConditionalGeneration
  195. >>> model = VipLlavaForConditionalGeneration.from_pretrained("llava-hf/vip-llava-7b-hf", device_map="auto", dtype=torch.float16)
  196. >>> processor = AutoProcessor.from_pretrained("llava-hf/vip-llava-7b-hf")
  197. >>> prompt = "A chat between a curious human and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the human's questions.###Human: <image>\n{}###Assistant:"
  198. >>> question = "Can you please describe this image?"
  199. >>> prompt = prompt.format(question)
  200. >>> url = "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/compel-neg.png"
  201. >>> with httpx.stream("GET", url) as response:
  202. ... image = Image.open(BytesIO(response.read()))
  203. >>> inputs = processor(text=text, images=image, return_tensors="pt").to(0, torch.float16)
  204. >>> # Generate
  205. >>> generate_ids = model.generate(**inputs, max_new_tokens=20)
  206. >>> processor.decode(generate_ids[0][len(inputs["input_ids"][0]):], skip_special_tokens=True)
  207. The image features a brown and white cat sitting on a green surface, with a red ball in its
  208. ```"""
  209. vision_feature_layers = (
  210. vision_feature_layers if vision_feature_layers is not None else self.config.vision_feature_layers
  211. )
  212. outputs: VipLlavaModelOutputWithPast = self.model(
  213. input_ids=input_ids,
  214. pixel_values=pixel_values,
  215. attention_mask=attention_mask,
  216. position_ids=position_ids,
  217. past_key_values=past_key_values,
  218. inputs_embeds=inputs_embeds,
  219. use_cache=use_cache,
  220. vision_feature_layers=vision_feature_layers,
  221. **lm_kwargs,
  222. )
  223. hidden_states = outputs.last_hidden_state
  224. # Only compute necessary logits, and do not upcast them to float if we are not computing the loss
  225. slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep
  226. logits = self.lm_head(hidden_states[:, slice_indices, :])
  227. loss = None
  228. if labels is not None:
  229. loss = self.loss_function(logits=logits, labels=labels, vocab_size=self.config.text_config.vocab_size)
  230. return VipLlavaCausalLMOutputWithPast(
  231. loss=loss,
  232. logits=logits,
  233. past_key_values=outputs.past_key_values,
  234. hidden_states=outputs.hidden_states,
  235. attentions=outputs.attentions,
  236. image_hidden_states=outputs.image_hidden_states,
  237. )
  238. __all__ = ["VipLlavaModel", "VipLlavaForConditionalGeneration", "VipLlavaPreTrainedModel"]