table_question_answering.py 17 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382
  1. import collections
  2. import types
  3. import numpy as np
  4. from ..generation import GenerationConfig
  5. from ..utils import (
  6. add_end_docstrings,
  7. is_torch_available,
  8. requires_backends,
  9. )
  10. from .base import ArgumentHandler, Dataset, Pipeline, PipelineException, build_pipeline_init_args
  11. if is_torch_available():
  12. import torch
  13. from ..models.auto.modeling_auto import (
  14. MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING_NAMES,
  15. MODEL_FOR_TABLE_QUESTION_ANSWERING_MAPPING_NAMES,
  16. )
  17. class TableQuestionAnsweringArgumentHandler(ArgumentHandler):
  18. """
  19. Handles arguments for the TableQuestionAnsweringPipeline
  20. """
  21. def __call__(self, table=None, query=None, **kwargs):
  22. # Returns tqa_pipeline_inputs of shape:
  23. # [
  24. # {"table": pd.DataFrame, "query": list[str]},
  25. # ...,
  26. # {"table": pd.DataFrame, "query" : list[str]}
  27. # ]
  28. requires_backends(self, "pandas")
  29. import pandas as pd
  30. if table is None:
  31. raise ValueError("Keyword argument `table` cannot be None.")
  32. elif query is None:
  33. if isinstance(table, dict) and table.get("query") is not None and table.get("table") is not None:
  34. tqa_pipeline_inputs = [table]
  35. elif isinstance(table, list) and len(table) > 0:
  36. if not all(isinstance(d, dict) for d in table):
  37. raise ValueError(
  38. f"Keyword argument `table` should be a list of dict, but is {(type(d) for d in table)}"
  39. )
  40. if table[0].get("query") is not None and table[0].get("table") is not None:
  41. tqa_pipeline_inputs = table
  42. else:
  43. raise ValueError(
  44. "If keyword argument `table` is a list of dictionaries, each dictionary should have a `table`"
  45. f" and `query` key, but only dictionary has keys {table[0].keys()} `table` and `query` keys."
  46. )
  47. elif Dataset is not None and isinstance(table, Dataset) or isinstance(table, types.GeneratorType):
  48. return table
  49. else:
  50. raise ValueError(
  51. "Invalid input. Keyword argument `table` should be either of type `dict` or `list`, but "
  52. f"is {type(table)})"
  53. )
  54. else:
  55. tqa_pipeline_inputs = [{"table": table, "query": query}]
  56. for tqa_pipeline_input in tqa_pipeline_inputs:
  57. if not isinstance(tqa_pipeline_input["table"], pd.DataFrame):
  58. if tqa_pipeline_input["table"] is None:
  59. raise ValueError("Table cannot be None.")
  60. tqa_pipeline_input["table"] = pd.DataFrame(tqa_pipeline_input["table"])
  61. return tqa_pipeline_inputs
  62. @add_end_docstrings(build_pipeline_init_args(has_tokenizer=True))
  63. class TableQuestionAnsweringPipeline(Pipeline):
  64. """
  65. Table Question Answering pipeline using a `ModelForTableQuestionAnswering`. This pipeline is only available in
  66. PyTorch.
  67. Unless the model you're using explicitly sets these generation parameters in its configuration files
  68. (`generation_config.json`), the following default values will be used:
  69. - max_new_tokens: 256
  70. Example:
  71. ```python
  72. >>> from transformers import pipeline
  73. >>> oracle = pipeline(model="google/tapas-base-finetuned-wtq")
  74. >>> table = {
  75. ... "Repository": ["Transformers", "Datasets", "Tokenizers"],
  76. ... "Stars": ["36542", "4512", "3934"],
  77. ... "Contributors": ["651", "77", "34"],
  78. ... "Programming language": ["Python", "Python", "Rust, Python and NodeJS"],
  79. ... }
  80. >>> oracle(query="How many stars does the transformers repository have?", table=table)
  81. {'answer': 'AVERAGE > 36542', 'coordinates': [(0, 1)], 'cells': ['36542'], 'aggregator': 'AVERAGE'}
  82. ```
  83. Learn more about the basics of using a pipeline in the [pipeline tutorial](../pipeline_tutorial)
  84. This tabular question answering pipeline can currently be loaded from [`pipeline`] using the following task
  85. identifier: `"table-question-answering"`.
  86. The models that this pipeline can use are models that have been fine-tuned on a tabular question answering task.
  87. See the up-to-date list of available models on
  88. [huggingface.co/models](https://huggingface.co/models?filter=table-question-answering).
  89. """
  90. default_input_names = "table,query"
  91. _pipeline_calls_generate = True
  92. _load_processor = False
  93. _load_image_processor = False
  94. _load_feature_extractor = False
  95. _load_tokenizer = True
  96. # Make sure the docstring is updated when the default generation config is changed
  97. _default_generation_config = GenerationConfig(
  98. max_new_tokens=256,
  99. )
  100. def __init__(self, args_parser=TableQuestionAnsweringArgumentHandler(), **kwargs):
  101. super().__init__(**kwargs)
  102. self._args_parser = args_parser
  103. mapping = MODEL_FOR_TABLE_QUESTION_ANSWERING_MAPPING_NAMES.copy()
  104. mapping.update(MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING_NAMES)
  105. self.check_model_type(mapping)
  106. self.aggregate = getattr(self.model.config, "aggregation_labels", None) and getattr(
  107. self.model.config, "num_aggregation_labels", None
  108. )
  109. self.type = "tapas" if hasattr(self.model.config, "aggregation_labels") else None
  110. def batch_inference(self, **inputs):
  111. return self.model(**inputs)
  112. def sequential_inference(self, **inputs):
  113. """
  114. Inference used for models that need to process sequences in a sequential fashion, like the SQA models which
  115. handle conversational query related to a table.
  116. """
  117. all_logits = []
  118. all_aggregations = []
  119. prev_answers = None
  120. batch_size = inputs["input_ids"].shape[0]
  121. input_ids = inputs["input_ids"].to(self.device)
  122. attention_mask = inputs["attention_mask"].to(self.device)
  123. token_type_ids = inputs["token_type_ids"].to(self.device)
  124. token_type_ids_example = None
  125. for index in range(batch_size):
  126. # If sequences have already been processed, the token type IDs will be created according to the previous
  127. # answer.
  128. if prev_answers is not None:
  129. prev_labels_example = token_type_ids_example[:, 3] # shape (seq_len,)
  130. model_labels = np.zeros_like(prev_labels_example.cpu().numpy()) # shape (seq_len,)
  131. token_type_ids_example = token_type_ids[index] # shape (seq_len, 7)
  132. for i in range(model_labels.shape[0]):
  133. segment_id = token_type_ids_example[:, 0].tolist()[i]
  134. col_id = token_type_ids_example[:, 1].tolist()[i] - 1
  135. row_id = token_type_ids_example[:, 2].tolist()[i] - 1
  136. if row_id >= 0 and col_id >= 0 and segment_id == 1:
  137. model_labels[i] = int(prev_answers[(col_id, row_id)])
  138. token_type_ids_example[:, 3] = torch.from_numpy(model_labels).type(torch.long).to(self.device)
  139. input_ids_example = input_ids[index]
  140. attention_mask_example = attention_mask[index] # shape (seq_len,)
  141. token_type_ids_example = token_type_ids[index] # shape (seq_len, 7)
  142. outputs = self.model(
  143. input_ids=input_ids_example.unsqueeze(0),
  144. attention_mask=attention_mask_example.unsqueeze(0),
  145. token_type_ids=token_type_ids_example.unsqueeze(0),
  146. )
  147. logits = outputs.logits
  148. if self.aggregate:
  149. all_aggregations.append(outputs.logits_aggregation)
  150. all_logits.append(logits)
  151. dist_per_token = torch.distributions.Bernoulli(logits=logits)
  152. probabilities = dist_per_token.probs * attention_mask_example.type(torch.float32).to(
  153. dist_per_token.probs.device
  154. )
  155. coords_to_probs = collections.defaultdict(list)
  156. for i, p in enumerate(probabilities.squeeze().tolist()):
  157. segment_id = token_type_ids_example[:, 0].tolist()[i]
  158. col = token_type_ids_example[:, 1].tolist()[i] - 1
  159. row = token_type_ids_example[:, 2].tolist()[i] - 1
  160. if col >= 0 and row >= 0 and segment_id == 1:
  161. coords_to_probs[(col, row)].append(p)
  162. prev_answers = {key: np.array(coords_to_probs[key]).mean() > 0.5 for key in coords_to_probs}
  163. logits_batch = torch.cat(tuple(all_logits), 0)
  164. return (logits_batch,) if not self.aggregate else (logits_batch, torch.cat(tuple(all_aggregations), 0))
  165. def __call__(self, *args, **kwargs):
  166. r"""
  167. Answers queries according to a table. The pipeline accepts several types of inputs which are detailed below:
  168. - `pipeline(table, query)`
  169. - `pipeline(table, [query])`
  170. - `pipeline(table=table, query=query)`
  171. - `pipeline(table=table, query=[query])`
  172. - `pipeline({"table": table, "query": query})`
  173. - `pipeline({"table": table, "query": [query]})`
  174. - `pipeline([{"table": table, "query": query}, {"table": table, "query": query}])`
  175. The `table` argument should be a dict or a DataFrame built from that dict, containing the whole table:
  176. Example:
  177. ```python
  178. data = {
  179. "actors": ["brad pitt", "leonardo di caprio", "george clooney"],
  180. "age": ["56", "45", "59"],
  181. "number of movies": ["87", "53", "69"],
  182. "date of birth": ["7 february 1967", "10 june 1996", "28 november 1967"],
  183. }
  184. ```
  185. This dictionary can be passed in as such, or can be converted to a pandas DataFrame:
  186. Example:
  187. ```python
  188. import pandas as pd
  189. table = pd.DataFrame.from_dict(data)
  190. ```
  191. Args:
  192. table (`pd.DataFrame` or `Dict`):
  193. Pandas DataFrame or dictionary that will be converted to a DataFrame containing all the table values.
  194. See above for an example of dictionary.
  195. query (`str` or `list[str]`):
  196. Query or list of queries that will be sent to the model alongside the table.
  197. sequential (`bool`, *optional*, defaults to `False`):
  198. Whether to do inference sequentially or as a batch. Batching is faster, but models like SQA require the
  199. inference to be done sequentially to extract relations within sequences, given their conversational
  200. nature.
  201. padding (`bool`, `str` or [`~utils.PaddingStrategy`], *optional*, defaults to `False`):
  202. Activates and controls padding. Accepts the following values:
  203. - `True` or `'longest'`: Pad to the longest sequence in the batch (or no padding if only a single
  204. sequence if provided).
  205. - `'max_length'`: Pad to a maximum length specified with the argument `max_length` or to the maximum
  206. acceptable input length for the model if that argument is not provided.
  207. - `False` or `'do_not_pad'` (default): No padding (i.e., can output a batch with sequences of different
  208. lengths).
  209. truncation (`bool`, `str` or [`TapasTruncationStrategy`], *optional*, defaults to `False`):
  210. Activates and controls truncation. Accepts the following values:
  211. - `True` or `'drop_rows_to_fit'`: Truncate to a maximum length specified with the argument `max_length`
  212. or to the maximum acceptable input length for the model if that argument is not provided. This will
  213. truncate row by row, removing rows from the table.
  214. - `False` or `'do_not_truncate'` (default): No truncation (i.e., can output batch with sequence lengths
  215. greater than the model maximum admissible input size).
  216. Return:
  217. A dictionary or a list of dictionaries containing results: Each result is a dictionary with the following
  218. keys:
  219. - **answer** (`str`) -- The answer of the query given the table. If there is an aggregator, the answer will
  220. be preceded by `AGGREGATOR >`.
  221. - **coordinates** (`list[tuple[int, int]]`) -- Coordinates of the cells of the answers.
  222. - **cells** (`list[str]`) -- List of strings made up of the answer cell values.
  223. - **aggregator** (`str`) -- If the model has an aggregator, this returns the aggregator.
  224. """
  225. pipeline_inputs = self._args_parser(*args, **kwargs)
  226. results = super().__call__(pipeline_inputs, **kwargs)
  227. if len(results) == 1:
  228. return results[0]
  229. return results
  230. def _sanitize_parameters(self, sequential=None, padding=None, truncation=None, **kwargs):
  231. preprocess_params = {}
  232. if padding is not None:
  233. preprocess_params["padding"] = padding
  234. if truncation is not None:
  235. preprocess_params["truncation"] = truncation
  236. forward_params = {}
  237. if sequential is not None:
  238. forward_params["sequential"] = sequential
  239. if getattr(self, "assistant_model", None) is not None:
  240. forward_params["assistant_model"] = self.assistant_model
  241. if getattr(self, "assistant_tokenizer", None) is not None:
  242. forward_params["tokenizer"] = self.tokenizer
  243. forward_params["assistant_tokenizer"] = self.assistant_tokenizer
  244. return preprocess_params, forward_params, {}
  245. def preprocess(self, pipeline_input, padding=True, truncation=None):
  246. if truncation is None:
  247. if self.type == "tapas":
  248. truncation = "drop_rows_to_fit"
  249. else:
  250. truncation = "do_not_truncate"
  251. table, query = pipeline_input["table"], pipeline_input["query"]
  252. if table.empty:
  253. raise ValueError("table is empty")
  254. if query is None or query == "":
  255. raise ValueError("query is empty")
  256. inputs = self.tokenizer(table, query, return_tensors="pt", truncation=truncation, padding=padding)
  257. inputs["table"] = table
  258. return inputs
  259. def _forward(self, model_inputs, sequential=False, **generate_kwargs):
  260. table = model_inputs.pop("table")
  261. if self.type == "tapas":
  262. if sequential:
  263. outputs = self.sequential_inference(**model_inputs)
  264. else:
  265. outputs = self.batch_inference(**model_inputs)
  266. else:
  267. # User-defined `generation_config` passed to the pipeline call take precedence
  268. if "generation_config" not in generate_kwargs:
  269. generate_kwargs["generation_config"] = self.generation_config
  270. outputs = self.model.generate(**model_inputs, **generate_kwargs)
  271. model_outputs = {"model_inputs": model_inputs, "table": table, "outputs": outputs}
  272. return model_outputs
  273. def postprocess(self, model_outputs):
  274. inputs = model_outputs["model_inputs"]
  275. table = model_outputs["table"]
  276. outputs = model_outputs["outputs"]
  277. if self.type == "tapas":
  278. if self.aggregate:
  279. logits, logits_agg = outputs[:2]
  280. predictions = self.tokenizer.convert_logits_to_predictions(inputs, logits, logits_agg)
  281. answer_coordinates_batch, agg_predictions = predictions
  282. aggregators = {i: self.model.config.aggregation_labels[pred] for i, pred in enumerate(agg_predictions)}
  283. no_agg_label_index = self.model.config.no_aggregation_label_index
  284. aggregators_prefix = {
  285. i: aggregators[i] + " > " for i, pred in enumerate(agg_predictions) if pred != no_agg_label_index
  286. }
  287. else:
  288. logits = outputs[0]
  289. predictions = self.tokenizer.convert_logits_to_predictions(inputs, logits)
  290. answer_coordinates_batch = predictions[0]
  291. aggregators = {}
  292. aggregators_prefix = {}
  293. answers = []
  294. for index, coordinates in enumerate(answer_coordinates_batch):
  295. cells = [table.iat[coordinate] for coordinate in coordinates]
  296. aggregator = aggregators.get(index, "")
  297. aggregator_prefix = aggregators_prefix.get(index, "")
  298. answer = {
  299. "answer": aggregator_prefix + ", ".join(cells),
  300. "coordinates": coordinates,
  301. "cells": [table.iat[coordinate] for coordinate in coordinates],
  302. }
  303. if aggregator:
  304. answer["aggregator"] = aggregator
  305. answers.append(answer)
  306. if len(answer) == 0:
  307. raise PipelineException("Table question answering", self.model.name_or_path, "Empty answer")
  308. else:
  309. answers = [{"answer": answer} for answer in self.tokenizer.batch_decode(outputs, skip_special_tokens=True)]
  310. return answers if len(answers) > 1 else answers[0]