modular_colpali.py 13 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296
  1. # Copyright 2024 The HuggingFace Inc. team.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. from typing import Optional, Union
  15. from transformers.models.paligemma.processing_paligemma import IMAGE_TOKEN, PaliGemmaProcessor, build_string_from_input
  16. from ...feature_extraction_utils import BatchFeature
  17. from ...image_utils import ImageInput, make_flat_list_of_images
  18. from ...processing_utils import ProcessingKwargs, Unpack
  19. from ...tokenization_utils_base import PreTokenizedInput, TextInput
  20. from ...utils import is_torch_available, logging
# `torch` is only imported when it is installed; it is needed at runtime by
# `ColPaliProcessor.score_retrieval` (tensor padding, einsum and concatenation).
if is_torch_available():
    import torch

# Module-level logger for this processor module.
logger = logging.get_logger(__name__)
  24. class ColPaliProcessorKwargs(ProcessingKwargs, total=False):
  25. _defaults = {
  26. "text_kwargs": {
  27. "padding": "longest",
  28. },
  29. "images_kwargs": {
  30. "data_format": "channels_first",
  31. "do_convert_rgb": True,
  32. },
  33. "common_kwargs": {"return_tensors": "pt"},
  34. }
  35. class ColPaliProcessor(PaliGemmaProcessor):
  36. def __init__(
  37. self,
  38. image_processor=None,
  39. tokenizer=None,
  40. chat_template=None,
  41. visual_prompt_prefix: str = "Describe the image.",
  42. query_prefix: str = "Question: ",
  43. ):
  44. r"""
  45. visual_prompt_prefix (`str`, *optional*, defaults to `"Describe the image."`):
  46. A string that gets tokenized and prepended to the image tokens.
  47. query_prefix (`str`, *optional*, defaults to `"Question: "`):
  48. A prefix to be used for the query.
  49. """
  50. self.visual_prompt_prefix = visual_prompt_prefix
  51. self.query_prefix = query_prefix
  52. super().__init__(image_processor=image_processor, tokenizer=tokenizer, chat_template=chat_template)
  53. @property
  54. def query_augmentation_token(self) -> str:
  55. """
  56. Return the query augmentation token.
  57. Query augmentation buffers are used as reasoning buffers during inference.
  58. """
  59. return self.tokenizer.pad_token
  60. def __call__(
  61. self,
  62. images: ImageInput | None = None,
  63. text: TextInput | PreTokenizedInput | list[TextInput] | list[PreTokenizedInput] = None,
  64. **kwargs: Unpack[ColPaliProcessorKwargs],
  65. ) -> BatchFeature:
  66. r"""
  67. Returns:
  68. [`BatchFeature`]: A [`BatchFeature`] with the following fields:
  69. - **input_ids** -- List of token ids to be fed to a model.
  70. - **attention_mask** -- List of indices specifying which tokens should be attended to by the model (when
  71. `return_attention_mask=True` or if *"attention_mask"* is in `self.model_input_names` and if `text` is not
  72. `None`).
  73. - **pixel_values** -- Pixel values to be fed to a model. Returned when `images` is not `None`.
  74. """
  75. output_kwargs = self._merge_kwargs(
  76. ColPaliProcessorKwargs,
  77. tokenizer_init_kwargs=self.tokenizer.init_kwargs,
  78. **kwargs,
  79. )
  80. suffix = output_kwargs["text_kwargs"].pop("suffix", None)
  81. return_token_type_ids = True
  82. if text is None and images is None:
  83. raise ValueError("Either text or images must be provided")
  84. if text is not None and images is not None:
  85. raise ValueError("Only one of text or images can be processed at a time")
  86. if images is not None:
  87. images = self.image_processor.fetch_images(images)
  88. images = make_flat_list_of_images(images)
  89. texts_doc = [self.visual_prompt_prefix] * len(images)
  90. images = [image.convert("RGB") for image in images]
  91. input_strings = [
  92. build_string_from_input(
  93. prompt=prompt,
  94. bos_token=self.tokenizer.bos_token,
  95. image_seq_len=self.image_seq_length,
  96. image_token=IMAGE_TOKEN,
  97. num_images=len(image_list) if isinstance(image_list, list) else 1,
  98. )
  99. for prompt, image_list in zip(texts_doc, images)
  100. ]
  101. pixel_values = self.image_processor(images, **output_kwargs["images_kwargs"])["pixel_values"]
  102. # max_length has to account for the image tokens
  103. if output_kwargs["text_kwargs"].get("max_length", None) is not None:
  104. output_kwargs["text_kwargs"]["max_length"] += self.image_seq_length
  105. inputs = self.tokenizer(
  106. input_strings,
  107. return_token_type_ids=return_token_type_ids,
  108. **output_kwargs["text_kwargs"],
  109. )
  110. return_data = {**inputs, "pixel_values": pixel_values}
  111. if return_token_type_ids:
  112. labels = inputs["input_ids"].masked_fill(inputs["token_type_ids"] == 0, -100)
  113. return_data.update({"labels": labels})
  114. return BatchFeature(data=return_data)
  115. elif text is not None:
  116. if isinstance(text, str):
  117. text = [text]
  118. elif not (isinstance(text, list) and isinstance(text[0], str)):
  119. raise ValueError("Text must be a string or a list of strings")
  120. if suffix is None:
  121. suffix = self.query_augmentation_token * 10
  122. texts_query: list[str] = []
  123. for query in text:
  124. query = self.tokenizer.bos_token + self.query_prefix + query + suffix + "\n"
  125. texts_query.append(query)
  126. output_kwargs["text_kwargs"]["max_length"] = output_kwargs["text_kwargs"].get("max_length", 50)
  127. batch_query = self.tokenizer(
  128. texts_query,
  129. return_token_type_ids=return_token_type_ids,
  130. **output_kwargs["text_kwargs"],
  131. )
  132. return batch_query
  133. def process_images(
  134. self,
  135. images: ImageInput | None = None,
  136. **kwargs: Unpack[ColPaliProcessorKwargs],
  137. ) -> BatchFeature:
  138. """
  139. Prepare for the model one or several image(s). This method is a wrapper around the `__call__` method of the ColPaliProcessor's
  140. [`ColPaliProcessor.__call__`].
  141. This method forwards the `images` and `kwargs` arguments to the image processor.
  142. Args:
  143. images (`PIL.Image.Image`, `np.ndarray`, `torch.Tensor`, `list[PIL.Image.Image]`, `list[np.ndarray]`, `list[torch.Tensor]`):
  144. The image or batch of images to be prepared. Each image can be a PIL image, NumPy array or PyTorch
  145. tensor. In case of a NumPy array/PyTorch tensor, each image should be of shape (C, H, W), where C is a
  146. number of channels, H and W are image height and width.
  147. return_tensors (`str` or [`~utils.TensorType`], *optional*):
  148. If set, will return tensors of a particular framework. Acceptable values are:
  149. - `'pt'`: Return PyTorch `torch.Tensor` objects.
  150. - `'np'`: Return NumPy `np.ndarray` objects.
  151. Returns:
  152. [`BatchFeature`]: A [`BatchFeature`] with the following fields:
  153. - **input_ids** -- List of token ids to be fed to a model.
  154. - **attention_mask** -- List of indices specifying which tokens should be attended to by the model (when
  155. `return_attention_mask=True` or if *"attention_mask"* is in `self.model_input_names` and if `text` is not
  156. `None`).
  157. - **pixel_values** -- Pixel values to be fed to a model. Returned when `images` is not `None`.
  158. """
  159. return self.__call__(images=images, **kwargs)
  160. def process_queries(
  161. self,
  162. text: TextInput | list[TextInput],
  163. **kwargs: Unpack[ColPaliProcessorKwargs],
  164. ) -> BatchFeature:
  165. """
  166. Prepare for the model one or several texts. This method is a wrapper around the `__call__` method of the ColPaliProcessor's
  167. [`ColPaliProcessor.__call__`].
  168. This method forwards the `text` and `kwargs` arguments to the tokenizer.
  169. Args:
  170. text (`str`, `list[str]`, `list[list[str]]`):
  171. The sequence or batch of sequences to be encoded. Each sequence can be a string or a list of strings
  172. (pretokenized string). If the sequences are provided as list of strings (pretokenized), you must set
  173. `is_split_into_words=True` (to lift the ambiguity with a batch of sequences).
  174. return_tensors (`str` or [`~utils.TensorType`], *optional*):
  175. If set, will return tensors of a particular framework. Acceptable values are:
  176. - `'pt'`: Return PyTorch `torch.Tensor` objects.
  177. - `'np'`: Return NumPy `np.ndarray` objects.
  178. Returns:
  179. [`BatchFeature`]: A [`BatchFeature`] with the following fields:
  180. - **input_ids** -- List of token ids to be fed to a model.
  181. - **attention_mask** -- List of indices specifying which tokens should be attended to by the model (when
  182. `return_attention_mask=True` or if *"attention_mask"* is in `self.model_input_names` and if `text` is not
  183. `None`).
  184. """
  185. return self.__call__(text=text, **kwargs)
  186. def score_retrieval(
  187. self,
  188. query_embeddings: Union["torch.Tensor", list["torch.Tensor"]],
  189. passage_embeddings: Union["torch.Tensor", list["torch.Tensor"]],
  190. batch_size: int = 128,
  191. output_dtype: Optional["torch.dtype"] = None,
  192. output_device: Union["torch.device", str] = "cpu",
  193. ) -> "torch.Tensor":
  194. """
  195. Compute the late-interaction/MaxSim score (ColBERT-like) for the given multi-vector
  196. query embeddings (`qs`) and passage embeddings (`ps`). For ColPali, a passage is the
  197. image of a document page.
  198. Because the embedding tensors are multi-vector and can thus have different shapes, they
  199. should be fed as:
  200. (1) a list of tensors, where the i-th tensor is of shape (sequence_length_i, embedding_dim)
  201. (2) a single tensor of shape (n_passages, max_sequence_length, embedding_dim) -> usually
  202. obtained by padding the list of tensors.
  203. Args:
  204. query_embeddings (`Union[torch.Tensor, list[torch.Tensor]`): Query embeddings.
  205. passage_embeddings (`Union[torch.Tensor, list[torch.Tensor]`): Passage embeddings.
  206. batch_size (`int`, *optional*, defaults to 128): Batch size for computing scores.
  207. output_dtype (`torch.dtype`, *optional*, defaults to `torch.float32`): The dtype of the output tensor.
  208. If `None`, the dtype of the input embeddings is used.
  209. output_device (`torch.device` or `str`, *optional*, defaults to "cpu"): The device of the output tensor.
  210. Returns:
  211. `torch.Tensor`: A tensor of shape `(n_queries, n_passages)` containing the scores. The score
  212. tensor is saved on the "cpu" device.
  213. """
  214. if len(query_embeddings) == 0:
  215. raise ValueError("No queries provided")
  216. if len(passage_embeddings) == 0:
  217. raise ValueError("No passages provided")
  218. if query_embeddings[0].device != passage_embeddings[0].device:
  219. raise ValueError("Queries and passages must be on the same device")
  220. if query_embeddings[0].dtype != passage_embeddings[0].dtype:
  221. raise ValueError("Queries and passages must have the same dtype")
  222. if output_dtype is None:
  223. output_dtype = query_embeddings[0].dtype
  224. scores: list[torch.Tensor] = []
  225. for i in range(0, len(query_embeddings), batch_size):
  226. batch_scores: list[torch.Tensor] = []
  227. batch_queries = torch.nn.utils.rnn.pad_sequence(
  228. query_embeddings[i : i + batch_size], batch_first=True, padding_value=0
  229. )
  230. for j in range(0, len(passage_embeddings), batch_size):
  231. batch_passages = torch.nn.utils.rnn.pad_sequence(
  232. passage_embeddings[j : j + batch_size], batch_first=True, padding_value=0
  233. )
  234. batch_scores.append(
  235. torch.einsum("bnd,csd->bcns", batch_queries, batch_passages).max(dim=3)[0].sum(dim=2)
  236. )
  237. scores.append(torch.cat(batch_scores, dim=1).to(output_dtype).to(output_device))
  238. return torch.cat(scores, dim=0)
  239. __all__ = [
  240. "ColPaliProcessor",
  241. ]