# tokenization_tapas.py
# Copyright 2020 Google Research and The HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
  14. """Tokenization class for TAPAS model."""
  15. import collections
  16. import datetime
  17. import enum
  18. import itertools
  19. import math
  20. import os
  21. import re
  22. import unicodedata
  23. from collections.abc import Callable, Generator
  24. from dataclasses import dataclass
  25. from typing import Union
  26. import numpy as np
  27. from ...tokenization_python import PreTrainedTokenizer, Trie, _is_control, _is_punctuation, _is_whitespace
  28. from ...tokenization_utils_base import (
  29. ENCODE_KWARGS_DOCSTRING,
  30. VERY_LARGE_INTEGER,
  31. BatchEncoding,
  32. EncodedInput,
  33. PreTokenizedInput,
  34. TextInput,
  35. )
  36. from ...utils import ExplicitEnum, PaddingStrategy, TensorType, add_end_docstrings, is_pandas_available, logging
  37. if is_pandas_available():
  38. import pandas as pd
  39. logger = logging.get_logger(__name__)
  40. VOCAB_FILES_NAMES = {"vocab_file": "vocab.txt"}
class TapasTruncationStrategy(ExplicitEnum):
    """
    Possible values for the `truncation` argument in [`~TapasTokenizer.__call__`]. Useful for tab-completion in an IDE.
    """

    # Remove table rows until the encoding fits in `max_length` (or the model maximum input length).
    DROP_ROWS_TO_FIT = "drop_rows_to_fit"
    # Never truncate; encodings may exceed the model's maximum admissible input size.
    DO_NOT_TRUNCATE = "do_not_truncate"
  47. TableValue = collections.namedtuple("TokenValue", ["token", "column_id", "row_id"])
@dataclass(frozen=True)
class TokenCoordinates:
    """Immutable (column, row, token) position of a single token inside a tokenized table."""

    # NOTE(review): presumably these index `TokenizedTable.rows[row_index][column_index][token_index]` — confirm.
    column_index: int
    row_index: int
    token_index: int
@dataclass
class TokenizedTable:
    """A table whose cells have been split into tokens, plus the coordinates of selected tokens."""

    # Nested lists: rows -> cells -> tokens of that cell.
    rows: list[list[list[str]]]
    # Positions of the tokens picked out of `rows` (selection criteria are decided by the caller — not visible here).
    selected_tokens: list[TokenCoordinates]
@dataclass(frozen=True)
class SerializedExample:
    """Immutable flattened (question + table) example: tokens plus their per-token structural ids."""

    tokens: list[str]
    # NOTE(review): the id lists below are presumably parallel to `tokens` (one entry per token) — confirm at the
    # construction site.
    column_ids: list[int]
    row_ids: list[int]
    segment_ids: list[int]
  63. def _is_inner_wordpiece(token: str):
  64. return token.startswith("##")
  65. def load_vocab(vocab_file):
  66. """Loads a vocabulary file into a dictionary."""
  67. vocab = collections.OrderedDict()
  68. with open(vocab_file, "r", encoding="utf-8") as reader:
  69. tokens = reader.readlines()
  70. for index, token in enumerate(tokens):
  71. token = token.rstrip("\n")
  72. vocab[token] = index
  73. return vocab
  74. def whitespace_tokenize(text):
  75. """Runs basic whitespace cleaning and splitting on a piece of text."""
  76. text = text.strip()
  77. if not text:
  78. return []
  79. tokens = text.split()
  80. return tokens
  81. TAPAS_ENCODE_PLUS_ADDITIONAL_KWARGS_DOCSTRING = r"""
  82. add_special_tokens (`bool`, *optional*, defaults to `True`):
  83. Whether or not to encode the sequences with the special tokens relative to their model.
  84. padding (`bool`, `str` or [`~utils.PaddingStrategy`], *optional*, defaults to `False`):
  85. Activates and controls padding. Accepts the following values:
  86. - `True` or `'longest'`: Pad to the longest sequence in the batch (or no padding if only a single
  87. sequence if provided).
  88. - `'max_length'`: Pad to a maximum length specified with the argument `max_length` or to the maximum
  89. acceptable input length for the model if that argument is not provided.
  90. - `False` or `'do_not_pad'` (default): No padding (i.e., can output a batch with sequences of different
  91. lengths).
  92. truncation (`bool`, `str` or [`TapasTruncationStrategy`], *optional*, defaults to `False`):
  93. Activates and controls truncation. Accepts the following values:
  94. - `True` or `'drop_rows_to_fit'`: Truncate to a maximum length specified with the argument `max_length`
  95. or to the maximum acceptable input length for the model if that argument is not provided. This will
  96. truncate row by row, removing rows from the table.
  97. - `False` or `'do_not_truncate'` (default): No truncation (i.e., can output batch with sequence lengths
  98. greater than the model maximum admissible input size).
  99. max_length (`int`, *optional*):
  100. Controls the maximum length to use by one of the truncation/padding parameters.
  101. If left unset or set to `None`, this will use the predefined model maximum length if a maximum length
  102. is required by one of the truncation/padding parameters. If the model has no specific maximum input
  103. length (like XLNet) truncation/padding to a maximum length will be deactivated.
  104. is_split_into_words (`bool`, *optional*, defaults to `False`):
  105. Whether or not the input is already pre-tokenized (e.g., split into words). If set to `True`, the
  106. tokenizer assumes the input is already split into words (for instance, by splitting it on whitespace)
  107. which it will tokenize. This is useful for NER or token classification.
  108. pad_to_multiple_of (`int`, *optional*):
  109. If set will pad the sequence to a multiple of the provided value. This is especially useful to enable
  110. the use of Tensor Cores on NVIDIA hardware with compute capability `>= 7.5` (Volta).
  111. return_tensors (`str` or [`~utils.TensorType`], *optional*):
  112. If set, will return tensors instead of list of python integers. Acceptable values are:
  113. - `'pt'`: Return PyTorch `torch.Tensor` objects.
  114. - `'np'`: Return Numpy `np.ndarray` objects.
  115. """
  116. class TapasTokenizer(PreTrainedTokenizer):
  117. r"""
  118. Construct a TAPAS tokenizer. Based on WordPiece. Flattens a table and one or more related sentences to be used by
  119. TAPAS models.
  120. This tokenizer inherits from [`PreTrainedTokenizer`] which contains most of the main methods. Users should refer to
  121. this superclass for more information regarding those methods. [`TapasTokenizer`] creates several token type ids to
  122. encode tabular structure. To be more precise, it adds 7 token type ids, in the following order: `segment_ids`,
  123. `column_ids`, `row_ids`, `prev_labels`, `column_ranks`, `inv_column_ranks` and `numeric_relations`:
  124. - segment_ids: indicate whether a token belongs to the question (0) or the table (1). 0 for special tokens and
  125. padding.
  126. - column_ids: indicate to which column of the table a token belongs (starting from 1). Is 0 for all question
  127. tokens, special tokens and padding.
  128. - row_ids: indicate to which row of the table a token belongs (starting from 1). Is 0 for all question tokens,
  129. special tokens and padding. Tokens of column headers are also 0.
  130. - prev_labels: indicate whether a token was (part of) an answer to the previous question (1) or not (0). Useful in
  131. a conversational setup (such as SQA).
  132. - column_ranks: indicate the rank of a table token relative to a column, if applicable. For example, if you have a
  133. column "number of movies" with values 87, 53 and 69, then the column ranks of these tokens are 3, 1 and 2
  134. respectively. 0 for all question tokens, special tokens and padding.
  135. - inv_column_ranks: indicate the inverse rank of a table token relative to a column, if applicable. For example, if
  136. you have a column "number of movies" with values 87, 53 and 69, then the inverse column ranks of these tokens are
  137. 1, 3 and 2 respectively. 0 for all question tokens, special tokens and padding.
  138. - numeric_relations: indicate numeric relations between the question and the tokens of the table. 0 for all
  139. question tokens, special tokens and padding.
  140. [`TapasTokenizer`] runs end-to-end tokenization on a table and associated sentences: punctuation splitting and
  141. wordpiece.
  142. Args:
  143. vocab_file (`str`):
  144. File containing the vocabulary.
  145. do_lower_case (`bool`, *optional*, defaults to `True`):
  146. Whether or not to lowercase the input when tokenizing.
  147. do_basic_tokenize (`bool`, *optional*, defaults to `True`):
  148. Whether or not to do basic tokenization before WordPiece.
  149. never_split (`Iterable`, *optional*):
  150. Collection of tokens which will never be split during tokenization. Only has an effect when
  151. `do_basic_tokenize=True`
  152. unk_token (`str`, *optional*, defaults to `"[UNK]"`):
  153. The unknown token. A token that is not in the vocabulary cannot be converted to an ID and is set to be this
  154. token instead.
  155. sep_token (`str`, *optional*, defaults to `"[SEP]"`):
  156. The separator token, which is used when building a sequence from multiple sequences, e.g. two sequences for
  157. sequence classification or for a text and a question for question answering. It is also used as the last
  158. token of a sequence built with special tokens.
  159. pad_token (`str`, *optional*, defaults to `"[PAD]"`):
  160. The token used for padding, for example when batching sequences of different lengths.
  161. cls_token (`str`, *optional*, defaults to `"[CLS]"`):
  162. The classifier token which is used when doing sequence classification (classification of the whole sequence
  163. instead of per-token classification). It is the first token of the sequence when built with special tokens.
  164. mask_token (`str`, *optional*, defaults to `"[MASK]"`):
  165. The token used for masking values. This is the token used when training this model with masked language
  166. modeling. This is the token which the model will try to predict.
  167. empty_token (`str`, *optional*, defaults to `"[EMPTY]"`):
  168. The token used for empty cell values in a table. Empty cell values include "", "n/a", "nan" and "?".
  169. tokenize_chinese_chars (`bool`, *optional*, defaults to `True`):
  170. Whether or not to tokenize Chinese characters. This should likely be deactivated for Japanese (see this
  171. [issue](https://github.com/huggingface/transformers/issues/328)).
  172. strip_accents (`bool`, *optional*):
  173. Whether or not to strip all accents. If this option is not specified, then it will be determined by the
  174. value for `lowercase` (as in the original BERT).
  175. cell_trim_length (`int`, *optional*, defaults to -1):
  176. If > 0: Trim cells so that the length is <= this value. Also disables further cell trimming, should thus be
  177. used with `truncation` set to `True`.
  178. max_column_id (`int`, *optional*):
  179. Max column id to extract.
  180. max_row_id (`int`, *optional*):
  181. Max row id to extract.
  182. strip_column_names (`bool`, *optional*, defaults to `False`):
  183. Whether to add empty strings instead of column names.
  184. update_answer_coordinates (`bool`, *optional*, defaults to `False`):
  185. Whether to recompute the answer coordinates from the answer text.
  186. min_question_length (`int`, *optional*):
  187. Minimum length of each question in terms of tokens (will be skipped otherwise).
  188. max_question_length (`int`, *optional*):
  189. Maximum length of each question in terms of tokens (will be skipped otherwise).
  190. clean_up_tokenization_spaces (`bool`, *optional*, defaults to `True`):
  191. Whether or not to cleanup spaces after decoding, cleanup consists in removing potential artifacts like
  192. extra spaces.
  193. """
  194. model_input_names = ["input_ids", "attention_mask", "token_type_ids"]
  195. vocab_files_names = VOCAB_FILES_NAMES
  196. def __init__(
  197. self,
  198. vocab_file,
  199. do_lower_case=True,
  200. do_basic_tokenize=True,
  201. never_split=None,
  202. unk_token="[UNK]",
  203. sep_token="[SEP]",
  204. pad_token="[PAD]",
  205. cls_token="[CLS]",
  206. mask_token="[MASK]",
  207. empty_token="[EMPTY]",
  208. tokenize_chinese_chars=True,
  209. strip_accents=None,
  210. cell_trim_length: int = -1,
  211. max_column_id: int | None = None,
  212. max_row_id: int | None = None,
  213. strip_column_names: bool = False,
  214. update_answer_coordinates: bool = False,
  215. min_question_length=None,
  216. max_question_length=None,
  217. model_max_length: int = 512,
  218. additional_special_tokens: list[str] | None = None,
  219. clean_up_tokenization_spaces=True,
  220. **kwargs,
  221. ):
  222. if not is_pandas_available():
  223. raise ImportError("Pandas is required for the TAPAS tokenizer.")
  224. if additional_special_tokens is not None:
  225. if empty_token not in additional_special_tokens:
  226. additional_special_tokens.append(empty_token)
  227. else:
  228. additional_special_tokens = [empty_token]
  229. if not os.path.isfile(vocab_file):
  230. raise ValueError(
  231. f"Can't find a vocabulary file at path '{vocab_file}'. To load the vocabulary from a Google pretrained"
  232. " model use `tokenizer = BertTokenizer.from_pretrained(PRETRAINED_MODEL_NAME)`"
  233. )
  234. self.vocab = load_vocab(vocab_file)
  235. self.ids_to_tokens = collections.OrderedDict([(ids, tok) for tok, ids in self.vocab.items()])
  236. self.do_basic_tokenize = do_basic_tokenize
  237. if do_basic_tokenize:
  238. self.basic_tokenizer = BasicTokenizer(
  239. do_lower_case=do_lower_case,
  240. never_split=never_split,
  241. tokenize_chinese_chars=tokenize_chinese_chars,
  242. strip_accents=strip_accents,
  243. )
  244. self.wordpiece_tokenizer = WordpieceTokenizer(vocab=self.vocab, unk_token=str(unk_token))
  245. # Additional properties
  246. self.cell_trim_length = cell_trim_length
  247. self.max_column_id = (
  248. max_column_id
  249. if max_column_id is not None
  250. else model_max_length
  251. if model_max_length is not None
  252. else VERY_LARGE_INTEGER
  253. )
  254. self.max_row_id = (
  255. max_row_id
  256. if max_row_id is not None
  257. else model_max_length
  258. if model_max_length is not None
  259. else VERY_LARGE_INTEGER
  260. )
  261. self.strip_column_names = strip_column_names
  262. self.update_answer_coordinates = update_answer_coordinates
  263. self.min_question_length = min_question_length
  264. self.max_question_length = max_question_length
  265. super().__init__(
  266. do_lower_case=do_lower_case,
  267. do_basic_tokenize=do_basic_tokenize,
  268. never_split=never_split,
  269. unk_token=unk_token,
  270. sep_token=sep_token,
  271. pad_token=pad_token,
  272. cls_token=cls_token,
  273. mask_token=mask_token,
  274. empty_token=empty_token,
  275. tokenize_chinese_chars=tokenize_chinese_chars,
  276. strip_accents=strip_accents,
  277. cell_trim_length=cell_trim_length,
  278. max_column_id=max_column_id,
  279. max_row_id=max_row_id,
  280. strip_column_names=strip_column_names,
  281. update_answer_coordinates=update_answer_coordinates,
  282. min_question_length=min_question_length,
  283. max_question_length=max_question_length,
  284. model_max_length=model_max_length,
  285. additional_special_tokens=additional_special_tokens,
  286. clean_up_tokenization_spaces=clean_up_tokenization_spaces,
  287. **kwargs,
  288. )
  289. # Tests override the vocab while reusing a tokenizer_config.json coming from a pretrained model.
  290. # This can register base vocab tokens (like [UNK]) as added tokens with mismatched ids (e.g. 100)
  291. # and breaks assumptions on token ordering. Drop any added-token entry that overlaps with the vocab
  292. # so these tokens rely on the vocab-provided ids.
  293. removed_overlap = False
  294. for token, added_id in list(self._added_tokens_encoder.items()):
  295. if token in self.vocab:
  296. self._added_tokens_encoder.pop(token, None)
  297. self._added_tokens_decoder.pop(added_id, None)
  298. removed_overlap = True
  299. if removed_overlap:
  300. self.tokens_trie = Trie()
  301. self._update_trie()
  302. @property
  303. def do_lower_case(self):
  304. return self.basic_tokenizer.do_lower_case
  305. @property
  306. def vocab_size(self):
  307. return len(self.vocab)
  308. def get_vocab(self):
  309. return dict(self.vocab, **self.added_tokens_encoder)
  310. def _tokenize(self, text):
  311. if format_text(text) == EMPTY_TEXT:
  312. return [self.extra_special_tokens[0]]
  313. split_tokens = []
  314. if self.do_basic_tokenize:
  315. for token in self.basic_tokenizer.tokenize(text, never_split=self.all_special_tokens):
  316. # If the token is part of the never_split set
  317. if token in self.basic_tokenizer.never_split:
  318. split_tokens.append(token)
  319. else:
  320. split_tokens += self.wordpiece_tokenizer.tokenize(token)
  321. else:
  322. split_tokens = self.wordpiece_tokenizer.tokenize(text)
  323. return split_tokens
  324. def _convert_token_to_id(self, token):
  325. """Converts a token (str) in an id using the vocab."""
  326. return self.vocab.get(token, self.vocab.get(self.unk_token))
  327. def _convert_id_to_token(self, index):
  328. """Converts an index (integer) in a token (str) using the vocab."""
  329. return self.ids_to_tokens.get(index, self.unk_token)
  330. def convert_tokens_to_string(self, tokens):
  331. """Converts a sequence of tokens (string) in a single string."""
  332. out_string = " ".join(tokens).replace(" ##", "").strip()
  333. return out_string
  334. def save_vocabulary(self, save_directory: str, filename_prefix: str | None = None) -> tuple[str]:
  335. index = 0
  336. if os.path.isdir(save_directory):
  337. vocab_file = os.path.join(
  338. save_directory, (filename_prefix + "-" if filename_prefix else "") + VOCAB_FILES_NAMES["vocab_file"]
  339. )
  340. else:
  341. vocab_file = (filename_prefix + "-" if filename_prefix else "") + save_directory
  342. with open(vocab_file, "w", encoding="utf-8") as writer:
  343. for token, token_index in sorted(self.vocab.items(), key=lambda kv: kv[1]):
  344. if index != token_index:
  345. logger.warning(
  346. f"Saving vocabulary to {vocab_file}: vocabulary indices are not consecutive."
  347. " Please check that the vocabulary is not corrupted!"
  348. )
  349. index = token_index
  350. writer.write(token + "\n")
  351. index += 1
  352. return (vocab_file,)
  353. def create_attention_mask_from_sequences(self, query_ids: list[int], table_values: list[TableValue]) -> list[int]:
  354. """
  355. Creates the attention mask according to the query token IDs and a list of table values.
  356. Args:
  357. query_ids (`list[int]`): list of token IDs corresponding to the ID.
  358. table_values (`list[TableValue]`): lift of table values, which are named tuples containing the
  359. token value, the column ID and the row ID of said token.
  360. Returns:
  361. `list[int]`: List of ints containing the attention mask values.
  362. """
  363. return [1] * (1 + len(query_ids) + 1 + len(table_values))
  364. def create_segment_token_type_ids_from_sequences(
  365. self, query_ids: list[int], table_values: list[TableValue]
  366. ) -> list[int]:
  367. """
  368. Creates the segment token type IDs according to the query token IDs and a list of table values.
  369. Args:
  370. query_ids (`list[int]`): list of token IDs corresponding to the ID.
  371. table_values (`list[TableValue]`): lift of table values, which are named tuples containing the
  372. token value, the column ID and the row ID of said token.
  373. Returns:
  374. `list[int]`: List of ints containing the segment token type IDs values.
  375. """
  376. table_ids = list(zip(*table_values))[0] if table_values else []
  377. return [0] * (1 + len(query_ids) + 1) + [1] * len(table_ids)
  378. def create_column_token_type_ids_from_sequences(
  379. self, query_ids: list[int], table_values: list[TableValue]
  380. ) -> list[int]:
  381. """
  382. Creates the column token type IDs according to the query token IDs and a list of table values.
  383. Args:
  384. query_ids (`list[int]`): list of token IDs corresponding to the ID.
  385. table_values (`list[TableValue]`): lift of table values, which are named tuples containing the
  386. token value, the column ID and the row ID of said token.
  387. Returns:
  388. `list[int]`: List of ints containing the column token type IDs values.
  389. """
  390. table_column_ids = list(zip(*table_values))[1] if table_values else []
  391. return [0] * (1 + len(query_ids) + 1) + list(table_column_ids)
  392. def create_row_token_type_ids_from_sequences(
  393. self, query_ids: list[int], table_values: list[TableValue]
  394. ) -> list[int]:
  395. """
  396. Creates the row token type IDs according to the query token IDs and a list of table values.
  397. Args:
  398. query_ids (`list[int]`): list of token IDs corresponding to the ID.
  399. table_values (`list[TableValue]`): lift of table values, which are named tuples containing the
  400. token value, the column ID and the row ID of said token.
  401. Returns:
  402. `list[int]`: List of ints containing the row token type IDs values.
  403. """
  404. table_row_ids = list(zip(*table_values))[2] if table_values else []
  405. return [0] * (1 + len(query_ids) + 1) + list(table_row_ids)
  406. def build_inputs_with_special_tokens(
  407. self, token_ids_0: list[int], token_ids_1: list[int] | None = None
  408. ) -> list[int]:
  409. """
  410. Build model inputs from a question and flattened table for question answering or sequence classification tasks
  411. by concatenating and adding special tokens.
  412. Args:
  413. token_ids_0 (`list[int]`): The ids of the question.
  414. token_ids_1 (`list[int]`, *optional*): The ids of the flattened table.
  415. Returns:
  416. `list[int]`: The model input with special tokens.
  417. """
  418. if token_ids_1 is None:
  419. raise ValueError("With TAPAS, you must provide both question IDs and table IDs.")
  420. return [self.cls_token_id] + token_ids_0 + [self.sep_token_id] + token_ids_1
  421. def get_special_tokens_mask(
  422. self, token_ids_0: list[int], token_ids_1: list[int] | None = None, already_has_special_tokens: bool = False
  423. ) -> list[int]:
  424. """
  425. Retrieve sequence ids from a token list that has no special tokens added. This method is called when adding
  426. special tokens using the tokenizer `prepare_for_model` method.
  427. Args:
  428. token_ids_0 (`list[int]`):
  429. List of question IDs.
  430. token_ids_1 (`list[int]`, *optional*):
  431. List of flattened table IDs.
  432. already_has_special_tokens (`bool`, *optional*, defaults to `False`):
  433. Whether or not the token list is already formatted with special tokens for the model.
  434. Returns:
  435. `list[int]`: A list of integers in the range [0, 1]: 1 for a special token, 0 for a sequence token.
  436. """
  437. if already_has_special_tokens:
  438. return super().get_special_tokens_mask(
  439. token_ids_0=token_ids_0, token_ids_1=token_ids_1, already_has_special_tokens=True
  440. )
  441. if token_ids_1 is not None:
  442. return [1] + ([0] * len(token_ids_0)) + [1] + ([0] * len(token_ids_1))
  443. return [1] + ([0] * len(token_ids_0)) + [1]
  @add_end_docstrings(TAPAS_ENCODE_PLUS_ADDITIONAL_KWARGS_DOCSTRING)
  def __call__(
      self,
      table: Union["pd.DataFrame", TextInput, list[TextInput], None],
      queries: TextInput
      | PreTokenizedInput
      | EncodedInput
      | list[TextInput]
      | list[PreTokenizedInput]
      | list[EncodedInput]
      | None = None,
      answer_coordinates: list[tuple] | list[list[tuple]] | None = None,
      answer_text: list[TextInput] | list[list[TextInput]] | None = None,
      add_special_tokens: bool = True,
      padding: bool | str | PaddingStrategy = False,
      truncation: bool | str | TapasTruncationStrategy = False,
      max_length: int | None = None,
      pad_to_multiple_of: int | None = None,
      padding_side: str | None = None,
      return_tensors: str | TensorType | None = None,
      return_token_type_ids: bool | None = None,
      return_attention_mask: bool | None = None,
      return_overflowing_tokens: bool = False,
      return_special_tokens_mask: bool = False,
      return_offsets_mapping: bool = False,
      return_length: bool = False,
      verbose: bool = True,
      **kwargs,
  ) -> BatchEncoding:
      """
      Main method to tokenize and prepare for the model one or several sequence(s) related to a table.

      A single query (`str`) is routed to `encode_plus`; a `list`/`tuple` of queries is routed to
      `batch_encode_plus`.

      Args:
          table (`pd.DataFrame` or `str` or `list[str]`):
              Table containing tabular data. Note that all cell values must be text. Use *.astype(str)* on a Pandas
              dataframe to convert it to string. When passing a string or list of strings, those will be interpreted
              as queries with an empty table (to support generic tokenizer tests); `queries` must then be `None`.
          queries (`str` or `list[str]`):
              Question or batch of questions related to a table to be encoded. Note that in case of a batch, all
              questions must refer to the **same** table.
          answer_coordinates (`list[Tuple]` or `list[list[Tuple]]`, *optional*):
              Answer coordinates of each table-question pair in the batch. In case only a single table-question pair
              is provided, then the answer_coordinates must be a single list of one or more tuples. Each tuple must
              be a (row_index, column_index) pair. The first data row (not the column header row) has index 0. The
              first column has index 0. In case a batch of table-question pairs is provided, then the
              answer_coordinates must be a list of lists of tuples (each list corresponding to a single
              table-question pair).
          answer_text (`list[str]` or `list[list[str]]`, *optional*):
              Answer text of each table-question pair in the batch. In case only a single table-question pair is
              provided, then the answer_text must be a single list of one or more strings. Each string must be the
              answer text of a corresponding answer coordinate. In case a batch of table-question pairs is provided,
              then the answer_text must be a list of lists of strings (each list corresponding to a single
              table-question pair).
      Raises:
          `AssertionError`: If `table` is not a DataFrame while `queries` is also given.
          `ValueError`: If `queries` is neither `None`, a `str`, nor a list/tuple whose first element is a `str`.
      """
      if not isinstance(table, pd.DataFrame):
          # The positional "table" actually carries the query (or queries); pair it with an empty table.
          if queries is not None:
              raise AssertionError("Table must be of type pd.DataFrame when queries are provided separately.")
          table, queries = pd.DataFrame.from_dict({}), table

      # Input type checking for clearer error
      is_batched = isinstance(queries, (list, tuple))
      if is_batched:
          valid_query = len(queries) == 0 or isinstance(queries[0], str)
      else:
          valid_query = queries is None or isinstance(queries, str)
      if not valid_query:
          raise ValueError(
              "queries input must of type `str` (single example), `list[str]` (batch or single pretokenized"
              " example). "
          )

      # Everything except the query argument itself is forwarded identically to both code paths.
      shared_kwargs = {
          "table": table,
          "answer_coordinates": answer_coordinates,
          "answer_text": answer_text,
          "add_special_tokens": add_special_tokens,
          "padding": padding,
          "truncation": truncation,
          "max_length": max_length,
          "pad_to_multiple_of": pad_to_multiple_of,
          "padding_side": padding_side,
          "return_tensors": return_tensors,
          "return_token_type_ids": return_token_type_ids,
          "return_attention_mask": return_attention_mask,
          "return_overflowing_tokens": return_overflowing_tokens,
          "return_special_tokens_mask": return_special_tokens_mask,
          "return_offsets_mapping": return_offsets_mapping,
          "return_length": return_length,
          "verbose": verbose,
      }
      if is_batched:
          return self.batch_encode_plus(queries=queries, **shared_kwargs, **kwargs)
      return self.encode_plus(query=queries, **shared_kwargs, **kwargs)
  @add_end_docstrings(ENCODE_KWARGS_DOCSTRING, TAPAS_ENCODE_PLUS_ADDITIONAL_KWARGS_DOCSTRING)
  def batch_encode_plus(
      self,
      table: "pd.DataFrame",
      queries: list[TextInput] | list[PreTokenizedInput] | list[EncodedInput] | None = None,
      answer_coordinates: list[list[tuple]] | None = None,
      answer_text: list[list[TextInput]] | None = None,
      add_special_tokens: bool = True,
      padding: bool | str | PaddingStrategy = False,
      truncation: bool | str | TapasTruncationStrategy = False,
      max_length: int | None = None,
      pad_to_multiple_of: int | None = None,
      padding_side: str | None = None,
      return_tensors: str | TensorType | None = None,
      return_token_type_ids: bool | None = None,
      return_attention_mask: bool | None = None,
      return_overflowing_tokens: bool = False,
      return_special_tokens_mask: bool = False,
      return_offsets_mapping: bool = False,
      return_length: bool = False,
      verbose: bool = True,
      **kwargs,
  ) -> BatchEncoding:
      """
      Prepare a table and a list of strings for the model.

      <Tip warning={true}>

      This method is deprecated, `__call__` should be used instead.

      </Tip>

      Args:
          table (`pd.DataFrame`):
              Table containing tabular data. Note that all cell values must be text. Use *.astype(str)* on a Pandas
              dataframe to convert it to string.
          queries (`list[str]`):
              Batch of questions related to a table to be encoded. Note that all questions must refer to the **same**
              table.
          answer_coordinates (`list[list[Tuple]]`, *optional*):
              Answer coordinates of each table-question pair in the batch. Each tuple must be a (row_index,
              column_index) pair. The first data row (not the column header row) has index 0. The first column has
              index 0. The answer_coordinates must be a list of lists of tuples, with exactly one list per query.
          answer_text (`list[list[str]]`, *optional*):
              Answer text of each table-question pair in the batch. The answer_text must be a list of lists of
              strings, with exactly one list per query. Each string must be the answer text of a corresponding
              answer coordinate.
      Raises:
          `ValueError`: If `return_token_type_ids` is set while `add_special_tokens` is `False`, if only one of
              `answer_coordinates` / `answer_text` is provided, or if they do not hold exactly one entry per query.
          `NotImplementedError`: If `is_split_into_words` is passed or `return_offsets_mapping` is requested.
      """
      if return_token_type_ids is not None and not add_special_tokens:
          raise ValueError(
              "Asking to return token_type_ids while setting add_special_tokens to False "
              "results in an undefined behavior. Please set add_special_tokens to True or "
              "set return_token_type_ids to None."
          )
      if bool(answer_coordinates) != bool(answer_text):
          raise ValueError("In case you provide answers, both answer_coordinates and answer_text should be provided")
      if not answer_coordinates:
          # No answers supplied (both None or both empty). Use one `None` placeholder per query so that every
          # query is still encoded, just without labels. Previously a None/[] mix crashed inside `zip` and
          # []/[] silently produced an empty batch.
          answer_coordinates = answer_text = [None] * len(queries)
      elif len(answer_coordinates) != len(queries) or len(answer_text) != len(queries):
          # Examples are later zipped with their answers, so a mismatch would silently drop queries.
          raise ValueError(
              f"Got {len(queries)} queries but {len(answer_coordinates)} answer_coordinates and "
              f"{len(answer_text)} answer_text entries; provide exactly one entry per query."
          )
      if "is_split_into_words" in kwargs:
          raise NotImplementedError("Currently TapasTokenizer only supports questions as strings.")
      if return_offsets_mapping:
          raise NotImplementedError(
              "return_offset_mapping is not available when using Python tokenizers. "
              "To use this feature, change your tokenizer to one deriving from "
              "transformers.PreTrainedTokenizerFast."
          )
      return self._batch_encode_plus(
          table=table,
          queries=queries,
          answer_coordinates=answer_coordinates,
          answer_text=answer_text,
          add_special_tokens=add_special_tokens,
          padding=padding,
          truncation=truncation,
          max_length=max_length,
          pad_to_multiple_of=pad_to_multiple_of,
          padding_side=padding_side,
          return_tensors=return_tensors,
          return_token_type_ids=return_token_type_ids,
          return_attention_mask=return_attention_mask,
          return_overflowing_tokens=return_overflowing_tokens,
          return_special_tokens_mask=return_special_tokens_mask,
          return_offsets_mapping=return_offsets_mapping,
          return_length=return_length,
          verbose=verbose,
          **kwargs,
      )
  645. def _get_question_tokens(self, query):
  646. """Tokenizes the query, taking into account the max and min question length."""
  647. query_tokens = self.tokenize(query)
  648. if self.max_question_length is not None and len(query_tokens) > self.max_question_length:
  649. logger.warning("Skipping query as its tokens are longer than the max question length")
  650. return "", []
  651. if self.min_question_length is not None and len(query_tokens) < self.min_question_length:
  652. logger.warning("Skipping query as its tokens are shorter than the min question length")
  653. return "", []
  654. return query, query_tokens
  def _batch_encode_plus(
      self,
      table,
      queries: list[TextInput] | list[PreTokenizedInput] | list[EncodedInput],
      answer_coordinates: list[list[tuple]] | None = None,
      answer_text: list[list[TextInput]] | None = None,
      add_special_tokens: bool = True,
      padding: bool | str | PaddingStrategy = False,
      truncation: bool | str | TapasTruncationStrategy = False,
      max_length: int | None = None,
      pad_to_multiple_of: int | None = None,
      padding_side: str | None = None,
      return_tensors: str | TensorType | None = None,
      return_token_type_ids: bool | None = True,
      return_attention_mask: bool | None = None,
      return_overflowing_tokens: bool = False,
      return_special_tokens_mask: bool = False,
      return_offsets_mapping: bool = False,
      return_length: bool = False,
      verbose: bool = True,
      **kwargs,
  ) -> BatchEncoding:
      """
      Tokenize the shared table once and every query of the batch, then delegate to `_batch_prepare_for_model`.

      Each query goes through `_get_question_tokens`, which may replace it by `""` when it violates the question
      length bounds. The resulting queries are collected in a new list rather than written back into `queries`.
      This avoids mutating the caller's list and keeps tuple input (accepted as a batch by `__call__`) working.
      """
      table_tokens = self._tokenize_table(table)
      processed_queries = []
      queries_tokens = []
      for query in queries:
          processed_query, query_tokens = self._get_question_tokens(query)
          processed_queries.append(processed_query)
          queries_tokens.append(query_tokens)
      batch_outputs = self._batch_prepare_for_model(
          table,
          processed_queries,
          tokenized_table=table_tokens,
          queries_tokens=queries_tokens,
          answer_coordinates=answer_coordinates,
          padding=padding,
          truncation=truncation,
          answer_text=answer_text,
          add_special_tokens=add_special_tokens,
          max_length=max_length,
          pad_to_multiple_of=pad_to_multiple_of,
          padding_side=padding_side,
          return_tensors=return_tensors,
          prepend_batch_axis=True,
          return_attention_mask=return_attention_mask,
          return_token_type_ids=return_token_type_ids,
          return_overflowing_tokens=return_overflowing_tokens,
          return_special_tokens_mask=return_special_tokens_mask,
          return_length=return_length,
          verbose=verbose,
      )
      return BatchEncoding(batch_outputs)
  def _batch_prepare_for_model(
      self,
      raw_table: "pd.DataFrame",
      raw_queries: list[TextInput] | list[PreTokenizedInput] | list[EncodedInput],
      tokenized_table: TokenizedTable | None = None,
      queries_tokens: list[list[str]] | None = None,
      answer_coordinates: list[list[tuple]] | None = None,
      answer_text: list[list[TextInput]] | None = None,
      add_special_tokens: bool = True,
      padding: bool | str | PaddingStrategy = False,
      truncation: bool | str | TapasTruncationStrategy = False,
      max_length: int | None = None,
      pad_to_multiple_of: int | None = None,
      padding_side: str | None = None,
      return_tensors: str | TensorType | None = None,
      return_token_type_ids: bool | None = True,
      return_attention_mask: bool | None = True,
      return_special_tokens_mask: bool = False,
      return_offsets_mapping: bool = False,
      return_length: bool = False,
      verbose: bool = True,
      prepend_batch_axis: bool = False,
      **kwargs,
  ) -> BatchEncoding:
      """
      Encode every (query, answers) example of a batch against the shared table, then pad and tensorize together.

      Examples are prepared one by one without padding, attention mask or tensor conversion; those are applied
      once to the collected batch at the end. Each example also receives the answers of the example before it
      as `prev_answer_coordinates` / `prev_answer_text` (`None` for the first example).
      """
      batch_outputs = {}
      examples = zip(raw_queries, queries_tokens, answer_coordinates, answer_text)
      for position, (raw_query, query_tokens, answer_coords, answer_txt) in enumerate(examples):
          is_first = position == 0
          outputs = self.prepare_for_model(
              raw_table,
              raw_query,
              tokenized_table=tokenized_table,
              query_tokens=query_tokens,
              answer_coordinates=answer_coords,
              answer_text=answer_txt,
              add_special_tokens=add_special_tokens,
              padding=PaddingStrategy.DO_NOT_PAD.value,  # the whole batch is padded below
              truncation=truncation,
              max_length=max_length,
              pad_to_multiple_of=None,  # applied by the batch-level pad() call
              padding_side=None,  # applied by the batch-level pad() call
              return_attention_mask=False,  # produced by the batch-level pad() call
              return_token_type_ids=return_token_type_ids,
              return_special_tokens_mask=return_special_tokens_mask,
              return_length=return_length,
              return_tensors=None,  # tensors are built once for the full batch
              prepend_batch_axis=False,
              verbose=verbose,
              prev_answer_coordinates=None if is_first else answer_coordinates[position - 1],
              prev_answer_text=None if is_first else answer_text[position - 1],
          )
          for key, value in outputs.items():
              batch_outputs.setdefault(key, []).append(value)

      padded_outputs = self.pad(
          batch_outputs,
          padding=padding,
          max_length=max_length,
          pad_to_multiple_of=pad_to_multiple_of,
          padding_side=padding_side,
          return_attention_mask=return_attention_mask,
      )
      return BatchEncoding(padded_outputs, tensor_type=return_tensors)
  @add_end_docstrings(ENCODE_KWARGS_DOCSTRING)
  def encode(
      self,
      table: Union["pd.DataFrame", TextInput, list[TextInput]],
      query: TextInput | PreTokenizedInput | EncodedInput | None = None,
      add_special_tokens: bool = True,
      padding: bool | str | PaddingStrategy = False,
      truncation: bool | str | TapasTruncationStrategy = False,
      max_length: int | None = None,
      return_tensors: str | TensorType | None = None,
      **kwargs,
  ) -> list[int]:
      """
      Prepare a table and a string for the model and return only the input IDs.

      Token type IDs, attention masks, etc. are discarded even though the model needs them to work correctly.
      Use this method only if you build the rest of the processing yourself; otherwise use `__call__`.

      Args:
          table (`pd.DataFrame` or `str` or `list[str]`):
              Table containing tabular data. When passing a string or list of strings, those will be interpreted as
              queries with an empty table (to support generic tokenizer tests); `query` must then be `None`.
          query (`str` or `list[str]`):
              Question related to a table to be encoded.
      Returns:
          The `"input_ids"` entry of the `encode_plus` output.
      """
      if not isinstance(table, pd.DataFrame):
          if query is not None:
              raise AssertionError("Table must be of type pd.DataFrame when queries are provided separately.")
          # The positional "table" actually carries the query; encode it against an empty table.
          query, table = table, pd.DataFrame.from_dict({})

      encoding = self.encode_plus(
          table,
          query=query,
          add_special_tokens=add_special_tokens,
          padding=padding,
          truncation=truncation,
          max_length=max_length,
          return_tensors=return_tensors,
          **kwargs,
      )
      return encoding["input_ids"]
  @add_end_docstrings(ENCODE_KWARGS_DOCSTRING, TAPAS_ENCODE_PLUS_ADDITIONAL_KWARGS_DOCSTRING)
  def encode_plus(
      self,
      table: Union["pd.DataFrame", TextInput, list[TextInput]],
      query: TextInput | PreTokenizedInput | EncodedInput | None = None,
      answer_coordinates: list[tuple] | None = None,
      answer_text: list[TextInput] | None = None,
      add_special_tokens: bool = True,
      padding: bool | str | PaddingStrategy = False,
      truncation: bool | str | TapasTruncationStrategy = False,
      max_length: int | None = None,
      pad_to_multiple_of: int | None = None,
      padding_side: str | None = None,
      return_tensors: str | TensorType | None = None,
      return_token_type_ids: bool | None = None,
      return_attention_mask: bool | None = None,
      return_special_tokens_mask: bool = False,
      return_offsets_mapping: bool = False,
      return_length: bool = False,
      verbose: bool = True,
      **kwargs,
  ) -> BatchEncoding:
      """
      Prepare a single table and question for the model.

      Args:
          table (`pd.DataFrame` or `str` or `list[str]`):
              Table containing tabular data. When passing a string or list of strings, those will be interpreted as
              queries with an empty table (to support generic tokenizer tests); `query` must then be `None`.
          query (`str` or `list[str]`):
              Question related to a table to be encoded.
          answer_coordinates (`list[Tuple]`, *optional*):
              Answer coordinates of the table-question pair: a single list of one or more (row_index, column_index)
              tuples. The first data row (not the column header row) has index 0. The first column has index 0.
          answer_text (`list[str]`, *optional*):
              Answer text of the table-question pair: a single list of one or more strings, each being the answer
              text of the corresponding answer coordinate.
      Raises:
          `ValueError`: If `return_token_type_ids` is set while `add_special_tokens` is `False`, or if only one of
              `answer_coordinates` / `answer_text` is provided.
          `NotImplementedError`: If `is_split_into_words` is passed or `return_offsets_mapping` is requested.
          `AssertionError`: If `table` is not a DataFrame while `query` is also given.
      """
      if return_token_type_ids is not None and not add_special_tokens:
          raise ValueError(
              "Asking to return token_type_ids while setting add_special_tokens to False "
              "results in an undefined behavior. Please set add_special_tokens to True or "
              "set return_token_type_ids to None."
          )
      # Answers must come in pairs: exactly one of them being set (truthy) is an error.
      if bool(answer_coordinates) != bool(answer_text):
          raise ValueError("In case you provide answers, both answer_coordinates and answer_text should be provided")
      if "is_split_into_words" in kwargs:
          raise NotImplementedError("Currently TapasTokenizer only supports questions as strings.")
      if return_offsets_mapping:
          raise NotImplementedError(
              "return_offset_mapping is not available when using Python tokenizers. "
              "To use this feature, change your tokenizer to one deriving from "
              "transformers.PreTrainedTokenizerFast."
          )
      if not isinstance(table, pd.DataFrame):
          if query is not None:
              raise AssertionError("Table must be of type pd.DataFrame when queries are provided separately.")
          # The positional "table" actually carries the query; encode it against an empty table.
          query, table = table, pd.DataFrame.from_dict({})

      return self._encode_plus(
          table=table,
          query=query,
          answer_coordinates=answer_coordinates,
          answer_text=answer_text,
          add_special_tokens=add_special_tokens,
          padding=padding,
          truncation=truncation,
          max_length=max_length,
          pad_to_multiple_of=pad_to_multiple_of,
          padding_side=padding_side,
          return_tensors=return_tensors,
          return_token_type_ids=return_token_type_ids,
          return_attention_mask=return_attention_mask,
          return_special_tokens_mask=return_special_tokens_mask,
          return_offsets_mapping=return_offsets_mapping,
          return_length=return_length,
          verbose=verbose,
          **kwargs,
      )
  def _encode_plus(
      self,
      table: "pd.DataFrame",
      query: TextInput | PreTokenizedInput | EncodedInput,
      answer_coordinates: list[tuple] | None = None,
      answer_text: list[TextInput] | None = None,
      add_special_tokens: bool = True,
      padding: bool | str | PaddingStrategy = False,
      truncation: bool | str | TapasTruncationStrategy = False,
      max_length: int | None = None,
      pad_to_multiple_of: int | None = None,
      padding_side: str | None = None,
      return_tensors: str | TensorType | None = None,
      return_token_type_ids: bool | None = True,
      return_attention_mask: bool | None = True,
      return_special_tokens_mask: bool = False,
      return_offsets_mapping: bool = False,
      return_length: bool = False,
      verbose: bool = True,
      **kwargs,
  ):
      """
      Tokenize a single table/query pair and hand the tokens to `prepare_for_model`.

      A missing query is replaced by the empty string (with a warning). The query is then passed through
      `_get_question_tokens`, so a query outside the configured length bounds is encoded as empty. The result is
      built with a leading batch axis (`prepend_batch_axis=True`).
      """
      if query is None:
          logger.warning(
              "TAPAS is a question answering model but you have not passed a query. Please be aware that the "
              "model will probably not behave correctly."
          )
          query = ""

      tokenized_table = self._tokenize_table(table)
      query, query_tokens = self._get_question_tokens(query)

      prepare_kwargs = {
          "tokenized_table": tokenized_table,
          "query_tokens": query_tokens,
          "answer_coordinates": answer_coordinates,
          "answer_text": answer_text,
          "add_special_tokens": add_special_tokens,
          "truncation": truncation,
          "padding": padding,
          "max_length": max_length,
          "pad_to_multiple_of": pad_to_multiple_of,
          "padding_side": padding_side,
          "return_tensors": return_tensors,
          "prepend_batch_axis": True,
          "return_attention_mask": return_attention_mask,
          "return_token_type_ids": return_token_type_ids,
          "return_special_tokens_mask": return_special_tokens_mask,
          "return_length": return_length,
          "verbose": verbose,
      }
      return self.prepare_for_model(table, query, **prepare_kwargs)
  @add_end_docstrings(ENCODE_KWARGS_DOCSTRING, TAPAS_ENCODE_PLUS_ADDITIONAL_KWARGS_DOCSTRING)
  def prepare_for_model(
      self,
      raw_table: "pd.DataFrame",
      raw_query: TextInput | PreTokenizedInput | EncodedInput,
      tokenized_table: TokenizedTable | None = None,
      query_tokens: list[str] | None = None,
      answer_coordinates: list[tuple] | None = None,
      answer_text: list[TextInput] | None = None,
      add_special_tokens: bool = True,
      padding: bool | str | PaddingStrategy = False,
      truncation: bool | str | TapasTruncationStrategy = False,
      max_length: int | None = None,
      pad_to_multiple_of: int | None = None,
      padding_side: str | None = None,
      return_tensors: str | TensorType | None = None,
      return_token_type_ids: bool | None = True,
      return_attention_mask: bool | None = True,
      return_special_tokens_mask: bool = False,
      return_offsets_mapping: bool = False,
      return_length: bool = False,
      verbose: bool = True,
      prepend_batch_axis: bool = False,
      **kwargs,
  ) -> BatchEncoding:
      """
      Prepares a sequence of input id so that it can be used by the model. It adds special tokens, truncates
      sequences if overflowing while taking into account the special tokens.

      Besides `input_ids`, this builds the seven TAPAS token type id channels per token (segment, column, row,
      previous labels, column ranks, inverse column ranks and numeric relations). When answers are given, it
      also adds `labels`, `numeric_values` and `numeric_values_scale`.

      Args:
          raw_table (`pd.DataFrame`):
              The original table before any transformation (like tokenization) was applied to it.
          raw_query (`TextInput` or `PreTokenizedInput` or `EncodedInput`):
              The original query before any transformation (like tokenization) was applied to it.
          tokenized_table (`TokenizedTable`):
              The table after tokenization.
          query_tokens (`list[str]`):
              The query after tokenization.
          answer_coordinates (`list[Tuple]` or `list[list[Tuple]]`, *optional*):
              Answer coordinates of each table-question pair in the batch. The answer_coordinates must be a single
              list of one or more tuples. Each tuple must be a (row_index, column_index) pair. The first data row
              (not the column header row) has index 0. The first column has index 0.
          answer_text (`list[str]` or `list[list[str]]`, *optional*):
              Answer text of each table-question pair in the batch. The answer_text must be a single list of one or
              more strings. Each string must be the answer text of a corresponding answer coordinate.
          kwargs:
              `prev_answer_coordinates` and `prev_answer_text` (both must be present) mark the call as part of a
              batch and are used to build the previous-label channel. A truthy `return_overflowing_tokens` raises
              `ValueError`.
      Returns:
          [`BatchEncoding`]: the encoded inputs, converted to `return_tensors` if requested.
      Raises:
          `ValueError`: If overflowing tokens are requested, or if the encoded sequence is longer than
              `max_length`.
      """
      # Normalize `padding`: `True` only pads (to MAX_LENGTH) when `max_length` or `pad_to_multiple_of` is given;
      # otherwise a single example is left unpadded.
      if isinstance(padding, bool):
          if padding and (max_length is not None or pad_to_multiple_of is not None):
              padding = PaddingStrategy.MAX_LENGTH
          else:
              padding = PaddingStrategy.DO_NOT_PAD
      elif not isinstance(padding, PaddingStrategy):
          padding = PaddingStrategy(padding)
      # Normalize `truncation`: `True` selects the row-dropping strategy.
      if isinstance(truncation, bool):
          if truncation:
              truncation = TapasTruncationStrategy.DROP_ROWS_TO_FIT
          else:
              truncation = TapasTruncationStrategy.DO_NOT_TRUNCATE
      elif not isinstance(truncation, TapasTruncationStrategy):
          truncation = TapasTruncationStrategy(truncation)
      encoded_inputs = {}
      # Batch callers (`_batch_prepare_for_model`) always pass both previous-answer kwargs, possibly as None.
      is_part_of_batch = False
      prev_answer_coordinates, prev_answer_text = None, None
      if "prev_answer_coordinates" in kwargs and "prev_answer_text" in kwargs:
          is_part_of_batch = True
          prev_answer_coordinates = kwargs["prev_answer_coordinates"]
          prev_answer_text = kwargs["prev_answer_text"]
      num_rows = self._get_num_rows(raw_table, truncation != TapasTruncationStrategy.DO_NOT_TRUNCATE)
      num_columns = self._get_num_columns(raw_table)
      # Only the third returned value is used here; it may be reduced by truncation just below.
      _, _, num_tokens = self._get_table_boundaries(tokenized_table)
      if truncation != TapasTruncationStrategy.DO_NOT_TRUNCATE:
          num_rows, num_tokens = self._get_truncated_table_rows(
              query_tokens, tokenized_table, num_rows, num_columns, max_length, truncation_strategy=truncation
          )
      table_data = list(self._get_table_values(tokenized_table, num_columns, num_rows, num_tokens))
      query_ids = self.convert_tokens_to_ids(query_tokens)
      # Take the first field (the token) of every table value; for an empty table this is just an empty list.
      table_ids = list(zip(*table_data))[0] if len(table_data) > 0 else list(zip(*table_data))
      table_ids = self.convert_tokens_to_ids(list(table_ids))
      if "return_overflowing_tokens" in kwargs and kwargs["return_overflowing_tokens"]:
          raise ValueError("TAPAS does not return overflowing tokens as it works on tables.")
      if add_special_tokens:
          input_ids = self.build_inputs_with_special_tokens(query_ids, table_ids)
      else:
          input_ids = query_ids + table_ids
      if max_length is not None and len(input_ids) > max_length:
          raise ValueError(
              "Could not encode the query and table header given the maximum length. Encoding the query and table "
              f"header results in a length of {len(input_ids)} which is higher than the max_length of {max_length}"
          )
      encoded_inputs["input_ids"] = input_ids
      segment_ids = self.create_segment_token_type_ids_from_sequences(query_ids, table_data)
      column_ids = self.create_column_token_type_ids_from_sequences(query_ids, table_data)
      row_ids = self.create_row_token_type_ids_from_sequences(query_ids, table_data)
      if not is_part_of_batch or (prev_answer_coordinates is None and prev_answer_text is None):
          # simply set the prev_labels to zeros
          prev_labels = [0] * len(row_ids)
      else:
          # Within a batch, the previous example's answers are turned into labels for this example's tokens.
          prev_labels = self.get_answer_ids(
              column_ids, row_ids, table_data, prev_answer_text, prev_answer_coordinates
          )
      # FIRST: parse both the table and question in terms of numeric values
      raw_table = add_numeric_table_values(raw_table)
      raw_query = add_numeric_values_to_question(raw_query)
      # SECOND: add numeric-related features (and not parse them in these functions):
      column_ranks, inv_column_ranks = self._get_numeric_column_ranks(column_ids, row_ids, raw_table)
      numeric_relations = self._get_numeric_relations(raw_query, column_ids, row_ids, raw_table)
      # Load from model defaults
      if return_token_type_ids is None:
          return_token_type_ids = "token_type_ids" in self.model_input_names
      if return_attention_mask is None:
          return_attention_mask = "attention_mask" in self.model_input_names
      if return_attention_mask:
          attention_mask = self.create_attention_mask_from_sequences(query_ids, table_data)
          encoded_inputs["attention_mask"] = attention_mask
      # Supervision targets are only produced when both answer inputs are given.
      if answer_coordinates is not None and answer_text is not None:
          labels = self.get_answer_ids(column_ids, row_ids, table_data, answer_text, answer_coordinates)
          numeric_values = self._get_numeric_values(raw_table, column_ids, row_ids)
          numeric_values_scale = self._get_numeric_values_scale(raw_table, column_ids, row_ids)
          encoded_inputs["labels"] = labels
          encoded_inputs["numeric_values"] = numeric_values
          encoded_inputs["numeric_values_scale"] = numeric_values_scale
      if return_token_type_ids:
          token_type_ids = [
              segment_ids,
              column_ids,
              row_ids,
              prev_labels,
              column_ranks,
              inv_column_ranks,
              numeric_relations,
          ]
          # Transpose the seven per-channel sequences into one list of seven type ids per token.
          token_type_ids = [list(ids) for ids in list(zip(*token_type_ids))]
          encoded_inputs["token_type_ids"] = token_type_ids
      if return_special_tokens_mask:
          if add_special_tokens:
              encoded_inputs["special_tokens_mask"] = self.get_special_tokens_mask(query_ids, table_ids)
          else:
              encoded_inputs["special_tokens_mask"] = [0] * len(input_ids)
      # Check lengths
      # The over-length warning is emitted at most once per tokenizer instance (tracked in deprecation_warnings).
      if max_length is None and len(encoded_inputs["input_ids"]) > self.model_max_length and verbose:
          if not self.deprecation_warnings.get("sequence-length-is-longer-than-the-specified-maximum", False):
              logger.warning(
                  "Token indices sequence length is longer than the specified maximum sequence length "
                  f"for this model ({len(encoded_inputs['input_ids'])} > {self.model_max_length}). Running this "
                  "sequence through the model will result in indexing errors."
              )
          self.deprecation_warnings["sequence-length-is-longer-than-the-specified-maximum"] = True
      # Padding
      # pad() also runs when no padding is requested but an attention mask is.
      if padding != PaddingStrategy.DO_NOT_PAD or return_attention_mask:
          encoded_inputs = self.pad(
              encoded_inputs,
              max_length=max_length,
              padding=padding.value,
              pad_to_multiple_of=pad_to_multiple_of,
              padding_side=padding_side,
              return_attention_mask=return_attention_mask,
          )
      if return_length:
          encoded_inputs["length"] = len(encoded_inputs["input_ids"])
      batch_outputs = BatchEncoding(
          encoded_inputs, tensor_type=return_tensors, prepend_batch_axis=prepend_batch_axis
      )
      return batch_outputs
  1102. def _get_truncated_table_rows(
  1103. self,
  1104. query_tokens: list[str],
  1105. tokenized_table: TokenizedTable,
  1106. num_rows: int,
  1107. num_columns: int,
  1108. max_length: int,
  1109. truncation_strategy: str | TapasTruncationStrategy,
  1110. ) -> tuple[int, int]:
  1111. """
  1112. Truncates a sequence pair in-place following the strategy.
  1113. Args:
  1114. query_tokens (`list[str]`):
  1115. List of strings corresponding to the tokenized query.
  1116. tokenized_table (`TokenizedTable`):
  1117. Tokenized table
  1118. num_rows (`int`):
  1119. Total number of table rows
  1120. num_columns (`int`):
  1121. Total number of table columns
  1122. max_length (`int`):
  1123. Total maximum length.
  1124. truncation_strategy (`str` or [`TapasTruncationStrategy]`):
  1125. Truncation strategy to use. Seeing as this method should only be called when truncating, the only
  1126. available strategy is the `"drop_rows_to_fit"` strategy.
  1127. Returns:
  1128. `Tuple(int, int)`: tuple containing the number of rows after truncation, and the number of tokens available
  1129. for each table element.
  1130. """
  1131. if not isinstance(truncation_strategy, TapasTruncationStrategy):
  1132. truncation_strategy = TapasTruncationStrategy(truncation_strategy)
  1133. if max_length is None:
  1134. max_length = self.model_max_length
  1135. if truncation_strategy == TapasTruncationStrategy.DROP_ROWS_TO_FIT:
  1136. while True:
  1137. num_tokens = self._get_max_num_tokens(
  1138. query_tokens, tokenized_table, num_rows=num_rows, num_columns=num_columns, max_length=max_length
  1139. )
  1140. if num_tokens is not None:
  1141. # We could fit the table.
  1142. break
  1143. # Try to drop a row to fit the table.
  1144. num_rows -= 1
  1145. if num_rows < 1:
  1146. break
  1147. elif truncation_strategy != TapasTruncationStrategy.DO_NOT_TRUNCATE:
  1148. raise ValueError(f"Unknown truncation strategy {truncation_strategy}.")
  1149. return num_rows, num_tokens or 1
  1150. def _tokenize_table(
  1151. self,
  1152. table=None,
  1153. ):
  1154. """
  1155. Tokenizes column headers and cell texts of a table.
  1156. Args:
  1157. table (`pd.Dataframe`):
  1158. Table. Returns: `TokenizedTable`: TokenizedTable object.
  1159. """
  1160. tokenized_rows = []
  1161. tokenized_row = []
  1162. # tokenize column headers
  1163. for column in table:
  1164. if self.strip_column_names:
  1165. tokenized_row.append(self.tokenize(""))
  1166. else:
  1167. tokenized_row.append(self.tokenize(column))
  1168. tokenized_rows.append(tokenized_row)
  1169. # tokenize cell values
  1170. for idx, row in table.iterrows():
  1171. tokenized_row = []
  1172. for cell in row:
  1173. tokenized_row.append(self.tokenize(cell))
  1174. tokenized_rows.append(tokenized_row)
  1175. token_coordinates = []
  1176. for row_index, row in enumerate(tokenized_rows):
  1177. for column_index, cell in enumerate(row):
  1178. for token_index, _ in enumerate(cell):
  1179. token_coordinates.append(
  1180. TokenCoordinates(
  1181. row_index=row_index,
  1182. column_index=column_index,
  1183. token_index=token_index,
  1184. )
  1185. )
  1186. return TokenizedTable(
  1187. rows=tokenized_rows,
  1188. selected_tokens=token_coordinates,
  1189. )
  1190. def _question_encoding_cost(self, question_tokens):
  1191. # Two extra spots of SEP and CLS.
  1192. return len(question_tokens) + 2
  1193. def _get_token_budget(self, question_tokens, max_length=None):
  1194. """
  1195. Computes the number of tokens left for the table after tokenizing a question, taking into account the max
  1196. sequence length of the model.
  1197. Args:
  1198. question_tokens (`list[String]`):
  1199. List of question tokens. Returns: `int`: the number of tokens left for the table, given the model max
  1200. length.
  1201. """
  1202. return (max_length if max_length is not None else self.model_max_length) - self._question_encoding_cost(
  1203. question_tokens
  1204. )
  1205. def _get_table_values(self, table, num_columns, num_rows, num_tokens) -> Generator[TableValue, None, None]:
  1206. """Iterates over partial table and returns token, column and row indexes."""
  1207. for tc in table.selected_tokens:
  1208. # First row is header row.
  1209. if tc.row_index >= num_rows + 1:
  1210. continue
  1211. if tc.column_index >= num_columns:
  1212. continue
  1213. cell = table.rows[tc.row_index][tc.column_index]
  1214. token = cell[tc.token_index]
  1215. word_begin_index = tc.token_index
  1216. # Don't add partial words. Find the starting word piece and check if it
  1217. # fits in the token budget.
  1218. while word_begin_index >= 0 and _is_inner_wordpiece(cell[word_begin_index]):
  1219. word_begin_index -= 1
  1220. if word_begin_index >= num_tokens:
  1221. continue
  1222. yield TableValue(token, tc.column_index + 1, tc.row_index)
  1223. def _get_table_boundaries(self, table):
  1224. """Return maximal number of rows, columns and tokens."""
  1225. max_num_tokens = 0
  1226. max_num_columns = 0
  1227. max_num_rows = 0
  1228. for tc in table.selected_tokens:
  1229. max_num_columns = max(max_num_columns, tc.column_index + 1)
  1230. max_num_rows = max(max_num_rows, tc.row_index + 1)
  1231. max_num_tokens = max(max_num_tokens, tc.token_index + 1)
  1232. max_num_columns = min(self.max_column_id, max_num_columns)
  1233. max_num_rows = min(self.max_row_id, max_num_rows)
  1234. return max_num_rows, max_num_columns, max_num_tokens
  1235. def _get_table_cost(self, table, num_columns, num_rows, num_tokens):
  1236. return sum(1 for _ in self._get_table_values(table, num_columns, num_rows, num_tokens))
  1237. def _get_max_num_tokens(self, question_tokens, tokenized_table, num_columns, num_rows, max_length):
  1238. """Computes max number of tokens that can be squeezed into the budget."""
  1239. token_budget = self._get_token_budget(question_tokens, max_length)
  1240. _, _, max_num_tokens = self._get_table_boundaries(tokenized_table)
  1241. if self.cell_trim_length >= 0 and max_num_tokens > self.cell_trim_length:
  1242. max_num_tokens = self.cell_trim_length
  1243. num_tokens = 0
  1244. for num_tokens in range(max_num_tokens + 1):
  1245. cost = self._get_table_cost(tokenized_table, num_columns, num_rows, num_tokens + 1)
  1246. if cost > token_budget:
  1247. break
  1248. if num_tokens < max_num_tokens:
  1249. if self.cell_trim_length >= 0:
  1250. # We don't allow dynamic trimming if a cell_trim_length is set.
  1251. return None
  1252. if num_tokens == 0:
  1253. return None
  1254. return num_tokens
  1255. def _get_num_columns(self, table):
  1256. num_columns = table.shape[1]
  1257. if num_columns >= self.max_column_id:
  1258. raise ValueError("Too many columns")
  1259. return num_columns
  1260. def _get_num_rows(self, table, drop_rows_to_fit):
  1261. num_rows = table.shape[0]
  1262. if num_rows >= self.max_row_id:
  1263. if drop_rows_to_fit:
  1264. num_rows = self.max_row_id - 1
  1265. else:
  1266. raise ValueError("Too many rows")
  1267. return num_rows
  1268. def _serialize_text(self, question_tokens):
  1269. """Serializes texts in index arrays."""
  1270. tokens = []
  1271. segment_ids = []
  1272. column_ids = []
  1273. row_ids = []
  1274. # add [CLS] token at the beginning
  1275. tokens.append(self.cls_token)
  1276. segment_ids.append(0)
  1277. column_ids.append(0)
  1278. row_ids.append(0)
  1279. for token in question_tokens:
  1280. tokens.append(token)
  1281. segment_ids.append(0)
  1282. column_ids.append(0)
  1283. row_ids.append(0)
  1284. return tokens, segment_ids, column_ids, row_ids
  1285. def _serialize(
  1286. self,
  1287. question_tokens,
  1288. table,
  1289. num_columns,
  1290. num_rows,
  1291. num_tokens,
  1292. ):
  1293. """Serializes table and text."""
  1294. tokens, segment_ids, column_ids, row_ids = self._serialize_text(question_tokens)
  1295. # add [SEP] token between question and table tokens
  1296. tokens.append(self.sep_token)
  1297. segment_ids.append(0)
  1298. column_ids.append(0)
  1299. row_ids.append(0)
  1300. for token, column_id, row_id in self._get_table_values(table, num_columns, num_rows, num_tokens):
  1301. tokens.append(token)
  1302. segment_ids.append(1)
  1303. column_ids.append(column_id)
  1304. row_ids.append(row_id)
  1305. return SerializedExample(
  1306. tokens=tokens,
  1307. segment_ids=segment_ids,
  1308. column_ids=column_ids,
  1309. row_ids=row_ids,
  1310. )
  1311. def _get_column_values(self, table, col_index):
  1312. table_numeric_values = {}
  1313. for row_index, row in table.iterrows():
  1314. cell = row[col_index]
  1315. if cell.numeric_value is not None:
  1316. table_numeric_values[row_index] = cell.numeric_value
  1317. return table_numeric_values
  1318. def _get_cell_token_indexes(self, column_ids, row_ids, column_id, row_id):
  1319. for index in range(len(column_ids)):
  1320. if column_ids[index] - 1 == column_id and row_ids[index] - 1 == row_id:
  1321. yield index
  1322. def _get_numeric_column_ranks(self, column_ids, row_ids, table):
  1323. """Returns column ranks for all numeric columns."""
  1324. ranks = [0] * len(column_ids)
  1325. inv_ranks = [0] * len(column_ids)
  1326. # original code from tf_example_utils.py of the original implementation
  1327. if table is not None:
  1328. for col_index in range(len(table.columns)):
  1329. table_numeric_values = self._get_column_values(table, col_index)
  1330. if not table_numeric_values:
  1331. continue
  1332. try:
  1333. key_fn = get_numeric_sort_key_fn(table_numeric_values.values())
  1334. except ValueError:
  1335. continue
  1336. table_numeric_values = {row_index: key_fn(value) for row_index, value in table_numeric_values.items()}
  1337. table_numeric_values_inv = collections.defaultdict(list)
  1338. for row_index, value in table_numeric_values.items():
  1339. table_numeric_values_inv[value].append(row_index)
  1340. unique_values = sorted(table_numeric_values_inv.keys())
  1341. for rank, value in enumerate(unique_values):
  1342. for row_index in table_numeric_values_inv[value]:
  1343. for index in self._get_cell_token_indexes(column_ids, row_ids, col_index, row_index):
  1344. ranks[index] = rank + 1
  1345. inv_ranks[index] = len(unique_values) - rank
  1346. return ranks, inv_ranks
  1347. def _get_numeric_sort_key_fn(self, table_numeric_values, value):
  1348. """
  1349. Returns the sort key function for comparing value to table values. The function returned will be a suitable
  1350. input for the key param of the sort(). See number_annotation_utils._get_numeric_sort_key_fn for details
  1351. Args:
  1352. table_numeric_values: Numeric values of a column
  1353. value: Numeric value in the question
  1354. Returns:
  1355. A function key function to compare column and question values.
  1356. """
  1357. if not table_numeric_values:
  1358. return None
  1359. all_values = list(table_numeric_values.values())
  1360. all_values.append(value)
  1361. try:
  1362. return get_numeric_sort_key_fn(all_values)
  1363. except ValueError:
  1364. return None
  1365. def _get_numeric_relations(self, question, column_ids, row_ids, table):
  1366. """
  1367. Returns numeric relations embeddings
  1368. Args:
  1369. question: Question object.
  1370. column_ids: Maps word piece position to column id.
  1371. row_ids: Maps word piece position to row id.
  1372. table: The table containing the numeric cell values.
  1373. """
  1374. numeric_relations = [0] * len(column_ids)
  1375. # first, we add any numeric value spans to the question:
  1376. # Create a dictionary that maps a table cell to the set of all relations
  1377. # this cell has with any value in the question.
  1378. cell_indices_to_relations = collections.defaultdict(set)
  1379. if question is not None and table is not None:
  1380. for numeric_value_span in question.numeric_spans:
  1381. for value in numeric_value_span.values:
  1382. for column_index in range(len(table.columns)):
  1383. table_numeric_values = self._get_column_values(table, column_index)
  1384. sort_key_fn = self._get_numeric_sort_key_fn(table_numeric_values, value)
  1385. if sort_key_fn is None:
  1386. continue
  1387. for row_index, cell_value in table_numeric_values.items():
  1388. relation = get_numeric_relation(value, cell_value, sort_key_fn)
  1389. if relation is not None:
  1390. cell_indices_to_relations[column_index, row_index].add(relation)
  1391. # For each cell add a special feature for all its word pieces.
  1392. for (column_index, row_index), relations in cell_indices_to_relations.items():
  1393. relation_set_index = 0
  1394. for relation in relations:
  1395. assert relation.value >= Relation.EQ.value
  1396. relation_set_index += 2 ** (relation.value - Relation.EQ.value)
  1397. for cell_token_index in self._get_cell_token_indexes(column_ids, row_ids, column_index, row_index):
  1398. numeric_relations[cell_token_index] = relation_set_index
  1399. return numeric_relations
  1400. def _get_numeric_values(self, table, column_ids, row_ids):
  1401. """Returns numeric values for computation of answer loss."""
  1402. numeric_values = [float("nan")] * len(column_ids)
  1403. if table is not None:
  1404. num_rows = table.shape[0]
  1405. num_columns = table.shape[1]
  1406. for col_index in range(num_columns):
  1407. for row_index in range(num_rows):
  1408. numeric_value = table.iloc[row_index, col_index].numeric_value
  1409. if numeric_value is not None:
  1410. if numeric_value.float_value is None:
  1411. continue
  1412. float_value = numeric_value.float_value
  1413. if float_value == float("inf"):
  1414. continue
  1415. for index in self._get_cell_token_indexes(column_ids, row_ids, col_index, row_index):
  1416. numeric_values[index] = float_value
  1417. return numeric_values
  1418. def _get_numeric_values_scale(self, table, column_ids, row_ids):
  1419. """Returns a scale to each token to down weigh the value of long words."""
  1420. numeric_values_scale = [1.0] * len(column_ids)
  1421. if table is None:
  1422. return numeric_values_scale
  1423. num_rows = table.shape[0]
  1424. num_columns = table.shape[1]
  1425. for col_index in range(num_columns):
  1426. for row_index in range(num_rows):
  1427. indices = list(self._get_cell_token_indexes(column_ids, row_ids, col_index, row_index))
  1428. num_indices = len(indices)
  1429. if num_indices > 1:
  1430. for index in indices:
  1431. numeric_values_scale[index] = float(num_indices)
  1432. return numeric_values_scale
  1433. def _pad_to_seq_length(self, inputs):
  1434. while len(inputs) > self.model_max_length:
  1435. inputs.pop()
  1436. while len(inputs) < self.model_max_length:
  1437. inputs.append(0)
  1438. def _get_all_answer_ids_from_coordinates(
  1439. self,
  1440. column_ids,
  1441. row_ids,
  1442. answers_list,
  1443. ):
  1444. """Maps lists of answer coordinates to token indexes."""
  1445. answer_ids = [0] * len(column_ids)
  1446. found_answers = set()
  1447. all_answers = set()
  1448. for answers in answers_list:
  1449. column_index, row_index = answers
  1450. all_answers.add((column_index, row_index))
  1451. for index in self._get_cell_token_indexes(column_ids, row_ids, column_index, row_index):
  1452. found_answers.add((column_index, row_index))
  1453. answer_ids[index] = 1
  1454. missing_count = len(all_answers) - len(found_answers)
  1455. return answer_ids, missing_count
  1456. def _get_all_answer_ids(self, column_ids, row_ids, answer_coordinates):
  1457. """
  1458. Maps answer coordinates of a question to token indexes.
  1459. In the SQA format (TSV), the coordinates are given as (row, column) tuples. Here, we first swap them to
  1460. (column, row) format before calling _get_all_answer_ids_from_coordinates.
  1461. """
  1462. def _to_coordinates(answer_coordinates_question):
  1463. return [(coords[1], coords[0]) for coords in answer_coordinates_question]
  1464. return self._get_all_answer_ids_from_coordinates(
  1465. column_ids, row_ids, answers_list=(_to_coordinates(answer_coordinates))
  1466. )
  1467. def _find_tokens(self, text, segment):
  1468. """Return start index of segment in text or None."""
  1469. logging.info(f"text: {text} {segment}")
  1470. for index in range(1 + len(text) - len(segment)):
  1471. for seg_index, seg_token in enumerate(segment):
  1472. if text[index + seg_index].piece != seg_token.piece:
  1473. break
  1474. else:
  1475. return index
  1476. return None
  1477. def _find_answer_coordinates_from_answer_text(
  1478. self,
  1479. tokenized_table,
  1480. answer_text,
  1481. ):
  1482. """Returns all occurrences of answer_text in the table."""
  1483. logging.info(f"answer text: {answer_text}")
  1484. for row_index, row in enumerate(tokenized_table.rows):
  1485. if row_index == 0:
  1486. # We don't search for answers in the header.
  1487. continue
  1488. for col_index, cell in enumerate(row):
  1489. token_index = self._find_tokens(cell, answer_text)
  1490. if token_index is not None:
  1491. yield TokenCoordinates(
  1492. row_index=row_index,
  1493. column_index=col_index,
  1494. token_index=token_index,
  1495. )
  1496. def _find_answer_ids_from_answer_texts(
  1497. self,
  1498. column_ids,
  1499. row_ids,
  1500. tokenized_table,
  1501. answer_texts,
  1502. ):
  1503. """Maps question with answer texts to the first matching token indexes."""
  1504. answer_ids = [0] * len(column_ids)
  1505. for answer_text in answer_texts:
  1506. for coordinates in self._find_answer_coordinates_from_answer_text(
  1507. tokenized_table,
  1508. answer_text,
  1509. ):
  1510. # Maps answer coordinates to indexes this can fail if tokens / rows have
  1511. # been pruned.
  1512. indexes = list(
  1513. self._get_cell_token_indexes(
  1514. column_ids,
  1515. row_ids,
  1516. column_id=coordinates.column_index,
  1517. row_id=coordinates.row_index - 1,
  1518. )
  1519. )
  1520. indexes.sort()
  1521. coordinate_answer_ids = []
  1522. if indexes:
  1523. begin_index = coordinates.token_index + indexes[0]
  1524. end_index = begin_index + len(answer_text)
  1525. for index in indexes:
  1526. if index >= begin_index and index < end_index:
  1527. coordinate_answer_ids.append(index)
  1528. if len(coordinate_answer_ids) == len(answer_text):
  1529. for index in coordinate_answer_ids:
  1530. answer_ids[index] = 1
  1531. break
  1532. return answer_ids
  1533. def _get_answer_ids(self, column_ids, row_ids, answer_coordinates):
  1534. """Maps answer coordinates of a question to token indexes."""
  1535. answer_ids, missing_count = self._get_all_answer_ids(column_ids, row_ids, answer_coordinates)
  1536. if missing_count:
  1537. raise ValueError("Couldn't find all answers")
  1538. return answer_ids
  1539. def get_answer_ids(self, column_ids, row_ids, tokenized_table, answer_texts_question, answer_coordinates_question):
  1540. if self.update_answer_coordinates:
  1541. return self._find_answer_ids_from_answer_texts(
  1542. column_ids,
  1543. row_ids,
  1544. tokenized_table,
  1545. answer_texts=[self.tokenize(at) for at in answer_texts_question],
  1546. )
  1547. return self._get_answer_ids(column_ids, row_ids, answer_coordinates_question)
  1548. def _pad(
  1549. self,
  1550. encoded_inputs: dict[str, EncodedInput] | BatchEncoding,
  1551. max_length: int | None = None,
  1552. padding_strategy: PaddingStrategy = PaddingStrategy.DO_NOT_PAD,
  1553. pad_to_multiple_of: int | None = None,
  1554. padding_side: str | None = None,
  1555. return_attention_mask: bool | None = None,
  1556. ) -> dict:
  1557. """
  1558. Pad encoded inputs (on left/right and up to predefined length or max length in the batch)
  1559. Args:
  1560. encoded_inputs:
  1561. Dictionary of tokenized inputs (`list[int]`) or batch of tokenized inputs (`list[list[int]]`).
  1562. max_length: maximum length of the returned list and optionally padding length (see below).
  1563. Will truncate by taking into account the special tokens.
  1564. padding_strategy: PaddingStrategy to use for padding.
  1565. - PaddingStrategy.LONGEST Pad to the longest sequence in the batch
  1566. - PaddingStrategy.MAX_LENGTH: Pad to the max length (default)
  1567. - PaddingStrategy.DO_NOT_PAD: Do not pad
  1568. The tokenizer padding sides are defined in self.padding_side:
  1569. - 'left': pads on the left of the sequences
  1570. - 'right': pads on the right of the sequences
  1571. pad_to_multiple_of: (optional) Integer if set will pad the sequence to a multiple of the provided value.
  1572. This is especially useful to enable the use of Tensor Core on NVIDIA hardware with compute capability
  1573. `>= 7.5` (Volta).
  1574. padding_side:
  1575. The side on which the model should have padding applied. Should be selected between ['right', 'left'].
  1576. Default value is picked from the class attribute of the same name.
  1577. return_attention_mask:
  1578. (optional) Set to False to avoid returning attention mask (default: set to model specifics)
  1579. """
  1580. # Load from model defaults
  1581. if return_attention_mask is None:
  1582. return_attention_mask = "attention_mask" in self.model_input_names
  1583. if padding_strategy == PaddingStrategy.LONGEST:
  1584. max_length = len(encoded_inputs["input_ids"])
  1585. if max_length is not None and pad_to_multiple_of is not None and (max_length % pad_to_multiple_of != 0):
  1586. max_length = ((max_length // pad_to_multiple_of) + 1) * pad_to_multiple_of
  1587. needs_to_be_padded = (
  1588. padding_strategy != PaddingStrategy.DO_NOT_PAD and len(encoded_inputs["input_ids"]) != max_length
  1589. )
  1590. # Initialize attention mask if not present.
  1591. if return_attention_mask and "attention_mask" not in encoded_inputs:
  1592. encoded_inputs["attention_mask"] = [1] * len(encoded_inputs["input_ids"])
  1593. if needs_to_be_padded:
  1594. difference = max_length - len(encoded_inputs["input_ids"])
  1595. padding_side = padding_side if padding_side is not None else self.padding_side
  1596. if padding_side == "right":
  1597. if return_attention_mask:
  1598. encoded_inputs["attention_mask"] = encoded_inputs["attention_mask"] + [0] * difference
  1599. if "token_type_ids" in encoded_inputs:
  1600. encoded_inputs["token_type_ids"] = (
  1601. encoded_inputs["token_type_ids"] + [[self.pad_token_type_id] * 7] * difference
  1602. )
  1603. if "labels" in encoded_inputs:
  1604. encoded_inputs["labels"] = encoded_inputs["labels"] + [0] * difference
  1605. if "numeric_values" in encoded_inputs:
  1606. encoded_inputs["numeric_values"] = encoded_inputs["numeric_values"] + [float("nan")] * difference
  1607. if "numeric_values_scale" in encoded_inputs:
  1608. encoded_inputs["numeric_values_scale"] = (
  1609. encoded_inputs["numeric_values_scale"] + [1.0] * difference
  1610. )
  1611. if "special_tokens_mask" in encoded_inputs:
  1612. encoded_inputs["special_tokens_mask"] = encoded_inputs["special_tokens_mask"] + [1] * difference
  1613. encoded_inputs["input_ids"] = encoded_inputs["input_ids"] + [self.pad_token_id] * difference
  1614. elif padding_side == "left":
  1615. if return_attention_mask:
  1616. encoded_inputs["attention_mask"] = [0] * difference + encoded_inputs["attention_mask"]
  1617. if "token_type_ids" in encoded_inputs:
  1618. encoded_inputs["token_type_ids"] = [[self.pad_token_type_id] * 7] * difference + encoded_inputs[
  1619. "token_type_ids"
  1620. ]
  1621. if "labels" in encoded_inputs:
  1622. encoded_inputs["labels"] = [0] * difference + encoded_inputs["labels"]
  1623. if "numeric_values" in encoded_inputs:
  1624. encoded_inputs["numeric_values"] = [float("nan")] * difference + encoded_inputs["numeric_values"]
  1625. if "numeric_values_scale" in encoded_inputs:
  1626. encoded_inputs["numeric_values_scale"] = [1.0] * difference + encoded_inputs[
  1627. "numeric_values_scale"
  1628. ]
  1629. if "special_tokens_mask" in encoded_inputs:
  1630. encoded_inputs["special_tokens_mask"] = [1] * difference + encoded_inputs["special_tokens_mask"]
  1631. encoded_inputs["input_ids"] = [self.pad_token_id] * difference + encoded_inputs["input_ids"]
  1632. else:
  1633. raise ValueError("Invalid padding strategy:" + str(padding_side))
  1634. return encoded_inputs
  1635. # Everything related to converting logits to predictions
  1636. def _get_cell_token_probs(self, probabilities, segment_ids, row_ids, column_ids):
  1637. for i, p in enumerate(probabilities):
  1638. segment_id = segment_ids[i]
  1639. col = column_ids[i] - 1
  1640. row = row_ids[i] - 1
  1641. if col >= 0 and row >= 0 and segment_id == 1:
  1642. yield i, p
  1643. def _get_mean_cell_probs(self, probabilities, segment_ids, row_ids, column_ids):
  1644. """Computes average probability per cell, aggregating over tokens."""
  1645. coords_to_probs = collections.defaultdict(list)
  1646. for i, prob in self._get_cell_token_probs(probabilities, segment_ids, row_ids, column_ids):
  1647. col = column_ids[i] - 1
  1648. row = row_ids[i] - 1
  1649. coords_to_probs[(col, row)].append(prob)
  1650. return {coords: np.array(cell_probs).mean() for coords, cell_probs in coords_to_probs.items()}
  1651. def convert_logits_to_predictions(self, data, logits, logits_agg=None, cell_classification_threshold=0.5):
  1652. """
  1653. Converts logits of [`TapasForQuestionAnswering`] to actual predicted answer coordinates and optional
  1654. aggregation indices.
  1655. The original implementation, on which this function is based, can be found
  1656. [here](https://github.com/google-research/tapas/blob/4908213eb4df7aa988573350278b44c4dbe3f71b/tapas/experiments/prediction_utils.py#L288).
  1657. Args:
  1658. data (`dict`):
  1659. Dictionary mapping features to actual values. Should be created using [`TapasTokenizer`].
  1660. logits (`torch.Tensor` of shape `(batch_size, sequence_length)`):
  1661. Tensor containing the logits at the token level.
  1662. logits_agg (`torch.Tensor` of shape `(batch_size, num_aggregation_labels)`, *optional*):
  1663. Tensor containing the aggregation logits.
  1664. cell_classification_threshold (`float`, *optional*, defaults to 0.5):
  1665. Threshold to be used for cell selection. All table cells for which their probability is larger than
  1666. this threshold will be selected.
  1667. Returns:
  1668. `tuple` comprising various elements depending on the inputs:
  1669. - predicted_answer_coordinates (`list[list[[tuple]]` of length `batch_size`): Predicted answer coordinates
  1670. as a list of lists of tuples. Each element in the list contains the predicted answer coordinates of a
  1671. single example in the batch, as a list of tuples. Each tuple is a cell, i.e. (row index, column index).
  1672. - predicted_aggregation_indices (`list[int]`of length `batch_size`, *optional*, returned when
  1673. `logits_aggregation` is provided): Predicted aggregation operator indices of the aggregation head.
  1674. """
  1675. logits = logits.numpy()
  1676. if logits_agg is not None:
  1677. logits_agg = logits_agg.numpy()
  1678. data = {key: value.numpy() for key, value in data.items() if key != "training"}
  1679. # input data is of type float32
  1680. # np.log(np.finfo(np.float32).max) = 88.72284
  1681. # Any value over 88.72284 will overflow when passed through the exponential, sending a warning
  1682. # We disable this warning by truncating the logits.
  1683. logits[logits < -88.7] = -88.7
  1684. # Compute probabilities from token logits
  1685. probabilities = 1 / (1 + np.exp(-logits)) * data["attention_mask"]
  1686. token_types = [
  1687. "segment_ids",
  1688. "column_ids",
  1689. "row_ids",
  1690. "prev_labels",
  1691. "column_ranks",
  1692. "inv_column_ranks",
  1693. "numeric_relations",
  1694. ]
  1695. # collect input_ids, segment ids, row ids and column ids of batch. Shape (batch_size, seq_len)
  1696. input_ids = data["input_ids"]
  1697. segment_ids = data["token_type_ids"][:, :, token_types.index("segment_ids")]
  1698. row_ids = data["token_type_ids"][:, :, token_types.index("row_ids")]
  1699. column_ids = data["token_type_ids"][:, :, token_types.index("column_ids")]
  1700. # next, get answer coordinates for every example in the batch
  1701. num_batch = input_ids.shape[0]
  1702. predicted_answer_coordinates = []
  1703. for i in range(num_batch):
  1704. probabilities_example = probabilities[i].tolist()
  1705. segment_ids_example = segment_ids[i]
  1706. row_ids_example = row_ids[i]
  1707. column_ids_example = column_ids[i]
  1708. max_width = column_ids_example.max()
  1709. max_height = row_ids_example.max()
  1710. if max_width == 0 and max_height == 0:
  1711. continue
  1712. cell_coords_to_prob = self._get_mean_cell_probs(
  1713. probabilities_example,
  1714. segment_ids_example.tolist(),
  1715. row_ids_example.tolist(),
  1716. column_ids_example.tolist(),
  1717. )
  1718. # Select the answers above the classification threshold.
  1719. answer_coordinates = []
  1720. for col in range(max_width):
  1721. for row in range(max_height):
  1722. cell_prob = cell_coords_to_prob.get((col, row), None)
  1723. if cell_prob is not None:
  1724. if cell_prob > cell_classification_threshold:
  1725. answer_coordinates.append((row, col))
  1726. answer_coordinates = sorted(answer_coordinates)
  1727. predicted_answer_coordinates.append(answer_coordinates)
  1728. output = (predicted_answer_coordinates,)
  1729. if logits_agg is not None:
  1730. predicted_aggregation_indices = logits_agg.argmax(axis=-1)
  1731. output = (predicted_answer_coordinates, predicted_aggregation_indices.tolist())
  1732. return output
  1733. # End of everything related to converting logits to predictions
  1734. class BasicTokenizer:
  1735. """
  1736. Constructs a BasicTokenizer that will run basic tokenization (punctuation splitting, lower casing, etc.).
  1737. Args:
  1738. do_lower_case (`bool`, *optional*, defaults to `True`):
  1739. Whether or not to lowercase the input when tokenizing.
  1740. never_split (`Iterable`, *optional*):
  1741. Collection of tokens which will never be split during tokenization. Only has an effect when
  1742. `do_basic_tokenize=True`
  1743. tokenize_chinese_chars (`bool`, *optional*, defaults to `True`):
  1744. Whether or not to tokenize Chinese characters.
  1745. This should likely be deactivated for Japanese (see this
  1746. [issue](https://github.com/huggingface/transformers/issues/328)).
  1747. strip_accents (`bool`, *optional*):
  1748. Whether or not to strip all accents. If this option is not specified, then it will be determined by the
  1749. value for `lowercase` (as in the original BERT).
  1750. do_split_on_punc (`bool`, *optional*, defaults to `True`):
  1751. In some instances we want to skip the basic punctuation splitting so that later tokenization can capture
  1752. the full context of the words, such as contractions.
  1753. """
  1754. def __init__(
  1755. self,
  1756. do_lower_case=True,
  1757. never_split=None,
  1758. tokenize_chinese_chars=True,
  1759. strip_accents=None,
  1760. do_split_on_punc=True,
  1761. ):
  1762. if never_split is None:
  1763. never_split = []
  1764. self.do_lower_case = do_lower_case
  1765. self.never_split = set(never_split)
  1766. self.tokenize_chinese_chars = tokenize_chinese_chars
  1767. self.strip_accents = strip_accents
  1768. self.do_split_on_punc = do_split_on_punc
  1769. def tokenize(self, text, never_split=None):
  1770. """
  1771. Basic Tokenization of a piece of text. For sub-word tokenization, see WordPieceTokenizer.
  1772. Args:
  1773. never_split (`List[str]`, *optional*)
  1774. Kept for backward compatibility purposes. Now implemented directly at the base class level (see
  1775. [`PreTrainedTokenizer.tokenize`]) List of token not to split.
  1776. """
  1777. # union() returns a new set by concatenating the two sets.
  1778. never_split = self.never_split.union(set(never_split)) if never_split else self.never_split
  1779. text = self._clean_text(text)
  1780. # This was added on November 1st, 2018 for the multilingual and Chinese
  1781. # models. This is also applied to the English models now, but it doesn't
  1782. # matter since the English models were not trained on any Chinese data
  1783. # and generally don't have any Chinese data in them (there are Chinese
  1784. # characters in the vocabulary because Wikipedia does have some Chinese
  1785. # words in the English Wikipedia.).
  1786. if self.tokenize_chinese_chars:
  1787. text = self._tokenize_chinese_chars(text)
  1788. # prevents treating the same character with different unicode codepoints as different characters
  1789. unicode_normalized_text = unicodedata.normalize("NFC", text)
  1790. orig_tokens = whitespace_tokenize(unicode_normalized_text)
  1791. split_tokens = []
  1792. for token in orig_tokens:
  1793. if token not in never_split:
  1794. if self.do_lower_case:
  1795. token = token.lower()
  1796. if self.strip_accents is not False:
  1797. token = self._run_strip_accents(token)
  1798. elif self.strip_accents:
  1799. token = self._run_strip_accents(token)
  1800. split_tokens.extend(self._run_split_on_punc(token, never_split))
  1801. output_tokens = whitespace_tokenize(" ".join(split_tokens))
  1802. return output_tokens
  1803. def _run_strip_accents(self, text):
  1804. """Strips accents from a piece of text."""
  1805. text = unicodedata.normalize("NFD", text)
  1806. output = []
  1807. for char in text:
  1808. cat = unicodedata.category(char)
  1809. if cat == "Mn":
  1810. continue
  1811. output.append(char)
  1812. return "".join(output)
  1813. def _run_split_on_punc(self, text, never_split=None):
  1814. """Splits punctuation on a piece of text."""
  1815. if not self.do_split_on_punc or (never_split is not None and text in never_split):
  1816. return [text]
  1817. chars = list(text)
  1818. i = 0
  1819. start_new_word = True
  1820. output = []
  1821. while i < len(chars):
  1822. char = chars[i]
  1823. if _is_punctuation(char):
  1824. output.append([char])
  1825. start_new_word = True
  1826. else:
  1827. if start_new_word:
  1828. output.append([])
  1829. start_new_word = False
  1830. output[-1].append(char)
  1831. i += 1
  1832. return ["".join(x) for x in output]
  1833. def _tokenize_chinese_chars(self, text):
  1834. """Adds whitespace around any CJK character."""
  1835. output = []
  1836. for char in text:
  1837. cp = ord(char)
  1838. if self._is_chinese_char(cp):
  1839. output.append(" ")
  1840. output.append(char)
  1841. output.append(" ")
  1842. else:
  1843. output.append(char)
  1844. return "".join(output)
  1845. def _is_chinese_char(self, cp):
  1846. """Checks whether CP is the codepoint of a CJK character."""
  1847. # This defines a "chinese character" as anything in the CJK Unicode block:
  1848. # https://en.wikipedia.org/wiki/CJK_Unified_Ideographs_(Unicode_block)
  1849. #
  1850. # Note that the CJK Unicode block is NOT all Japanese and Korean characters,
  1851. # despite its name. The modern Korean Hangul alphabet is a different block,
  1852. # as is Japanese Hiragana and Katakana. Those alphabets are used to write
  1853. # space-separated words, so they are not treated specially and handled
  1854. # like the all of the other languages.
  1855. if (
  1856. (cp >= 0x4E00 and cp <= 0x9FFF)
  1857. or (cp >= 0x3400 and cp <= 0x4DBF)
  1858. or (cp >= 0x20000 and cp <= 0x2A6DF)
  1859. or (cp >= 0x2A700 and cp <= 0x2B73F)
  1860. or (cp >= 0x2B740 and cp <= 0x2B81F)
  1861. or (cp >= 0x2B820 and cp <= 0x2CEAF)
  1862. or (cp >= 0xF900 and cp <= 0xFAFF)
  1863. or (cp >= 0x2F800 and cp <= 0x2FA1F)
  1864. ):
  1865. return True
  1866. return False
  1867. def _clean_text(self, text):
  1868. """Performs invalid character removal and whitespace cleanup on text."""
  1869. output = []
  1870. for char in text:
  1871. cp = ord(char)
  1872. if cp == 0 or cp == 0xFFFD or _is_control(char):
  1873. continue
  1874. if _is_whitespace(char):
  1875. output.append(" ")
  1876. else:
  1877. output.append(char)
  1878. return "".join(output)
  1879. class WordpieceTokenizer:
  1880. """Runs WordPiece tokenization."""
  1881. def __init__(self, vocab, unk_token, max_input_chars_per_word=100):
  1882. self.vocab = vocab
  1883. self.unk_token = unk_token
  1884. self.max_input_chars_per_word = max_input_chars_per_word
  1885. def tokenize(self, text):
  1886. """
  1887. Tokenizes a piece of text into its word pieces. This uses a greedy longest-match-first algorithm to perform
  1888. tokenization using the given vocabulary.
  1889. For example, `input = "unaffable"` will return as output `["un", "##aff", "##able"]`.
  1890. Args:
  1891. text: A single token or whitespace separated tokens. This should have
  1892. already been passed through *BasicTokenizer*.
  1893. Returns:
  1894. A list of wordpiece tokens.
  1895. """
  1896. output_tokens = []
  1897. for token in whitespace_tokenize(text):
  1898. chars = list(token)
  1899. if len(chars) > self.max_input_chars_per_word:
  1900. output_tokens.append(self.unk_token)
  1901. continue
  1902. is_bad = False
  1903. start = 0
  1904. sub_tokens = []
  1905. while start < len(chars):
  1906. end = len(chars)
  1907. cur_substr = None
  1908. while start < end:
  1909. substr = "".join(chars[start:end])
  1910. if start > 0:
  1911. substr = "##" + substr
  1912. if substr in self.vocab:
  1913. cur_substr = substr
  1914. break
  1915. end -= 1
  1916. if cur_substr is None:
  1917. is_bad = True
  1918. break
  1919. sub_tokens.append(cur_substr)
  1920. start = end
  1921. if is_bad:
  1922. output_tokens.append(self.unk_token)
  1923. else:
  1924. output_tokens.extend(sub_tokens)
  1925. return output_tokens
# Below: utilities for TAPAS tokenizer
# This includes functions to parse numeric values (dates and numbers) from both the table and questions in order
# to create the column_ranks, inv_column_ranks, numeric_values, numeric_values_scale and numeric_relations in
# prepare_for_model of TapasTokenizer.
# These are meant to be used in an academic setup; for production use cases, Gold Mine or Aqua should be used.
# Taken from constants.py of the original implementation.
# URL: https://github.com/google-research/tapas/blob/master/tapas/utils/constants.py
class Relation(enum.Enum):
    """Relation types between parts of the input (query, headers, cells, rows) and between numeric values."""

    HEADER_TO_CELL = 1  # Connects header to cell.
    CELL_TO_HEADER = 2  # Connects cell to header.
    QUERY_TO_HEADER = 3  # Connects query to headers.
    QUERY_TO_CELL = 4  # Connects query to cells.
    ROW_TO_CELL = 5  # Connects row to cells.
    CELL_TO_ROW = 6  # Connects cells to row.
    # The numeric relations below are the ones returned by `get_numeric_relation`.
    EQ = 7  # Annotation value is same as cell value
    LT = 8  # Annotation value is less than cell value
    GT = 9  # Annotation value is greater than cell value
@dataclass
class Date:
    """A possibly partial date; fields not selected by the parsing pattern's mask stay None."""

    year: int | None = None
    month: int | None = None
    day: int | None = None
@dataclass
class NumericValue:
    """A parsed numeric value: a number (`float_value`) or a date (`date`); `float_value` takes precedence if both."""

    float_value: float | None = None
    date: Date | None = None
@dataclass
class NumericValueSpan:
    """A span `text[begin_index:end_index]` together with the numeric values parsed from it."""

    begin_index: int | None = None
    end_index: int | None = None
    # Defaults to None (not an empty list), hence the Optional annotation.
    values: list[NumericValue] | None = None
@dataclass
class Cell:
    """A table cell: its text and, if its column could be consolidated, its numeric value."""

    text: str
    numeric_value: NumericValue | None = None
@dataclass
class Question:
    """A question with its normalized text and the numeric value spans parsed from the normalized text."""

    original_text: str  # The original raw question string.
    text: str  # The question string after normalization.
    numeric_spans: list[NumericValueSpan] | None = None
  1966. # Below: all functions from number_utils.py as well as 2 functions (namely get_all_spans and normalize_for_match)
  1967. # from text_utils.py of the original implementation. URL's:
  1968. # - https://github.com/google-research/tapas/blob/master/tapas/utils/number_utils.py
  1969. # - https://github.com/google-research/tapas/blob/master/tapas/utils/text_utils.py
  1970. # Constants for parsing date expressions.
  1971. # Masks that specify (by a bool) which of (year, month, day) will be populated.
  1972. _DateMask = collections.namedtuple("_DateMask", ["year", "month", "day"])
  1973. _YEAR = _DateMask(True, False, False)
  1974. _YEAR_MONTH = _DateMask(True, True, False)
  1975. _YEAR_MONTH_DAY = _DateMask(True, True, True)
  1976. _MONTH = _DateMask(False, True, False)
  1977. _MONTH_DAY = _DateMask(False, True, True)
  1978. # Pairs of patterns to pass to 'datetime.strptime' and masks specifying which
  1979. # fields will be set by the corresponding pattern.
  1980. _DATE_PATTERNS = (
  1981. ("%B", _MONTH),
  1982. ("%Y", _YEAR),
  1983. ("%Ys", _YEAR),
  1984. ("%b %Y", _YEAR_MONTH),
  1985. ("%B %Y", _YEAR_MONTH),
  1986. ("%B %d", _MONTH_DAY),
  1987. ("%b %d", _MONTH_DAY),
  1988. ("%d %b", _MONTH_DAY),
  1989. ("%d %B", _MONTH_DAY),
  1990. ("%B %d, %Y", _YEAR_MONTH_DAY),
  1991. ("%d %B %Y", _YEAR_MONTH_DAY),
  1992. ("%m-%d-%Y", _YEAR_MONTH_DAY),
  1993. ("%Y-%m-%d", _YEAR_MONTH_DAY),
  1994. ("%Y-%m", _YEAR_MONTH),
  1995. ("%B %Y", _YEAR_MONTH),
  1996. ("%d %b %Y", _YEAR_MONTH_DAY),
  1997. ("%Y-%m-%d", _YEAR_MONTH_DAY),
  1998. ("%b %d, %Y", _YEAR_MONTH_DAY),
  1999. ("%d.%m.%Y", _YEAR_MONTH_DAY),
  2000. ("%A, %b %d", _MONTH_DAY),
  2001. ("%A, %B %d", _MONTH_DAY),
  2002. )
  2003. # This mapping is used to convert date patterns to regex patterns.
  2004. _FIELD_TO_REGEX = (
  2005. ("%A", r"\w+"), # Weekday as locale’s full name.
  2006. ("%B", r"\w+"), # Month as locale’s full name.
  2007. ("%Y", r"\d{4}"), # Year with century as a decimal number.
  2008. ("%b", r"\w{3}"), # Month as locale’s abbreviated name.
  2009. ("%d", r"\d{1,2}"), # Day of the month as a zero-padded decimal number.
  2010. ("%m", r"\d{1,2}"), # Month as a zero-padded decimal number.
  2011. )
  2012. def _process_date_pattern(dp):
  2013. """Compute a regex for each date pattern to use as a prefilter."""
  2014. pattern, mask = dp
  2015. regex = pattern
  2016. regex = regex.replace(".", re.escape("."))
  2017. regex = regex.replace("-", re.escape("-"))
  2018. regex = regex.replace(" ", r"\s+")
  2019. for field, field_regex in _FIELD_TO_REGEX:
  2020. regex = regex.replace(field, field_regex)
  2021. # Make sure we didn't miss any of the fields.
  2022. assert "%" not in regex, regex
  2023. return pattern, mask, re.compile("^" + regex + "$")
  2024. def _process_date_patterns():
  2025. return tuple(_process_date_pattern(dp) for dp in _DATE_PATTERNS)
# Compiled (pattern, mask, regex) triples for every entry of _DATE_PATTERNS, built once at import time.
_PROCESSED_DATE_PATTERNS = _process_date_patterns()
# Longest span (in tokens, as produced by get_all_spans) that parse_text tries to parse as a date.
_MAX_DATE_NGRAM_SIZE = 5

# Following DynSp:
# https://github.com/Microsoft/DynSP/blob/master/util.py#L414.
# Cardinal number words; a word's position in the list is its numeric value (see parse_text).
_NUMBER_WORDS = [
    "zero",
    "one",
    "two",
    "three",
    "four",
    "five",
    "six",
    "seven",
    "eight",
    "nine",
    "ten",
    "eleven",
    "twelve",
]

# Ordinal number words; as above, a word's position in the list is its numeric value.
_ORDINAL_WORDS = [
    "zeroth",
    "first",
    "second",
    "third",
    "fourth",
    "fifth",
    "sixth",
    "seventh",
    "eighth",
    "ninth",
    "tenth",
    "eleventh",
    "twelfth",
]

# Suffixes stripped by _parse_number before converting to float (e.g. "21st" -> "21").
_ORDINAL_SUFFIXES = ["st", "nd", "rd", "th"]

# An optional sign (allowed only at the start of the text or after whitespace, which then becomes part of the match),
# followed by either a bare decimal fraction (".5") or digits with optional comma thousands separators and an optional
# decimal part ("1,234.5").
_NUMBER_PATTERN = re.compile(r"((^|\s)[+-])?((\.\d+)|(\d+(,\d\d\d)*(\.\d*)?))")

# Following DynSp:
# https://github.com/Microsoft/DynSP/blob/master/util.py#L293.
# Dates whose year lies outside [_MIN_YEAR, _MAX_YEAR] are rejected by _get_numeric_value_from_date.
# NOTE(review): the upper bound is inherited from DynSP and excludes later years; raising it would change the
# extracted numeric features relative to the original implementation.
_MIN_YEAR = 1700
_MAX_YEAR = 2016

# Positive infinity; _parse_number rejects it as a value.
_INF = float("INF")
  2067. def _get_numeric_value_from_date(date, mask):
  2068. """Converts date (datetime Python object) to a NumericValue object with a Date object value."""
  2069. if date.year < _MIN_YEAR or date.year > _MAX_YEAR:
  2070. raise ValueError(f"Invalid year: {date.year}")
  2071. new_date = Date()
  2072. if mask.year:
  2073. new_date.year = date.year
  2074. if mask.month:
  2075. new_date.month = date.month
  2076. if mask.day:
  2077. new_date.day = date.day
  2078. return NumericValue(date=new_date)
  2079. def _get_span_length_key(span):
  2080. """Sorts span by decreasing length first and increasing first index second."""
  2081. return span[1] - span[0], -span[0]
  2082. def _get_numeric_value_from_float(value):
  2083. """Converts float (Python) to a NumericValue object with a float value."""
  2084. return NumericValue(float_value=value)
  2085. # Doesn't parse ordinal expressions such as '18th of february 1655'.
  2086. def _parse_date(text):
  2087. """Attempts to format a text as a standard date string (yyyy-mm-dd)."""
  2088. text = re.sub(r"Sept\b", "Sep", text)
  2089. for in_pattern, mask, regex in _PROCESSED_DATE_PATTERNS:
  2090. if not regex.match(text):
  2091. continue
  2092. try:
  2093. date = datetime.datetime.strptime(text, in_pattern).date()
  2094. except ValueError:
  2095. continue
  2096. try:
  2097. return _get_numeric_value_from_date(date, mask)
  2098. except ValueError:
  2099. continue
  2100. return None
  2101. def _parse_number(text):
  2102. """Parses simple cardinal and ordinals numbers."""
  2103. for suffix in _ORDINAL_SUFFIXES:
  2104. if text.endswith(suffix):
  2105. text = text[: -len(suffix)]
  2106. break
  2107. text = text.replace(",", "")
  2108. try:
  2109. value = float(text)
  2110. except ValueError:
  2111. return None
  2112. if math.isnan(value):
  2113. return None
  2114. if value == _INF:
  2115. return None
  2116. return value
  2117. def get_all_spans(text, max_ngram_length):
  2118. """
  2119. Split a text into all possible ngrams up to 'max_ngram_length'. Split points are white space and punctuation.
  2120. Args:
  2121. text: Text to split.
  2122. max_ngram_length: maximal ngram length.
  2123. Yields:
  2124. Spans, tuples of begin-end index.
  2125. """
  2126. start_indexes = []
  2127. for index, char in enumerate(text):
  2128. if not char.isalnum():
  2129. continue
  2130. if index == 0 or not text[index - 1].isalnum():
  2131. start_indexes.append(index)
  2132. if index + 1 == len(text) or not text[index + 1].isalnum():
  2133. for start_index in start_indexes[-max_ngram_length:]:
  2134. yield start_index, index + 1
  2135. def normalize_for_match(text):
  2136. return " ".join(text.lower().split())
  2137. def format_text(text):
  2138. """Lowercases and strips punctuation."""
  2139. text = text.lower().strip()
  2140. if text == "n/a" or text == "?" or text == "nan":
  2141. text = EMPTY_TEXT
  2142. text = re.sub(r"[^\w\d]+", " ", text).replace("_", " ")
  2143. text = " ".join(text.split())
  2144. text = text.strip()
  2145. if text:
  2146. return text
  2147. return EMPTY_TEXT
def parse_text(text):
    """
    Extracts longest number and date spans.

    Candidate spans are collected in three passes: matches of `_NUMBER_PATTERN` parsed as numbers; single-token
    spans (from `get_all_spans`) parsed as numbers or matched against number/ordinal words; and spans of up to
    `_MAX_DATE_NGRAM_SIZE` tokens parsed as dates. Spans are then visited longest first, and a span nested inside an
    already selected span is dropped (overlapping but non-nested spans are all kept).

    Args:
        text: text to annotate

    Returns:
        List of longest numeric value spans, sorted by begin index.
    """
    # Maps (begin_index, end_index) -> list of NumericValue parsed from text[begin_index:end_index].
    span_dict = collections.defaultdict(list)
    for match in _NUMBER_PATTERN.finditer(text):
        span_text = text[match.start() : match.end()]
        number = _parse_number(span_text)
        if number is not None:
            span_dict[match.span()].append(_get_numeric_value_from_float(number))

    for begin_index, end_index in get_all_spans(text, max_ngram_length=1):
        # Spans already parsed from a number-regex match are not parsed again.
        if (begin_index, end_index) in span_dict:
            continue
        span_text = text[begin_index:end_index]

        number = _parse_number(span_text)
        if number is not None:
            span_dict[begin_index, end_index].append(_get_numeric_value_from_float(number))
        # Exact (case-sensitive) match against number words; the list index is the value.
        for number, word in enumerate(_NUMBER_WORDS):
            if span_text == word:
                span_dict[begin_index, end_index].append(_get_numeric_value_from_float(float(number)))
                break
        for number, word in enumerate(_ORDINAL_WORDS):
            if span_text == word:
                span_dict[begin_index, end_index].append(_get_numeric_value_from_float(float(number)))
                break

    # Dates may be added to spans that already hold numbers.
    for begin_index, end_index in get_all_spans(text, max_ngram_length=_MAX_DATE_NGRAM_SIZE):
        span_text = text[begin_index:end_index]
        date = _parse_date(span_text)
        if date is not None:
            span_dict[begin_index, end_index].append(date)

    # Longest spans first (ties broken by smaller begin index), so enclosing spans are selected before nested ones.
    spans = sorted(span_dict.items(), key=lambda span_value: _get_span_length_key(span_value[0]), reverse=True)
    selected_spans = []
    for span, value in spans:
        for selected_span, _ in selected_spans:
            if selected_span[0] <= span[0] and span[1] <= selected_span[1]:
                break
        else:
            # Not nested inside any selected span.
            selected_spans.append((span, value))

    selected_spans.sort(key=lambda span_value: span_value[0][0])

    numeric_value_spans = []
    for span, values in selected_spans:
        numeric_value_spans.append(NumericValueSpan(begin_index=span[0], end_index=span[1], values=values))
    return numeric_value_spans
  2195. # Below: all functions from number_annotation_utils.py and 2 functions (namely filter_invalid_unicode
  2196. # and filter_invalid_unicode_from_table) from text_utils.py of the original implementation. URL's:
  2197. # - https://github.com/google-research/tapas/blob/master/tapas/utils/number_annotation_utils.py
  2198. # - https://github.com/google-research/tapas/blob/master/tapas/utils/text_utils.py
  2199. _PrimitiveNumericValue = float | tuple[float | None]
  2200. _SortKeyFn = Callable[[NumericValue], tuple[float, Ellipsis]]
  2201. _DATE_TUPLE_SIZE = 3
  2202. EMPTY_TEXT = "EMPTY"
  2203. NUMBER_TYPE = "number"
  2204. DATE_TYPE = "date"
  2205. def _get_value_type(numeric_value):
  2206. if numeric_value.float_value is not None:
  2207. return NUMBER_TYPE
  2208. elif numeric_value.date is not None:
  2209. return DATE_TYPE
  2210. raise ValueError(f"Unknown type: {numeric_value}")
  2211. def _get_value_as_primitive_value(numeric_value):
  2212. """Maps a NumericValue proto to a float or tuple of float."""
  2213. if numeric_value.float_value is not None:
  2214. return numeric_value.float_value
  2215. if numeric_value.date is not None:
  2216. date = numeric_value.date
  2217. value_tuple = [None, None, None]
  2218. # All dates fields are cased to float to produce a simple primitive value.
  2219. if date.year is not None:
  2220. value_tuple[0] = float(date.year)
  2221. if date.month is not None:
  2222. value_tuple[1] = float(date.month)
  2223. if date.day is not None:
  2224. value_tuple[2] = float(date.day)
  2225. return tuple(value_tuple)
  2226. raise ValueError(f"Unknown type: {numeric_value}")
  2227. def _get_all_types(numeric_values):
  2228. return {_get_value_type(value) for value in numeric_values}
  2229. def get_numeric_sort_key_fn(numeric_values):
  2230. """
  2231. Creates a function that can be used as a sort key or to compare the values. Maps to primitive types and finds the
  2232. biggest common subset. Consider the values "05/05/2010" and "August 2007". With the corresponding primitive values
  2233. (2010.,5.,5.) and (2007.,8., None). These values can be compared by year and date so we map to the sequence (2010.,
  2234. 5.), (2007., 8.). If we added a third value "2006" with primitive value (2006., None, None), we could only compare
  2235. by the year so we would map to (2010.,), (2007.,) and (2006.,).
  2236. Args:
  2237. numeric_values: Values to compare
  2238. Returns:
  2239. A function that can be used as a sort key function (mapping numeric values to a comparable tuple)
  2240. Raises:
  2241. ValueError if values don't have a common type or are not comparable.
  2242. """
  2243. value_types = _get_all_types(numeric_values)
  2244. if len(value_types) != 1:
  2245. raise ValueError(f"No common value type in {numeric_values}")
  2246. value_type = next(iter(value_types))
  2247. if value_type == NUMBER_TYPE:
  2248. # Primitive values are simple floats, nothing to do here.
  2249. return _get_value_as_primitive_value
  2250. # The type can only be Date at this point which means the primitive type
  2251. # is a float triple.
  2252. valid_indexes = set(range(_DATE_TUPLE_SIZE))
  2253. for numeric_value in numeric_values:
  2254. value = _get_value_as_primitive_value(numeric_value)
  2255. assert isinstance(value, tuple)
  2256. for tuple_index, inner_value in enumerate(value):
  2257. if inner_value is None:
  2258. valid_indexes.discard(tuple_index)
  2259. if not valid_indexes:
  2260. raise ValueError(f"No common value in {numeric_values}")
  2261. def _sort_key_fn(numeric_value):
  2262. value = _get_value_as_primitive_value(numeric_value)
  2263. return tuple(value[index] for index in valid_indexes)
  2264. return _sort_key_fn
  2265. def _consolidate_numeric_values(row_index_to_values, min_consolidation_fraction, debug_info):
  2266. """
  2267. Finds the most common numeric values in a column and returns them
  2268. Args:
  2269. row_index_to_values:
  2270. For each row index all the values in that cell.
  2271. min_consolidation_fraction:
  2272. Fraction of cells that need to have consolidated value.
  2273. debug_info:
  2274. Additional information only used for logging
  2275. Returns:
  2276. For each row index the first value that matches the most common value. Rows that don't have a matching value
  2277. are dropped. Empty list if values can't be consolidated.
  2278. """
  2279. type_counts = collections.Counter()
  2280. for numeric_values in row_index_to_values.values():
  2281. type_counts.update(_get_all_types(numeric_values))
  2282. if not type_counts:
  2283. return {}
  2284. max_count = max(type_counts.values())
  2285. if max_count < len(row_index_to_values) * min_consolidation_fraction:
  2286. # logging.log_every_n(logging.INFO, f'Can\'t consolidate types: {debug_info} {row_index_to_values} {max_count}', 100)
  2287. return {}
  2288. valid_types = set()
  2289. for value_type, count in type_counts.items():
  2290. if count == max_count:
  2291. valid_types.add(value_type)
  2292. if len(valid_types) > 1:
  2293. assert DATE_TYPE in valid_types
  2294. max_type = DATE_TYPE
  2295. else:
  2296. max_type = next(iter(valid_types))
  2297. new_row_index_to_value = {}
  2298. for index, values in row_index_to_values.items():
  2299. # Extract the first matching value.
  2300. for value in values:
  2301. if _get_value_type(value) == max_type:
  2302. new_row_index_to_value[index] = value
  2303. break
  2304. return new_row_index_to_value
  2305. def _get_numeric_values(text):
  2306. """Parses text and returns numeric values."""
  2307. numeric_spans = parse_text(text)
  2308. return itertools.chain(*(span.values for span in numeric_spans))
  2309. def _get_column_values(table, col_index):
  2310. """
  2311. Parses text in column and returns a dict mapping row_index to values. This is the _get_column_values function from
  2312. number_annotation_utils.py of the original implementation
  2313. Args:
  2314. table: Pandas dataframe
  2315. col_index: integer, indicating the index of the column to get the numeric values of
  2316. """
  2317. index_to_values = {}
  2318. for row_index, row in table.iterrows():
  2319. text = normalize_for_match(row[col_index].text)
  2320. index_to_values[row_index] = list(_get_numeric_values(text))
  2321. return index_to_values
  2322. def get_numeric_relation(value, other_value, sort_key_fn):
  2323. """Compares two values and returns their relation or None."""
  2324. value = sort_key_fn(value)
  2325. other_value = sort_key_fn(other_value)
  2326. if value == other_value:
  2327. return Relation.EQ
  2328. if value < other_value:
  2329. return Relation.LT
  2330. if value > other_value:
  2331. return Relation.GT
  2332. return None
  2333. def add_numeric_values_to_question(question):
  2334. """Adds numeric value spans to a question."""
  2335. original_text = question
  2336. question = normalize_for_match(question)
  2337. numeric_spans = parse_text(question)
  2338. return Question(original_text=original_text, text=question, numeric_spans=numeric_spans)
  2339. def filter_invalid_unicode(text):
  2340. """Return an empty string and True if 'text' is in invalid unicode."""
  2341. return ("", True) if isinstance(text, bytes) else (text, False)
  2342. def filter_invalid_unicode_from_table(table):
  2343. """
  2344. Removes invalid unicode from table. Checks whether a table cell text contains an invalid unicode encoding. If yes,
  2345. reset the table cell text to an empty str and log a warning for each invalid cell
  2346. Args:
  2347. table: table to clean.
  2348. """
  2349. # to do: add table id support
  2350. if not hasattr(table, "table_id"):
  2351. table.table_id = 0
  2352. for row_index, row in table.iterrows():
  2353. for col_index, cell in enumerate(row):
  2354. cell, is_invalid = filter_invalid_unicode(cell)
  2355. if is_invalid:
  2356. logging.warning(
  2357. f"Scrub an invalid table body @ table_id: {table.table_id}, row_index: {row_index}, "
  2358. f"col_index: {col_index}",
  2359. )
  2360. for col_index, column in enumerate(table.columns):
  2361. column, is_invalid = filter_invalid_unicode(column)
  2362. if is_invalid:
  2363. logging.warning(f"Scrub an invalid table header @ table_id: {table.table_id}, col_index: {col_index}")
  2364. def add_numeric_table_values(table, min_consolidation_fraction=0.7, debug_info=None):
  2365. """
  2366. Parses text in table column-wise and adds the consolidated values. Consolidation refers to finding values with a
  2367. common types (date or number)
  2368. Args:
  2369. table:
  2370. Table to annotate.
  2371. min_consolidation_fraction:
  2372. Fraction of cells in a column that need to have consolidated value.
  2373. debug_info:
  2374. Additional information used for logging.
  2375. """
  2376. table = table.copy()
  2377. # First, filter table on invalid unicode
  2378. filter_invalid_unicode_from_table(table)
  2379. # Second, replace cell values by Cell objects
  2380. for row_index, row in table.iterrows():
  2381. for col_index, cell in enumerate(row):
  2382. table.iloc[row_index, col_index] = Cell(text=cell)
  2383. # Third, add numeric_value attributes to these Cell objects
  2384. for col_index, column in enumerate(table.columns):
  2385. column_values = _consolidate_numeric_values(
  2386. _get_column_values(table, col_index),
  2387. min_consolidation_fraction=min_consolidation_fraction,
  2388. debug_info=(debug_info, column),
  2389. )
  2390. for row_index, numeric_value in column_values.items():
  2391. table.iloc[row_index, col_index].numeric_value = numeric_value
  2392. return table
# Names exported by `from ... import *`: only the tokenizer class is public.
__all__ = ["TapasTokenizer"]