modeling_fuyu.py 15 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345
  1. # Copyright 2023 HuggingFace Inc. team. All rights reserved.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. """PyTorch Fuyu model."""
  15. import torch
  16. from torch import nn
  17. from ...cache_utils import Cache
  18. from ...generation import GenerationMixin
  19. from ...modeling_outputs import BaseModelOutputWithPooling, CausalLMOutputWithPast
  20. from ...modeling_utils import PreTrainedModel
  21. from ...models.auto.modeling_auto import AutoModel
  22. from ...processing_utils import Unpack
  23. from ...utils import TransformersKwargs, auto_docstring, can_return_tuple, logging, torch_compilable_check
  24. from .configuration_fuyu import FuyuConfig
  25. logger = logging.get_logger(__name__)
  26. @auto_docstring
  27. class FuyuPreTrainedModel(PreTrainedModel):
  28. config: FuyuConfig
  29. base_model_prefix = "model"
  30. input_modalities = ("image", "text")
  31. supports_gradient_checkpointing = True
  32. _supports_attention_backend = True
  33. _supports_flash_attn = True
  34. _supports_sdpa = True
  35. _supports_flex_attn = True
  36. _no_split_modules = []
  37. _skip_keys_device_placement = "past_key_values"
  38. @auto_docstring(
  39. custom_intro="""
  40. The Fuyu model which consists of a vision backbone and a language model, without a language modeling head.
  41. """
  42. )
  43. class FuyuModel(FuyuPreTrainedModel):
  44. def __init__(self, config: FuyuConfig):
  45. super().__init__(config)
  46. self.padding_idx = config.pad_token_id
  47. self.vocab_size = config.text_config.vocab_size
  48. self.language_model = AutoModel.from_config(config.text_config)
  49. self.vision_embed_tokens = nn.Linear(
  50. config.patch_size * config.patch_size * config.num_channels, config.hidden_size
  51. )
  52. self.gradient_checkpointing = False
  53. # Initialize weights and apply final processing
  54. self.post_init()
  55. def get_input_embeddings(self):
  56. return self.language_model.get_input_embeddings()
  57. def set_input_embeddings(self, value):
  58. self.language_model.set_input_embeddings(value)
  59. def gather_continuous_embeddings(
  60. self,
  61. word_embeddings: torch.Tensor,
  62. continuous_embeddings: list[torch.Tensor],
  63. image_patch_input_indices: torch.Tensor,
  64. ) -> torch.Tensor:
  65. """This function places the continuous_embeddings into the word_embeddings at the locations
  66. indicated by image_patch_input_indices. Different batch elements can have different numbers of continuous
  67. embeddings.
  68. Args:
  69. word_embeddings (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`):
  70. Tensor of word embeddings.
  71. continuous_embeddings (`torch.FloatTensor` of shape `(batch_size, num_patches, hidden_size)`):
  72. Tensor of continuous embeddings. The length of the list is the batch size. Each entry is shape
  73. [num_image_embeddings, hidden], and num_image_embeddings needs to match the number of non-negative
  74. indices in image_patch_input_indices for that batch element.
  75. image_patch_input_indices (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
  76. Tensor of indices of the image patches in the input_ids tensor.
  77. """
  78. if not (word_embeddings.shape[0] == len(continuous_embeddings)):
  79. raise ValueError(
  80. f"Batch sizes must match! Got {len(continuous_embeddings)=} and {word_embeddings.shape[0]=}"
  81. )
  82. output_embeddings = word_embeddings.clone()
  83. for batch_idx in range(word_embeddings.shape[0]):
  84. # First, find the positions of all the non-negative values in image_patch_input_indices, those are the
  85. # positions in word_embeddings that we want to replace with content from continuous_embeddings.
  86. dst_indices = torch.nonzero(image_patch_input_indices[batch_idx] >= 0, as_tuple=True)[0]
  87. # Next look up those indices in image_patch_input_indices to find the indices in continuous_embeddings that we
  88. # want to use to replace the values in word_embeddings.
  89. src_indices = image_patch_input_indices[batch_idx][dst_indices]
  90. # Check if we have more indices than embeddings. Note that we could have fewer indices if images got truncated.
  91. if src_indices.shape[0] > continuous_embeddings[batch_idx].shape[0]:
  92. raise ValueError(
  93. f"Number of continuous embeddings {continuous_embeddings[batch_idx].shape=} does not match "
  94. f"number of continuous token ids {src_indices.shape=} in batch element {batch_idx}."
  95. )
  96. output_embeddings[batch_idx, dst_indices] = continuous_embeddings[batch_idx][src_indices].to(
  97. output_embeddings.device
  98. )
  99. return output_embeddings
  100. @can_return_tuple
  101. @auto_docstring
  102. def get_image_features(
  103. self, pixel_values: torch.FloatTensor, **kwargs: Unpack[TransformersKwargs]
  104. ) -> tuple | BaseModelOutputWithPooling:
  105. r"""
  106. pixel_values (`torch.FloatTensor` of shape `(batch_size, num_channels, image_size, image_size)`):
  107. The tensors corresponding to the input images.
  108. """
  109. patch_embeddings = self.vision_embed_tokens(pixel_values)
  110. return BaseModelOutputWithPooling(last_hidden_state=patch_embeddings)
  111. def get_placeholder_mask(
  112. self, input_ids: torch.LongTensor, inputs_embeds: torch.FloatTensor, image_features: torch.FloatTensor
  113. ):
  114. """
  115. Obtains multimodal placeholder mask from `input_ids` or `inputs_embeds`, and checks that the placeholder token count is
  116. equal to the length of multimodal features. If the lengths are different, an error is raised.
  117. """
  118. if input_ids is None:
  119. special_image_mask = inputs_embeds == self.get_input_embeddings()(
  120. torch.tensor(self.config.image_token_id, dtype=torch.long, device=inputs_embeds.device)
  121. )
  122. special_image_mask = special_image_mask.all(-1)
  123. else:
  124. special_image_mask = input_ids == self.config.image_token_id
  125. n_image_tokens = special_image_mask.sum()
  126. n_image_features = image_features.shape[0] * image_features.shape[1]
  127. special_image_mask = special_image_mask.unsqueeze(-1).expand_as(inputs_embeds).to(inputs_embeds.device)
  128. torch_compilable_check(
  129. inputs_embeds[special_image_mask].numel() == image_features.numel(),
  130. f"Image features and image tokens do not match, tokens: {n_image_tokens}, features: {n_image_features}",
  131. )
  132. return special_image_mask
  133. @can_return_tuple
  134. @auto_docstring
  135. def forward(
  136. self,
  137. input_ids: torch.LongTensor | None = None,
  138. # [batch_size, num_total_patches, patch_size_ x patch_size x num_channels ]
  139. image_patches: torch.Tensor | None = None,
  140. image_patches_indices: torch.Tensor | None = None,
  141. attention_mask: torch.Tensor | None = None,
  142. position_ids: torch.LongTensor | None = None,
  143. past_key_values: Cache | None = None,
  144. inputs_embeds: torch.FloatTensor | None = None,
  145. use_cache: bool | None = None,
  146. **kwargs: Unpack[TransformersKwargs],
  147. ) -> tuple | CausalLMOutputWithPast:
  148. r"""
  149. image_patches (`torch.FloatTensor` of shape `(batch_size, num_total_patches, patch_size_ x patch_size x num_channels)`, *optional*):
  150. Image patches to be used as continuous embeddings. The patches are flattened and then projected to the
  151. hidden size of the model.
  152. image_patches_indices (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
  153. Tensor of indices of the image patches in the input_ids tensor.
  154. """
  155. if (input_ids is None) ^ (inputs_embeds is not None):
  156. raise ValueError("You must specify exactly one of input_ids or inputs_embeds")
  157. if inputs_embeds is None:
  158. inputs_embeds = self.language_model.get_input_embeddings()(input_ids)
  159. seq_len = inputs_embeds.shape[1]
  160. if position_ids is None:
  161. device = input_ids.device if input_ids is not None else inputs_embeds.device
  162. past_key_values_length = past_key_values.get_seq_length() if past_key_values is not None else 0
  163. position_ids = torch.arange(
  164. past_key_values_length, seq_len + past_key_values_length, dtype=torch.long, device=device
  165. )
  166. position_ids = position_ids.unsqueeze(0)
  167. if image_patches is not None:
  168. patch_embeddings = self.get_image_features(image_patches, return_dict=True).last_hidden_state
  169. patch_embeddings = patch_embeddings.to(inputs_embeds.device, inputs_embeds.dtype)
  170. special_image_mask = self.get_placeholder_mask(
  171. input_ids, inputs_embeds=inputs_embeds, image_features=patch_embeddings
  172. )
  173. inputs_embeds = inputs_embeds.masked_scatter(special_image_mask, patch_embeddings)
  174. outputs = self.language_model(
  175. inputs_embeds=inputs_embeds,
  176. attention_mask=attention_mask,
  177. position_ids=position_ids,
  178. past_key_values=past_key_values,
  179. use_cache=use_cache,
  180. **kwargs,
  181. )
  182. return outputs
  183. @auto_docstring(
  184. custom_intro="""
  185. Fuyu Model with a language modeling head on top for causal language model conditioned on image patches and text.
  186. """
  187. )
  188. class FuyuForCausalLM(FuyuPreTrainedModel, GenerationMixin):
  189. _tied_weights_keys = {"lm_head.weight": "model.language_model.embed_tokens.weight"}
  190. def __init__(self, config: FuyuConfig):
  191. super().__init__(config)
  192. self.model = FuyuModel(config)
  193. self.lm_head = nn.Linear(config.text_config.hidden_size, config.text_config.vocab_size, bias=False)
  194. self.post_init()
  195. def get_input_embeddings(self):
  196. return self.model.get_input_embeddings()
  197. def set_input_embeddings(self, value):
  198. self.model.set_input_embeddings(value)
  199. @can_return_tuple
  200. @auto_docstring
  201. def forward(
  202. self,
  203. input_ids: torch.LongTensor | None = None,
  204. # [batch_size, num_total_patches, patch_size_ x patch_size x num_channels ]
  205. image_patches: torch.Tensor | None = None,
  206. image_patches_indices: torch.Tensor | None = None,
  207. attention_mask: torch.Tensor | None = None,
  208. position_ids: torch.LongTensor | None = None,
  209. past_key_values: Cache | None = None,
  210. inputs_embeds: torch.FloatTensor | None = None,
  211. use_cache: bool | None = None,
  212. labels: torch.Tensor | None = None,
  213. logits_to_keep: int | None = 0,
  214. **kwargs: Unpack[TransformersKwargs],
  215. ) -> tuple | CausalLMOutputWithPast:
  216. r"""
  217. image_patches (`torch.FloatTensor` of shape `(batch_size, num_total_patches, patch_size_ x patch_size x num_channels)`, *optional*):
  218. Image patches to be used as continuous embeddings. The patches are flattened and then projected to the
  219. hidden size of the model.
  220. image_patches_indices (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
  221. Tensor of indices of the image patches in the input_ids tensor.
  222. labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
  223. Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
  224. config.text_config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
  225. (masked), the loss is only computed for the tokens with labels in `[0, ..., config.text_config.vocab_size]`.
  226. Examples:
  227. ```python
  228. >>> from transformers import FuyuProcessor, FuyuForCausalLM
  229. >>> from PIL import Image
  230. >>> import httpx
  231. >>> from io import BytesIO
  232. >>> processor = FuyuProcessor.from_pretrained("adept/fuyu-8b")
  233. >>> model = FuyuForCausalLM.from_pretrained("adept/fuyu-8b")
  234. >>> url = "https://huggingface.co/datasets/hf-internal-testing/fixtures-captioning/resolve/main/bus.png"
  235. >>> with httpx.stream("GET", url) as response:
  236. ... image = Image.open(BytesIO(response.read()))
  237. >>> prompt = "Generate a coco-style caption.\n"
  238. >>> inputs = processor(images=image, text=prompt, return_tensors="pt")
  239. >>> outputs = model(**inputs)
  240. >>> generated_ids = model.generate(**inputs, max_new_tokens=7)
  241. >>> generation_text = processor.batch_decode(generated_ids[:, -7:], skip_special_tokens=True)
  242. >>> print(generation_text[0])
  243. A blue bus parked on the side of a road.
  244. ```"""
  245. outputs = self.model(
  246. input_ids=input_ids,
  247. image_patches=image_patches,
  248. image_patches_indices=image_patches_indices,
  249. inputs_embeds=inputs_embeds,
  250. attention_mask=attention_mask,
  251. position_ids=position_ids,
  252. past_key_values=past_key_values,
  253. use_cache=use_cache,
  254. **kwargs,
  255. )
  256. hidden_states = outputs[0]
  257. # Only compute necessary logits, and do not upcast them to float if we are not computing the loss
  258. slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep
  259. logits = self.lm_head(hidden_states[:, slice_indices, :])
  260. loss = None
  261. if labels is not None:
  262. loss = self.loss_function(
  263. logits=logits, labels=labels, vocab_size=self.config.text_config.vocab_size, **kwargs
  264. )
  265. return CausalLMOutputWithPast(
  266. loss=loss,
  267. logits=logits,
  268. past_key_values=outputs.past_key_values,
  269. hidden_states=outputs.hidden_states,
  270. attentions=outputs.attentions,
  271. )
  272. def prepare_inputs_for_generation(
  273. self,
  274. input_ids,
  275. past_key_values=None,
  276. attention_mask=None,
  277. inputs_embeds=None,
  278. image_patches=None,
  279. image_patches_indices=None,
  280. is_first_iteration=False,
  281. **kwargs,
  282. ):
  283. # Overwritten -- in specific circumstances we don't want to forward image inputs to the model
  284. model_inputs = super().prepare_inputs_for_generation(
  285. input_ids,
  286. past_key_values=past_key_values,
  287. attention_mask=attention_mask,
  288. inputs_embeds=inputs_embeds,
  289. image_patches=image_patches,
  290. image_patches_indices=image_patches_indices,
  291. is_first_iteration=is_first_iteration,
  292. **kwargs,
  293. )
  294. if not is_first_iteration and kwargs.get("use_cache", True):
  295. # set image_patches and image_patches_indices to `None` for decoding stage
  296. model_inputs["image_patches_indices"] = None
  297. model_inputs["image_patches"] = None
  298. return model_inputs
  299. __all__ = ["FuyuForCausalLM", "FuyuPreTrainedModel", "FuyuModel"]