modular_colqwen2.py 15 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348
  1. # Copyright 2025 The HuggingFace Inc. team.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. from dataclasses import dataclass
  15. from ...cache_utils import Cache
  16. from ...feature_extraction_utils import BatchFeature
  17. from ...image_utils import ImageInput, is_valid_image
  18. from ...processing_utils import MultiModalData, ProcessingKwargs, ProcessorMixin, Unpack
  19. from ...tokenization_utils_base import PreTokenizedInput, TextInput
  20. from ...utils import ModelOutput, auto_docstring, can_return_tuple, is_torch_available, logging
  21. from ..colpali.modeling_colpali import ColPaliForRetrieval, ColPaliPreTrainedModel
  22. from ..colpali.processing_colpali import ColPaliProcessor
  23. from .configuration_colqwen2 import ColQwen2Config
# torch is treated as an optional dependency: it is imported only when `is_torch_available()` is true.
if is_torch_available():
    import torch

# Module-level logger named after this module.
logger = logging.get_logger(__name__)
  27. class ColQwen2ProcessorKwargs(ProcessingKwargs, total=False):
  28. _defaults = {
  29. "text_kwargs": {
  30. "padding": "longest",
  31. },
  32. "images_kwargs": {
  33. "data_format": "channels_first",
  34. "do_convert_rgb": True,
  35. },
  36. "common_kwargs": {"return_tensors": "pt"},
  37. }
  38. class ColQwen2Processor(ColPaliProcessor):
  39. def __init__(
  40. self,
  41. image_processor=None,
  42. tokenizer=None,
  43. chat_template=None,
  44. visual_prompt_prefix: str | None = None,
  45. query_prefix: str | None = None,
  46. **kwargs,
  47. ):
  48. r"""
  49. visual_prompt_prefix (`str`, *optional*, defaults to `"<|im_start|>user\n<|vision_start|><|image_pad|><|vision_end|>Describe the image.<|im_end|><|endoftext|>"`):
  50. A string that gets tokenized and prepended to the image tokens.
  51. query_prefix (`str`, *optional*, defaults to `"Query: "`):
  52. A prefix to be used for the query.
  53. """
  54. ProcessorMixin.__init__(self, image_processor, tokenizer, chat_template=chat_template)
  55. self.image_token = "<|image_pad|>" if not hasattr(tokenizer, "image_token") else tokenizer.image_token
  56. self.video_token = "<|video_pad|>" if not hasattr(tokenizer, "video_token") else tokenizer.video_token
  57. self.visual_prompt_prefix = visual_prompt_prefix or (
  58. "<|im_start|>user\n<|vision_start|><|image_pad|><|vision_end|>Describe the image.<|im_end|><|endoftext|>"
  59. )
  60. self.query_prefix = query_prefix or "Query: "
  61. def __call__(
  62. self,
  63. images: ImageInput | None = None,
  64. text: TextInput | PreTokenizedInput | list[TextInput] | list[PreTokenizedInput] = None,
  65. **kwargs: Unpack[ColQwen2ProcessorKwargs],
  66. ) -> BatchFeature:
  67. r"""
  68. Returns:
  69. [`BatchFeature`]: A [`BatchFeature`] with the following fields:
  70. - **input_ids** -- List of token ids to be fed to a model.
  71. - **attention_mask** -- List of indices specifying which tokens should be attended to by the model (when
  72. `return_attention_mask=True` or if *"attention_mask"* is in `self.model_input_names` and if `text` is not
  73. `None`).
  74. - **pixel_values** -- Pixel values to be fed to a model. Returned when `images` is not `None`.
  75. """
  76. output_kwargs = self._merge_kwargs(
  77. ColQwen2ProcessorKwargs,
  78. tokenizer_init_kwargs=self.tokenizer.init_kwargs,
  79. **kwargs,
  80. )
  81. suffix = output_kwargs["text_kwargs"].pop("suffix", None)
  82. return_token_type_ids = suffix is not None
  83. if text is None and images is None:
  84. raise ValueError("Either text or images must be provided")
  85. if text is not None and images is not None:
  86. raise ValueError("Only one of text or images can be processed at a time")
  87. if images is not None:
  88. if is_valid_image(images):
  89. images = [images]
  90. elif isinstance(images, list) and is_valid_image(images[0]):
  91. pass
  92. elif not (isinstance(images, list) and isinstance(images[0], list) and is_valid_image(images[0][0])):
  93. raise ValueError("images must be an image, list of images or list of list of images")
  94. texts_doc = [self.visual_prompt_prefix] * len(images)
  95. image_inputs = self.image_processor(images=images, **output_kwargs["images_kwargs"])
  96. image_grid_thw = image_inputs["image_grid_thw"]
  97. if image_grid_thw is not None:
  98. merge_length = self.image_processor.merge_size**2
  99. index = 0
  100. for i in range(len(texts_doc)):
  101. while self.image_token in texts_doc[i]:
  102. texts_doc[i] = texts_doc[i].replace(
  103. self.image_token, "<|placeholder|>" * (image_grid_thw[index].prod() // merge_length), 1
  104. )
  105. index += 1
  106. texts_doc[i] = texts_doc[i].replace("<|placeholder|>", self.image_token)
  107. text_inputs = self.tokenizer(
  108. texts_doc,
  109. return_token_type_ids=False,
  110. **output_kwargs["text_kwargs"],
  111. )
  112. return_data = BatchFeature(data={**text_inputs, **image_inputs})
  113. # NOTE: The following adjustment ensures correct behavior with DDP on multiple GPUs.
  114. offsets = return_data["image_grid_thw"][:, 1] * return_data["image_grid_thw"][:, 2] # (batch_size,)
  115. # Split the pixel_values tensor into a list of tensors, one per image
  116. pixel_values = list(
  117. torch.split(return_data["pixel_values"], offsets.tolist())
  118. ) # [(num_patches_image_0, pixel_values), ..., (num_patches_image_n, pixel_values)]
  119. # Pad the list of pixel_value tensors to the same length along the sequence dimension
  120. return_data["pixel_values"] = torch.nn.utils.rnn.pad_sequence(
  121. pixel_values, batch_first=True
  122. ) # (batch_size, max_num_patches, pixel_values)
  123. if return_token_type_ids:
  124. labels = return_data["input_ids"].masked_fill(return_data["token_type_ids"] == 0, -100)
  125. return_data.update({"labels": labels})
  126. return return_data
  127. elif text is not None:
  128. if isinstance(text, str):
  129. text = [text]
  130. elif not (isinstance(text, list) and isinstance(text[0], str)):
  131. raise ValueError("Text must be a string or a list of strings")
  132. if suffix is None:
  133. suffix = self.query_augmentation_token * 10
  134. texts_query: list[str] = []
  135. for query in text:
  136. augmented_query = self.query_prefix + query + suffix
  137. texts_query.append(augmented_query)
  138. batch_query = self.tokenizer(
  139. texts_query,
  140. return_token_type_ids=False,
  141. **output_kwargs["text_kwargs"],
  142. )
  143. return batch_query
  144. def _get_num_multimodal_tokens(self, image_sizes=None, **kwargs):
  145. """
  146. Computes the number of placeholder tokens needed for multimodal inputs with the given sizes.
  147. Args:
  148. image_sizes (`list[list[int]]`, *optional*):
  149. The input sizes formatted as (height, width) per each image.
  150. Returns:
  151. `MultiModalData`: A `MultiModalData` object holding number of tokens per each of the provided
  152. input modalities, along with other useful data.
  153. """
  154. vision_data = {}
  155. if image_sizes is not None:
  156. images_kwargs = ColQwen2ProcessorKwargs._defaults.get("images_kwargs", {})
  157. images_kwargs.update(kwargs)
  158. merge_size = images_kwargs.get("merge_size", None) or self.image_processor.merge_size
  159. num_image_patches = [
  160. self.image_processor.get_number_of_image_patches(*image_size, images_kwargs)
  161. for image_size in image_sizes
  162. ]
  163. num_image_tokens = [(num_patches // merge_size**2) for num_patches in num_image_patches]
  164. vision_data.update({"num_image_tokens": num_image_tokens, "num_image_patches": num_image_patches})
  165. return MultiModalData(**vision_data)
  166. @property
  167. def model_input_names(self):
  168. tokenizer_input_names = self.tokenizer.model_input_names
  169. image_processor_input_names = self.image_processor.model_input_names
  170. # ColQwen doesn't process videos. Make a copy of list when removing
  171. # otherwise `self.feature_extractor.model_input_names` is also modified
  172. image_processor_input_names = [
  173. name for name in image_processor_input_names if name not in ["pixel_values_videos", "video_grid_thw"]
  174. ]
  175. return tokenizer_input_names + image_processor_input_names
# ColQwen2 adds nothing of its own here: the class body is empty, so all behaviour comes from
# `ColPaliPreTrainedModel`.
class ColQwen2PreTrainedModel(ColPaliPreTrainedModel):
    pass
@dataclass
@auto_docstring(
    custom_intro="""
    Base class for ColQwen2 embeddings output.
    """
)
class ColQwen2ForRetrievalOutput(ModelOutput):
    r"""
    loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` is provided):
        Language modeling loss (for next-token prediction).
    embeddings (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`):
        The embeddings of the model.
    past_key_values (`Cache`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
        It is a [`~cache_utils.Cache`] instance. For more details, see our [kv cache guide](https://huggingface.co/docs/transformers/en/kv_cache).

        Contains pre-computed hidden-states (key and values in the self-attention blocks) that can be used (see
        `past_key_values` input) to speed up sequential decoding.
    """

    # NOTE(review): `ColQwen2ForRetrieval.forward` in this file never passes `loss`, so it keeps its
    # default of None there.
    loss: torch.FloatTensor | None = None
    # Projected, L2-normalised token embeddings produced by `ColQwen2ForRetrieval.forward`.
    embeddings: torch.Tensor | None = None
    # Cache object taken from the wrapped VLM output.
    past_key_values: Cache | None = None
    # Hidden states of the wrapped VLM; `forward` only fills this in when `output_hidden_states` is truthy.
    hidden_states: tuple[torch.FloatTensor] | None = None
    # Attentions as reported by the wrapped VLM output.
    attentions: tuple[torch.FloatTensor] | None = None
  200. @auto_docstring(
  201. custom_intro="""
  202. Following the ColPali approach, ColQwen2 leverages VLMs to construct efficient multi-vector embeddings directly
  203. from document images (“screenshots”) for document retrieval. The model is trained to maximize the similarity
  204. between these document embeddings and the corresponding query embeddings, using the late interaction method
  205. introduced in ColBERT.
  206. Using ColQwen2 removes the need for potentially complex and brittle layout recognition and OCR pipelines with
  207. a single model that can take into account both the textual and visual content (layout, charts, ...) of a document.
  208. ColQwen2 is part of the ColVision model family, which was introduced with ColPali in the following paper:
  209. [*ColPali: Efficient Document Retrieval with Vision Language Models*](https://huggingface.co/papers/2407.01449).
  210. """
  211. )
  212. class ColQwen2ForRetrieval(ColPaliForRetrieval):
  213. def __init__(self, config: ColQwen2Config):
  214. super().__init__(config)
  215. del self._tied_weights_keys
  216. @can_return_tuple
  217. @auto_docstring
  218. def forward(
  219. self,
  220. input_ids: torch.LongTensor | None = None,
  221. attention_mask: torch.Tensor | None = None,
  222. position_ids: torch.LongTensor | None = None,
  223. past_key_values: Cache | None = None,
  224. labels: torch.LongTensor | None = None,
  225. inputs_embeds: torch.FloatTensor | None = None,
  226. use_cache: bool | None = None,
  227. output_attentions: bool | None = None,
  228. output_hidden_states: bool | None = None,
  229. return_dict: bool | None = None,
  230. pixel_values: torch.Tensor | None = None,
  231. image_grid_thw: torch.LongTensor | None = None,
  232. **kwargs,
  233. ) -> ColQwen2ForRetrievalOutput:
  234. r"""
  235. image_grid_thw (`torch.LongTensor` of shape `(num_images, 3)`, *optional*):
  236. The temporal, height and width of feature shape of each image in LLM.
  237. """
  238. # Handle the custom "pixel_values" input obtained with `ColQwen2Processor` through unpadding
  239. if pixel_values is not None and image_grid_thw is not None:
  240. # NOTE: image_grid_thw: (batch_size, 3) where image_grid_thw[i] = (num_patches_h, num_patches_w, temporal_patch_size)
  241. offsets = image_grid_thw[:, 1] * image_grid_thw[:, 2] # (batch_size,)
  242. arange = torch.arange(pixel_values.shape[1], device=offsets.device) # (max_len,)
  243. mask = arange.unsqueeze(0) < offsets.unsqueeze(1) # (batch_size, max_len)
  244. pixel_values = pixel_values[mask] # (total_valid_patches, channels, height, width)
  245. output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
  246. output_hidden_states = (
  247. output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
  248. )
  249. return_dict = return_dict if return_dict is not None else self.config.return_dict
  250. # Custom data preparation to fix an issue with the gradient flow when training with multiple GPUs.
  251. if inputs_embeds is None:
  252. inputs_embeds = self.vlm.get_input_embeddings()(input_ids)
  253. if pixel_values is not None:
  254. image_embeds = self.vlm.visual(pixel_values, grid_thw=image_grid_thw, return_dict=True).pooler_output
  255. image_mask = (
  256. (input_ids == self.config.vlm_config.image_token_id).unsqueeze(-1).expand_as(inputs_embeds)
  257. )
  258. image_embeds = image_embeds.to(inputs_embeds.device, inputs_embeds.dtype)
  259. inputs_embeds = inputs_embeds.masked_scatter(image_mask, image_embeds)
  260. vlm_output = self.vlm(
  261. input_ids=None,
  262. position_ids=position_ids,
  263. attention_mask=attention_mask,
  264. past_key_values=past_key_values,
  265. inputs_embeds=inputs_embeds,
  266. use_cache=use_cache,
  267. output_attentions=output_attentions,
  268. output_hidden_states=output_hidden_states,
  269. return_dict=return_dict,
  270. )
  271. vlm_hidden_states = vlm_output.hidden_states if output_hidden_states else None
  272. last_hidden_states = vlm_output[0] # (batch_size, sequence_length, hidden_size)
  273. proj_dtype = self.embedding_proj_layer.weight.dtype
  274. embeddings = self.embedding_proj_layer(last_hidden_states.to(proj_dtype)) # (batch_size, sequence_length, dim)
  275. # L2 normalization
  276. embeddings = embeddings / embeddings.norm(dim=-1, keepdim=True) # (batch_size, sequence_length, dim)
  277. if attention_mask is not None:
  278. embeddings = embeddings * attention_mask.unsqueeze(-1) # (batch_size, sequence_length, dim)
  279. return ColQwen2ForRetrievalOutput(
  280. embeddings=embeddings,
  281. past_key_values=vlm_output.past_key_values,
  282. hidden_states=vlm_hidden_states,
  283. attentions=vlm_output.attentions,
  284. )
# Names exported by this module.
__all__ = [
    "ColQwen2ForRetrieval",
    "ColQwen2PreTrainedModel",
    "ColQwen2Processor",
]