modeling_tapas.py 95 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
37047057067077087097107117127137147157167177187197207217227237247257267277287297307317327337347357367377387397407417427437447457467477487497507517527537547557567577587597607617627637647657667677687697707717727737747757767777787797807817827837847857867877887897907917927937947957967977987998008018028038048058068078088098108118128138148158168178188198208218228238248258268278288298308318328338348358368378388398408418428438448458468478488498508518528538548558568578588598608618628638648658668678688698708718728738748758768778788798808818828838848858868878888898908918928938948958968978988999009019029039049059069079089099109119129139149159169179189199209219229239249259269279289299309319329339349359369379389399409419429439449459469479489499509519529539549559569579589599609619629639649659669679689699709719729739749759769779789799809819829839849859869879889899909919929939949959969979989991000100110021003100410051006100710081009101010111012101310141015101610171018101910201021102210231024102510261027102810291030103110321033103410351036103710381039104010411042104310441045104610471048104910501051105210531054105510561057105810591060106110621063106410651066106710681069107010711072107310741075107610771078107910801081108210831084108510861087108810891090109110921093109410951096109710981099110011011102110311041105110611071108110911101111111211131114111511161117111811191120112111221123112411251126112711281129113011311132113311341135113611371138113911401141114211431144114511461147114811491150115111521153115411551156115711581159116011611162116311641165116611671168116911701171117211731174117511761177117811791180118111821183118411851186118711881189119011911192119311941195119611971198119912001201120212031204120512061207120812091210121112121213121412151216121712181219122012211222122312241225122612271228122912301231123212331234123512361237123812391240124112421243124412451246124712481249125012511252125312541255125612571258125912601261126212631264126512661267126812691270127112721273127412751276127
71278127912801281128212831284128512861287128812891290129112921293129412951296129712981299130013011302130313041305130613071308130913101311131213131314131513161317131813191320132113221323132413251326132713281329133013311332133313341335133613371338133913401341134213431344134513461347134813491350135113521353135413551356135713581359136013611362136313641365136613671368136913701371137213731374137513761377137813791380138113821383138413851386138713881389139013911392139313941395139613971398139914001401140214031404140514061407140814091410141114121413141414151416141714181419142014211422142314241425142614271428142914301431143214331434143514361437143814391440144114421443144414451446144714481449145014511452145314541455145614571458145914601461146214631464146514661467146814691470147114721473147414751476147714781479148014811482148314841485148614871488148914901491149214931494149514961497149814991500150115021503150415051506150715081509151015111512151315141515151615171518151915201521152215231524152515261527152815291530153115321533153415351536153715381539154015411542154315441545154615471548154915501551155215531554155515561557155815591560156115621563156415651566156715681569157015711572157315741575157615771578157915801581158215831584158515861587158815891590159115921593159415951596159715981599160016011602160316041605160616071608160916101611161216131614161516161617161816191620162116221623162416251626162716281629163016311632163316341635163616371638163916401641164216431644164516461647164816491650165116521653165416551656165716581659166016611662166316641665166616671668166916701671167216731674167516761677167816791680168116821683168416851686168716881689169016911692169316941695169616971698169917001701170217031704170517061707170817091710171117121713171417151716171717181719172017211722172317241725172617271728172917301731173217331734173517361737173817391740174117421743174417451746174717481749175017511752175317541755175617571758175917601761176217631764176517661767176817691770177117721773177417751776177
71778177917801781178217831784178517861787178817891790179117921793179417951796179717981799180018011802180318041805180618071808180918101811181218131814181518161817181818191820182118221823182418251826182718281829183018311832183318341835183618371838183918401841184218431844184518461847184818491850185118521853185418551856185718581859186018611862186318641865186618671868186918701871187218731874187518761877187818791880188118821883188418851886188718881889189018911892189318941895189618971898189919001901190219031904190519061907190819091910191119121913191419151916191719181919192019211922192319241925192619271928192919301931193219331934193519361937193819391940194119421943194419451946194719481949195019511952195319541955195619571958195919601961196219631964196519661967196819691970197119721973197419751976197719781979198019811982198319841985198619871988198919901991199219931994199519961997199819992000200120022003200420052006200720082009201020112012201320142015201620172018201920202021202220232024202520262027202820292030203120322033203420352036203720382039204020412042204320442045204620472048204920502051205220532054205520562057205820592060206120622063206420652066206720682069207020712072207320742075207620772078207920802081208220832084208520862087208820892090209120922093209420952096209720982099210021012102210321042105
  1. # Copyright 2020 Google Research and The HuggingFace Inc. team.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. """PyTorch TAPAS model."""
  15. import enum
  16. import math
  17. from dataclasses import dataclass
  18. import torch
  19. from torch import nn
  20. from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
  21. from ... import initialization as init
  22. from ...activations import ACT2FN
  23. from ...cache_utils import Cache, DynamicCache, EncoderDecoderCache
  24. from ...modeling_layers import GradientCheckpointingLayer
  25. from ...modeling_outputs import BaseModelOutput, BaseModelOutputWithPooling, MaskedLMOutput, SequenceClassifierOutput
  26. from ...modeling_utils import PreTrainedModel
  27. from ...pytorch_utils import apply_chunking_to_forward
  28. from ...utils import ModelOutput, auto_docstring, logging
  29. from .configuration_tapas import TapasConfig
  30. logger = logging.get_logger(__name__)
  31. EPSILON_ZERO_DIVISION = 1e-10
  32. CLOSE_ENOUGH_TO_LOG_ZERO = -10000.0
@dataclass
@auto_docstring(
    custom_intro="""
    Output type of [`TapasForQuestionAnswering`].
    """
)
class TableQuestionAnsweringOutput(ModelOutput):
    r"""
    loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` (and possibly `answer`, `aggregation_labels`, `numeric_values` and `numeric_values_scale` are provided)):
        Total loss as the sum of the hierarchical cell selection log-likelihood loss and (optionally) the
        semi-supervised regression loss and (optionally) supervised loss for aggregations.
    logits (`torch.FloatTensor` of shape `(batch_size, sequence_length)`):
        Prediction scores of the cell selection head, for every token.
    logits_aggregation (`torch.FloatTensor`, *optional*, of shape `(batch_size, num_aggregation_labels)`):
        Prediction scores of the aggregation head, for every aggregation operator.
    """

    # NOTE(review): keep the field order as-is; ModelOutput subclasses are presumably also accessed
    # positionally / via to_tuple(), so reordering could silently change what callers get — confirm.

    # Training loss; stays None unless labels were supplied (see docstring above).
    loss: torch.FloatTensor | None = None
    # Per-token cell-selection scores, shape (batch_size, sequence_length).
    logits: torch.FloatTensor | None = None
    # Per-operator aggregation scores, shape (batch_size, num_aggregation_labels); optional.
    logits_aggregation: torch.FloatTensor | None = None
    # Not described in the docstring above; presumably the per-layer hidden states collected by the
    # encoder (cf. TapasEncoder's `all_hidden_states`) — confirm.
    hidden_states: tuple[torch.FloatTensor] | None = None
    # Presumably the per-layer attention probabilities (cf. TapasEncoder's `all_attentions`) — confirm.
    attentions: tuple[torch.FloatTensor] | None = None
  54. class TapasEmbeddings(nn.Module):
  55. """
  56. Construct the embeddings from word, position and token_type embeddings. Same as BertEmbeddings but with a number of
  57. additional token type embeddings to encode tabular structure.
  58. """
  59. def __init__(self, config):
  60. super().__init__()
  61. # we do not include config.disabled_features and config.disable_position_embeddings from the original implementation
  62. # word embeddings
  63. self.word_embeddings = nn.Embedding(config.vocab_size, config.hidden_size, padding_idx=config.pad_token_id)
  64. # position embeddings
  65. self.position_embeddings = nn.Embedding(config.max_position_embeddings, config.hidden_size)
  66. # token type embeddings
  67. for i, type_vocab_sizes in enumerate(config.type_vocab_sizes):
  68. name = f"token_type_embeddings_{i}"
  69. setattr(self, name, nn.Embedding(type_vocab_sizes, config.hidden_size))
  70. self.number_of_token_type_embeddings = len(config.type_vocab_sizes)
  71. self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
  72. self.dropout = nn.Dropout(config.hidden_dropout_prob)
  73. self.config = config
  74. def forward(self, input_ids=None, token_type_ids=None, position_ids=None, inputs_embeds=None):
  75. if input_ids is not None:
  76. input_shape = input_ids.size()
  77. else:
  78. input_shape = inputs_embeds.size()[:-1]
  79. seq_length = input_shape[1]
  80. device = input_ids.device if input_ids is not None else inputs_embeds.device
  81. if position_ids is None:
  82. # create absolute position embeddings
  83. position_ids = torch.arange(seq_length, dtype=torch.long, device=device)
  84. position_ids = position_ids.unsqueeze(0).expand(input_shape)
  85. # when self.config.reset_position_index_per_cell is set to True, create relative position embeddings
  86. if self.config.reset_position_index_per_cell:
  87. # shape (batch_size, seq_len)
  88. col_index = IndexMap(token_type_ids[:, :, 1], self.config.type_vocab_sizes[1], batch_dims=1)
  89. # shape (batch_size, seq_len)
  90. row_index = IndexMap(token_type_ids[:, :, 2], self.config.type_vocab_sizes[2], batch_dims=1)
  91. # shape (batch_size, seq_len)
  92. full_index = ProductIndexMap(col_index, row_index)
  93. # shape (max_rows * max_columns,). First absolute position for every cell
  94. first_position_per_segment = reduce_min(position_ids, full_index)[0]
  95. # ? shape (batch_size, seq_len). First absolute position of the cell for every token
  96. first_position = gather(first_position_per_segment, full_index)
  97. # shape (1, seq_len)
  98. position = torch.arange(seq_length, dtype=torch.long, device=device).unsqueeze(0)
  99. position_ids = torch.min(
  100. torch.as_tensor(self.config.max_position_embeddings - 1, device=device), position - first_position
  101. )
  102. if token_type_ids is None:
  103. token_type_ids = torch.zeros(
  104. (input_shape + self.number_of_token_type_embeddings), dtype=torch.long, device=device
  105. )
  106. if inputs_embeds is None:
  107. inputs_embeds = self.word_embeddings(input_ids)
  108. position_embeddings = self.position_embeddings(position_ids)
  109. embeddings = inputs_embeds + position_embeddings
  110. for i in range(self.number_of_token_type_embeddings):
  111. name = f"token_type_embeddings_{i}"
  112. embeddings += getattr(self, name)(token_type_ids[:, :, i])
  113. embeddings = self.LayerNorm(embeddings)
  114. embeddings = self.dropout(embeddings)
  115. return embeddings
  116. class TapasSelfAttention(nn.Module):
  117. def __init__(self, config, layer_idx=None):
  118. super().__init__()
  119. if config.hidden_size % config.num_attention_heads != 0 and not hasattr(config, "embedding_size"):
  120. raise ValueError(
  121. f"The hidden size {config.hidden_size} is not a multiple of the number of attention "
  122. f"heads {config.num_attention_heads}"
  123. )
  124. self.num_attention_heads = config.num_attention_heads
  125. self.attention_head_size = int(config.hidden_size / config.num_attention_heads)
  126. self.all_head_size = self.num_attention_heads * self.attention_head_size
  127. self.query = nn.Linear(config.hidden_size, self.all_head_size)
  128. self.key = nn.Linear(config.hidden_size, self.all_head_size)
  129. self.value = nn.Linear(config.hidden_size, self.all_head_size)
  130. self.dropout = nn.Dropout(config.attention_probs_dropout_prob)
  131. self.is_decoder = config.is_decoder
  132. self.layer_idx = layer_idx
  133. def forward(
  134. self,
  135. hidden_states,
  136. attention_mask=None,
  137. encoder_hidden_states=None,
  138. past_key_values=None,
  139. output_attentions=False,
  140. **kwargs,
  141. ):
  142. input_shape = hidden_states.shape[:-1]
  143. hidden_shape = (*input_shape, -1, self.attention_head_size)
  144. query_layer = self.query(hidden_states).view(hidden_shape).transpose(1, 2)
  145. is_updated = False
  146. is_cross_attention = encoder_hidden_states is not None
  147. if past_key_values is not None:
  148. if isinstance(past_key_values, EncoderDecoderCache):
  149. is_updated = past_key_values.is_updated.get(self.layer_idx)
  150. if is_cross_attention:
  151. # after the first generated id, we can subsequently re-use all key/value_layer from cache
  152. curr_past_key_values = past_key_values.cross_attention_cache
  153. else:
  154. curr_past_key_values = past_key_values.self_attention_cache
  155. else:
  156. curr_past_key_values = past_key_values
  157. current_states = encoder_hidden_states if is_cross_attention else hidden_states
  158. if is_cross_attention and past_key_values is not None and is_updated:
  159. # reuse k,v, cross_attentions
  160. key_layer = curr_past_key_values.layers[self.layer_idx].keys
  161. value_layer = curr_past_key_values.layers[self.layer_idx].values
  162. else:
  163. key_layer = self.key(current_states).view(hidden_shape).transpose(1, 2)
  164. value_layer = self.value(current_states).view(hidden_shape).transpose(1, 2)
  165. if past_key_values is not None:
  166. # save all key/value_layer to cache to be re-used for fast auto-regressive generation
  167. key_layer, value_layer = curr_past_key_values.update(key_layer, value_layer, self.layer_idx)
  168. # set flag that curr layer for cross-attn is already updated so we can re-use in subsequent calls
  169. if is_cross_attention and isinstance(past_key_values, EncoderDecoderCache):
  170. past_key_values.is_updated[self.layer_idx] = True
  171. # Take the dot product between "query" and "key" to get the raw attention scores.
  172. attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2))
  173. attention_scores = attention_scores / math.sqrt(self.attention_head_size)
  174. if attention_mask is not None:
  175. # Apply the attention mask is (precomputed for all layers in TapasModel forward() function)
  176. attention_scores = attention_scores + attention_mask
  177. # Normalize the attention scores to probabilities.
  178. attention_probs = nn.functional.softmax(attention_scores, dim=-1)
  179. # This is actually dropping out entire tokens to attend to, which might
  180. # seem a bit unusual, but is taken from the original Transformer paper.
  181. attention_probs = self.dropout(attention_probs)
  182. context_layer = torch.matmul(attention_probs, value_layer)
  183. context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
  184. new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,)
  185. context_layer = context_layer.view(*new_context_layer_shape)
  186. outputs = (context_layer, attention_probs) if output_attentions else (context_layer,)
  187. if self.is_decoder:
  188. outputs = outputs + (past_key_values,)
  189. return outputs
  190. # Copied from transformers.models.bert.modeling_bert.BertSelfOutput
  191. class TapasSelfOutput(nn.Module):
  192. def __init__(self, config):
  193. super().__init__()
  194. self.dense = nn.Linear(config.hidden_size, config.hidden_size)
  195. self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
  196. self.dropout = nn.Dropout(config.hidden_dropout_prob)
  197. def forward(self, hidden_states: torch.Tensor, input_tensor: torch.Tensor) -> torch.Tensor:
  198. hidden_states = self.dense(hidden_states)
  199. hidden_states = self.dropout(hidden_states)
  200. hidden_states = self.LayerNorm(hidden_states + input_tensor)
  201. return hidden_states
  202. class TapasAttention(nn.Module):
  203. def __init__(self, config, layer_idx=None):
  204. super().__init__()
  205. self.self = TapasSelfAttention(config, layer_idx=layer_idx)
  206. self.output = TapasSelfOutput(config)
  207. # Copied from transformers.models.rembert.modeling_rembert.RemBertAttention.forward
  208. def forward(
  209. self,
  210. hidden_states: torch.Tensor,
  211. attention_mask: torch.FloatTensor | None = None,
  212. encoder_hidden_states: torch.FloatTensor | None = None,
  213. past_key_values: Cache | None = None,
  214. output_attentions: bool | None = False,
  215. **kwargs,
  216. ) -> tuple[torch.Tensor]:
  217. self_outputs = self.self(
  218. hidden_states,
  219. attention_mask=attention_mask,
  220. encoder_hidden_states=encoder_hidden_states,
  221. past_key_values=past_key_values,
  222. output_attentions=output_attentions,
  223. )
  224. attention_output = self.output(self_outputs[0], hidden_states)
  225. outputs = (attention_output,) + self_outputs[1:] # add attentions if we output them
  226. return outputs
  227. # Copied from transformers.models.bert.modeling_bert.BertIntermediate
  228. class TapasIntermediate(nn.Module):
  229. def __init__(self, config):
  230. super().__init__()
  231. self.dense = nn.Linear(config.hidden_size, config.intermediate_size)
  232. if isinstance(config.hidden_act, str):
  233. self.intermediate_act_fn = ACT2FN[config.hidden_act]
  234. else:
  235. self.intermediate_act_fn = config.hidden_act
  236. def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
  237. hidden_states = self.dense(hidden_states)
  238. hidden_states = self.intermediate_act_fn(hidden_states)
  239. return hidden_states
  240. # Copied from transformers.models.bert.modeling_bert.BertOutput
  241. class TapasOutput(nn.Module):
  242. def __init__(self, config):
  243. super().__init__()
  244. self.dense = nn.Linear(config.intermediate_size, config.hidden_size)
  245. self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
  246. self.dropout = nn.Dropout(config.hidden_dropout_prob)
  247. def forward(self, hidden_states: torch.Tensor, input_tensor: torch.Tensor) -> torch.Tensor:
  248. hidden_states = self.dense(hidden_states)
  249. hidden_states = self.dropout(hidden_states)
  250. hidden_states = self.LayerNorm(hidden_states + input_tensor)
  251. return hidden_states
  252. class TapasLayer(GradientCheckpointingLayer):
  253. def __init__(self, config, layer_idx=None):
  254. super().__init__()
  255. self.chunk_size_feed_forward = config.chunk_size_feed_forward
  256. self.seq_len_dim = 1
  257. self.attention = TapasAttention(config, layer_idx=layer_idx)
  258. self.is_decoder = config.is_decoder
  259. self.add_cross_attention = config.add_cross_attention
  260. if self.add_cross_attention:
  261. if not self.is_decoder:
  262. raise ValueError(f"{self} should be used as a decoder model if cross attention is added")
  263. self.crossattention = TapasAttention(config, layer_idx=layer_idx)
  264. self.intermediate = TapasIntermediate(config)
  265. self.output = TapasOutput(config)
  266. # Copied from transformers.models.rembert.modeling_rembert.RemBertLayer.forward
  267. def forward(
  268. self,
  269. hidden_states: torch.Tensor,
  270. attention_mask: torch.FloatTensor | None = None,
  271. encoder_hidden_states: torch.FloatTensor | None = None,
  272. encoder_attention_mask: torch.FloatTensor | None = None,
  273. past_key_values: Cache | None = None,
  274. output_attentions: bool | None = False,
  275. **kwargs,
  276. ) -> tuple[torch.Tensor]:
  277. self_attention_outputs = self.attention(
  278. hidden_states,
  279. attention_mask=attention_mask,
  280. output_attentions=output_attentions,
  281. past_key_values=past_key_values,
  282. )
  283. attention_output = self_attention_outputs[0]
  284. outputs = self_attention_outputs[1:] # add self attentions if we output attention weights
  285. if self.is_decoder and encoder_hidden_states is not None:
  286. if not hasattr(self, "crossattention"):
  287. raise ValueError(
  288. f"If `encoder_hidden_states` are passed, {self} has to be instantiated with cross-attention layers"
  289. " by setting `config.add_cross_attention=True`"
  290. )
  291. cross_attention_outputs = self.crossattention(
  292. attention_output,
  293. attention_mask=encoder_attention_mask,
  294. encoder_hidden_states=encoder_hidden_states,
  295. past_key_values=past_key_values,
  296. output_attentions=output_attentions,
  297. )
  298. attention_output = cross_attention_outputs[0]
  299. outputs = outputs + cross_attention_outputs[1:] # add cross attentions if we output attention weights
  300. layer_output = apply_chunking_to_forward(
  301. self.feed_forward_chunk, self.chunk_size_feed_forward, self.seq_len_dim, attention_output
  302. )
  303. outputs = (layer_output,) + outputs
  304. return outputs
  305. # Copied from transformers.models.bert.modeling_bert.BertLayer.feed_forward_chunk
  306. def feed_forward_chunk(self, attention_output):
  307. intermediate_output = self.intermediate(attention_output)
  308. layer_output = self.output(intermediate_output, attention_output)
  309. return layer_output
  310. class TapasEncoder(nn.Module):
  311. def __init__(self, config):
  312. super().__init__()
  313. self.config = config
  314. self.layer = nn.ModuleList([TapasLayer(config, layer_idx=i) for i in range(config.num_hidden_layers)])
  315. self.gradient_checkpointing = False
  316. def forward(
  317. self,
  318. hidden_states,
  319. attention_mask=None,
  320. encoder_hidden_states=None,
  321. encoder_attention_mask=None,
  322. past_key_values=None,
  323. use_cache=None,
  324. output_attentions=False,
  325. output_hidden_states=False,
  326. return_dict=True,
  327. **kwargs,
  328. ):
  329. if use_cache and past_key_values is None:
  330. past_key_values = EncoderDecoderCache(DynamicCache(config=self.config), DynamicCache(config=self.config))
  331. all_hidden_states = () if output_hidden_states else None
  332. all_attentions = () if output_attentions else None
  333. for i, layer_module in enumerate(self.layer):
  334. if output_hidden_states:
  335. all_hidden_states = all_hidden_states + (hidden_states,)
  336. layer_outputs = layer_module(
  337. hidden_states,
  338. attention_mask,
  339. encoder_hidden_states, # as a positional argument for gradient checkpointing
  340. encoder_attention_mask=encoder_attention_mask,
  341. past_key_values=past_key_values,
  342. output_attentions=output_attentions,
  343. )
  344. hidden_states = layer_outputs[0]
  345. if output_attentions:
  346. all_attentions = all_attentions + (layer_outputs[1],)
  347. if output_hidden_states:
  348. all_hidden_states = all_hidden_states + (hidden_states,)
  349. if not return_dict:
  350. return tuple(v for v in [hidden_states, all_hidden_states, all_attentions] if v is not None)
  351. return BaseModelOutput(
  352. last_hidden_state=hidden_states, hidden_states=all_hidden_states, attentions=all_attentions
  353. )
  354. # Copied from transformers.models.bert.modeling_bert.BertPooler
  355. class TapasPooler(nn.Module):
  356. def __init__(self, config):
  357. super().__init__()
  358. self.dense = nn.Linear(config.hidden_size, config.hidden_size)
  359. self.activation = nn.Tanh()
  360. def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
  361. # We "pool" the model by simply taking the hidden state corresponding
  362. # to the first token.
  363. first_token_tensor = hidden_states[:, 0]
  364. pooled_output = self.dense(first_token_tensor)
  365. pooled_output = self.activation(pooled_output)
  366. return pooled_output
  367. # Copied from transformers.models.bert.modeling_bert.BertPredictionHeadTransform with Bert->Tapas
  368. class TapasPredictionHeadTransform(nn.Module):
  369. def __init__(self, config):
  370. super().__init__()
  371. self.dense = nn.Linear(config.hidden_size, config.hidden_size)
  372. if isinstance(config.hidden_act, str):
  373. self.transform_act_fn = ACT2FN[config.hidden_act]
  374. else:
  375. self.transform_act_fn = config.hidden_act
  376. self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
  377. def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
  378. hidden_states = self.dense(hidden_states)
  379. hidden_states = self.transform_act_fn(hidden_states)
  380. hidden_states = self.LayerNorm(hidden_states)
  381. return hidden_states
  382. # Copied from transformers.models.bert.modeling_bert.BertLMPredictionHead with Bert->Tapas
  383. class TapasLMPredictionHead(nn.Module):
  384. def __init__(self, config):
  385. super().__init__()
  386. self.transform = TapasPredictionHeadTransform(config)
  387. # The output weights are the same as the input embeddings, but there is
  388. # an output-only bias for each token.
  389. self.decoder = nn.Linear(config.hidden_size, config.vocab_size, bias=True)
  390. self.bias = nn.Parameter(torch.zeros(config.vocab_size))
  391. def forward(self, hidden_states):
  392. hidden_states = self.transform(hidden_states)
  393. hidden_states = self.decoder(hidden_states)
  394. return hidden_states
  395. # Copied from transformers.models.bert.modeling_bert.BertOnlyMLMHead with Bert->Tapas
  396. class TapasOnlyMLMHead(nn.Module):
  397. def __init__(self, config):
  398. super().__init__()
  399. self.predictions = TapasLMPredictionHead(config)
  400. def forward(self, sequence_output: torch.Tensor) -> torch.Tensor:
  401. prediction_scores = self.predictions(sequence_output)
  402. return prediction_scores
  403. @auto_docstring
  404. class TapasPreTrainedModel(PreTrainedModel):
  405. config: TapasConfig
  406. base_model_prefix = "tapas"
  407. supports_gradient_checkpointing = True
  408. @torch.no_grad()
  409. def _init_weights(self, module):
  410. """Initialize the weights"""
  411. super()._init_weights(module)
  412. if isinstance(module, TapasLMPredictionHead):
  413. init.zeros_(module.bias)
  414. if isinstance(module, TapasForQuestionAnswering):
  415. if module.config.init_cell_selection_weights_to_zero:
  416. init.zeros_(module.output_weights)
  417. init.zeros_(module.column_output_weights)
  418. else:
  419. init.normal_(module.output_weights, std=module.config.initializer_range)
  420. init.normal_(module.column_output_weights, std=module.config.initializer_range)
  421. init.zeros_(module.output_bias)
  422. init.zeros_(module.column_output_bias)
  423. @auto_docstring
  424. class TapasModel(TapasPreTrainedModel):
  425. """
  426. This class is a small change compared to [`BertModel`], taking into account the additional token type ids.
  427. The model can behave as an encoder (with only self-attention) as well as a decoder, in which case a layer of
  428. cross-attention is added between the self-attention layers, following the architecture described in [Attention is
  429. all you need](https://huggingface.co/papers/1706.03762) by Ashish Vaswani, Noam Shazeer, Niki Parmar, Jakob Uszkoreit,
  430. Llion Jones, Aidan N. Gomez, Lukasz Kaiser and Illia Polosukhin.
  431. """
  432. def __init__(self, config, add_pooling_layer=True):
  433. r"""
  434. add_pooling_layer (bool, *optional*, defaults to `True`):
  435. Whether to add a pooling layer
  436. """
  437. super().__init__(config)
  438. self.config = config
  439. self.embeddings = TapasEmbeddings(config)
  440. self.encoder = TapasEncoder(config)
  441. self.pooler = TapasPooler(config) if add_pooling_layer else None
  442. # Initialize weights and apply final processing
  443. self.post_init()
  444. def get_input_embeddings(self):
  445. return self.embeddings.word_embeddings
  446. def set_input_embeddings(self, value):
  447. self.embeddings.word_embeddings = value
  448. @auto_docstring
  449. def forward(
  450. self,
  451. input_ids: torch.LongTensor | None = None,
  452. attention_mask: torch.FloatTensor | None = None,
  453. token_type_ids: torch.LongTensor | None = None,
  454. position_ids: torch.LongTensor | None = None,
  455. inputs_embeds: torch.FloatTensor | None = None,
  456. encoder_hidden_states: torch.FloatTensor | None = None,
  457. encoder_attention_mask: torch.FloatTensor | None = None,
  458. output_attentions: bool | None = None,
  459. output_hidden_states: bool | None = None,
  460. return_dict: bool | None = None,
  461. **kwargs,
  462. ) -> tuple | BaseModelOutputWithPooling:
  463. r"""
  464. token_type_ids (`torch.LongTensor` of shape `(batch_size, sequence_length, 7)`, *optional*):
  465. Token indices that encode tabular structure. Indices can be obtained using [`AutoTokenizer`]. See this
  466. class for more info.
  467. [What are token type IDs?](../glossary#token-type-ids)
  468. position_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
  469. Indices of positions of each input sequence tokens in the position embeddings. If
  470. `reset_position_index_per_cell` of [`TapasConfig`] is set to `True`, relative position embeddings will be
  471. used. Selected in the range `[0, config.max_position_embeddings - 1]`.
  472. [What are position IDs?](../glossary#position-ids)
  473. Examples:
  474. ```python
  475. >>> from transformers import AutoTokenizer, TapasModel
  476. >>> import pandas as pd
  477. >>> tokenizer = AutoTokenizer.from_pretrained("google/tapas-base")
  478. >>> model = TapasModel.from_pretrained("google/tapas-base")
  479. >>> data = {
  480. ... "Actors": ["Brad Pitt", "Leonardo Di Caprio", "George Clooney"],
  481. ... "Age": ["56", "45", "59"],
  482. ... "Number of movies": ["87", "53", "69"],
  483. ... }
  484. >>> table = pd.DataFrame.from_dict(data)
  485. >>> queries = ["How many movies has George Clooney played in?", "How old is Brad Pitt?"]
  486. >>> inputs = tokenizer(table=table, queries=queries, padding="max_length", return_tensors="pt")
  487. >>> outputs = model(**inputs)
  488. >>> last_hidden_states = outputs.last_hidden_state
  489. ```"""
  490. output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
  491. output_hidden_states = (
  492. output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
  493. )
  494. return_dict = return_dict if return_dict is not None else self.config.return_dict
  495. if input_ids is not None and inputs_embeds is not None:
  496. raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
  497. elif input_ids is not None:
  498. self.warn_if_padding_and_no_attention_mask(input_ids, attention_mask)
  499. input_shape = input_ids.size()
  500. elif inputs_embeds is not None:
  501. input_shape = inputs_embeds.size()[:-1]
  502. else:
  503. raise ValueError("You have to specify either input_ids or inputs_embeds")
  504. device = input_ids.device if input_ids is not None else inputs_embeds.device
  505. if attention_mask is None:
  506. attention_mask = torch.ones(input_shape, device=device)
  507. if token_type_ids is None:
  508. token_type_ids = torch.zeros(
  509. (*input_shape, len(self.config.type_vocab_sizes)), dtype=torch.long, device=device
  510. )
  511. # We can provide a self-attention mask of dimensions [batch_size, from_seq_length, to_seq_length]
  512. # ourselves in which case we just need to make it broadcastable to all heads.
  513. extended_attention_mask: torch.Tensor = self.get_extended_attention_mask(attention_mask, input_shape)
  514. # If a 2D ou 3D attention mask is provided for the cross-attention
  515. # we need to make broadcastabe to [batch_size, num_heads, seq_length, seq_length]
  516. if self.config.is_decoder and encoder_hidden_states is not None:
  517. encoder_batch_size, encoder_sequence_length, _ = encoder_hidden_states.size()
  518. encoder_hidden_shape = (encoder_batch_size, encoder_sequence_length)
  519. if encoder_attention_mask is None:
  520. encoder_attention_mask = torch.ones(encoder_hidden_shape, device=device)
  521. encoder_extended_attention_mask = self.invert_attention_mask(encoder_attention_mask)
  522. else:
  523. encoder_extended_attention_mask = None
  524. embedding_output = self.embeddings(
  525. input_ids=input_ids, position_ids=position_ids, token_type_ids=token_type_ids, inputs_embeds=inputs_embeds
  526. )
  527. encoder_outputs = self.encoder(
  528. embedding_output,
  529. attention_mask=extended_attention_mask,
  530. encoder_hidden_states=encoder_hidden_states,
  531. encoder_attention_mask=encoder_extended_attention_mask,
  532. output_attentions=output_attentions,
  533. output_hidden_states=output_hidden_states,
  534. return_dict=return_dict,
  535. )
  536. sequence_output = encoder_outputs[0]
  537. pooled_output = self.pooler(sequence_output) if self.pooler is not None else None
  538. if not return_dict:
  539. return (sequence_output, pooled_output) + encoder_outputs[1:]
  540. return BaseModelOutputWithPooling(
  541. last_hidden_state=sequence_output,
  542. pooler_output=pooled_output,
  543. hidden_states=encoder_outputs.hidden_states,
  544. attentions=encoder_outputs.attentions,
  545. )
  546. @auto_docstring
  547. class TapasForMaskedLM(TapasPreTrainedModel):
  548. _tied_weights_keys = {
  549. "cls.predictions.decoder.bias": "cls.predictions.bias",
  550. "cls.predictions.decoder.weight": "tapas.embeddings.word_embeddings.weight",
  551. }
  552. config: TapasConfig
  553. base_model_prefix = "tapas"
  554. def __init__(self, config):
  555. super().__init__(config)
  556. self.tapas = TapasModel(config, add_pooling_layer=False)
  557. self.cls = TapasOnlyMLMHead(config)
  558. # Initialize weights and apply final processing
  559. self.post_init()
  560. def get_output_embeddings(self):
  561. return self.cls.predictions.decoder
  562. def set_output_embeddings(self, new_embeddings):
  563. self.cls.predictions.decoder = new_embeddings
  564. self.cls.predictions.bias = new_embeddings.bias
  565. @auto_docstring
  566. def forward(
  567. self,
  568. input_ids: torch.LongTensor | None = None,
  569. attention_mask: torch.FloatTensor | None = None,
  570. token_type_ids: torch.LongTensor | None = None,
  571. position_ids: torch.LongTensor | None = None,
  572. inputs_embeds: torch.FloatTensor | None = None,
  573. encoder_hidden_states: torch.FloatTensor | None = None,
  574. encoder_attention_mask: torch.FloatTensor | None = None,
  575. labels: torch.LongTensor | None = None,
  576. output_attentions: bool | None = None,
  577. output_hidden_states: bool | None = None,
  578. return_dict: bool | None = None,
  579. **kwargs,
  580. ) -> tuple | MaskedLMOutput:
  581. r"""
  582. token_type_ids (`torch.LongTensor` of shape `(batch_size, sequence_length, 7)`, *optional*):
  583. Token indices that encode tabular structure. Indices can be obtained using [`AutoTokenizer`]. See this
  584. class for more info.
  585. [What are token type IDs?](../glossary#token-type-ids)
  586. position_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
  587. Indices of positions of each input sequence tokens in the position embeddings. If
  588. `reset_position_index_per_cell` of [`TapasConfig`] is set to `True`, relative position embeddings will be
  589. used. Selected in the range `[0, config.max_position_embeddings - 1]`.
  590. [What are position IDs?](../glossary#position-ids)
  591. labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
  592. Labels for computing the masked language modeling loss. Indices should be in `[-100, 0, ...,
  593. config.vocab_size]` (see `input_ids` docstring) Tokens with indices set to `-100` are ignored (masked), the
  594. loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`
  595. Examples:
  596. ```python
  597. >>> from transformers import AutoTokenizer, TapasForMaskedLM
  598. >>> import pandas as pd
  599. >>> tokenizer = AutoTokenizer.from_pretrained("google/tapas-base")
  600. >>> model = TapasForMaskedLM.from_pretrained("google/tapas-base")
  601. >>> data = {
  602. ... "Actors": ["Brad Pitt", "Leonardo Di Caprio", "George Clooney"],
  603. ... "Age": ["56", "45", "59"],
  604. ... "Number of movies": ["87", "53", "69"],
  605. ... }
  606. >>> table = pd.DataFrame.from_dict(data)
  607. >>> inputs = tokenizer(
  608. ... table=table, queries="How many [MASK] has George [MASK] played in?", return_tensors="pt"
  609. ... )
  610. >>> labels = tokenizer(
  611. ... table=table, queries="How many movies has George Clooney played in?", return_tensors="pt"
  612. ... )["input_ids"]
  613. >>> outputs = model(**inputs, labels=labels)
  614. >>> logits = outputs.logits
  615. ```"""
  616. return_dict = return_dict if return_dict is not None else self.config.return_dict
  617. outputs = self.tapas(
  618. input_ids,
  619. attention_mask=attention_mask,
  620. token_type_ids=token_type_ids,
  621. position_ids=position_ids,
  622. inputs_embeds=inputs_embeds,
  623. encoder_hidden_states=encoder_hidden_states,
  624. encoder_attention_mask=encoder_attention_mask,
  625. output_attentions=output_attentions,
  626. output_hidden_states=output_hidden_states,
  627. return_dict=return_dict,
  628. )
  629. sequence_output = outputs[0]
  630. prediction_scores = self.cls(sequence_output)
  631. masked_lm_loss = None
  632. if labels is not None:
  633. loss_fct = CrossEntropyLoss() # -100 index = padding token
  634. masked_lm_loss = loss_fct(prediction_scores.view(-1, self.config.vocab_size), labels.view(-1))
  635. if not return_dict:
  636. output = (prediction_scores,) + outputs[2:]
  637. return ((masked_lm_loss,) + output) if masked_lm_loss is not None else output
  638. return MaskedLMOutput(
  639. loss=masked_lm_loss,
  640. logits=prediction_scores,
  641. hidden_states=outputs.hidden_states,
  642. attentions=outputs.attentions,
  643. )
  644. @auto_docstring(
  645. custom_intro="""
  646. Tapas Model with a cell selection head and optional aggregation head on top for question-answering tasks on tables
  647. (linear layers on top of the hidden-states output to compute `logits` and optional `logits_aggregation`), e.g. for
  648. SQA, WTQ or WikiSQL-supervised tasks.
  649. """
  650. )
  651. class TapasForQuestionAnswering(TapasPreTrainedModel):
  652. def __init__(self, config: TapasConfig):
  653. super().__init__(config)
  654. # base model
  655. self.tapas = TapasModel(config)
  656. # dropout (only used when training)
  657. self.dropout = nn.Dropout(config.hidden_dropout_prob)
  658. # cell selection heads
  659. self.output_weights = nn.Parameter(torch.empty(config.hidden_size))
  660. self.column_output_weights = nn.Parameter(torch.empty(config.hidden_size))
  661. self.output_bias = nn.Parameter(torch.empty([]))
  662. self.column_output_bias = nn.Parameter(torch.empty([]))
  663. # aggregation head
  664. if config.num_aggregation_labels > 0:
  665. self.aggregation_classifier = nn.Linear(config.hidden_size, config.num_aggregation_labels)
  666. # Initialize weights and apply final processing
  667. self.post_init()
  668. @auto_docstring
  669. def forward(
  670. self,
  671. input_ids: torch.LongTensor | None = None,
  672. attention_mask: torch.FloatTensor | None = None,
  673. token_type_ids: torch.LongTensor | None = None,
  674. position_ids: torch.LongTensor | None = None,
  675. inputs_embeds: torch.FloatTensor | None = None,
  676. table_mask: torch.LongTensor | None = None,
  677. labels: torch.LongTensor | None = None,
  678. aggregation_labels: torch.LongTensor | None = None,
  679. float_answer: torch.FloatTensor | None = None,
  680. numeric_values: torch.FloatTensor | None = None,
  681. numeric_values_scale: torch.FloatTensor | None = None,
  682. output_attentions: bool | None = None,
  683. output_hidden_states: bool | None = None,
  684. return_dict: bool | None = None,
  685. **kwargs,
  686. ) -> tuple | TableQuestionAnsweringOutput:
  687. r"""
  688. token_type_ids (`torch.LongTensor` of shape `(batch_size, sequence_length, 7)`, *optional*):
  689. Token indices that encode tabular structure. Indices can be obtained using [`AutoTokenizer`]. See this
  690. class for more info.
  691. [What are token type IDs?](../glossary#token-type-ids)
  692. position_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
  693. Indices of positions of each input sequence tokens in the position embeddings. If
  694. `reset_position_index_per_cell` of [`TapasConfig`] is set to `True`, relative position embeddings will be
  695. used. Selected in the range `[0, config.max_position_embeddings - 1]`.
  696. [What are position IDs?](../glossary#position-ids)
  697. table_mask (`torch.LongTensor` of shape `(batch_size, seq_length)`, *optional*):
  698. Mask for the table. Indicates which tokens belong to the table (1). Question tokens, table headers and
  699. padding are 0.
  700. labels (`torch.LongTensor` of shape `(batch_size, seq_length)`, *optional*):
  701. Labels per token for computing the hierarchical cell selection loss. This encodes the positions of the
  702. answer appearing in the table. Can be obtained using [`AutoTokenizer`].
  703. - 1 for tokens that are **part of the answer**,
  704. - 0 for tokens that are **not part of the answer**.
  705. aggregation_labels (`torch.LongTensor` of shape `(batch_size, )`, *optional*):
  706. Aggregation function index for every example in the batch for computing the aggregation loss. Indices
  707. should be in `[0, ..., config.num_aggregation_labels - 1]`. Only required in case of strong supervision for
  708. aggregation (WikiSQL-supervised).
  709. float_answer (`torch.FloatTensor` of shape `(batch_size, )`, *optional*):
  710. Float answer for every example in the batch. Set to *float('nan')* for cell selection questions. Only
  711. required in case of weak supervision (WTQ) to calculate the aggregate mask and regression loss.
  712. numeric_values (`torch.FloatTensor` of shape `(batch_size, seq_length)`, *optional*):
  713. Numeric values of every token, NaN for tokens which are not numeric values. Can be obtained using
  714. [`AutoTokenizer`]. Only required in case of weak supervision for aggregation (WTQ) to calculate the
  715. regression loss.
  716. numeric_values_scale (`torch.FloatTensor` of shape `(batch_size, seq_length)`, *optional*):
  717. Scale of the numeric values of every token. Can be obtained using [`AutoTokenizer`]. Only required in case
  718. of weak supervision for aggregation (WTQ) to calculate the regression loss.
  719. Examples:
  720. ```python
  721. >>> from transformers import AutoTokenizer, TapasForQuestionAnswering
  722. >>> import pandas as pd
  723. >>> tokenizer = AutoTokenizer.from_pretrained("google/tapas-base-finetuned-wtq")
  724. >>> model = TapasForQuestionAnswering.from_pretrained("google/tapas-base-finetuned-wtq")
  725. >>> data = {
  726. ... "Actors": ["Brad Pitt", "Leonardo Di Caprio", "George Clooney"],
  727. ... "Age": ["56", "45", "59"],
  728. ... "Number of movies": ["87", "53", "69"],
  729. ... }
  730. >>> table = pd.DataFrame.from_dict(data)
  731. >>> queries = ["How many movies has George Clooney played in?", "How old is Brad Pitt?"]
  732. >>> inputs = tokenizer(table=table, queries=queries, padding="max_length", return_tensors="pt")
  733. >>> outputs = model(**inputs)
  734. >>> logits = outputs.logits
  735. >>> logits_aggregation = outputs.logits_aggregation
  736. ```"""
  737. return_dict = return_dict if return_dict is not None else self.config.return_dict
  738. outputs = self.tapas(
  739. input_ids,
  740. attention_mask=attention_mask,
  741. token_type_ids=token_type_ids,
  742. position_ids=position_ids,
  743. inputs_embeds=inputs_embeds,
  744. output_attentions=output_attentions,
  745. output_hidden_states=output_hidden_states,
  746. return_dict=return_dict,
  747. )
  748. sequence_output = outputs[0]
  749. pooled_output = outputs[1]
  750. sequence_output = self.dropout(sequence_output)
  751. if input_ids is not None:
  752. input_shape = input_ids.size()
  753. else:
  754. input_shape = inputs_embeds.size()[:-1]
  755. device = input_ids.device if input_ids is not None else inputs_embeds.device
  756. # Construct indices for the table.
  757. if token_type_ids is None:
  758. token_type_ids = torch.zeros(
  759. (*input_shape, len(self.config.type_vocab_sizes)), dtype=torch.long, device=device
  760. )
  761. token_types = [
  762. "segment_ids",
  763. "column_ids",
  764. "row_ids",
  765. "prev_labels",
  766. "column_ranks",
  767. "inv_column_ranks",
  768. "numeric_relations",
  769. ]
  770. row_ids = token_type_ids[:, :, token_types.index("row_ids")]
  771. column_ids = token_type_ids[:, :, token_types.index("column_ids")]
  772. row_index = IndexMap(
  773. indices=torch.min(row_ids, torch.as_tensor(self.config.max_num_rows - 1, device=row_ids.device)),
  774. num_segments=self.config.max_num_rows,
  775. batch_dims=1,
  776. )
  777. col_index = IndexMap(
  778. indices=torch.min(column_ids, torch.as_tensor(self.config.max_num_columns - 1, device=column_ids.device)),
  779. num_segments=self.config.max_num_columns,
  780. batch_dims=1,
  781. )
  782. cell_index = ProductIndexMap(row_index, col_index)
  783. # Masks.
  784. input_shape = input_ids.size() if input_ids is not None else inputs_embeds.size()[:-1]
  785. device = input_ids.device if input_ids is not None else inputs_embeds.device
  786. if attention_mask is None:
  787. attention_mask = torch.ones(input_shape, device=device)
  788. # Table cells only, without question tokens and table headers.
  789. if table_mask is None:
  790. table_mask = torch.where(row_ids > 0, torch.ones_like(row_ids), torch.zeros_like(row_ids))
  791. # torch.FloatTensor[batch_size, seq_length]
  792. input_mask_float = attention_mask.to(device=device, dtype=torch.float)
  793. table_mask_float = table_mask.to(device=device, dtype=torch.float)
  794. # Mask for cells that exist in the table (i.e. that are not padding).
  795. cell_mask, _ = reduce_mean(input_mask_float, cell_index)
  796. # Compute logits per token. These are used to select individual cells.
  797. logits = compute_token_logits(sequence_output, self.config.temperature, self.output_weights, self.output_bias)
  798. # Compute logits per column. These are used to select a column.
  799. column_logits = None
  800. if self.config.select_one_column:
  801. column_logits = compute_column_logits(
  802. sequence_output,
  803. self.column_output_weights,
  804. self.column_output_bias,
  805. cell_index,
  806. cell_mask,
  807. self.config.allow_empty_column_selection,
  808. )
  809. # Aggregation logits
  810. logits_aggregation = None
  811. if self.config.num_aggregation_labels > 0:
  812. logits_aggregation = self.aggregation_classifier(pooled_output)
  813. # Total loss calculation
  814. total_loss = 0.0
  815. calculate_loss = False
  816. if labels is not None:
  817. calculate_loss = True
  818. is_supervised = not self.config.num_aggregation_labels > 0 or not self.config.use_answer_as_supervision
  819. # Semi-supervised cell selection in case of no aggregation:
  820. # If the answer (the denotation) appears directly in the table we might
  821. # select the answer without applying any aggregation function. There are
  822. # some ambiguous cases, see utils._calculate_aggregate_mask for more info.
  823. # `aggregate_mask` is 1 for examples where we chose to aggregate and 0
  824. # for examples where we chose to select the answer directly.
  825. # `labels` encodes the positions of the answer appearing in the table.
  826. if is_supervised:
  827. aggregate_mask = None
  828. else:
  829. if float_answer is not None:
  830. assert labels.shape[0] == float_answer.shape[0], (
  831. "Make sure the answers are a FloatTensor of shape (batch_size,)"
  832. )
  833. # <float32>[batch_size]
  834. aggregate_mask = _calculate_aggregate_mask(
  835. float_answer,
  836. pooled_output,
  837. self.config.cell_selection_preference,
  838. labels,
  839. self.aggregation_classifier,
  840. )
  841. else:
  842. raise ValueError("You have to specify float answers in order to calculate the aggregate mask")
  843. # Cell selection log-likelihood
  844. if self.config.average_logits_per_cell:
  845. logits_per_cell, _ = reduce_mean(logits, cell_index)
  846. logits = gather(logits_per_cell, cell_index)
  847. dist_per_token = torch.distributions.Bernoulli(logits=logits)
  848. # Compute cell selection loss per example.
  849. selection_loss_per_example = None
  850. if not self.config.select_one_column:
  851. weight = torch.where(
  852. labels == 0,
  853. torch.ones_like(labels, dtype=torch.float32),
  854. self.config.positive_label_weight * torch.ones_like(labels, dtype=torch.float32),
  855. )
  856. selection_loss_per_token = -dist_per_token.log_prob(labels) * weight
  857. selection_loss_per_example = torch.sum(selection_loss_per_token * input_mask_float, dim=1) / (
  858. torch.sum(input_mask_float, dim=1) + EPSILON_ZERO_DIVISION
  859. )
  860. else:
  861. selection_loss_per_example, logits = _single_column_cell_selection_loss(
  862. logits, column_logits, labels, cell_index, col_index, cell_mask
  863. )
  864. dist_per_token = torch.distributions.Bernoulli(logits=logits)
  865. # Supervised cell selection
  866. if self.config.disable_per_token_loss:
  867. pass
  868. elif is_supervised:
  869. total_loss += torch.mean(selection_loss_per_example)
  870. else:
  871. # For the not supervised case, do not assign loss for cell selection
  872. total_loss += torch.mean(selection_loss_per_example * (1.0 - aggregate_mask))
  873. # Semi-supervised regression loss and supervised loss for aggregations
  874. if self.config.num_aggregation_labels > 0:
  875. if is_supervised:
  876. # Note that `aggregate_mask` is None if the setting is supervised.
  877. if aggregation_labels is not None:
  878. assert labels.shape[0] == aggregation_labels.shape[0], (
  879. "Make sure the aggregation labels are a LongTensor of shape (batch_size,)"
  880. )
  881. per_example_additional_loss = _calculate_aggregation_loss(
  882. logits_aggregation,
  883. aggregate_mask,
  884. aggregation_labels,
  885. self.config.use_answer_as_supervision,
  886. self.config.num_aggregation_labels,
  887. self.config.aggregation_loss_weight,
  888. )
  889. else:
  890. raise ValueError(
  891. "You have to specify aggregation labels in order to calculate the aggregation loss"
  892. )
  893. else:
  894. # Set aggregation labels to zeros
  895. aggregation_labels = torch.zeros(labels.shape[0], dtype=torch.long, device=labels.device)
  896. per_example_additional_loss = _calculate_aggregation_loss(
  897. logits_aggregation,
  898. aggregate_mask,
  899. aggregation_labels,
  900. self.config.use_answer_as_supervision,
  901. self.config.num_aggregation_labels,
  902. self.config.aggregation_loss_weight,
  903. )
  904. if self.config.use_answer_as_supervision:
  905. if numeric_values is not None and numeric_values_scale is not None:
  906. assert numeric_values.shape == numeric_values_scale.shape
  907. # Add regression loss for numeric answers which require aggregation.
  908. answer_loss, large_answer_loss_mask = _calculate_regression_loss(
  909. float_answer,
  910. aggregate_mask,
  911. dist_per_token,
  912. numeric_values,
  913. numeric_values_scale,
  914. table_mask_float,
  915. logits_aggregation,
  916. self.config,
  917. )
  918. per_example_additional_loss += answer_loss
  919. # Zero loss for examples with answer_loss > cutoff.
  920. per_example_additional_loss *= large_answer_loss_mask
  921. else:
  922. raise ValueError(
  923. "You have to specify numeric values and numeric values scale in order to calculate the"
  924. " regression loss"
  925. )
  926. total_loss += torch.mean(per_example_additional_loss)
  927. else:
  928. # if no label ids are provided, set them to zeros in order to properly compute logits
  929. labels = torch.zeros_like(logits)
  930. _, logits = _single_column_cell_selection_loss(
  931. logits, column_logits, labels, cell_index, col_index, cell_mask
  932. )
  933. if not return_dict:
  934. output = (logits, logits_aggregation) + outputs[2:]
  935. return ((total_loss,) + output) if calculate_loss else output
  936. return TableQuestionAnsweringOutput(
  937. loss=total_loss if calculate_loss else None,
  938. logits=logits,
  939. logits_aggregation=logits_aggregation,
  940. hidden_states=outputs.hidden_states,
  941. attentions=outputs.attentions,
  942. )
  943. @auto_docstring(
  944. custom_intro="""
  945. Tapas Model with a sequence classification head on top (a linear layer on top of the pooled output), e.g. for table
  946. entailment tasks, such as TabFact (Chen et al., 2020).
  947. """
  948. )
  949. class TapasForSequenceClassification(TapasPreTrainedModel):
  950. def __init__(self, config):
  951. super().__init__(config)
  952. self.num_labels = config.num_labels
  953. self.tapas = TapasModel(config)
  954. self.dropout = nn.Dropout(config.hidden_dropout_prob)
  955. self.classifier = nn.Linear(config.hidden_size, config.num_labels)
  956. # Initialize weights and apply final processing
  957. self.post_init()
  958. @auto_docstring
  959. def forward(
  960. self,
  961. input_ids: torch.LongTensor | None = None,
  962. attention_mask: torch.FloatTensor | None = None,
  963. token_type_ids: torch.LongTensor | None = None,
  964. position_ids: torch.LongTensor | None = None,
  965. inputs_embeds: torch.FloatTensor | None = None,
  966. labels: torch.LongTensor | None = None,
  967. output_attentions: bool | None = None,
  968. output_hidden_states: bool | None = None,
  969. return_dict: bool | None = None,
  970. **kwargs,
  971. ) -> tuple[torch.Tensor] | SequenceClassifierOutput:
  972. r"""
  973. token_type_ids (`torch.LongTensor` of shape `(batch_size, sequence_length, 7)`, *optional*):
  974. Token indices that encode tabular structure. Indices can be obtained using [`AutoTokenizer`]. See this
  975. class for more info.
  976. [What are token type IDs?](../glossary#token-type-ids)
  977. position_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
  978. Indices of positions of each input sequence tokens in the position embeddings. If
  979. `reset_position_index_per_cell` of [`TapasConfig`] is set to `True`, relative position embeddings will be
  980. used. Selected in the range `[0, config.max_position_embeddings - 1]`.
  981. [What are position IDs?](../glossary#position-ids)
  982. labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
  983. Labels for computing the sequence classification/regression loss. Indices should be in `[0, ...,
  984. config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If
  985. `config.num_labels > 1` a classification loss is computed (Cross-Entropy). Note: this is called
  986. "classification_class_index" in the original implementation.
  987. Examples:
  988. ```python
  989. >>> from transformers import AutoTokenizer, TapasForSequenceClassification
  990. >>> import torch
  991. >>> import pandas as pd
  992. >>> tokenizer = AutoTokenizer.from_pretrained("google/tapas-base-finetuned-tabfact")
  993. >>> model = TapasForSequenceClassification.from_pretrained("google/tapas-base-finetuned-tabfact")
  994. >>> data = {
  995. ... "Actors": ["Brad Pitt", "Leonardo Di Caprio", "George Clooney"],
  996. ... "Age": ["56", "45", "59"],
  997. ... "Number of movies": ["87", "53", "69"],
  998. ... }
  999. >>> table = pd.DataFrame.from_dict(data)
  1000. >>> queries = [
  1001. ... "There is only one actor who is 45 years old",
  1002. ... "There are 3 actors which played in more than 60 movies",
  1003. ... ]
  1004. >>> inputs = tokenizer(table=table, queries=queries, padding="max_length", return_tensors="pt")
  1005. >>> labels = torch.tensor([1, 0]) # 1 means entailed, 0 means refuted
  1006. >>> outputs = model(**inputs, labels=labels)
  1007. >>> loss = outputs.loss
  1008. >>> logits = outputs.logits
  1009. ```"""
  1010. return_dict = return_dict if return_dict is not None else self.config.return_dict
  1011. outputs = self.tapas(
  1012. input_ids,
  1013. attention_mask=attention_mask,
  1014. token_type_ids=token_type_ids,
  1015. position_ids=position_ids,
  1016. inputs_embeds=inputs_embeds,
  1017. output_attentions=output_attentions,
  1018. output_hidden_states=output_hidden_states,
  1019. return_dict=return_dict,
  1020. )
  1021. pooled_output = outputs[1]
  1022. pooled_output = self.dropout(pooled_output)
  1023. logits = self.classifier(pooled_output)
  1024. loss = None
  1025. if labels is not None:
  1026. if self.config.problem_type is None:
  1027. if self.num_labels == 1:
  1028. self.config.problem_type = "regression"
  1029. elif self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int):
  1030. self.config.problem_type = "single_label_classification"
  1031. else:
  1032. self.config.problem_type = "multi_label_classification"
  1033. if self.config.problem_type == "regression":
  1034. loss_fct = MSELoss()
  1035. if self.num_labels == 1:
  1036. loss = loss_fct(logits.squeeze(), labels.squeeze())
  1037. else:
  1038. loss = loss_fct(logits, labels)
  1039. elif self.config.problem_type == "single_label_classification":
  1040. loss_fct = CrossEntropyLoss()
  1041. loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
  1042. elif self.config.problem_type == "multi_label_classification":
  1043. loss_fct = BCEWithLogitsLoss()
  1044. loss = loss_fct(logits, labels)
  1045. if not return_dict:
  1046. output = (logits,) + outputs[2:]
  1047. return ((loss,) + output) if loss is not None else output
  1048. return SequenceClassifierOutput(
  1049. loss=loss,
  1050. logits=logits,
  1051. hidden_states=outputs.hidden_states,
  1052. attentions=outputs.attentions,
  1053. )
  1054. """ TAPAS utilities."""
  1055. class AverageApproximationFunction(str, enum.Enum):
  1056. RATIO = "ratio"
  1057. FIRST_ORDER = "first_order"
  1058. SECOND_ORDER = "second_order"
  1059. # Beginning of everything related to segmented tensors
  1060. class IndexMap:
  1061. """Index grouping entries within a tensor."""
  1062. def __init__(self, indices, num_segments, batch_dims=0):
  1063. """
  1064. Creates an index
  1065. Args:
  1066. indices (`torch.LongTensor`, same shape as a *values* Tensor to which the indices refer):
  1067. Tensor containing the indices.
  1068. num_segments (`torch.LongTensor`):
  1069. Scalar tensor, the number of segments. All elements in a batched segmented tensor must have the same
  1070. number of segments (although many segments can be empty).
  1071. batch_dims (`int`, *optional*, defaults to 0):
  1072. The number of batch dimensions. The first *batch_dims* dimensions of a SegmentedTensor are treated as
  1073. batch dimensions. Segments in different batch elements are always distinct even if they have the same
  1074. index.
  1075. """
  1076. self.indices = torch.as_tensor(indices, device=indices.device)
  1077. self.num_segments = torch.as_tensor(num_segments, device=indices.device)
  1078. self.batch_dims = batch_dims
  1079. def batch_shape(self):
  1080. return self.indices.size()[: self.batch_dims] # returns a torch.Size object
  1081. class ProductIndexMap(IndexMap):
  1082. """The product of two indices."""
  1083. def __init__(self, outer_index, inner_index):
  1084. """
  1085. Combines indices i and j into pairs (i, j). The result is an index where each segment (i, j) is the
  1086. intersection of segments i and j. For example if the inputs represent table cells indexed by respectively rows
  1087. and columns the output will be a table indexed by (row, column) pairs, i.e. by cell. The implementation
  1088. combines indices {0, .., n - 1} and {0, .., m - 1} into {0, .., nm - 1}. The output has *num_segments* equal to
  1089. *outer_index.num_segments* * *inner_index.num_segments*
  1090. Args:
  1091. outer_index (`IndexMap`):
  1092. IndexMap.
  1093. inner_index (`IndexMap`):
  1094. IndexMap, must have the same shape as *outer_index*.
  1095. """
  1096. if outer_index.batch_dims != inner_index.batch_dims:
  1097. raise ValueError("outer_index.batch_dims and inner_index.batch_dims must be the same.")
  1098. super().__init__(
  1099. indices=(inner_index.indices + outer_index.indices * inner_index.num_segments),
  1100. num_segments=inner_index.num_segments * outer_index.num_segments,
  1101. batch_dims=inner_index.batch_dims,
  1102. )
  1103. self.outer_index = outer_index
  1104. self.inner_index = inner_index
  1105. def project_outer(self, index):
  1106. """Projects an index with the same index set onto the outer components."""
  1107. indices = torch.div(index.indices, self.inner_index.num_segments, rounding_mode="floor").type(torch.long)
  1108. return IndexMap(indices=indices, num_segments=self.outer_index.num_segments, batch_dims=index.batch_dims)
  1109. def project_inner(self, index):
  1110. """Projects an index with the same index set onto the inner components."""
  1111. return IndexMap(
  1112. indices=torch.fmod(index.indices, self.inner_index.num_segments)
  1113. .type(torch.float)
  1114. .floor()
  1115. .type(torch.long),
  1116. num_segments=self.inner_index.num_segments,
  1117. batch_dims=index.batch_dims,
  1118. )
  1119. def gather(values, index, name="segmented_gather"):
  1120. """
  1121. Gathers from *values* using the index map. For each element in the domain of the index map this operation looks up
  1122. a value for that index in *values*. Two elements from the same segment always get assigned the same value.
  1123. Args:
  1124. values (`torch.Tensor` of shape (B1, ..., Bn, num_segments, V1, ...)):
  1125. Tensor with segment values.
  1126. index (`IndexMap` of shape (B1, ..., Bn, I1, ..., Ik)):
  1127. IndexMap.
  1128. name (`str`, *optional*, defaults to 'segmented_gather'):
  1129. Name for the operation. Currently not used
  1130. Returns:
  1131. `tuple(torch.Tensor)`: Tensor of shape (B1, ..., Bn, I1, ..., Ik, V1, ...) with the gathered values.
  1132. """
  1133. indices = index.indices
  1134. # first, check whether the indices of the index represent scalar values (i.e. not vectorized)
  1135. if len(values.shape[index.batch_dims :]) < 2:
  1136. return torch.gather(
  1137. values,
  1138. index.batch_dims,
  1139. indices.view(
  1140. values.size()[0], -1
  1141. ), # torch.gather expects index to have the same number of dimensions as values
  1142. ).view(indices.size())
  1143. else:
  1144. # this means we have a vectorized version
  1145. # we have to adjust the index
  1146. indices = indices.unsqueeze(-1).expand(values.shape)
  1147. return torch.gather(values, index.batch_dims, indices)
  1148. def flatten(index, name="segmented_flatten"):
  1149. """
  1150. Flattens a batched index map (which is typically of shape batch_size, seq_length) to a 1d index map. This operation
  1151. relabels the segments to keep batch elements distinct. The k-th batch element will have indices shifted by
  1152. *num_segments* * (k - 1). The result is a tensor with *num_segments* multiplied by the number of elements in the
  1153. batch.
  1154. Args:
  1155. index (`IndexMap`):
  1156. IndexMap to flatten.
  1157. name (`str`, *optional*, defaults to 'segmented_flatten'):
  1158. Name for the operation. Currently not used
  1159. Returns:
  1160. (`IndexMap`): The flattened IndexMap.
  1161. """
  1162. # first, get batch_size as scalar tensor
  1163. batch_size = torch.prod(torch.tensor(list(index.batch_shape())))
  1164. # next, create offset as 1-D tensor of length batch_size,
  1165. # and multiply element-wise by num segments (to offset different elements in the batch) e.g. if batch size is 2: [0, 64]
  1166. offset = torch.arange(start=0, end=batch_size, device=index.num_segments.device) * index.num_segments
  1167. offset = offset.view(index.batch_shape())
  1168. for _ in range(index.batch_dims, len(index.indices.size())): # typically range(1,2)
  1169. offset = offset.unsqueeze(-1)
  1170. indices = offset + index.indices
  1171. return IndexMap(indices=indices.view(-1), num_segments=index.num_segments * batch_size, batch_dims=0)
  1172. def range_index_map(batch_shape, num_segments, name="range_index_map"):
  1173. """
  1174. Constructs an index map equal to range(num_segments).
  1175. Args:
  1176. batch_shape (`torch.Size`):
  1177. Batch shape
  1178. num_segments (`int`):
  1179. Number of segments
  1180. name (`str`, *optional*, defaults to 'range_index_map'):
  1181. Name for the operation. Currently not used
  1182. Returns:
  1183. (`IndexMap`): IndexMap of shape batch_shape with elements equal to range(num_segments).
  1184. """
  1185. device = num_segments.device if torch.is_tensor(num_segments) else "cpu"
  1186. batch_shape = torch.as_tensor(
  1187. batch_shape, dtype=torch.long, device=device
  1188. ) # create a rank 1 tensor vector containing batch_shape (e.g. [2])
  1189. assert len(batch_shape.size()) == 1
  1190. num_segments = torch.as_tensor(
  1191. num_segments, device=device
  1192. ) # create a rank 0 tensor (scalar) containing num_segments (e.g. 64)
  1193. assert len(num_segments.size()) == 0
  1194. indices = torch.arange(
  1195. start=0, end=num_segments, device=num_segments.device
  1196. ) # create a rank 1 vector with num_segments elements
  1197. new_tensor = torch.cat(
  1198. [torch.ones_like(batch_shape, dtype=torch.long, device=num_segments.device), num_segments.unsqueeze(dim=0)],
  1199. dim=0,
  1200. )
  1201. # new_tensor is just a vector of [1 64] for example (assuming only 1 batch dimension)
  1202. new_shape = [int(x) for x in new_tensor.tolist()]
  1203. indices = indices.view(new_shape)
  1204. multiples = torch.cat([batch_shape, torch.as_tensor([1], device=device)], dim=0)
  1205. indices = indices.repeat(multiples.tolist())
  1206. # equivalent (in Numpy:)
  1207. # indices = torch.as_tensor(np.tile(indices.numpy(), multiples.tolist()))
  1208. return IndexMap(indices=indices, num_segments=num_segments, batch_dims=list(batch_shape.size())[0])
  1209. def _segment_reduce(values, index, segment_reduce_fn, name):
  1210. """
  1211. Applies a segment reduction segment-wise.
  1212. Args:
  1213. values (`torch.Tensor`):
  1214. Tensor with segment values.
  1215. index (`IndexMap`):
  1216. IndexMap.
  1217. segment_reduce_fn (`str`):
  1218. Name for the reduce operation. One of "sum", "mean", "max" or "min".
  1219. name (`str`):
  1220. Name for the operation. Currently not used
  1221. Returns:
  1222. (`IndexMap`): IndexMap of shape batch_shape with elements equal to range(num_segments).
  1223. """
  1224. # Flatten the batch dimensions, as segments ops (scatter) do not support batching.
  1225. # However if `values` has extra dimensions to the right keep them
  1226. # unflattened. Segmented ops support vector-valued operations.
  1227. flat_index = flatten(index)
  1228. vector_shape = values.size()[len(index.indices.size()) :] # torch.Size object
  1229. flattened_shape = torch.cat(
  1230. [torch.as_tensor([-1], dtype=torch.long), torch.as_tensor(vector_shape, dtype=torch.long)], dim=0
  1231. )
  1232. # changed "view" by "reshape" in the following line
  1233. flat_values = values.reshape(flattened_shape.tolist())
  1234. out = torch.zeros(int(flat_index.num_segments), dtype=torch.float, device=flat_values.device)
  1235. segment_means = out.scatter_reduce(
  1236. dim=0, index=flat_index.indices.long(), src=flat_values.float(), reduce=segment_reduce_fn, include_self=False
  1237. )
  1238. device = index.num_segments.device
  1239. # Unflatten the values.
  1240. new_shape = torch.cat(
  1241. [
  1242. torch.as_tensor(index.batch_shape(), dtype=torch.long, device=device),
  1243. torch.as_tensor(index.num_segments, dtype=torch.long, device=device).unsqueeze(dim=0),
  1244. torch.as_tensor(vector_shape, dtype=torch.long, device=device),
  1245. ],
  1246. dim=0,
  1247. )
  1248. output_values = segment_means.clone().view(new_shape.tolist()).to(values.dtype)
  1249. output_index = range_index_map(index.batch_shape(), index.num_segments)
  1250. return output_values, output_index
  1251. def reduce_sum(values, index, name="segmented_reduce_sum"):
  1252. """
  1253. Sums a tensor over its segments.
  1254. Outputs 0 for empty segments.
  1255. This operations computes the sum over segments, with support for:
  1256. - Batching using the first dimensions [B1, B2, ..., Bn]. Each element in a batch can have different indices.
  1257. - Vectorization using the last dimension [V1, V2, ...]. If they are present, the output will be a sum of
  1258. vectors rather than scalars. Only the middle dimensions [I1, ..., Ik] are reduced by the operation.
  1259. Args:
  1260. values (`torch.Tensor` of shape [B1, B2, ..., Bn, I1, .., Ik, V1, V2, ..]):
  1261. Tensor containing the values of which the sum must be taken segment-wise.
  1262. index (`IndexMap`, indices are of shape [B1, B2, ..., Bn, I1, .., Ik].):
  1263. Index defining the segments.
  1264. name (`str`, *optional*, defaults to 'segmented_reduce_sum'):
  1265. Name for the operation. Currently not used
  1266. Returns:
  1267. output_values (`torch.Tensor`of shape [B1, B2, ..., Bn, num_segments, V1, V2, ..]): Tensor containing the
  1268. output values. output_index (`IndexMap`): IndexMap with shape [B1, B2, ..., Bn, num_segments]. .
  1269. """
  1270. return _segment_reduce(values, index, "sum", name)
  1271. def reduce_mean(values, index, name="segmented_reduce_mean"):
  1272. """
  1273. Averages a tensor over its segments.
  1274. Outputs 0 for empty segments.
  1275. This operations computes the mean over segments, with support for:
  1276. - Batching using the first dimensions [B1, B2, ..., Bn]. Each element in a batch can have different indices.
  1277. - Vectorization using the last dimension [V1, V2, ...]. If they are present, the output will be a mean of
  1278. vectors rather than scalars.
  1279. Only the middle dimensions [I1, ..., Ik] are reduced by the operation.
  1280. Args:
  1281. values (`torch.Tensor` of shape [B1, B2, ..., Bn, I1, .., Ik, V1, V2, ..]):
  1282. Tensor containing the values of which the mean must be taken segment-wise.
  1283. index (`IndexMap`, indices are of shape [B1, B2, ..., Bn, I1, .., Ik].):
  1284. Index defining the segments.
  1285. name (`str`, *optional*, defaults to 'segmented_reduce_sum'):
  1286. Name for the operation. Currently not used
  1287. Returns:
  1288. output_values (`torch.Tensor`of shape [B1, B2, ..., Bn, num_segments, V1, V2, ..]): Tensor containing the
  1289. output values. output_index (`IndexMap`): IndexMap with shape [B1, B2, ..., Bn, num_segments].
  1290. """
  1291. return _segment_reduce(values, index, "mean", name)
  1292. def reduce_max(values, index, name="segmented_reduce_max"):
  1293. """
  1294. Computes the maximum over segments.
  1295. This operation computes the maximum over segments, with support for:
  1296. - Batching using the first dimensions [B1, B2, ..., Bn]. Each element in a batch can have different indices.
  1297. - Vectorization using the last dimension [V1, V2, ...]. If they are present, the output will be an element-wise
  1298. maximum of vectors rather than scalars.
  1299. Only the middle dimensions [I1, ..., Ik] are reduced by the operation.
  1300. Args:
  1301. values (`torch.Tensor` of shape [B1, B2, ..., Bn, I1, .., Ik, V1, V2, ..]):
  1302. Tensor containing the values of which the max must be taken segment-wise.
  1303. index (`IndexMap`, indices are of shape [B1, B2, ..., Bn, I1, .., Ik].):
  1304. Index defining the segments.
  1305. name (`str`, *optional*, defaults to 'segmented_reduce_sum'):
  1306. Name for the operation. Currently not used
  1307. Returns:
  1308. output_values (`torch.Tensor`of shape [B1, B2, ..., Bn, num_segments, V1, V2, ..]): Tensor containing the
  1309. output values. output_index (`IndexMap`): IndexMap with shape [B1, B2, ..., Bn, num_segments].
  1310. """
  1311. return _segment_reduce(values, index, "amax", name)
  1312. def reduce_min(values, index, name="segmented_reduce_min"):
  1313. """
  1314. Computes the minimum over segments.
  1315. This operations computes the minimum over segments, with support for:
  1316. - Batching using the first dimensions [B1, B2, ..., Bn]. Each element in a batch can have different indices.
  1317. - Vectorization using the last dimension [V1, V2, ...]. If they are present, the output will be an element-wise
  1318. minimum of vectors rather than scalars.
  1319. Only the middle dimensions [I1, ..., Ik] are reduced by the operation.
  1320. Args:
  1321. values (`torch.Tensor` of shape [B1, B2, ..., Bn, I1, .., Ik, V1, V2, ..]):
  1322. Tensor containing the values of which the min must be taken segment-wise.
  1323. index (`IndexMap`, indices are of shape [B1, B2, ..., Bn, I1, .., Ik].):
  1324. Index defining the segments.
  1325. name (`str`, *optional*, defaults to 'segmented_reduce_sum'):
  1326. Name for the operation. Currently not used
  1327. Returns:
  1328. output_values (`torch.Tensor`of shape [B1, B2, ..., Bn, num_segments, V1, V2, ..]): Tensor containing the
  1329. output values. output_index (`IndexMap`): IndexMap with shape [B1, B2, ..., Bn, num_segments].
  1330. """
  1331. return _segment_reduce(values, index, "amin", name)
  1332. # End of everything related to segmented tensors
  1333. def compute_column_logits(
  1334. sequence_output, column_output_weights, column_output_bias, cell_index, cell_mask, allow_empty_column_selection
  1335. ):
  1336. """
  1337. Computes the column logits.
  1338. Args:
  1339. sequence_output (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`):
  1340. Also known as last_hidden_state. Sequence of hidden-states at the output of the last layer of the model.
  1341. column_output_weights (`torch.FloatTensor` of shape `(hidden_size)`):
  1342. Weights of the linear layer for column selection.
  1343. column_output_bias (`torch.FloatTensor` of shape `()`):
  1344. Bias of the linear layer for column selection.
  1345. cell_index (`ProductIndexMap`):
  1346. Index that groups tokens into cells.
  1347. cell_mask (`torch.FloatTensor` of shape `(batch_size, max_num_rows * max_num_cols)`):
  1348. Mask for cells that exist in the table (i.e. that are not padding).
  1349. allow_empty_column_selection (`bool`):
  1350. Whether to allow not to select any column
  1351. Returns:
  1352. column_logits (`torch.FloatTensor`of shape `(batch_size, max_num_cols)`): Tensor containing the column logits
  1353. for every example in the batch.
  1354. """
  1355. # First, compute the token logits (batch_size, seq_len) - without temperature
  1356. token_logits = torch.einsum("bsj,j->bs", sequence_output, column_output_weights) + column_output_bias
  1357. # Next, average the logits per cell (batch_size, max_num_cols*max_num_rows)
  1358. cell_logits, cell_logits_index = reduce_mean(token_logits, cell_index)
  1359. # Finally, average the logits per column (batch_size, max_num_cols)
  1360. column_index = cell_index.project_inner(cell_logits_index)
  1361. column_logits, out_index = reduce_sum(cell_logits * cell_mask, column_index)
  1362. cell_count, _ = reduce_sum(cell_mask, column_index)
  1363. column_logits /= cell_count + EPSILON_ZERO_DIVISION
  1364. # Mask columns that do not appear in the example.
  1365. is_padding = torch.logical_and(cell_count < 0.5, ~torch.eq(out_index.indices, 0))
  1366. column_logits += CLOSE_ENOUGH_TO_LOG_ZERO * torch.as_tensor(
  1367. is_padding, dtype=torch.float32, device=is_padding.device
  1368. )
  1369. if not allow_empty_column_selection:
  1370. column_logits += CLOSE_ENOUGH_TO_LOG_ZERO * torch.as_tensor(
  1371. torch.eq(out_index.indices, 0), dtype=torch.float32, device=out_index.indices.device
  1372. )
  1373. return column_logits
  1374. def _single_column_cell_selection_loss(token_logits, column_logits, labels, cell_index, col_index, cell_mask):
  1375. """
  1376. Computes the loss for cell selection constrained to a single column. The loss is a hierarchical log-likelihood. The
  1377. model first predicts a column and then selects cells within that column (conditioned on the column). Cells outside
  1378. the selected column are never selected.
  1379. Args:
  1380. token_logits (`torch.FloatTensor` of shape `(batch_size, sequence_length)`):
  1381. Tensor containing the logits per token.
  1382. column_logits (`torch.FloatTensor` of shape `(batch_size, max_num_cols)`):
  1383. Tensor containing the logits per column.
  1384. labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
  1385. Labels per token.
  1386. cell_index (`ProductIndexMap`):
  1387. Index that groups tokens into cells.
  1388. col_index (`IndexMap`):
  1389. Index that groups tokens into columns.
  1390. cell_mask (`torch.FloatTensor` of shape `(batch_size, max_num_rows * max_num_cols)`):
  1391. Mask for cells that exist in the table (i.e. that are not padding).
  1392. Returns:
  1393. selection_loss_per_example (`torch.FloatTensor` of shape `(batch_size,)`): Loss for each example. logits
  1394. (`torch.FloatTensor` of shape `(batch_size, sequence_length)`): New logits which are only allowed to select
  1395. cells in a single column. Logits outside of the most likely column according to *column_logits* will be set to
  1396. a very low value (such that the probabilities are 0).
  1397. """
  1398. # Part 1: column loss
  1399. # First find the column we should select. We use the column with maximum number of selected cells.
  1400. labels_per_column, _ = reduce_sum(torch.as_tensor(labels, dtype=torch.float32, device=labels.device), col_index)
  1401. # shape of labels_per_column is (batch_size, max_num_cols). It contains the number of label ids for every column, for every example
  1402. column_label = torch.argmax(labels_per_column, dim=-1) # shape (batch_size,)
  1403. # Check if there are no selected cells in the column. In that case the model
  1404. # should predict the special column id 0, which means "select nothing".
  1405. no_cell_selected = torch.eq(
  1406. torch.max(labels_per_column, dim=-1)[0], 0
  1407. ) # no_cell_selected is of shape (batch_size,) and equals True
  1408. # if an example of the batch has no cells selected (i.e. if there are no labels set to 1 for that example)
  1409. column_label = torch.where(
  1410. no_cell_selected.view(column_label.size()), torch.zeros_like(column_label), column_label
  1411. )
  1412. column_dist = torch.distributions.Categorical(logits=column_logits) # shape (batch_size, max_num_cols)
  1413. column_loss_per_example = -column_dist.log_prob(column_label)
  1414. # Part 2: cell loss
  1415. # Reduce the labels and logits to per-cell from per-token.
  1416. # logits_per_cell: shape (batch_size, max_num_rows*max_num_cols) i.e. (batch_size, 64*32)
  1417. logits_per_cell, _ = reduce_mean(token_logits, cell_index)
  1418. # labels_per_cell: shape (batch_size, 64*32), indicating whether each cell should be selected (1) or not (0)
  1419. labels_per_cell, labels_index = reduce_max(
  1420. torch.as_tensor(labels, dtype=torch.long, device=labels.device), cell_index
  1421. )
  1422. # Mask for the selected column.
  1423. # column_id_for_cells: shape (batch_size, 64*32), indicating to which column each cell belongs
  1424. column_id_for_cells = cell_index.project_inner(labels_index).indices
  1425. # column_mask: shape (batch_size, 64*32), equal to 1 if cell belongs to column to be selected
  1426. column_mask = torch.as_tensor(
  1427. torch.eq(column_id_for_cells, torch.unsqueeze(column_label, dim=-1)),
  1428. dtype=torch.float32,
  1429. device=cell_mask.device,
  1430. )
  1431. # Compute the log-likelihood for cells, but only for the selected column.
  1432. cell_dist = torch.distributions.Bernoulli(logits=logits_per_cell) # shape (batch_size, 64*32)
  1433. cell_log_prob = cell_dist.log_prob(labels_per_cell.type(torch.float32)) # shape(batch_size, 64*32)
  1434. cell_loss = -torch.sum(cell_log_prob * column_mask * cell_mask, dim=1)
  1435. # We need to normalize the loss by the number of cells in the column.
  1436. cell_loss /= torch.sum(column_mask * cell_mask, dim=1) + EPSILON_ZERO_DIVISION
  1437. selection_loss_per_example = column_loss_per_example
  1438. selection_loss_per_example += torch.where(
  1439. no_cell_selected.view(selection_loss_per_example.size()),
  1440. torch.zeros_like(selection_loss_per_example),
  1441. cell_loss,
  1442. )
  1443. # Set the probs outside the selected column (selected by the *model*)
  1444. # to 0. This ensures backwards compatibility with models that select
  1445. # cells from multiple columns.
  1446. selected_column_id = torch.as_tensor(
  1447. torch.argmax(column_logits, dim=-1), dtype=torch.long, device=column_logits.device
  1448. ) # shape (batch_size,)
  1449. # selected_column_mask: shape (batch_size, 64*32), equal to 1 if cell belongs to column selected by the model
  1450. selected_column_mask = torch.as_tensor(
  1451. torch.eq(column_id_for_cells, torch.unsqueeze(selected_column_id, dim=-1)),
  1452. dtype=torch.float32,
  1453. device=selected_column_id.device,
  1454. )
  1455. # Never select cells with the special column id 0.
  1456. selected_column_mask = torch.where(
  1457. torch.eq(column_id_for_cells, 0).view(selected_column_mask.size()),
  1458. torch.zeros_like(selected_column_mask),
  1459. selected_column_mask,
  1460. )
  1461. new_logits_per_cell = logits_per_cell + CLOSE_ENOUGH_TO_LOG_ZERO * (1.0 - cell_mask * selected_column_mask)
  1462. logits = gather(new_logits_per_cell, cell_index)
  1463. return selection_loss_per_example, logits
  1464. def compute_token_logits(sequence_output, temperature, output_weights, output_bias):
  1465. """
  1466. Computes logits per token
  1467. Args:
  1468. sequence_output (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`):
  1469. Also known as last_hidden_state. Sequence of hidden-states at the output of the last layer of the model.
  1470. temperature (`float`):
  1471. Temperature for the Bernoulli distribution.
  1472. output_weights (`torch.FloatTensor` of shape `(hidden_size,)`):
  1473. Weights of the linear layer for cell selection.
  1474. output_bias (`torch.FloatTensor` of shape `()`):
  1475. Bias of the linear layer for cell selection
  1476. Returns:
  1477. logits (`torch.FloatTensor` of shape `(batch_size, sequence_length)`): Logits per token.
  1478. """
  1479. logits = (torch.einsum("bsj,j->bs", sequence_output, output_weights) + output_bias) / temperature
  1480. return logits
  1481. def _calculate_aggregate_mask(answer, pooled_output, cell_selection_preference, labels, aggregation_classifier):
  1482. """
  1483. Finds examples where the model should select cells with no aggregation.
  1484. Returns a mask that determines for which examples should the model select answers directly from the table, without
  1485. any aggregation function. If the answer is a piece of text the case is unambiguous as aggregation functions only
  1486. apply to numbers. If the answer is a number but does not appear in the table then we must use some aggregation
  1487. case. The ambiguous case is when the answer is a number that also appears in the table. In this case we use the
  1488. aggregation function probabilities predicted by the model to decide whether to select or aggregate. The threshold
  1489. for this is a hyperparameter *cell_selection_preference*
  1490. Args:
  1491. answer (`torch.FloatTensor` of shape `(batch_size, )`):
  1492. Answer for every example in the batch. Nan if there is no scalar answer.
  1493. pooled_output (`torch.FloatTensor` of shape `(batch_size, hidden_size)`):
  1494. Output of the pooler (BertPooler) on top of the encoder layer.
  1495. cell_selection_preference (`float`):
  1496. Preference for cell selection in ambiguous cases.
  1497. labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
  1498. Labels per token. aggregation_classifier (`torch.nn.Linear`): Aggregation head
  1499. Returns:
  1500. aggregate_mask (`torch.FloatTensor` of shape `(batch_size,)`): A mask set to 1 for examples that should use
  1501. aggregation functions.
  1502. """
  1503. # torch.FloatTensor(batch_size,)
  1504. aggregate_mask_init = torch.logical_not(torch.isnan(answer)).type(torch.FloatTensor).to(answer.device)
  1505. logits_aggregation = aggregation_classifier(pooled_output)
  1506. dist_aggregation = torch.distributions.categorical.Categorical(logits=logits_aggregation)
  1507. # Index 0 corresponds to "no aggregation".
  1508. aggregation_ops_total_mass = torch.sum(dist_aggregation.probs[:, 1:], dim=1)
  1509. # Cell selection examples according to current model.
  1510. is_pred_cell_selection = aggregation_ops_total_mass <= cell_selection_preference
  1511. # Examples with non-empty cell selection supervision.
  1512. is_cell_supervision_available = torch.sum(labels, dim=1) > 0
  1513. aggregate_mask = torch.where(
  1514. torch.logical_and(is_pred_cell_selection, is_cell_supervision_available).view(aggregate_mask_init.size()),
  1515. torch.zeros_like(aggregate_mask_init, dtype=torch.float32),
  1516. aggregate_mask_init,
  1517. )
  1518. aggregate_mask = aggregate_mask.detach()
  1519. return aggregate_mask
  1520. def _calculate_aggregation_loss_known(
  1521. logits_aggregation, aggregate_mask, aggregation_labels, use_answer_as_supervision, num_aggregation_labels
  1522. ):
  1523. """
  1524. Calculates aggregation loss when its type is known during training.
  1525. In the weakly supervised setting, the only known information is that for cell selection examples, "no aggregation"
  1526. should be predicted. For other examples (those that require aggregation), no loss is accumulated. In the setting
  1527. where aggregation type is always known, standard cross entropy loss is accumulated for all examples
  1528. Args:
  1529. logits_aggregation (`torch.FloatTensor` of shape `(batch_size, num_aggregation_labels)`):
  1530. Logits per aggregation operation.
  1531. aggregate_mask (`torch.FloatTensor` of shape `(batch_size, )`):
  1532. A mask set to 1 for examples that should use aggregation functions.
  1533. aggregation_labels (`torch.LongTensor` of shape `(batch_size, )`):
  1534. Aggregation function id for every example in the batch.
  1535. use_answer_as_supervision (`bool`, *optional*):
  1536. Whether to use the answer as the only supervision for aggregation examples.
  1537. num_aggregation_labels (`int`, *optional*, defaults to 0):
  1538. The number of aggregation operators to predict.
  1539. Returns:
  1540. aggregation_loss_known (`torch.FloatTensor` of shape `(batch_size,)`): Aggregation loss (when its type is known
  1541. during training) per example.
  1542. """
  1543. if use_answer_as_supervision:
  1544. # Prepare "no aggregation" targets for cell selection examples.
  1545. target_aggregation = torch.zeros_like(aggregate_mask, dtype=torch.long)
  1546. else:
  1547. # Use aggregation supervision as the target.
  1548. target_aggregation = aggregation_labels
  1549. one_hot_labels = nn.functional.one_hot(target_aggregation, num_classes=num_aggregation_labels).type(torch.float32)
  1550. log_probs = nn.functional.log_softmax(logits_aggregation, dim=-1)
  1551. # torch.FloatTensor[batch_size]
  1552. per_example_aggregation_intermediate = -torch.sum(one_hot_labels * log_probs, dim=-1)
  1553. if use_answer_as_supervision:
  1554. # Accumulate loss only for examples requiring cell selection
  1555. # (no aggregation).
  1556. return per_example_aggregation_intermediate * (1 - aggregate_mask)
  1557. else:
  1558. return per_example_aggregation_intermediate
  1559. def _calculate_aggregation_loss_unknown(logits_aggregation, aggregate_mask):
  1560. """
  1561. Calculates aggregation loss in the case of answer supervision.
  1562. Args:
  1563. logits_aggregation (`torch.FloatTensor` of shape `(batch_size, num_aggregation_labels)`):
  1564. Logits per aggregation operation.
  1565. aggregate_mask (`torch.FloatTensor` of shape `(batch_size, )`):
  1566. A mask set to 1 for examples that should use aggregation functions
  1567. Returns:
  1568. aggregation_loss_unknown (`torch.FloatTensor` of shape `(batch_size,)`): Aggregation loss (in case of answer
  1569. supervision) per example.
  1570. """
  1571. dist_aggregation = torch.distributions.categorical.Categorical(logits=logits_aggregation)
  1572. # Index 0 corresponds to "no aggregation".
  1573. aggregation_ops_total_mass = torch.sum(dist_aggregation.probs[:, 1:], dim=1)
  1574. # Predict some aggregation in case of an answer that needs aggregation.
  1575. # This increases the probability of all aggregation functions, in a way
  1576. # similar to MML, but without considering whether the function gives the
  1577. # correct answer.
  1578. return -torch.log(aggregation_ops_total_mass) * aggregate_mask
  1579. def _calculate_aggregation_loss(
  1580. logits_aggregation,
  1581. aggregate_mask,
  1582. aggregation_labels,
  1583. use_answer_as_supervision,
  1584. num_aggregation_labels,
  1585. aggregation_loss_weight,
  1586. ):
  1587. """
  1588. Calculates the aggregation loss per example.
  1589. Args:
  1590. logits_aggregation (`torch.FloatTensor` of shape `(batch_size, num_aggregation_labels)`):
  1591. Logits per aggregation operation.
  1592. aggregate_mask (`torch.FloatTensor` of shape `(batch_size, )`):
  1593. A mask set to 1 for examples that should use aggregation functions.
  1594. aggregation_labels (`torch.LongTensor` of shape `(batch_size, )`):
  1595. Aggregation function id for every example in the batch.
  1596. use_answer_as_supervision (`bool`, *optional*):
  1597. Whether to use the answer as the only supervision for aggregation examples.
  1598. num_aggregation_labels (`int`, *optional*, defaults to 0):
  1599. The number of aggregation operators to predict.
  1600. aggregation_loss_weight (`float`, *optional*, defaults to 1.0):
  1601. Importance weight for the aggregation loss.
  1602. Returns:
  1603. aggregation_loss (`torch.FloatTensor` of shape `(batch_size,)`): Aggregation loss per example.
  1604. """
  1605. per_example_aggregation_loss = _calculate_aggregation_loss_known(
  1606. logits_aggregation, aggregate_mask, aggregation_labels, use_answer_as_supervision, num_aggregation_labels
  1607. )
  1608. if use_answer_as_supervision:
  1609. # Add aggregation loss for numeric answers that need aggregation.
  1610. per_example_aggregation_loss += _calculate_aggregation_loss_unknown(logits_aggregation, aggregate_mask)
  1611. return aggregation_loss_weight * per_example_aggregation_loss
  1612. def _calculate_expected_result(
  1613. dist_per_cell, numeric_values, numeric_values_scale, input_mask_float, logits_aggregation, config
  1614. ):
  1615. """
  1616. Calculates the expected result given cell and aggregation probabilities.
  1617. Args:
  1618. dist_per_cell (`torch.distributions.Bernoulli`):
  1619. Cell selection distribution for each cell.
  1620. numeric_values (`torch.FloatTensor` of shape `(batch_size, seq_length)`):
  1621. Numeric values of every token. Nan for tokens which are not numeric values.
  1622. numeric_values_scale (`torch.FloatTensor` of shape `(batch_size, seq_length)`):
  1623. Scale of the numeric values of every token.
  1624. input_mask_float (`torch.FloatTensor` of shape `(batch_size, seq_length)`):
  1625. Mask for the table, without question tokens and table headers.
  1626. logits_aggregation (`torch.FloatTensor` of shape `(batch_size, num_aggregation_labels)`):
  1627. Logits per aggregation operation.
  1628. config ([`TapasConfig`]):
  1629. Model configuration class with all the hyperparameters of the model
  1630. Returns:
  1631. expected_result (`torch.FloatTensor` of shape `(batch_size,)`): The expected result per example.
  1632. """
  1633. if config.use_gumbel_for_cells:
  1634. gumbel_dist = torch.distributions.RelaxedBernoulli(
  1635. # The token logits where already divided by the temperature and used for
  1636. # computing cell selection errors so we need to multiply it again here
  1637. temperature=config.temperature,
  1638. logits=dist_per_cell.logits * config.temperature,
  1639. )
  1640. scaled_probability_per_cell = gumbel_dist.sample()
  1641. else:
  1642. scaled_probability_per_cell = dist_per_cell.probs
  1643. # <float32>[batch_size, seq_length]
  1644. scaled_probability_per_cell = (scaled_probability_per_cell / numeric_values_scale) * input_mask_float
  1645. count_result = torch.sum(scaled_probability_per_cell, dim=1)
  1646. numeric_values_masked = torch.where(
  1647. torch.isnan(numeric_values), torch.zeros_like(numeric_values), numeric_values
  1648. ) # Mask non-numeric table values to zero.
  1649. sum_result = torch.sum(scaled_probability_per_cell * numeric_values_masked, dim=1)
  1650. avg_approximation = config.average_approximation_function
  1651. if avg_approximation == AverageApproximationFunction.RATIO:
  1652. average_result = sum_result / (count_result + EPSILON_ZERO_DIVISION)
  1653. elif avg_approximation == AverageApproximationFunction.FIRST_ORDER:
  1654. # The sum of all probabilities except that correspond to other cells
  1655. # Ex here stands for expectation, more explicitly the expectation of the sum of N-1 Bernoulli random variables plus
  1656. # the constant 1, which is computed as adding all N expected values and subtracting the extra one. It corresponds to X_c
  1657. # in Appendix D of the original TAPAS paper which is trying to approximate the average of a random set.
  1658. ex = torch.sum(scaled_probability_per_cell, dim=1, keepdim=True) - scaled_probability_per_cell + 1
  1659. average_result = torch.sum(numeric_values_masked * scaled_probability_per_cell / ex, dim=1)
  1660. elif avg_approximation == AverageApproximationFunction.SECOND_ORDER:
  1661. # The sum of all probabilities except that correspond to other cells
  1662. ex = torch.sum(scaled_probability_per_cell, dim=1, keepdim=True) - scaled_probability_per_cell + 1
  1663. pointwise_var = scaled_probability_per_cell * (1 - scaled_probability_per_cell)
  1664. var = torch.sum(pointwise_var, dim=1, keepdim=True) - pointwise_var
  1665. multiplier = (var / torch.square(ex) + 1) / ex
  1666. average_result = torch.sum(numeric_values_masked * scaled_probability_per_cell * multiplier, dim=1)
  1667. else:
  1668. raise ValueError(f"Invalid average_approximation_function: {config.average_approximation_function}")
  1669. if config.use_gumbel_for_aggregation:
  1670. gumbel_dist = torch.distributions.RelaxedOneHotCategorical(
  1671. config.aggregation_temperature, logits=logits_aggregation[:, 1:]
  1672. )
  1673. # <float32>[batch_size, num_aggregation_labels - 1]
  1674. aggregation_op_only_probs = gumbel_dist.sample()
  1675. else:
  1676. # <float32>[batch_size, num_aggregation_labels - 1]
  1677. aggregation_op_only_probs = nn.functional.softmax(
  1678. logits_aggregation[:, 1:] / config.aggregation_temperature, dim=-1
  1679. )
  1680. all_results = torch.cat(
  1681. [
  1682. torch.unsqueeze(sum_result, dim=1),
  1683. torch.unsqueeze(average_result, dim=1),
  1684. torch.unsqueeze(count_result, dim=1),
  1685. ],
  1686. dim=1,
  1687. )
  1688. expected_result = torch.sum(all_results * aggregation_op_only_probs, dim=1)
  1689. return expected_result
  1690. # PyTorch does not currently support Huber loss with custom delta so we define it ourself
  1691. def huber_loss(input, target, delta: float = 1.0):
  1692. errors = torch.abs(input - target) # shape (batch_size,)
  1693. return torch.where(errors < delta, 0.5 * errors**2, errors * delta - (0.5 * delta**2))
  1694. def _calculate_regression_loss(
  1695. answer,
  1696. aggregate_mask,
  1697. dist_per_cell,
  1698. numeric_values,
  1699. numeric_values_scale,
  1700. input_mask_float,
  1701. logits_aggregation,
  1702. config,
  1703. ):
  1704. """
  1705. Calculates the regression loss per example.
  1706. Args:
  1707. answer (`torch.FloatTensor` of shape `(batch_size,)`):
  1708. Answer for every example in the batch. Nan if there is no scalar answer.
  1709. aggregate_mask (`torch.FloatTensor` of shape `(batch_size,)`):
  1710. A mask set to 1 for examples that should use aggregation functions.
  1711. dist_per_cell (`torch.distributions.Bernoulli`):
  1712. Cell selection distribution for each cell.
  1713. numeric_values (`torch.FloatTensor` of shape `(batch_size, seq_length)`):
  1714. Numeric values of every token. Nan for tokens which are not numeric values.
  1715. numeric_values_scale (`torch.FloatTensor` of shape `(batch_size, seq_length)`):
  1716. Scale of the numeric values of every token.
  1717. input_mask_float (`torch.FloatTensor` of shape `(batch_size, seq_length)`):
  1718. Mask for the table, without question tokens and table headers.
  1719. logits_aggregation (`torch.FloatTensor` of shape `(batch_size, num_aggregation_labels)`):
  1720. Logits per aggregation operation.
  1721. config ([`TapasConfig`]):
  1722. Model configuration class with all the parameters of the model
  1723. Returns:
  1724. per_example_answer_loss_scaled (`torch.FloatTensor` of shape `(batch_size,)`): Scales answer loss for each
  1725. example in the batch. large_answer_loss_mask (`torch.FloatTensor` of shape `(batch_size,)`): A mask which is 1
  1726. for examples for which their answer loss is larger than the answer_loss_cutoff.
  1727. """
  1728. # float32 (batch_size,)
  1729. expected_result = _calculate_expected_result(
  1730. dist_per_cell, numeric_values, numeric_values_scale, input_mask_float, logits_aggregation, config
  1731. )
  1732. # float32 (batch_size,)
  1733. answer_masked = torch.where(torch.isnan(answer), torch.zeros_like(answer), answer)
  1734. if config.use_normalized_answer_loss:
  1735. normalizer = (torch.max(torch.abs(expected_result), torch.abs(answer_masked)) + EPSILON_ZERO_DIVISION).detach()
  1736. normalized_answer_masked = answer_masked / normalizer
  1737. normalized_expected_result = expected_result / normalizer
  1738. per_example_answer_loss = huber_loss(
  1739. normalized_expected_result * aggregate_mask, normalized_answer_masked * aggregate_mask
  1740. )
  1741. else:
  1742. per_example_answer_loss = huber_loss(
  1743. expected_result * aggregate_mask, answer_masked * aggregate_mask, delta=config.huber_loss_delta
  1744. )
  1745. if config.answer_loss_cutoff is None:
  1746. large_answer_loss_mask = torch.ones_like(per_example_answer_loss, dtype=torch.float32)
  1747. else:
  1748. large_answer_loss_mask = torch.where(
  1749. per_example_answer_loss > config.answer_loss_cutoff,
  1750. torch.zeros_like(per_example_answer_loss, dtype=torch.float32),
  1751. torch.ones_like(per_example_answer_loss, dtype=torch.float32),
  1752. )
  1753. per_example_answer_loss_scaled = config.answer_loss_importance * (per_example_answer_loss * aggregate_mask)
  1754. return per_example_answer_loss_scaled, large_answer_loss_mask
# Public API of this module: the names exported by `from <module> import *`.
__all__ = [
    "TapasForMaskedLM",
    "TapasForQuestionAnswering",
    "TapasForSequenceClassification",
    "TapasModel",
    "TapasPreTrainedModel",
]