any_to_any.py 25 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514
  1. # Copyright 2025 The HuggingFace Inc. team. All rights reserved.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. import enum
  15. import re
  16. from typing import Any, Union, overload
  17. import numpy as np
  18. from ..audio_utils import AudioInput
  19. from ..generation import GenerationConfig
  20. from ..image_utils import ImageInput
  21. from ..processing_utils import ProcessingKwargs, Unpack
  22. from ..utils import (
  23. add_end_docstrings,
  24. is_torch_available,
  25. is_vision_available,
  26. logging,
  27. requires_backends,
  28. )
  29. from ..video_utils import VideoInput
  30. from .base import Pipeline, build_pipeline_init_args
  31. if is_torch_available():
  32. import torch
  33. from ..models.auto.modeling_auto import MODEL_FOR_MULTIMODAL_LM_MAPPING_NAMES
  34. from .pt_utils import KeyDataset
  35. if is_vision_available():
  36. from PIL import Image
  # Module-level logger; `AnyToAnyPipeline.__call__` uses it for a one-time hint to switch to the chat format.
  logger = logging.get_logger(__name__)
  @enum.unique
  class ReturnType(enum.Enum):
      """Selects what `AnyToAnyPipeline.postprocess` emits for text generation."""

      TENSORS = 0  # raw generated token ids, no decoding
      NEW_TEXT = 1  # only the decoded continuation (prompt stripped)
      FULL_TEXT = 2  # prompt followed by the decoded continuation
  class Chat:
      """Marks a list of messages as ONE conversation rather than a batch of samples.

      This class is intended to be used only internally in this pipeline and is not exposed to users. Chats are
      wrapped in it because the rest of the pipeline code tends to assume that a list of messages is a batch of
      samples rather than the turns of a single conversation.

      Args:
          messages (`list[dict]`):
              The conversation. Every message must be a mapping with `role` and `content` keys.

      Raises:
          ValueError: if a message is not a mapping or lacks the `role` or `content` key.
      """

      def __init__(self, messages: list[dict]):
          from collections.abc import Mapping

          for message in messages:
              # Check the type first: on a plain string `"role" in message` is a substring test and would
              # wrongly accept something like "role/content" as a valid message.
              if not isinstance(message, Mapping) or "role" not in message or "content" not in message:
                  raise ValueError("When passing chat dicts as input, each dict must have a 'role' and 'content' key.")
          self.messages = messages
  @add_end_docstrings(build_pipeline_init_args(has_processor=True))
  class AnyToAnyPipeline(Pipeline):
      """
      Multimodal Generation pipeline using an `AutoModelForMultimodalLM`. This pipeline generates text given any
      combination of multimodal data and text. When the underlying model is a conversational model, it can also
      accept one or more chats, in which case the pipeline will operate in chat mode and will continue the
      chat(s) by adding its response(s). Each chat takes the form of a list of dicts, where each dict contains
      "role" and "content" keys.

      Unless the model you're using explicitly sets these generation parameters in its configuration files
      (`generation_config.json`), the following default values will be used:
      - max_new_tokens: 256

      Example:

      ```python
      >>> from transformers import pipeline

      >>> pipe = pipeline(task="any-to-any", model="google/gemma-3n-E4B-it")
      >>> pipe(images="https://huggingface.co/datasets/Narsil/image_dummy/raw/main/parrots.png", text="A photo of")
      [{'input_text': 'A photo of', 'generated_text': 'A photo of two birds'}]
      ```

      ```python
      >>> from transformers import pipeline

      >>> pipe = pipeline("any-to-any", model="google/gemma-3n-E4B-it")
      >>> messages = [
      ...     {
      ...         "role": "user",
      ...         "content": [
      ...             {
      ...                 "type": "image",
      ...                 "url": "https://qianwen-res.oss-cn-beijing.aliyuncs.com/Qwen-VL/assets/demo.jpeg",
      ...             },
      ...             {"type": "text", "text": "Describe this image."},
      ...         ],
      ...     },
      ...     {
      ...         "role": "assistant",
      ...         "content": [
      ...             {"type": "text", "text": "There is a dog and"},
      ...         ],
      ...     },
      ... ]
      >>> pipe(text=messages, max_new_tokens=20, return_full_text=False)
      [{'input_text': [{'role': 'user',
          'content': [{'type': 'image',
          'url': 'https://qianwen-res.oss-cn-beijing.aliyuncs.com/Qwen-VL/assets/demo.jpeg'},
          {'type': 'text', 'text': 'Describe this image.'}]},
      {'role': 'assistant',
          'content': [{'type': 'text', 'text': 'There is a dog and'}]}],
      'generated_text': ' a person in the image. The dog is sitting on the sand, and the person is sitting on'}]
      ```

      Learn more about the basics of using a pipeline in the [pipeline tutorial](../pipeline_tutorial)

      This multimodal pipeline can currently be loaded from pipeline() using the following task identifier:
      "any-to-any".

      See the list of available models on
      [huggingface.co/models](https://huggingface.co/models?pipeline_tag=any-to-any).
      """

      _load_processor = True
      _load_image_processor = False
      _load_feature_extractor = False
      _load_tokenizer = False

      _pipeline_calls_generate = True
      # Make sure the docstring is updated when the default generation config is changed
      _default_generation_config = GenerationConfig(
          max_new_tokens=256,
      )

      def __init__(self, *args, **kwargs):
          super().__init__(*args, **kwargs)
          # Only require the optional backends for the modalities this model actually consumes
          if "image" in self.model.input_modalities or "video" in self.model.input_modalities:
              requires_backends(self, "vision")
              requires_backends(self, "torchvision")
          if "audio" in self.model.input_modalities:
              requires_backends(self, "librosa")
          self.check_model_type(MODEL_FOR_MULTIMODAL_LM_MAPPING_NAMES)

      def _sanitize_parameters(
          self,
          max_new_tokens=None,
          generate_kwargs=None,
          timeout=None,
          return_full_text=None,
          return_tensors=None,
          return_type=None,
          clean_up_tokenization_spaces=None,
          stop_sequence=None,
          continue_final_message=None,
          skip_special_tokens=None,
          generation_mode=None,
          processor_kwargs=None,
          **kwargs: Unpack[ProcessingKwargs],
      ):
          """
          Split call arguments into `(preprocess_params, forward_kwargs, postprocess_params)`.

          Unrecognized keyword arguments (`kwargs`) go to `preprocess`, which forwards them to the processor.

          Raises:
              ValueError: if `max_new_tokens` is given both directly and in `generate_kwargs`, if `return_full_text`
                  is combined with `return_tensors`, if `return_type` is set for a non-text `generation_mode`, or if
                  `generation_mode` is unknown or not among the model's `output_modalities`.
          """
          forward_kwargs = {}
          preprocess_params = {}
          postprocess_params = {}

          # Preprocess params
          preprocess_params.update(kwargs)
          if timeout is not None:
              preprocess_params["timeout"] = timeout
          if continue_final_message is not None:
              preprocess_params["continue_final_message"] = continue_final_message
          if processor_kwargs is not None:
              preprocess_params["processor_kwargs"] = processor_kwargs

          # Forward kwargs. Work on a copy: the keys added below must not leak into the caller's dict, otherwise
          # re-using the same `generate_kwargs` together with `max_new_tokens` would fail the "defined twice" check.
          gen_kwargs = dict(generate_kwargs) if generate_kwargs else {}
          forward_kwargs["generate_kwargs"] = gen_kwargs
          if generation_mode is not None and generation_mode != "text":
              gen_kwargs["generation_mode"] = generation_mode
          # Qwen-Omni models need to know the origin of audio, to align mm position ids.
          # `[\d_]*` covers both e.g. "qwen3omni..." and "qwen2_5omni..." class names.
          if kwargs.get("load_audio_from_video") and re.search(r"qwen[\d_]*omni", self.model.__class__.__name__.lower()):
              gen_kwargs["use_audio_in_video"] = True
          if stop_sequence is not None:
              if isinstance(stop_sequence, str):
                  stop_sequence = [stop_sequence]
              gen_kwargs["stop_strings"] = stop_sequence
              # `stop_strings` needs a tokenizer in `generate` to match strings against tokens
              gen_kwargs["tokenizer"] = self.processor.tokenizer
          if max_new_tokens is not None:
              if "max_new_tokens" in gen_kwargs:
                  raise ValueError(
                      "'max_new_tokens' is defined twice, once in 'generate_kwargs' and "
                      "once as a direct argument. Please use only one."
                  )
              gen_kwargs["max_new_tokens"] = max_new_tokens

          if return_full_text is not None and return_type is None:
              if return_tensors is not None:
                  raise ValueError("`return_full_text` is mutually exclusive with `return_tensors`")
              return_type = ReturnType.FULL_TEXT if return_full_text else ReturnType.NEW_TEXT
          elif return_tensors is not None and return_type is None:
              return_type = ReturnType.TENSORS
          # We don't want to set the global default to FULLTEXT at init time. That is why
          # `_postprocess_params` is checked before setting the default value
          elif return_type is None and generation_mode in [None, "text"] and hasattr(self, "_postprocess_params"):
              return_type = ReturnType.FULL_TEXT

          # Postprocess params
          if generation_mode not in [None, "text"] and return_type is not None:
              raise ValueError(
                  f"`return_type` cannot be set to {return_type} when generation_mode={generation_mode}. "
                  "Set `return_type=None` or generation_mode='text'"
              )
          if generation_mode not in [None, "text", "image", "audio"]:
              raise ValueError(
                  f"`generation_mode` can be only one of the `text`, `audio`, `image` but got generation_mode[={generation_mode}]"
              )
          elif generation_mode is not None and generation_mode not in self.model.output_modalities:
              raise ValueError(
                  f"`generation_mode={generation_mode}` is not supported for {self.model.__class__.__name__}. "
                  f"The model can only output the following modalities: {self.model.output_modalities}"
              )

          if return_type is not None:
              postprocess_params["return_type"] = return_type
          if continue_final_message is not None:
              postprocess_params["continue_final_message"] = continue_final_message
          if clean_up_tokenization_spaces is not None:
              postprocess_params["clean_up_tokenization_spaces"] = clean_up_tokenization_spaces
          if skip_special_tokens is not None:
              postprocess_params["skip_special_tokens"] = skip_special_tokens
          postprocess_params["generation_mode"] = generation_mode
          return preprocess_params, forward_kwargs, postprocess_params

      @overload
      def __call__(
          self,
          text: str | None = None,
          images: Union[str, "Image.Image"] | None = None,
          videos: Union[str, "np.ndarray", "torch.Tensor"] | None = None,
          audio: Union[str, "np.ndarray"] | None = None,
          **kwargs: Any,
      ) -> list[dict[str, Any]]: ...

      @overload
      def __call__(
          self,
          text: list[str] | None = None,
          images: list[str] | list["Image.Image"] | None = None,
          videos: list[str] | list["np.ndarray"] | list["torch.Tensor"] | None = None,
          audio: list[str] | list["np.ndarray"] | None = None,
          **kwargs: Any,
      ) -> list[list[dict[str, Any]]]: ...

      def __call__(
          self,
          text: str | list[str] | list[dict] | None = None,
          images: str | list[str] | list[list[str]] | ImageInput | None = None,
          videos: str | list[str] | VideoInput | None = None,
          audio: str | list[str] | AudioInput | None = None,
          **kwargs,
      ) -> list[dict[str, Any]] | list[list[dict[str, Any]]]:
          """
          Generate an output given text and optionally multimodal data passed as inputs.

          Args:
              text (`str`, `list[str]`, `list[dict]`, *optional*):
                  The text to be used for generation. If a list of strings is passed, the length of the list should be
                  the same as the number of images. Text can also follow the chat format: a list of dictionaries where
                  each dictionary represents a message in a conversation. Each dictionary should have two keys: 'role'
                  and 'content'. 'role' should be one of 'user', 'system' or 'assistant'. 'content' should be a list of
                  dictionaries containing the content of the message and its type. A list of such conversations is
                  processed as a batch of chats. A dict (or a list of dicts, generator or dataset) whose keys are
                  processor argument names, e.g. `{"text": ..., "images": ..., "videos": ..., "audio": ...}`, is also
                  accepted; in that case the separate `images`, `videos` and `audio` arguments are ignored.
              images (`str`, `list[str]`, `ImageInput`, *optional*):
                  The pipeline handles three types of images:

                  - A string containing a HTTP(s) link pointing to an image
                  - A string containing a local path to an image
                  - An image loaded in PIL directly

                  The pipeline accepts either a single image or a batch of images.
              videos (`str`, `list[str]`, `VideoInput`, *optional*):
                  The pipeline handles three types of videos:

                  - A string containing a HTTP(s) link pointing to a video
                  - A string containing a local path to a video
                  - A video loaded and decoded to array format

                  The pipeline accepts either a single video or a batch of videos.
              audio (`str`, `list[str]`, `AudioInput`, *optional*):
                  The pipeline handles three types of audio:

                  - A string containing a HTTP(s) link pointing to an audio file
                  - A string containing a local path to an audio file
                  - An audio waveform already decoded into an array (e.g. `np.ndarray`)

                  The pipeline accepts either a single audio or a batch of audios.
              generation_mode (`str`, *optional*, defaults to `"text"`):
                  The output modality to generate: one of `"text"`, `"image"` or `"audio"`. It must be one of the
                  model's `output_modalities`.
              max_new_tokens (`int`, *optional*):
                  Maximum number of tokens to generate. Cannot be combined with `max_new_tokens` in `generate_kwargs`.
              generate_kwargs (`dict`, *optional*):
                  Additional keyword arguments forwarded to the model's `generate` method.
              stop_sequence (`str` or `list[str]`, *optional*):
                  String(s) that stop generation once produced (forwarded to `generate` as `stop_strings`).
              return_tensors (`bool`, *optional*, defaults to `False`):
                  Returns the tensors of predictions (as token indices) in the outputs. If set to
                  `True`, the decoded text is not returned.
              return_full_text (`bool`, *optional*, defaults to `True`):
                  If set to `False` only added text is returned, otherwise the full text is returned. Cannot be
                  specified at the same time as `return_tensors`.
              clean_up_tokenization_spaces (`bool`, *optional*):
                  Whether or not to clean up the potential extra spaces in the text output.
              skip_special_tokens (`bool`, *optional*, defaults to `True`):
                  Whether special tokens are removed when decoding the output.
              continue_final_message (`bool`, *optional*):
                  This indicates that you want the model to continue the last message in the input chat rather than
                  starting a new one, allowing you to "prefill" its response. By default this is `True` when the final
                  message in the input chat has the `assistant` role and `False` otherwise, but you can manually
                  override that behaviour by setting this flag.
              processor_kwargs (`dict`, *optional*):
                  Keyword arguments forwarded to the processor.

          Return:
              A list or a list of lists of `dict`: Each result comes as a dictionary with an `input_text` key and
              exactly one of the following keys:

              - **generated_text** (`str` or `list[dict]`, present when `generation_mode="text"` and `return_tensors`
                is not set) -- The generated text; in chat mode with `return_full_text=True`, the continued chat.
              - **generated_audio** (`np.ndarray`, present when `generation_mode="audio"`) -- The generated audio.
              - **generated_image** (`PIL.Image.Image`, present when `generation_mode="image"`) -- The generated image.
              - **generated_token_ids** (`torch.Tensor`, present when `return_tensors=True`) -- The token ids of the
                generated text.
              - **input_text** (`str` or `list[dict]`) -- The input text, or the input messages in chat mode.

          Raises:
              ValueError: if none of `text`, `images`, `videos` or `audio` is provided.
          """
          if text is None and images is None and videos is None and audio is None:
              raise ValueError("You must provide at least one of `text`, `images`, `videos` or `audio`.")

          if isinstance(text, (list, tuple, KeyDataset)) and isinstance(text[0], (list, tuple, dict)):
              # We have one or more prompts in list-of-dicts format, so this is chat mode
              if isinstance(text[0], dict) and "role" in text[0]:
                  return super().__call__(Chat(text), **kwargs)
              elif isinstance(text[0], (list, tuple)) and isinstance(text[0][0], dict) and "role" in text[0][0]:
                  chats = [Chat(chat) for chat in text]
                  return super().__call__(chats, **kwargs)

          if text is not None and not (isinstance(text, str) or (isinstance(text, list) and isinstance(text[0], str))):
              # Dict-style inputs, whose keys `preprocess` forwards as processor keyword arguments:
              # - {"text": text, "images": images, "videos": videos, "audio": audio}
              # - [{"text": text, "images": images, "videos": videos, "audio": audio}]
              # - Generators and datasets
              # This is a common pattern in other multimodal pipelines, so we support it here as well.
              return super().__call__(text, **kwargs)

          # encourage the user to use the chat format if supported
          if getattr(self.processor, "chat_template", None) is not None:
              logger.warning_once(
                  "The input data was not formatted as a chat with dicts containing 'role' and 'content' keys, even "
                  "though this model supports chat. Consider using the chat format for better results. For more "
                  "information, see https://huggingface.co/docs/transformers/en/chat_templating"
              )

          # Keys must match the processor's argument names (`images`, `videos`, `audio`): `preprocess` passes
          # them through as keyword arguments.
          return super().__call__({"text": text, "images": images, "videos": videos, "audio": audio}, **kwargs)

      def preprocess(self, inputs=None, timeout=None, continue_final_message=None, **processing_kwargs):
          """
          Turn one pipeline input into model inputs (a `BatchFeature` of tensors plus the original prompt).

          Args:
              inputs (`Chat`, `str`, `list`, `tuple` or `dict`):
                  A chat, raw text, or a dict whose `text` entry is the prompt and whose remaining entries are
                  forwarded to the processor as keyword arguments.
              timeout (`float`, *optional*):
                  Accepted for API compatibility; it is not used by this method.
              continue_final_message (`bool`, *optional*):
                  Chat mode only: continue the last message instead of adding a generation prompt. Defaults to
                  whether the last message has the `assistant` role.
              processing_kwargs:
                  Extra processor kwargs; a nested `processor_kwargs` dict is handled specially (see below).

          Returns:
              The processor output cast to `self.dtype`, with the original prompt stored under `"text"`.
          """
          if isinstance(inputs, Chat):
              # If the user passes a chat that ends in an assistant message, we treat it as a prefill by default
              # because very few models support multiple separate, consecutive assistant messages
              if continue_final_message is None:
                  continue_final_message = inputs.messages[-1]["role"] == "assistant"

              # Processor kwargs are passed separately from jinja kwargs to chat template
              # but it was added only in https://github.com/huggingface/transformers/pull/44881
              processor_kwargs = dict(processing_kwargs.pop("processor_kwargs", None) or {})
              chat_template_kwargs = {
                  "continue_final_message": continue_final_message,
                  "return_tensors": "pt",
                  "tokenize": True,
                  "return_dict": True,
                  "add_generation_prompt": not continue_final_message,
                  "processor_kwargs": processor_kwargs,
                  **processing_kwargs,
              }
              # The Mistral tokenizer does not accept processing kwargs: drop them, but keep the template controls
              # (`tokenize`, `return_dict`, ...) without which the result would not be a tensor dict.
              if self.processor.tokenizer.__class__.__name__ == "MistralCommonBackend":
                  chat_template_kwargs = {
                      k: v
                      for k, v in chat_template_kwargs.items()
                      if k not in ["padding", "truncation", "max_length"]
                  }
              model_inputs = self.processor.apply_chat_template(
                  inputs.messages,
                  **chat_template_kwargs,
              ).to(dtype=self.dtype)
              model_inputs["text"] = inputs
              return model_inputs

          # In case we only have text inputs
          if isinstance(inputs, (list, tuple, str)):
              text = inputs
              inputs = {}
          else:
              inputs = inputs.copy()  # avoid in-place changes if users passed dict
              # Dict inputs may legitimately omit text (e.g. images only)
              text = inputs.pop("text", None)

          # Feature extractor do not load audio files and expect a decoded array
          if inputs.get("audio", None) is not None and hasattr(self.processor, "feature_extractor"):
              inputs["audio"] = self.processor.feature_extractor.fetch_audio(inputs["audio"])

          # Copy so that the `setdefault` below never mutates a user-supplied `processor_kwargs` dict, which would
          # otherwise persist across calls.
          processor_kwargs = dict(processing_kwargs.pop("processor_kwargs", None) or processing_kwargs)
          # If batched text inputs, we set padding to True unless specified otherwise
          if isinstance(text, (list, tuple)) and len(text) > 1:
              processor_kwargs.setdefault("padding", True)
          model_inputs = self.processor(text=text, **inputs, return_tensors="pt", **processor_kwargs).to(
              dtype=self.dtype
          )
          model_inputs["text"] = text
          return model_inputs

      def _forward(self, model_inputs, generate_kwargs=None):
          """
          Run `model.generate` on preprocessed inputs.

          Returns:
              `dict` with `generated_sequence` (output of `generate`), `prompt_text` (the original prompt) and
              `input_ids` (`input_ids`, else `decoder_input_ids`, else `None`).
          """
          # Copy so that the default added below does not mutate a dict shared across calls
          generate_kwargs = dict(generate_kwargs) if generate_kwargs else {}
          prompt_text = model_inputs.pop("text")
          input_ids = model_inputs.get("input_ids", model_inputs.get("decoder_input_ids"))

          # User-defined `generation_config` passed to the pipeline call take precedence
          if "generation_config" not in generate_kwargs:
              generate_kwargs["generation_config"] = self.generation_config

          generated_sequence = self.model.generate(**model_inputs, **generate_kwargs)
          return {"generated_sequence": generated_sequence, "prompt_text": prompt_text, "input_ids": input_ids}

      def postprocess(
          self,
          model_outputs,
          return_type=None,
          continue_final_message=None,
          skip_special_tokens=None,
          **postprocess_kwargs,
      ):
          """
          Convert the output of `_forward` into one record per prompt.

          Args:
              model_outputs (`dict`): output of `_forward`.
              return_type (`ReturnType`, *optional*): `TENSORS`, `NEW_TEXT` or `FULL_TEXT`; `None` keeps whatever the
                  processor decodes.
              continue_final_message (`bool`, *optional*): chat mode with `FULL_TEXT` only; whether the generation
                  extends the last message. Decided per chat (last role is `assistant`) when `None`.
              skip_special_tokens (`bool`, *optional*, defaults to `True`): forced to `False` when the tokenizer
                  defines a `response_schema`.
              postprocess_kwargs: forwarded to `processor.post_process_multimodal_output` (e.g. `generation_mode`,
                  `clean_up_tokenization_spaces`).

          Returns:
              `list[dict]`: each record has `input_text` plus `generated_token_ids` (for `TENSORS`) or
              `generated_{generation_mode}`.
          """
          input_texts = model_outputs["prompt_text"]
          input_texts = [input_texts] if isinstance(input_texts, (str, Chat)) else input_texts
          generated_sequence = model_outputs["generated_sequence"]
          input_ids = model_outputs["input_ids"]

          if return_type == ReturnType.TENSORS:
              return [
                  {"input_text": self._unwrap_input_text(input_text), "generated_token_ids": generated_sequence[i]}
                  for i, input_text in enumerate(input_texts)
              ]

          # Decode inputs and outputs the same way to remove input text from generated text if present
          skip_special_tokens = skip_special_tokens if skip_special_tokens is not None else True
          if getattr(self.tokenizer, "response_schema", False):
              skip_special_tokens = False

          generation_mode = postprocess_kwargs.get("generation_mode") or "text"
          if generation_mode == "image" and hasattr(self.model, "decode_image_tokens"):
              generated_sequence = self.model.decode_image_tokens(generated_sequence.to(self.model.device))
          generated_outputs = self.processor.post_process_multimodal_output(
              generated_sequence, skip_special_tokens=skip_special_tokens, **postprocess_kwargs
          )

          # Force consistent behavior for including the input text in the output
          if return_type in {ReturnType.NEW_TEXT, ReturnType.FULL_TEXT}:
              # Remove the input text from the generated text if the generated text starts with the input text
              # (accounting for the possibility of a space between the input and generated text)
              decode_kwargs = {**postprocess_kwargs, "generation_mode": "text"}
              decoded_inputs = self.processor.post_process_multimodal_output(
                  input_ids, skip_special_tokens=skip_special_tokens, **decode_kwargs
              )
              new_generated_texts = []
              for text_generated, decoded_input in zip(generated_outputs, decoded_inputs):
                  # There can be added characters before the input text, so we need to find the beginning of the
                  # input text in the generated text
                  index_input_text = text_generated.find(decoded_input)
                  # Limit the search to 2 residual characters, like spaces or new lines, to avoid removing a large
                  # part of the answer
                  if 0 <= index_input_text <= 2:
                      new_generated_texts.append(text_generated[index_input_text + len(decoded_input) :])
                  else:
                      new_generated_texts.append(text_generated)
              generated_outputs = new_generated_texts

          if return_type == ReturnType.FULL_TEXT:
              generated_outputs = [
                  self._build_full_text(prompt_text, generated_text, continue_final_message)
                  for prompt_text, generated_text in zip(input_texts, generated_outputs)
              ]

          return [
              {
                  "input_text": self._unwrap_input_text(input_text),
                  f"generated_{generation_mode}": generated_output,
              }
              for input_text, generated_output in zip(input_texts, generated_outputs)
          ]

      @staticmethod
      def _unwrap_input_text(input_text):
          """Return the user-facing form of a prompt: the message list for `Chat`, otherwise the prompt unchanged."""
          return input_text.messages if isinstance(input_text, Chat) else input_text

      def _build_full_text(self, prompt_text, generated_text, continue_final_message=None):
          """Combine one prompt with its generated continuation for `ReturnType.FULL_TEXT`.

          A `str` prompt is prepended to the text. A `Chat` gets the text appended to its last message (prefill) or
          added as a new assistant message. Any other prompt type leaves `generated_text` unchanged.
          """
          if isinstance(prompt_text, str):
              return prompt_text + generated_text
          if not isinstance(prompt_text, Chat):
              return generated_text

          messages = list(prompt_text.messages)
          if continue_final_message is None:
              # Decided per chat (not carried over from a previous chat in the batch): a chat ending in an assistant
              # message is treated as a prefill, because very few models support consecutive assistant messages
              continue_final_message = messages[-1]["role"] == "assistant"

          if continue_final_message:
              # With assistant prefill, concat onto the end of the last message
              last_message = messages[-1]
              content = last_message["content"]
              if isinstance(content, str):
                  new_content = content + generated_text
              else:
                  new_text = dict(content[-1].items())
                  new_text["text"] += generated_text
                  new_content = content[:-1] + [new_text]
              return messages[:-1] + [{"role": last_message["role"], "content": new_content}]

          # When we're not starting from a prefill, the output is a new assistant message
          if getattr(self.tokenizer, "response_schema", False):
              assistant_message = self.tokenizer.parse_response(generated_text)
          else:
              assistant_message = {"role": "assistant", "content": generated_text}
          return messages + [assistant_message]