processing_fuyu.py 35 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
3704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771
  1. # Copyright 2023 The HuggingFace Inc. team.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. """
  15. Image/Text processor class for Fuyu
  16. """
  17. import re
  18. from typing import Union
  19. import numpy as np
  20. from ...image_utils import ImageInput
  21. from ...processing_utils import (
  22. MultiModalData,
  23. ProcessingKwargs,
  24. ProcessorMixin,
  25. Unpack,
  26. )
  27. from ...tokenization_utils_base import PreTokenizedInput, TextInput
  28. from ...utils import auto_docstring, is_torch_available, logging, requires_backends
  29. from ...utils.import_utils import requires
  30. if is_torch_available():
  31. from .image_processing_fuyu import FuyuBatchFeature
  32. logger = logging.get_logger(__name__)
  33. if is_torch_available():
  34. import torch
  35. TEXT_REPR_BBOX_OPEN = "<box>"
  36. TEXT_REPR_BBOX_CLOSE = "</box>"
  37. TEXT_REPR_POINT_OPEN = "<point>"
  38. TEXT_REPR_POINT_CLOSE = "</point>"
  39. TOKEN_BBOX_OPEN_STRING = "<0x00>" # <bbox>
  40. TOKEN_BBOX_CLOSE_STRING = "<0x01>" # </bbox>
  41. TOKEN_POINT_OPEN_STRING = "<0x02>" # <point>
  42. TOKEN_POINT_CLOSE_STRING = "<0x03>" # </point>
  43. BEGINNING_OF_ANSWER_STRING = "<0x04>" # <boa>
  44. class FuyuProcessorKwargs(ProcessingKwargs, total=False):
  45. _defaults = {
  46. "text_kwargs": {
  47. "add_special_tokens": True,
  48. "padding": False,
  49. "stride": 0,
  50. "return_attention_mask": True,
  51. "return_overflowing_tokens": False,
  52. "return_special_tokens_mask": False,
  53. "return_offsets_mapping": False,
  54. "return_token_type_ids": False,
  55. "return_length": False,
  56. "verbose": True,
  57. "return_mm_token_type_ids": False,
  58. },
  59. }
  60. def full_unpacked_stream_to_tensor(
  61. all_bi_tokens_to_place: list[int],
  62. full_unpacked_stream: list["torch.Tensor"],
  63. fill_value: int,
  64. batch_size: int,
  65. new_seq_len: int,
  66. offset: int,
  67. ) -> "torch.Tensor":
  68. """Takes an unpacked stream of tokens (i.e. a list of tensors, one for each item in the batch) and does
  69. the required padding to create a single tensor for the batch of shape batch_size x new_seq_len.
  70. """
  71. assert len(all_bi_tokens_to_place) == batch_size
  72. assert len(full_unpacked_stream) == batch_size
  73. # Create padded tensors for the full batch.
  74. new_padded_tensor = torch.full(
  75. [batch_size, new_seq_len],
  76. fill_value=fill_value,
  77. dtype=full_unpacked_stream[0].dtype,
  78. device=full_unpacked_stream[0].device,
  79. )
  80. # Place each batch entry into the batch tensor.
  81. for bi in range(batch_size):
  82. tokens_to_place = all_bi_tokens_to_place[bi]
  83. new_padded_tensor[bi, :tokens_to_place] = full_unpacked_stream[bi][offset : tokens_to_place + offset]
  84. return new_padded_tensor
  85. def construct_full_unpacked_stream(
  86. num_real_text_tokens: Union[list[list[int]], "torch.Tensor"],
  87. input_stream: "torch.Tensor",
  88. image_tokens: list[list["torch.Tensor"]],
  89. batch_size: int,
  90. num_sub_sequences: int,
  91. ) -> list["torch.Tensor"]:
  92. """Takes an input_stream tensor of shape B x S x ?. For each subsequence, adds any required
  93. padding to account for images and then unpacks the subsequences to create a single sequence per item in the batch.
  94. Returns a list of tensors, one for each item in the batch."""
  95. all_bi_stream = []
  96. for batch_index in range(batch_size):
  97. all_si_stream = []
  98. # First, construct full token stream (including image placeholder tokens) and loss mask for each subsequence
  99. # and append to lists. We use lists rather than tensors because each subsequence is variable-sized.
  100. # TODO Remove this logic in a subsequent release since subsequences are not supported.
  101. image_adjustment = image_tokens[batch_index][0]
  102. subsequence_stream = torch.cat([image_adjustment, input_stream[batch_index, 0]], dim=0)
  103. num_real_tokens = image_adjustment.shape[0] + num_real_text_tokens[batch_index][0]
  104. all_si_stream.append(subsequence_stream[:num_real_tokens])
  105. all_bi_stream.append(torch.cat(all_si_stream, dim=0))
  106. return all_bi_stream
  107. def _replace_string_repr_with_token_tags(prompt: str) -> str:
  108. prompt = prompt.replace(TEXT_REPR_POINT_OPEN, TOKEN_POINT_OPEN_STRING)
  109. prompt = prompt.replace(TEXT_REPR_POINT_CLOSE, TOKEN_POINT_CLOSE_STRING)
  110. prompt = prompt.replace(TEXT_REPR_BBOX_OPEN, TOKEN_BBOX_OPEN_STRING)
  111. prompt = prompt.replace(TEXT_REPR_BBOX_CLOSE, TOKEN_BBOX_CLOSE_STRING)
  112. return prompt
  113. def _segment_prompt_into_text_token_conversions(prompt: str) -> list:
  114. """
  115. Given a string prompt, converts the prompt into a list of TextTokenConversions.
  116. """
  117. # Wherever, we notice the [TOKEN_OPEN_STRING, TOKEN_CLOSE_STRING], we split the prompt
  118. prompt_text_list: list = []
  119. regex_pattern = re.compile(
  120. f"({TOKEN_BBOX_OPEN_STRING}|{TOKEN_BBOX_CLOSE_STRING}|{TOKEN_POINT_OPEN_STRING}|{TOKEN_POINT_CLOSE_STRING})"
  121. )
  122. # Split by the regex pattern
  123. prompt_split = regex_pattern.split(prompt)
  124. for i, elem in enumerate(prompt_split):
  125. if len(elem) == 0 or elem in [
  126. TOKEN_BBOX_OPEN_STRING,
  127. TOKEN_BBOX_CLOSE_STRING,
  128. TOKEN_POINT_OPEN_STRING,
  129. TOKEN_POINT_CLOSE_STRING,
  130. ]:
  131. continue
  132. prompt_text_list.append(
  133. (elem, i > 1 and prompt_split[i - 1] in [TOKEN_BBOX_OPEN_STRING, TOKEN_POINT_OPEN_STRING])
  134. )
  135. return prompt_text_list
  136. def _transform_coordinates_and_tokenize(prompt: str, scale_factor: float, tokenizer) -> list[int]:
  137. """
  138. This function transforms the prompt in the following fashion:
  139. - <box> <point> and </box> </point> to their respective token mappings
  140. - extract the coordinates from the tag
  141. - transform the coordinates into the transformed image space
  142. - return the prompt tokens with the transformed coordinates and new tags
  143. Bounding boxes and points MUST be in the following format: <box>y1, x1, y2, x2</box> <point>x, y</point> The spaces
  144. and punctuation added above are NOT optional.
  145. """
  146. # Make a namedtuple that stores "text" and "is_bbox"
  147. # We want to do the following: Tokenize the code normally -> when we see a point or box, tokenize using the tokenize_within_tag function
  148. # When point or box close tag, continue tokenizing normally
  149. # First, we replace the point and box tags with their respective tokens
  150. prompt = _replace_string_repr_with_token_tags(prompt)
  151. # Tokenize the prompt
  152. # Convert prompt into a list split
  153. prompt_text_list = _segment_prompt_into_text_token_conversions(prompt)
  154. transformed_prompt_tokens: list[int] = []
  155. for elem in prompt_text_list:
  156. if elem[1]:
  157. # This is a location, we need to tokenize it
  158. within_tag_tokenized = _transform_within_tags(elem[0], scale_factor, tokenizer)
  159. # Surround the text with the open and close tags
  160. transformed_prompt_tokens.extend(within_tag_tokenized)
  161. else:
  162. transformed_prompt_tokens.extend(tokenizer(elem[0], add_special_tokens=False).input_ids)
  163. return transformed_prompt_tokens
  164. def _transform_within_tags(text: str, scale_factor: float, tokenizer) -> list[int]:
  165. """
  166. Given a bounding box of the fashion <box>1, 2, 3, 4</box> | <point>1, 2</point> This function is responsible for
  167. converting 1, 2, 3, 4 into tokens of 1 2 3 4 without any commas.
  168. """
  169. # Convert the text into a list of strings.
  170. num_int_strs = text.split(",")
  171. if len(num_int_strs) == 2:
  172. # If there are any open or close tags, remove them.
  173. token_space_open_string = tokenizer.vocab[TOKEN_POINT_OPEN_STRING]
  174. token_space_close_string = tokenizer.vocab[TOKEN_POINT_CLOSE_STRING]
  175. else:
  176. token_space_open_string = tokenizer.vocab[TOKEN_BBOX_OPEN_STRING]
  177. token_space_close_string = tokenizer.vocab[TOKEN_BBOX_CLOSE_STRING]
  178. # Remove all spaces from num_ints
  179. num_ints = [float(num.strip()) for num in num_int_strs]
  180. # scale to transformed image size
  181. if len(num_ints) == 2:
  182. num_ints_translated = scale_point_to_transformed_image(x=num_ints[0], y=num_ints[1], scale_factor=scale_factor)
  183. elif len(num_ints) == 4:
  184. num_ints_translated = scale_bbox_to_transformed_image(
  185. top=num_ints[0],
  186. left=num_ints[1],
  187. bottom=num_ints[2],
  188. right=num_ints[3],
  189. scale_factor=scale_factor,
  190. )
  191. else:
  192. raise ValueError(f"Invalid number of ints: {len(num_ints)}")
  193. # Tokenize the text, skipping the
  194. tokens = [tokenizer.vocab[str(num)] for num in num_ints_translated]
  195. return [token_space_open_string] + tokens + [token_space_close_string]
  196. def _tokenize_prompts_with_image_and_batch(
  197. tokenizer,
  198. prompts: list[list[str]],
  199. scale_factors: list[list["torch.Tensor"]] | None,
  200. max_tokens_to_generate: int,
  201. max_position_embeddings: int,
  202. add_BOS: bool, # Same issue with types as above
  203. add_beginning_of_answer_token: bool,
  204. ) -> tuple["torch.Tensor", "torch.Tensor"]:
  205. """
  206. Given a set of prompts and number of tokens to generate:
  207. - tokenize prompts
  208. - set the sequence length to be the max of length of prompts plus the number of tokens we would like to generate
  209. - pad all the sequences to this length so we can convert them into a 3D tensor.
  210. """
  211. # If not tool use, transform the coordinates while tokenizing
  212. if scale_factors is not None:
  213. transformed_prompt_tokens = []
  214. for prompt_seq, scale_factor_seq in zip(prompts, scale_factors):
  215. transformed_prompt_tokens.append(
  216. [
  217. _transform_coordinates_and_tokenize(prompt, scale_factor.item(), tokenizer)
  218. for prompt, scale_factor in zip(prompt_seq, scale_factor_seq)
  219. ]
  220. )
  221. else:
  222. transformed_prompt_tokens = [[tokenizer.tokenize(prompt) for prompt in prompt_seq] for prompt_seq in prompts]
  223. prompts_tokens = transformed_prompt_tokens
  224. if add_BOS:
  225. bos_token = tokenizer.vocab["<s>"]
  226. else:
  227. bos_token = tokenizer.vocab["|ENDOFTEXT|"]
  228. prompts_tokens = [[[bos_token] + x for x in prompt_seq] for prompt_seq in prompts_tokens]
  229. if add_beginning_of_answer_token:
  230. beginning_of_answer = tokenizer.vocab[BEGINNING_OF_ANSWER_STRING]
  231. # Only add bbox open token to the last subsequence since that is what will be completed
  232. for token_seq in prompts_tokens:
  233. token_seq[-1].append(beginning_of_answer)
  234. # Now we have a list of list of tokens which each list has a different
  235. # size. We want to extend this list to:
  236. # - incorporate the tokens that need to be generated
  237. # - make all the sequences equal length.
  238. # Get the prompts length.
  239. prompts_length = [[len(x) for x in prompts_tokens_seq] for prompts_tokens_seq in prompts_tokens]
  240. # Get the max prompts length.
  241. max_prompt_len: int = np.max(prompts_length)
  242. # Number of tokens in the each sample of the batch.
  243. samples_length = min(max_prompt_len + max_tokens_to_generate, max_position_embeddings)
  244. if max_prompt_len + max_tokens_to_generate > max_position_embeddings:
  245. logger.warning(
  246. f"Max subsequence prompt length of {max_prompt_len} + max tokens to generate {max_tokens_to_generate}",
  247. f"exceeds context length of {max_position_embeddings}. Will generate as many tokens as possible.",
  248. )
  249. # Now update the list of list to be of the same size: samples_length.
  250. for prompt_tokens_seq, prompts_length_seq in zip(prompts_tokens, prompts_length):
  251. for prompt_tokens, prompt_length in zip(prompt_tokens_seq, prompts_length_seq):
  252. if len(prompt_tokens) > samples_length:
  253. raise ValueError("Length of subsequence prompt exceeds sequence length.")
  254. padding_size = samples_length - prompt_length
  255. prompt_tokens.extend([tokenizer.vocab["|ENDOFTEXT|"]] * padding_size)
  256. # Now we are in a structured format, we can convert to tensors.
  257. prompts_tokens_tensor = torch.tensor(prompts_tokens, dtype=torch.int64)
  258. prompts_length_tensor = torch.tensor(prompts_length, dtype=torch.int64)
  259. return prompts_tokens_tensor, prompts_length_tensor
  260. # Simplified assuming self.crop_top = self.padding_top = 0
  261. def original_to_transformed_h_coords(original_coords, scale_h):
  262. return np.round(original_coords * scale_h).astype(np.int32)
  263. # Simplified assuming self.crop_left = self.padding_left = 0
  264. def original_to_transformed_w_coords(original_coords, scale_w):
  265. return np.round(original_coords * scale_w).astype(np.int32)
  266. def scale_point_to_transformed_image(x: float, y: float, scale_factor: float) -> list[int]:
  267. x_scaled = original_to_transformed_w_coords(np.array([x / 2]), scale_factor)[0]
  268. y_scaled = original_to_transformed_h_coords(np.array([y / 2]), scale_factor)[0]
  269. return [x_scaled, y_scaled]
  270. def scale_bbox_to_transformed_image(
  271. top: float, left: float, bottom: float, right: float, scale_factor: float
  272. ) -> list[int]:
  273. top_scaled = original_to_transformed_w_coords(np.array([top / 2]), scale_factor)[0]
  274. left_scaled = original_to_transformed_h_coords(np.array([left / 2]), scale_factor)[0]
  275. bottom_scaled = original_to_transformed_w_coords(np.array([bottom / 2]), scale_factor)[0]
  276. right_scaled = original_to_transformed_h_coords(np.array([right / 2]), scale_factor)[0]
  277. return [top_scaled, left_scaled, bottom_scaled, right_scaled]
  278. @requires(backends=("vision",))
  279. @auto_docstring
  280. class FuyuProcessor(ProcessorMixin):
  281. @classmethod
  282. def _load_tokenizer_from_pretrained(
  283. cls, sub_processor_type, pretrained_model_name_or_path, subfolder="", **kwargs
  284. ):
  285. """
  286. Override for BC. Fuyu uses TokenizersBackend and requires token_type_ids to be removed from model_input_names
  287. because Fuyu uses mm_token_type_ids instead for multimodal token identification. `
  288. """
  289. from ...tokenization_utils_tokenizers import TokenizersBackend
  290. tokenizer = TokenizersBackend.from_pretrained(pretrained_model_name_or_path, **kwargs)
  291. # Remove token_type_ids as Fuyu uses mm_token_type_ids instead
  292. if "token_type_ids" in tokenizer.model_input_names:
  293. tokenizer.model_input_names.remove("token_type_ids")
  294. return tokenizer
  295. def __init__(self, image_processor, tokenizer, **kwargs):
  296. super().__init__(image_processor=image_processor, tokenizer=tokenizer)
  297. self.image_processor = image_processor
  298. self.tokenizer = tokenizer
  299. self.max_tokens_to_generate = 10
  300. self.max_position_embeddings = 16384 # TODO Can't derive this from model files: where to set it?
  301. self.pad_token_id = 0
  302. self.dummy_image_index = -1
  303. self.image_token_id = tokenizer.encode("|SPEAKER|", add_special_tokens=False)[1]
  304. self.image_newline_id = tokenizer.encode("|NEWLINE|", add_special_tokens=False)[1]
  305. self.image_ids = [self.image_newline_id, self.image_token_id]
  306. def _left_pad_inputs_with_attention_mask(self, model_inputs: list[dict], return_attention_mask: bool):
  307. max_length_input_ids = max(entry["input_ids"].shape[1] for entry in model_inputs)
  308. max_length_image_patch_indices = max(entry["image_patches_indices"].shape[1] for entry in model_inputs)
  309. batched_inputs = {"input_ids": [], "image_patches": [], "image_patches_indices": [], "attention_mask": []}
  310. for entry in model_inputs:
  311. for key, tensor in entry.items():
  312. if key == "input_ids":
  313. num_padding_tokens = max_length_input_ids - tensor.shape[1]
  314. padded_input_ids = torch.cat(
  315. [
  316. torch.full((tensor.shape[0], num_padding_tokens), self.pad_token_id, dtype=torch.long),
  317. tensor,
  318. ],
  319. dim=1,
  320. )
  321. batched_inputs[key].append(padded_input_ids)
  322. attention_mask = torch.cat(
  323. [torch.zeros(tensor.shape[0], num_padding_tokens, dtype=torch.long), torch.ones_like(tensor)],
  324. dim=1,
  325. )
  326. batched_inputs["attention_mask"].append(attention_mask)
  327. elif key == "image_patches":
  328. # For image_patches, we don't pad but just append them to the list.
  329. batched_inputs[key].append(tensor)
  330. else: # for image_patches_indices
  331. num_padding_indices = max_length_image_patch_indices - tensor.shape[1]
  332. padded_indices = torch.cat(
  333. [
  334. torch.full(
  335. (tensor.shape[0], num_padding_indices), self.dummy_image_index, dtype=torch.long
  336. ),
  337. tensor,
  338. ],
  339. dim=1,
  340. )
  341. batched_inputs[key].append(padded_indices)
  342. batched_keys = ["input_ids", "image_patches_indices"]
  343. if return_attention_mask:
  344. batched_keys.append("attention_mask")
  345. for key in batched_keys:
  346. batched_inputs[key] = torch.cat(batched_inputs[key], dim=0)
  347. # Cast images to tensor as well, if only one image passed and no padding needed
  348. # NOTE: vLLM expects all processor outputs to be a tensor
  349. if len(batched_inputs["image_patches"]) == 1:
  350. batched_inputs["image_patches"] = torch.cat(batched_inputs["image_patches"], dim=0)
  351. return batched_inputs
  352. def get_sample_encoding(
  353. self,
  354. prompts,
  355. scale_factors,
  356. image_unpadded_heights,
  357. image_unpadded_widths,
  358. image_placeholder_id,
  359. image_newline_id,
  360. tensor_batch_images,
  361. ):
  362. image_present = torch.ones(1, 1, 1)
  363. model_image_input = self.image_processor.preprocess_with_tokenizer_info(
  364. image_input=tensor_batch_images,
  365. image_present=image_present,
  366. image_unpadded_h=image_unpadded_heights,
  367. image_unpadded_w=image_unpadded_widths,
  368. image_placeholder_id=image_placeholder_id,
  369. image_newline_id=image_newline_id,
  370. variable_sized=True,
  371. )
  372. # FIXME max_tokens_to_generate is embedded into this processor's call.
  373. prompt_tokens, prompts_length = _tokenize_prompts_with_image_and_batch(
  374. tokenizer=self.tokenizer,
  375. prompts=prompts,
  376. scale_factors=scale_factors,
  377. max_tokens_to_generate=self.max_tokens_to_generate,
  378. max_position_embeddings=self.max_position_embeddings,
  379. add_BOS=True,
  380. add_beginning_of_answer_token=True,
  381. )
  382. image_padded_unpacked_tokens = construct_full_unpacked_stream(
  383. num_real_text_tokens=prompts_length,
  384. input_stream=prompt_tokens,
  385. image_tokens=model_image_input["image_input_ids"],
  386. batch_size=1,
  387. num_sub_sequences=self.subsequence_length,
  388. )
  389. # Construct inputs for image patch indices.
  390. unpacked_image_patch_indices_per_batch = construct_full_unpacked_stream(
  391. num_real_text_tokens=prompts_length,
  392. input_stream=torch.full_like(prompt_tokens, -1),
  393. image_tokens=model_image_input["image_patch_indices_per_batch"],
  394. batch_size=1,
  395. num_sub_sequences=self.subsequence_length,
  396. )
  397. max_prompt_length = max(x.shape[-1] for x in image_padded_unpacked_tokens)
  398. max_seq_len_batch = min(max_prompt_length + self.max_tokens_to_generate, self.max_position_embeddings)
  399. tokens_to_place = min(max_seq_len_batch, max(0, image_padded_unpacked_tokens[0].shape[0]))
  400. # Use same packing logic for the image patch indices.
  401. image_patch_input_indices = full_unpacked_stream_to_tensor(
  402. all_bi_tokens_to_place=[tokens_to_place],
  403. full_unpacked_stream=unpacked_image_patch_indices_per_batch,
  404. fill_value=-1,
  405. batch_size=1,
  406. new_seq_len=max_seq_len_batch,
  407. offset=0,
  408. )
  409. image_patches_tensor = torch.stack([img[0] for img in model_image_input["image_patches"]])
  410. batch_encoding = {
  411. "input_ids": image_padded_unpacked_tokens[0].unsqueeze(0),
  412. "image_patches": image_patches_tensor,
  413. "image_patches_indices": image_patch_input_indices,
  414. }
  415. return batch_encoding
  416. @auto_docstring
  417. def __call__(
  418. self,
  419. images: ImageInput | None = None,
  420. text: str | list[str] | TextInput | PreTokenizedInput | None = None,
  421. **kwargs: Unpack[FuyuProcessorKwargs],
  422. ) -> "FuyuBatchFeature":
  423. r"""
  424. Returns:
  425. [`FuyuBatchEncoding`]: A [`FuyuBatchEncoding`] with the following fields:
  426. - **input_ids** -- Tensor of token ids to be fed to a model. Returned when `text` is not `None`.
  427. - **image_patches** -- List of Tensor of image patches. Returned when `images` is not `None`.
  428. - **image_patches_indices** -- Tensor of indices where patch embeddings have to be inserted by the model.
  429. - **attention_mask** -- List of indices specifying which tokens should be attended to by the model when
  430. `return_attention_mask=True`.
  431. """
  432. requires_backends(self, ["torch"])
  433. # --- Check input validity ---
  434. if text is None and images is None:
  435. raise ValueError("You have to specify either text or images. Both cannot be None.")
  436. output_kwargs = self._merge_kwargs(
  437. FuyuProcessorKwargs,
  438. tokenizer_init_kwargs=self.tokenizer.init_kwargs,
  439. **kwargs,
  440. )
  441. return_mm_token_type_ids = output_kwargs["text_kwargs"].pop("return_mm_token_type_ids", False)
  442. if not output_kwargs["text_kwargs"].setdefault("return_attention_mask", True):
  443. raise ValueError("`return_attention_mask=False` is not supported for this model.")
  444. if text is not None and images is None:
  445. logger.warning("You are processing a text with no associated image. Make sure it is intended.")
  446. text_encoding = self.tokenizer(text, **output_kwargs["text_kwargs"])
  447. return text_encoding
  448. if text is None and images is not None:
  449. logger.warning("You are processing an image with no associated text. Make sure it is intended.")
  450. prompts = [[""]]
  451. if text is not None and images is not None:
  452. if isinstance(text, str):
  453. prompts = [[text]]
  454. elif isinstance(text, list):
  455. prompts = [[text_seq] for text_seq in text]
  456. # --- Preprocess images using self.image_processor ---
  457. # FIXME - We hard code "pt" here because the rest of the processing assumes torch tensors
  458. output_kwargs["images_kwargs"]["return_tensors"] = "pt"
  459. image_encoding = self.image_processor.preprocess(images, **output_kwargs["images_kwargs"])
  460. batch_images = image_encoding["images"]
  461. image_unpadded_heights = image_encoding["image_unpadded_heights"]
  462. image_unpadded_widths = image_encoding["image_unpadded_widths"]
  463. scale_factors = image_encoding["image_scale_factors"]
  464. self.subsequence_length = 1 # Each batch contains only one sequence.
  465. self.batch_size = len(batch_images)
  466. # --- Use self.tokenizer to get the ids of special tokens to insert into image ids ---
  467. tensor_batch_images = torch.stack([img[0] for img in batch_images if img]).unsqueeze(1)
  468. # --- Use self.image_processor again to obtain the full token ids and batch inputs ---
  469. all_encodings = []
  470. for prompt, scale_factor, image_unpadded_height, image_unpadded_width, tensor_batch_image in zip(
  471. prompts, scale_factors, image_unpadded_heights, image_unpadded_widths, tensor_batch_images
  472. ):
  473. sample_encoding = self.get_sample_encoding(
  474. prompts=[prompt],
  475. scale_factors=[scale_factor],
  476. image_unpadded_heights=torch.tensor([image_unpadded_height]),
  477. image_unpadded_widths=torch.tensor([image_unpadded_width]),
  478. image_placeholder_id=self.image_token_id,
  479. image_newline_id=self.image_newline_id,
  480. tensor_batch_images=tensor_batch_image.unsqueeze(0),
  481. )
  482. all_encodings.append(sample_encoding)
  483. batch_encoding = self._left_pad_inputs_with_attention_mask(
  484. model_inputs=all_encodings, return_attention_mask=True
  485. )
  486. if return_mm_token_type_ids:
  487. batch_encoding["mm_token_type_ids"] = self.create_mm_token_type_ids(batch_encoding["input_ids"])
  488. batch_encoding["mm_token_type_ids"] = torch.tensor(batch_encoding["mm_token_type_ids"])
  489. return FuyuBatchFeature(data=batch_encoding)
  490. def _get_num_multimodal_tokens(self, image_sizes=None, **kwargs):
  491. """
  492. Computes the number of placeholder tokens needed for multimodal inputs with the given sizes.
  493. Args:
  494. image_sizes (`list[list[int]]`, *optional*):
  495. The input sizes formatted as (height, width) per each image.
  496. Returns:
  497. `MultiModalData`: A `MultiModalData` object holding number of tokens per each of the provided
  498. input modalities, along with other useful data.
  499. """
  500. vision_data = {}
  501. if image_sizes is not None:
  502. size = kwargs.get("size") or self.image_processor.size
  503. padded_height, padded_width = size["height"], size["width"]
  504. num_image_tokens = []
  505. num_image_patches = [1] * len(image_sizes)
  506. for image_size in image_sizes:
  507. height_scale_factor = padded_height / image_size[0]
  508. width_scale_factor = padded_width / image_size[1]
  509. optimal_scale_factor = min(height_scale_factor, width_scale_factor)
  510. image_unpadded_h = min(int(image_size[0] * optimal_scale_factor), image_size[0])
  511. image_unpadded_w = min(int(image_size[1] * optimal_scale_factor), image_size[1])
  512. # We can use torch here because Fuyu processor has hard dependency on torch. NOTE: Fuyu can't do multi-image
  513. # thus the below (1, 1, 1) is hardcoded. Same as when calling the processor
  514. model_image_input = self.image_processor.preprocess_with_tokenizer_info(
  515. image_input=torch.zeros(1, 1, 3, padded_height, padded_width),
  516. image_present=torch.ones(1, 1, 1),
  517. image_unpadded_h=torch.tensor([[image_unpadded_h]]),
  518. image_unpadded_w=torch.tensor([[image_unpadded_w]]),
  519. image_placeholder_id=0, # dummy ids, we can be sure `id=0` is never out-of-range
  520. image_newline_id=0,
  521. variable_sized=True,
  522. )
  523. num_image_tokens.append(model_image_input["image_input_ids"][0][0].shape[-1])
  524. vision_data.update({"num_image_tokens": num_image_tokens, "num_image_patches": num_image_patches})
  525. return MultiModalData(**vision_data)
  526. def post_process_box_coordinates(self, outputs, target_sizes=None):
  527. """
  528. Transforms raw coordinates detected by [`FuyuForCausalLM`] to the original images' coordinate space.
  529. Coordinates will be returned in "box" format, with the following pattern:
  530. `<box>top, left, bottom, right</box>`
  531. Point coordinates are not supported yet.
  532. Args:
  533. outputs ([`GenerateOutput`]):
  534. Raw outputs from `generate`.
  535. target_sizes (`torch.Tensor`, *optional*):
  536. Tensor of shape (batch_size, 2) where each entry is the (height, width) of the corresponding image in
  537. the batch. If set, found coordinates in the output sequence are rescaled to the target sizes. If left
  538. to None, coordinates will not be rescaled.
  539. Returns:
  540. `GenerateOutput`: Same output type returned by `generate`, with output token ids replaced with
  541. boxed and possible rescaled coordinates.
  542. """
  543. def scale_factor_to_fit(original_size, target_size=None):
  544. height, width = original_size
  545. if target_size is None:
  546. max_height = self.image_processor.size["height"]
  547. max_width = self.image_processor.size["width"]
  548. else:
  549. max_height, max_width = target_size
  550. if width <= max_width and height <= max_height:
  551. return 1.0
  552. return min(max_height / height, max_width / width)
  553. def find_delimiters_pair(tokens, start_token, end_token):
  554. start_id = self.tokenizer.convert_tokens_to_ids(start_token)
  555. end_id = self.tokenizer.convert_tokens_to_ids(end_token)
  556. starting_positions = (tokens == start_id).nonzero(as_tuple=True)[0]
  557. ending_positions = (tokens == end_id).nonzero(as_tuple=True)[0]
  558. if torch.any(starting_positions) and torch.any(ending_positions):
  559. return (starting_positions[0], ending_positions[0])
  560. return (None, None)
  561. def tokens_to_boxes(tokens, original_size):
  562. while (pair := find_delimiters_pair(tokens, TOKEN_BBOX_OPEN_STRING, TOKEN_BBOX_CLOSE_STRING)) != (
  563. None,
  564. None,
  565. ):
  566. start, end = pair
  567. if end != start + 5:
  568. continue
  569. # Retrieve transformed coordinates from tokens
  570. coords = self.tokenizer.convert_ids_to_tokens(tokens[start + 1 : end])
  571. # Scale back to original image size and multiply by 2
  572. scale = scale_factor_to_fit(original_size)
  573. top, left, bottom, right = [2 * int(float(c) / scale) for c in coords]
  574. # Replace the IDs so they get detokenized right
  575. replacement = f" {TEXT_REPR_BBOX_OPEN}{top}, {left}, {bottom}, {right}{TEXT_REPR_BBOX_CLOSE}"
  576. replacement = self.tokenizer.tokenize(replacement)[1:]
  577. replacement = self.tokenizer.convert_tokens_to_ids(replacement)
  578. replacement = torch.tensor(replacement).to(tokens)
  579. tokens = torch.cat([tokens[:start], replacement, tokens[end + 1 :]], 0)
  580. return tokens
  581. def tokens_to_points(tokens, original_size):
  582. while (pair := find_delimiters_pair(tokens, TOKEN_POINT_OPEN_STRING, TOKEN_POINT_CLOSE_STRING)) != (
  583. None,
  584. None,
  585. ):
  586. start, end = pair
  587. if end != start + 3:
  588. continue
  589. # Retrieve transformed coordinates from tokens
  590. coords = self.tokenizer.convert_ids_to_tokens(tokens[start + 1 : end])
  591. # Scale back to original image size and multiply by 2
  592. scale = scale_factor_to_fit(original_size)
  593. x, y = [2 * int(float(c) / scale) for c in coords]
  594. # Replace the IDs so they get detokenized right
  595. replacement = f" {TEXT_REPR_POINT_OPEN}{x}, {y}{TEXT_REPR_POINT_CLOSE}"
  596. replacement = self.tokenizer.tokenize(replacement)[1:]
  597. replacement = self.tokenizer.convert_tokens_to_ids(replacement)
  598. replacement = torch.tensor(replacement).to(tokens)
  599. tokens = torch.cat([tokens[:start], replacement, tokens[end + 1 :]], 0)
  600. return tokens
  601. if target_sizes is None:
  602. target_sizes = ((self.image_processor.size["height"], self.image_processor.size["width"]),) * len(outputs)
  603. elif target_sizes.shape[1] != 2:
  604. raise ValueError("Each element of target_sizes must contain the size (h, w) of each image of the batch")
  605. if len(outputs) != len(target_sizes):
  606. raise ValueError("Make sure that you pass in as many target sizes as output sequences")
  607. results = []
  608. for seq, size in zip(outputs, target_sizes):
  609. seq = tokens_to_boxes(seq, size)
  610. seq = tokens_to_points(seq, size)
  611. results.append(seq)
  612. return results
  613. def post_process_image_text_to_text(self, generated_outputs, skip_special_tokens=True, **kwargs):
  614. """
  615. Post-processes the output of `FuyuForConditionalGeneration` to only return the text output.
  616. Args:
  617. generated_outputs (`torch.Tensor` or `np.ndarray`):
  618. The output of the model. The output is expected to be a tensor of shape `(batch_size, sequence_length)`
  619. containing the token ids of the generated sequences.
  620. skip_special_tokens (`bool`, *optional*, defaults to `True`):
  621. Whether or not to remove special tokens in the output. Argument passed to the tokenizer's `batch_decode` method.
  622. **kwargs:
  623. Additional arguments to be passed to the tokenizer's `batch_decode method`.
  624. Returns:
  625. `list[str]`: The decoded text output.
  626. """
  627. beginning_of_answer = self.tokenizer.convert_tokens_to_ids(BEGINNING_OF_ANSWER_STRING)
  628. # get boa index for each outputted sequence tensor
  629. # start all generated sequences from the beginning of the answer token, pad to have consistent length
  630. unpadded_output_sequences = [
  631. seq[(seq == beginning_of_answer).nonzero(as_tuple=True)[0] + 1 :] for seq in generated_outputs
  632. ]
  633. max_len = max(len(seq) for seq in unpadded_output_sequences)
  634. # convert to torch and pad sequences
  635. padded_output_sequences = torch.full((len(unpadded_output_sequences), max_len), self.pad_token_id)
  636. for i, seq in enumerate(unpadded_output_sequences):
  637. padded_output_sequences[i, : len(seq)] = torch.tensor(seq)
  638. return self.batch_decode(padded_output_sequences, skip_special_tokens=skip_special_tokens, **kwargs)
  639. @property
  640. def model_input_names(self):
  641. tokenizer_input_names = self.tokenizer.model_input_names
  642. image_processor_input_names = self.image_processor.model_input_names
  643. # Make a copy of list when removing otherwise `self.image_processor.model_input_names` is also modified
  644. extra_image_inputs = [
  645. "image_input_ids",
  646. "image_patch_indices_per_subsequence",
  647. "images",
  648. "image_patch_indices_per_batch",
  649. ]
  650. image_processor_input_names = [name for name in image_processor_input_names if name not in extra_image_inputs]
  651. return list(tokenizer_input_names + image_processor_input_names + ["image_patches_indices"])
  652. __all__ = ["FuyuProcessor"]