| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656657658659660661662663664665666667668669670671672673674675676677678679680681682683684685686687688689690691692693694695696697698699700701702
70370470570670770870971071171271371471571671771871972072172272372472572672772872973073173273373473573673773873974074174274374474574674774874975075175275375475575675775875976076176276376476576676776876977077177277377477577677777877978078178278378478578678778878979079179279379479579679779879980080180280380480580680780880981081181281381481581681781881982082182282382482582682782882983083183283383483583683783883984084184284384484584684784884985085185285385485585685785885986086186286386486586686786886987087187287387487587687787887988088188288388488588688788888989089189289389489589689789889990090190290390490590690790890991091191291391491591691791891992092192292392492592692792892993093193293393493593693793893994094194294394494594694794894995095195295395495595695795895996096196296396496596696796896997097197297397497597697797897998098198298398498598698798898999099199299399499599699799899910001001100210031004100510061007100810091010101110121013101410151016101710181019102010211022102310241025102610271028102910301031103210331034103510361037103810391040104110421043104410451046104710481049105010511052105310541055105610571058105910601061106210631064106510661067106810691070107110721073107410751076107710781079108010811082108310841085108610871088108910901091109210931094109510961097109810991100110111021103110411051106110711081109111011111112111311141115111611171118111911201121112211231124112511261127112811291130113111321133113411351136113711381139114011411142114311441145114611471148114911501151115211531154115511561157115811591160116111621163116411651166116711681169117011711172117311741175117611771178117911801181118211831184118511861187118811891190119111921193119411951196119711981199120012011202120312041205120612071208120912101211121212131214121512161217121812191220122112221223122412251226122712281229123012311232123312341235123612371238123912401241124212431244124512461247124812491250125112521253125412551256125712581259126012611262126312641265126612671268126912701271127212731274127512761
27712781279128012811282128312841285128612871288128912901291129212931294129512961297129812991300130113021303130413051306130713081309131013111312131313141315131613171318131913201321132213231324132513261327132813291330133113321333133413351336133713381339134013411342134313441345134613471348134913501351135213531354135513561357135813591360136113621363136413651366136713681369137013711372137313741375137613771378137913801381138213831384138513861387138813891390139113921393139413951396139713981399140014011402140314041405140614071408140914101411141214131414141514161417141814191420142114221423142414251426142714281429143014311432143314341435143614371438143914401441144214431444144514461447144814491450145114521453145414551456145714581459146014611462146314641465146614671468146914701471147214731474147514761477147814791480148114821483148414851486148714881489149014911492149314941495149614971498149915001501150215031504150515061507150815091510151115121513151415151516151715181519152015211522152315241525152615271528152915301531153215331534153515361537153815391540154115421543154415451546154715481549155015511552155315541555155615571558155915601561156215631564156515661567156815691570157115721573157415751576157715781579158015811582158315841585158615871588158915901591159215931594159515961597159815991600160116021603160416051606160716081609161016111612161316141615161616171618161916201621162216231624162516261627162816291630163116321633163416351636163716381639164016411642164316441645164616471648164916501651165216531654165516561657165816591660166116621663166416651666166716681669167016711672167316741675167616771678167916801681168216831684168516861687168816891690169116921693169416951696169716981699170017011702170317041705170617071708170917101711171217131714171517161717171817191720172117221723172417251726172717281729173017311732173317341735173617371738173917401741174217431744174517461747174817491750175117521753175417551756175717581759176017611762176317641765176617671768176917701771177217731774177517761
7771778177917801781178217831784178517861787178817891790179117921793179417951796179717981799180018011802180318041805180618071808180918101811181218131814181518161817181818191820182118221823182418251826182718281829183018311832183318341835183618371838183918401841184218431844184518461847184818491850185118521853185418551856185718581859186018611862186318641865186618671868186918701871187218731874187518761877187818791880188118821883188418851886188718881889189018911892189318941895189618971898189919001901190219031904190519061907190819091910191119121913191419151916191719181919192019211922192319241925192619271928192919301931193219331934193519361937193819391940194119421943194419451946194719481949195019511952195319541955195619571958195919601961196219631964196519661967196819691970197119721973197419751976197719781979198019811982198319841985198619871988198919901991199219931994199519961997199819992000200120022003200420052006200720082009201020112012201320142015201620172018201920202021202220232024202520262027202820292030203120322033203420352036203720382039204020412042204320442045204620472048204920502051205220532054205520562057205820592060206120622063206420652066206720682069207020712072207320742075207620772078207920802081208220832084208520862087208820892090209120922093209420952096209720982099210021012102210321042105 |
- # Copyright 2020 Google Research and The HuggingFace Inc. team.
- #
- # Licensed under the Apache License, Version 2.0 (the "License");
- # you may not use this file except in compliance with the License.
- # You may obtain a copy of the License at
- #
- # http://www.apache.org/licenses/LICENSE-2.0
- #
- # Unless required by applicable law or agreed to in writing, software
- # distributed under the License is distributed on an "AS IS" BASIS,
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- # See the License for the specific language governing permissions and
- # limitations under the License.
- """PyTorch TAPAS model."""
- import enum
- import math
- from dataclasses import dataclass
- import torch
- from torch import nn
- from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
- from ... import initialization as init
- from ...activations import ACT2FN
- from ...cache_utils import Cache, DynamicCache, EncoderDecoderCache
- from ...modeling_layers import GradientCheckpointingLayer
- from ...modeling_outputs import BaseModelOutput, BaseModelOutputWithPooling, MaskedLMOutput, SequenceClassifierOutput
- from ...modeling_utils import PreTrainedModel
- from ...pytorch_utils import apply_chunking_to_forward
- from ...utils import ModelOutput, auto_docstring, logging
- from .configuration_tapas import TapasConfig
logger = logging.get_logger(__name__)

# Tiny positive value; per its name, meant for guarding divisions against a zero denominator.
EPSILON_ZERO_DIVISION: float = 1e-10
# Large negative value; per its name, a finite stand-in for log(0).
CLOSE_ENOUGH_TO_LOG_ZERO: float = -10000.0
@dataclass
@auto_docstring(
    custom_intro="""
    Output type of [`TapasForQuestionAnswering`].
    """
)
class TableQuestionAnsweringOutput(ModelOutput):
    r"""
    loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` (and possibly `answer`, `aggregation_labels`, `numeric_values` and `numeric_values_scale` are provided)):
        Total loss as the sum of the hierarchical cell selection log-likelihood loss and (optionally) the
        semi-supervised regression loss and (optionally) supervised loss for aggregations.
    logits (`torch.FloatTensor` of shape `(batch_size, sequence_length)`):
        Prediction scores of the cell selection head, for every token.
    logits_aggregation (`torch.FloatTensor`, *optional*, of shape `(batch_size, num_aggregation_labels)`):
        Prediction scores of the aggregation head, for every aggregation operator.
    """

    # NOTE(review): field declaration order presumably defines the positional order when this ModelOutput is
    # converted to a tuple — confirm before reordering or inserting fields.
    loss: torch.FloatTensor | None = None
    logits: torch.FloatTensor | None = None
    logits_aggregation: torch.FloatTensor | None = None
    # hidden_states / attentions are not described in the docstring above; presumably @auto_docstring fills in
    # the standard descriptions — TODO confirm.
    hidden_states: tuple[torch.FloatTensor] | None = None
    attentions: tuple[torch.FloatTensor] | None = None
class TapasEmbeddings(nn.Module):
    """
    Construct the embeddings from word, position and token_type embeddings. Same as BertEmbeddings but with a number of
    additional token type embeddings to encode tabular structure.
    """

    def __init__(self, config):
        super().__init__()
        # we do not include config.disabled_features and config.disable_position_embeddings from the original implementation
        # word embeddings
        self.word_embeddings = nn.Embedding(config.vocab_size, config.hidden_size, padding_idx=config.pad_token_id)
        # position embeddings
        self.position_embeddings = nn.Embedding(config.max_position_embeddings, config.hidden_size)
        # token type embeddings: one table per entry of `config.type_vocab_sizes`, registered as
        # `token_type_embeddings_{i}` (these attribute names become the submodule names in the state dict)
        for i, type_vocab_sizes in enumerate(config.type_vocab_sizes):
            name = f"token_type_embeddings_{i}"
            setattr(self, name, nn.Embedding(type_vocab_sizes, config.hidden_size))
        self.number_of_token_type_embeddings = len(config.type_vocab_sizes)
        self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
        self.dropout = nn.Dropout(config.hidden_dropout_prob)
        self.config = config

    def forward(self, input_ids=None, token_type_ids=None, position_ids=None, inputs_embeds=None):
        """
        Sum the word, position and every token-type embedding, then apply LayerNorm and dropout.

        Args:
            input_ids: token ids of shape `(batch_size, seq_len)`; used when `inputs_embeds` is not given.
            token_type_ids: shape `(batch_size, seq_len, len(config.type_vocab_sizes))`; slot `i` of the last axis
                indexes `token_type_embeddings_{i}`. Defaults to all zeros.
            position_ids: shape `(batch_size, seq_len)`. When omitted, absolute positions are used, or (with
                `config.reset_position_index_per_cell`) positions relative to the first token of each table cell,
                where cells are identified by `token_type_ids[:, :, 1]` (column) and `token_type_ids[:, :, 2]` (row).
            inputs_embeds: precomputed word embeddings of shape `(batch_size, seq_len, hidden_size)`.

        Returns:
            Tensor of shape `(batch_size, seq_len, hidden_size)`.
        """
        if input_ids is not None:
            input_shape = input_ids.size()
        else:
            input_shape = inputs_embeds.size()[:-1]

        seq_length = input_shape[1]
        device = input_ids.device if input_ids is not None else inputs_embeds.device

        # The default must exist before the relative-position branch below reads it. Its shape is the input shape
        # plus a trailing axis with one slot per token-type embedding table (a torch.Size cannot be `+`-ed with an int).
        if token_type_ids is None:
            token_type_ids = torch.zeros(
                (*input_shape, self.number_of_token_type_embeddings), dtype=torch.long, device=device
            )

        if position_ids is None:
            # create absolute position embeddings
            position_ids = torch.arange(seq_length, dtype=torch.long, device=device)
            position_ids = position_ids.unsqueeze(0).expand(input_shape)
            # when self.config.reset_position_index_per_cell is set to True, create relative position embeddings
            if self.config.reset_position_index_per_cell:
                # shape (batch_size, seq_len)
                col_index = IndexMap(token_type_ids[:, :, 1], self.config.type_vocab_sizes[1], batch_dims=1)
                # shape (batch_size, seq_len)
                row_index = IndexMap(token_type_ids[:, :, 2], self.config.type_vocab_sizes[2], batch_dims=1)
                # shape (batch_size, seq_len)
                full_index = ProductIndexMap(col_index, row_index)
                # shape (max_rows * max_columns,). First absolute position for every cell
                first_position_per_segment = reduce_min(position_ids, full_index)[0]
                # ? shape (batch_size, seq_len). First absolute position of the cell for every token
                first_position = gather(first_position_per_segment, full_index)
                # shape (1, seq_len)
                position = torch.arange(seq_length, dtype=torch.long, device=device).unsqueeze(0)
                # clamp so relative positions never index past the position embedding table
                position_ids = torch.min(
                    torch.as_tensor(self.config.max_position_embeddings - 1, device=device), position - first_position
                )

        if inputs_embeds is None:
            inputs_embeds = self.word_embeddings(input_ids)

        position_embeddings = self.position_embeddings(position_ids)

        embeddings = inputs_embeds + position_embeddings

        for i in range(self.number_of_token_type_embeddings):
            name = f"token_type_embeddings_{i}"
            embeddings += getattr(self, name)(token_type_ids[:, :, i])

        embeddings = self.LayerNorm(embeddings)
        embeddings = self.dropout(embeddings)
        return embeddings
class TapasSelfAttention(nn.Module):
    """Multi-head scaled dot-product attention, usable as self-attention or (with encoder states) cross-attention."""

    def __init__(self, config, layer_idx=None):
        super().__init__()
        if config.hidden_size % config.num_attention_heads != 0 and not hasattr(config, "embedding_size"):
            raise ValueError(
                f"The hidden size {config.hidden_size} is not a multiple of the number of attention "
                f"heads {config.num_attention_heads}"
            )

        self.num_attention_heads = config.num_attention_heads
        self.attention_head_size = int(config.hidden_size / config.num_attention_heads)
        self.all_head_size = self.num_attention_heads * self.attention_head_size

        self.query = nn.Linear(config.hidden_size, self.all_head_size)
        self.key = nn.Linear(config.hidden_size, self.all_head_size)
        self.value = nn.Linear(config.hidden_size, self.all_head_size)

        self.dropout = nn.Dropout(config.attention_probs_dropout_prob)
        self.is_decoder = config.is_decoder
        # index of this layer inside the cache objects
        self.layer_idx = layer_idx

    def forward(
        self,
        hidden_states,
        attention_mask=None,
        encoder_hidden_states=None,
        past_key_values=None,
        output_attentions=False,
        **kwargs,
    ):
        """
        Args:
            hidden_states: query-side states of shape `(batch_size, query_len, hidden_size)`.
            attention_mask: additive mask broadcastable to the scores `(batch_size, num_heads, query_len, key_len)`.
            encoder_hidden_states: if given, keys/values are computed from these (cross-attention); their sequence
                length may differ from `query_len`.
            past_key_values: optional cache; an `EncoderDecoderCache` is routed to its self- or cross-attention part.
            output_attentions: also return the attention probabilities.

        Returns:
            `(context_layer,)`, plus `attention_probs` when `output_attentions`, plus `past_key_values` when the
            layer is a decoder.
        """
        input_shape = hidden_states.shape[:-1]
        hidden_shape = (*input_shape, -1, self.attention_head_size)
        # (batch_size, num_heads, query_len, head_size)
        query_layer = self.query(hidden_states).view(hidden_shape).transpose(1, 2)

        is_updated = False
        is_cross_attention = encoder_hidden_states is not None
        if past_key_values is not None:
            if isinstance(past_key_values, EncoderDecoderCache):
                is_updated = past_key_values.is_updated.get(self.layer_idx)
                if is_cross_attention:
                    # after the first generated id, we can subsequently re-use all key/value_layer from cache
                    curr_past_key_values = past_key_values.cross_attention_cache
                else:
                    curr_past_key_values = past_key_values.self_attention_cache
            else:
                curr_past_key_values = past_key_values

        current_states = encoder_hidden_states if is_cross_attention else hidden_states
        if is_cross_attention and past_key_values is not None and is_updated:
            # reuse k,v, cross_attentions
            key_layer = curr_past_key_values.layers[self.layer_idx].keys
            value_layer = curr_past_key_values.layers[self.layer_idx].values
        else:
            # Keys/values take their sequence length from `current_states`: for cross-attention that is the encoder
            # length, which may differ from the query length baked into `hidden_shape`.
            kv_shape = (*current_states.shape[:-1], -1, self.attention_head_size)
            key_layer = self.key(current_states).view(kv_shape).transpose(1, 2)
            value_layer = self.value(current_states).view(kv_shape).transpose(1, 2)

            if past_key_values is not None:
                # save all key/value_layer to cache to be re-used for fast auto-regressive generation
                key_layer, value_layer = curr_past_key_values.update(key_layer, value_layer, self.layer_idx)
                # set flag that curr layer for cross-attn is already updated so we can re-use in subsequent calls
                if is_cross_attention and isinstance(past_key_values, EncoderDecoderCache):
                    past_key_values.is_updated[self.layer_idx] = True

        # Take the dot product between "query" and "key" to get the raw attention scores.
        attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2))
        attention_scores = attention_scores / math.sqrt(self.attention_head_size)
        if attention_mask is not None:
            # Apply the attention mask (precomputed for all layers in the TapasModel forward() function)
            attention_scores = attention_scores + attention_mask

        # Normalize the attention scores to probabilities.
        attention_probs = nn.functional.softmax(attention_scores, dim=-1)

        # This is actually dropping out entire tokens to attend to, which might
        # seem a bit unusual, but is taken from the original Transformer paper.
        attention_probs = self.dropout(attention_probs)

        context_layer = torch.matmul(attention_probs, value_layer)
        context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
        new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,)
        context_layer = context_layer.view(*new_context_layer_shape)

        outputs = (context_layer, attention_probs) if output_attentions else (context_layer,)
        if self.is_decoder:
            outputs = outputs + (past_key_values,)
        return outputs
# Adapted from transformers.models.bert.modeling_bert.BertSelfOutput
class TapasSelfOutput(nn.Module):
    """Post-attention block: linear projection and dropout, then residual add followed by LayerNorm."""

    def __init__(self, config):
        super().__init__()
        width = config.hidden_size
        self.dense = nn.Linear(width, width)
        self.LayerNorm = nn.LayerNorm(width, eps=config.layer_norm_eps)
        self.dropout = nn.Dropout(config.hidden_dropout_prob)

    def forward(self, hidden_states: torch.Tensor, input_tensor: torch.Tensor) -> torch.Tensor:
        """Project `hidden_states`, apply dropout, add the residual `input_tensor`, then normalize."""
        projected = self.dropout(self.dense(hidden_states))
        return self.LayerNorm(projected + input_tensor)
class TapasAttention(nn.Module):
    """Attention sub-layer: `TapasSelfAttention` followed by the `TapasSelfOutput` residual projection."""

    def __init__(self, config, layer_idx=None):
        super().__init__()
        self.self = TapasSelfAttention(config, layer_idx=layer_idx)
        self.output = TapasSelfOutput(config)

    # Adapted from transformers.models.rembert.modeling_rembert.RemBertAttention.forward
    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: torch.FloatTensor | None = None,
        encoder_hidden_states: torch.FloatTensor | None = None,
        past_key_values: Cache | None = None,
        output_attentions: bool | None = False,
        **kwargs,
    ) -> tuple[torch.Tensor]:
        """
        Run attention and the output projection.

        Returns a tuple whose first element is the projected attention output. Any further elements produced by
        `TapasSelfAttention` (attention probabilities and/or the cache) are passed through unchanged.
        """
        context, *passthrough = self.self(
            hidden_states,
            attention_mask=attention_mask,
            encoder_hidden_states=encoder_hidden_states,
            past_key_values=past_key_values,
            output_attentions=output_attentions,
        )
        # the residual connection uses the (query-side) input hidden states
        return (self.output(context, hidden_states), *passthrough)
# Adapted from transformers.models.bert.modeling_bert.BertIntermediate
class TapasIntermediate(nn.Module):
    """Feed-forward expansion: `hidden_size -> intermediate_size` linear layer followed by the activation."""

    def __init__(self, config):
        super().__init__()
        self.dense = nn.Linear(config.hidden_size, config.intermediate_size)
        # `hidden_act` is either a key into ACT2FN or a callable that is used as-is
        act = config.hidden_act
        self.intermediate_act_fn = ACT2FN[act] if isinstance(act, str) else act

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        """Apply the expansion and the activation."""
        return self.intermediate_act_fn(self.dense(hidden_states))
# Adapted from transformers.models.bert.modeling_bert.BertOutput
class TapasOutput(nn.Module):
    """Feed-forward contraction: `intermediate_size -> hidden_size`, dropout, residual add and LayerNorm."""

    def __init__(self, config):
        super().__init__()
        self.dense = nn.Linear(config.intermediate_size, config.hidden_size)
        self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
        self.dropout = nn.Dropout(config.hidden_dropout_prob)

    def forward(self, hidden_states: torch.Tensor, input_tensor: torch.Tensor) -> torch.Tensor:
        """Contract `hidden_states`, apply dropout, add the residual `input_tensor`, then normalize."""
        contracted = self.dropout(self.dense(hidden_states))
        return self.LayerNorm(contracted + input_tensor)
class TapasLayer(GradientCheckpointingLayer):
    """One transformer block: self-attention, optional cross-attention (decoders only), then the feed-forward network."""

    def __init__(self, config, layer_idx=None):
        super().__init__()
        self.chunk_size_feed_forward = config.chunk_size_feed_forward
        # axis handed to apply_chunking_to_forward for chunking (the sequence axis, per its name)
        self.seq_len_dim = 1
        self.attention = TapasAttention(config, layer_idx=layer_idx)
        self.is_decoder = config.is_decoder
        self.add_cross_attention = config.add_cross_attention
        if self.add_cross_attention and not self.is_decoder:
            raise ValueError(f"{self} should be used as a decoder model if cross attention is added")
        if self.add_cross_attention:
            self.crossattention = TapasAttention(config, layer_idx=layer_idx)
        self.intermediate = TapasIntermediate(config)
        self.output = TapasOutput(config)

    # Adapted from transformers.models.rembert.modeling_rembert.RemBertLayer.forward
    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: torch.FloatTensor | None = None,
        encoder_hidden_states: torch.FloatTensor | None = None,
        encoder_attention_mask: torch.FloatTensor | None = None,
        past_key_values: Cache | None = None,
        output_attentions: bool | None = False,
        **kwargs,
    ) -> tuple[torch.Tensor]:
        """
        Run the block.

        Returns `(layer_output, *extras)`, where `extras` are the trailing outputs of the self-attention call
        followed by those of the cross-attention call (when it runs), in that order.
        """
        attention_output, *extra_outputs = self.attention(
            hidden_states,
            attention_mask=attention_mask,
            output_attentions=output_attentions,
            past_key_values=past_key_values,
        )

        if self.is_decoder and encoder_hidden_states is not None:
            if not hasattr(self, "crossattention"):
                raise ValueError(
                    f"If `encoder_hidden_states` are passed, {self} has to be instantiated with cross-attention layers"
                    " by setting `config.add_cross_attention=True`"
                )
            attention_output, *cross_extra_outputs = self.crossattention(
                attention_output,
                attention_mask=encoder_attention_mask,
                encoder_hidden_states=encoder_hidden_states,
                past_key_values=past_key_values,
                output_attentions=output_attentions,
            )
            extra_outputs += cross_extra_outputs

        layer_output = apply_chunking_to_forward(
            self.feed_forward_chunk, self.chunk_size_feed_forward, self.seq_len_dim, attention_output
        )
        return (layer_output, *extra_outputs)

    # Copied from transformers.models.bert.modeling_bert.BertLayer.feed_forward_chunk
    def feed_forward_chunk(self, attention_output):
        intermediate_output = self.intermediate(attention_output)
        layer_output = self.output(intermediate_output, attention_output)
        return layer_output
class TapasEncoder(nn.Module):
    """Stack of `config.num_hidden_layers` [`TapasLayer`] blocks applied in sequence."""

    def __init__(self, config):
        super().__init__()
        self.config = config
        self.layer = nn.ModuleList([TapasLayer(config, layer_idx=idx) for idx in range(config.num_hidden_layers)])
        self.gradient_checkpointing = False

    def forward(
        self,
        hidden_states,
        attention_mask=None,
        encoder_hidden_states=None,
        encoder_attention_mask=None,
        past_key_values=None,
        use_cache=None,
        output_attentions=False,
        output_hidden_states=False,
        return_dict=True,
        **kwargs,
    ):
        """
        Run every layer on `hidden_states`.

        Returns a `BaseModelOutput`, or, when `return_dict` is falsy, a tuple of the last hidden state followed by
        the collected hidden states and attentions (each present only when requested). Collected hidden states
        include the input to the first layer and the output of the last one.

        NOTE(review): a cache created here for `use_cache` is handed to the layers but never returned to the caller.
        """
        if use_cache and past_key_values is None:
            past_key_values = EncoderDecoderCache(DynamicCache(config=self.config), DynamicCache(config=self.config))

        collected_states = [] if output_hidden_states else None
        collected_attentions = [] if output_attentions else None

        for layer_module in self.layer:
            if collected_states is not None:
                collected_states.append(hidden_states)

            layer_outputs = layer_module(
                hidden_states,
                attention_mask,
                encoder_hidden_states,  # as a positional argument for gradient checkpointing
                encoder_attention_mask=encoder_attention_mask,
                past_key_values=past_key_values,
                output_attentions=output_attentions,
            )
            hidden_states = layer_outputs[0]

            if collected_attentions is not None:
                collected_attentions.append(layer_outputs[1])

        if collected_states is not None:
            collected_states.append(hidden_states)

        all_hidden_states = None if collected_states is None else tuple(collected_states)
        all_attentions = None if collected_attentions is None else tuple(collected_attentions)

        if return_dict:
            return BaseModelOutput(
                last_hidden_state=hidden_states, hidden_states=all_hidden_states, attentions=all_attentions
            )
        return tuple(item for item in (hidden_states, all_hidden_states, all_attentions) if item is not None)
# Adapted from transformers.models.bert.modeling_bert.BertPooler
class TapasPooler(nn.Module):
    """Pool a sequence by passing its first token's hidden state through a dense layer and tanh."""

    def __init__(self, config):
        super().__init__()
        self.dense = nn.Linear(config.hidden_size, config.hidden_size)
        self.activation = nn.Tanh()

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        """Return `tanh(dense(hidden_states[:, 0]))`."""
        # keep only position 0 of axis 1 (presumably the first token of a (batch, seq, hidden) tensor)
        first_token_state = hidden_states[:, 0]
        return self.activation(self.dense(first_token_state))
# Adapted from transformers.models.bert.modeling_bert.BertPredictionHeadTransform with Bert->Tapas
class TapasPredictionHeadTransform(nn.Module):
    """Dense layer, activation and LayerNorm applied before the vocabulary projection."""

    def __init__(self, config):
        super().__init__()
        self.dense = nn.Linear(config.hidden_size, config.hidden_size)
        # a string names an ACT2FN entry; anything else is used directly as the activation callable
        act = config.hidden_act
        self.transform_act_fn = ACT2FN[act] if isinstance(act, str) else act
        self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        """Return `LayerNorm(act(dense(hidden_states)))`."""
        return self.LayerNorm(self.transform_act_fn(self.dense(hidden_states)))
# Adapted from transformers.models.bert.modeling_bert.BertLMPredictionHead with Bert->Tapas
class TapasLMPredictionHead(nn.Module):
    """Masked-LM head: the shared transform followed by a projection to vocabulary logits."""

    def __init__(self, config):
        super().__init__()
        self.transform = TapasPredictionHeadTransform(config)
        # Projection to the vocabulary. TapasForMaskedLM's `_tied_weights_keys` ties `decoder.weight` to the input
        # word embeddings and `decoder.bias` to the separate output-only `bias` parameter below.
        self.decoder = nn.Linear(config.hidden_size, config.vocab_size, bias=True)
        self.bias = nn.Parameter(torch.zeros(config.vocab_size))

    def forward(self, hidden_states):
        """Return vocabulary logits for every position of `hidden_states`."""
        transformed = self.transform(hidden_states)
        return self.decoder(transformed)
# Adapted from transformers.models.bert.modeling_bert.BertOnlyMLMHead with Bert->Tapas
class TapasOnlyMLMHead(nn.Module):
    """Thin wrapper exposing the LM prediction head under the `predictions` attribute."""

    def __init__(self, config):
        super().__init__()
        self.predictions = TapasLMPredictionHead(config)

    def forward(self, sequence_output: torch.Tensor) -> torch.Tensor:
        """Delegate to the wrapped prediction head and return its vocabulary scores."""
        return self.predictions(sequence_output)
@auto_docstring
class TapasPreTrainedModel(PreTrainedModel):
    config: TapasConfig
    base_model_prefix = "tapas"
    supports_gradient_checkpointing = True

    @torch.no_grad()
    def _init_weights(self, module):
        """Initialize the weights: generic initialization first, then the TAPAS-specific parameters."""
        super()._init_weights(module)

        if isinstance(module, TapasLMPredictionHead):
            # the LM head's separate output bias starts at zero
            init.zeros_(module.bias)

        if isinstance(module, TapasForQuestionAnswering):
            qa_config = module.config
            # cell-selection weights: zeros or normal(std=initializer_range), chosen by the config flag
            for weight in (module.output_weights, module.column_output_weights):
                if qa_config.init_cell_selection_weights_to_zero:
                    init.zeros_(weight)
                else:
                    init.normal_(weight, std=qa_config.initializer_range)
            for bias in (module.output_bias, module.column_output_bias):
                init.zeros_(bias)
@auto_docstring
class TapasModel(TapasPreTrainedModel):
    """
    This class is a small change compared to [`BertModel`], taking into account the additional token type ids.

    The model can behave as an encoder (with only self-attention) as well as a decoder, in which case a layer of
    cross-attention is added between the self-attention layers, following the architecture described in [Attention is
    all you need](https://huggingface.co/papers/1706.03762) by Ashish Vaswani, Noam Shazeer, Niki Parmar, Jakob Uszkoreit,
    Llion Jones, Aidan N. Gomez, Lukasz Kaiser and Illia Polosukhin.
    """

    def __init__(self, config, add_pooling_layer=True):
        r"""
        add_pooling_layer (bool, *optional*, defaults to `True`):
            Whether to add a pooling layer
        """
        super().__init__(config)
        self.config = config

        self.embeddings = TapasEmbeddings(config)
        self.encoder = TapasEncoder(config)
        self.pooler = TapasPooler(config) if add_pooling_layer else None

        # Initialize weights and apply final processing
        self.post_init()

    def get_input_embeddings(self):
        """Return the word embedding table."""
        return self.embeddings.word_embeddings

    def set_input_embeddings(self, value):
        """Replace the word embedding table."""
        self.embeddings.word_embeddings = value

    @auto_docstring
    def forward(
        self,
        input_ids: torch.LongTensor | None = None,
        attention_mask: torch.FloatTensor | None = None,
        token_type_ids: torch.LongTensor | None = None,
        position_ids: torch.LongTensor | None = None,
        inputs_embeds: torch.FloatTensor | None = None,
        encoder_hidden_states: torch.FloatTensor | None = None,
        encoder_attention_mask: torch.FloatTensor | None = None,
        output_attentions: bool | None = None,
        output_hidden_states: bool | None = None,
        return_dict: bool | None = None,
        **kwargs,
    ) -> tuple | BaseModelOutputWithPooling:
        r"""
        token_type_ids (`torch.LongTensor` of shape `(batch_size, sequence_length, 7)`, *optional*):
            Token indices that encode tabular structure. Indices can be obtained using [`AutoTokenizer`]. See this
            class for more info.

            [What are token type IDs?](../glossary#token-type-ids)
        position_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
            Indices of positions of each input sequence tokens in the position embeddings. If
            `reset_position_index_per_cell` of [`TapasConfig`] is set to `True`, relative position embeddings will be
            used. Selected in the range `[0, config.max_position_embeddings - 1]`.

            [What are position IDs?](../glossary#position-ids)

        Examples:

        ```python
        >>> from transformers import AutoTokenizer, TapasModel
        >>> import pandas as pd

        >>> tokenizer = AutoTokenizer.from_pretrained("google/tapas-base")
        >>> model = TapasModel.from_pretrained("google/tapas-base")

        >>> data = {
        ...     "Actors": ["Brad Pitt", "Leonardo Di Caprio", "George Clooney"],
        ...     "Age": ["56", "45", "59"],
        ...     "Number of movies": ["87", "53", "69"],
        ... }
        >>> table = pd.DataFrame.from_dict(data)
        >>> queries = ["How many movies has George Clooney played in?", "How old is Brad Pitt?"]

        >>> inputs = tokenizer(table=table, queries=queries, padding="max_length", return_tensors="pt")
        >>> outputs = model(**inputs)

        >>> last_hidden_states = outputs.last_hidden_state
        ```"""
        config = self.config
        # Per-call arguments override the config defaults.
        if output_attentions is None:
            output_attentions = config.output_attentions
        if output_hidden_states is None:
            output_hidden_states = config.output_hidden_states
        if return_dict is None:
            return_dict = config.return_dict

        # Exactly one of input_ids / inputs_embeds must be given; it fixes the input shape and device.
        if input_ids is not None:
            if inputs_embeds is not None:
                raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
            self.warn_if_padding_and_no_attention_mask(input_ids, attention_mask)
            input_shape = input_ids.size()
            device = input_ids.device
        elif inputs_embeds is not None:
            input_shape = inputs_embeds.size()[:-1]
            device = inputs_embeds.device
        else:
            raise ValueError("You have to specify either input_ids or inputs_embeds")

        if attention_mask is None:
            # attend to every position
            attention_mask = torch.ones(input_shape, device=device)
        if token_type_ids is None:
            # one zero id per token-type embedding table
            token_type_ids = torch.zeros(
                (*input_shape, len(config.type_vocab_sizes)), dtype=torch.long, device=device
            )

        # A 2D mask, or a 3D [batch_size, from_seq_length, to_seq_length] mask, is made broadcastable to all heads;
        # the attention layers add the result to their scores.
        extended_attention_mask: torch.Tensor = self.get_extended_attention_mask(attention_mask, input_shape)

        # Cross-attention mask (decoder only): default to attending to every encoder position, then make it
        # broadcastable to [batch_size, num_heads, seq_length, encoder_seq_length].
        encoder_extended_attention_mask = None
        if config.is_decoder and encoder_hidden_states is not None:
            encoder_batch_size, encoder_sequence_length, _ = encoder_hidden_states.size()
            if encoder_attention_mask is None:
                encoder_attention_mask = torch.ones((encoder_batch_size, encoder_sequence_length), device=device)
            encoder_extended_attention_mask = self.invert_attention_mask(encoder_attention_mask)

        embedding_output = self.embeddings(
            input_ids=input_ids, position_ids=position_ids, token_type_ids=token_type_ids, inputs_embeds=inputs_embeds
        )
        encoder_outputs = self.encoder(
            embedding_output,
            attention_mask=extended_attention_mask,
            encoder_hidden_states=encoder_hidden_states,
            encoder_attention_mask=encoder_extended_attention_mask,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        sequence_output = encoder_outputs[0]
        pooled_output = None if self.pooler is None else self.pooler(sequence_output)

        if return_dict:
            return BaseModelOutputWithPooling(
                last_hidden_state=sequence_output,
                pooler_output=pooled_output,
                hidden_states=encoder_outputs.hidden_states,
                attentions=encoder_outputs.attentions,
            )
        return (sequence_output, pooled_output) + encoder_outputs[1:]
@auto_docstring
class TapasForMaskedLM(TapasPreTrainedModel):
    # Decoder weight/bias of the MLM head are tied to the input embeddings and the prediction bias.
    _tied_weights_keys = {
        "cls.predictions.decoder.bias": "cls.predictions.bias",
        "cls.predictions.decoder.weight": "tapas.embeddings.word_embeddings.weight",
    }
    config: TapasConfig
    base_model_prefix = "tapas"

    def __init__(self, config):
        super().__init__(config)

        # Encoder without pooler: only token-level hidden states feed the MLM head.
        self.tapas = TapasModel(config, add_pooling_layer=False)
        self.cls = TapasOnlyMLMHead(config)

        # Initialize weights and apply final processing
        self.post_init()

    def get_output_embeddings(self):
        return self.cls.predictions.decoder

    def set_output_embeddings(self, new_embeddings):
        # Keep the head's bias pointing at the bias of the replacement decoder.
        self.cls.predictions.decoder = new_embeddings
        self.cls.predictions.bias = new_embeddings.bias

    @auto_docstring
    def forward(
        self,
        input_ids: torch.LongTensor | None = None,
        attention_mask: torch.FloatTensor | None = None,
        token_type_ids: torch.LongTensor | None = None,
        position_ids: torch.LongTensor | None = None,
        inputs_embeds: torch.FloatTensor | None = None,
        encoder_hidden_states: torch.FloatTensor | None = None,
        encoder_attention_mask: torch.FloatTensor | None = None,
        labels: torch.LongTensor | None = None,
        output_attentions: bool | None = None,
        output_hidden_states: bool | None = None,
        return_dict: bool | None = None,
        **kwargs,
    ) -> tuple | MaskedLMOutput:
        r"""
        token_type_ids (`torch.LongTensor` of shape `(batch_size, sequence_length, 7)`, *optional*):
            Token indices that encode tabular structure. Indices can be obtained using [`AutoTokenizer`]. See this
            class for more info.

            [What are token type IDs?](../glossary#token-type-ids)
        position_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
            Indices of positions of each input sequence tokens in the position embeddings. If
            `reset_position_index_per_cell` of [`TapasConfig`] is set to `True`, relative position embeddings will be
            used. Selected in the range `[0, config.max_position_embeddings - 1]`.

            [What are position IDs?](../glossary#position-ids)
        labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
            Labels for computing the masked language modeling loss. Indices should be in `[-100, 0, ...,
            config.vocab_size]` (see `input_ids` docstring) Tokens with indices set to `-100` are ignored (masked), the
            loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`

        Examples:

        ```python
        >>> from transformers import AutoTokenizer, TapasForMaskedLM
        >>> import pandas as pd

        >>> tokenizer = AutoTokenizer.from_pretrained("google/tapas-base")
        >>> model = TapasForMaskedLM.from_pretrained("google/tapas-base")

        >>> data = {
        ...     "Actors": ["Brad Pitt", "Leonardo Di Caprio", "George Clooney"],
        ...     "Age": ["56", "45", "59"],
        ...     "Number of movies": ["87", "53", "69"],
        ... }
        >>> table = pd.DataFrame.from_dict(data)

        >>> inputs = tokenizer(
        ...     table=table, queries="How many [MASK] has George [MASK] played in?", return_tensors="pt"
        ... )
        >>> labels = tokenizer(
        ...     table=table, queries="How many movies has George Clooney played in?", return_tensors="pt"
        ... )["input_ids"]

        >>> outputs = model(**inputs, labels=labels)
        >>> logits = outputs.logits
        ```"""
        if return_dict is None:
            return_dict = self.config.return_dict

        tapas_outputs = self.tapas(
            input_ids,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
            position_ids=position_ids,
            inputs_embeds=inputs_embeds,
            encoder_hidden_states=encoder_hidden_states,
            encoder_attention_mask=encoder_attention_mask,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        # Vocabulary scores for every token, computed from the last hidden states.
        prediction_scores = self.cls(tapas_outputs[0])

        masked_lm_loss = None
        if labels is not None:
            # CrossEntropyLoss ignores positions labelled -100 (its default ignore_index).
            flat_scores = prediction_scores.view(-1, self.config.vocab_size)
            masked_lm_loss = CrossEntropyLoss()(flat_scores, labels.view(-1))

        if return_dict:
            return MaskedLMOutput(
                loss=masked_lm_loss,
                logits=prediction_scores,
                hidden_states=tapas_outputs.hidden_states,
                attentions=tapas_outputs.attentions,
            )

        # Tuple output: skip the (absent) pooled output at index 1 of the base model tuple.
        tuple_output = (prediction_scores,) + tapas_outputs[2:]
        if masked_lm_loss is None:
            return tuple_output
        return (masked_lm_loss,) + tuple_output
@auto_docstring(
    custom_intro="""
    Tapas Model with a cell selection head and optional aggregation head on top for question-answering tasks on tables
    (linear layers on top of the hidden-states output to compute `logits` and optional `logits_aggregation`), e.g. for
    SQA, WTQ or WikiSQL-supervised tasks.
    """
)
class TapasForQuestionAnswering(TapasPreTrainedModel):
    def __init__(self, config: TapasConfig):
        super().__init__(config)

        # base model
        self.tapas = TapasModel(config)

        # dropout (only used when training)
        self.dropout = nn.Dropout(config.hidden_dropout_prob)

        # cell selection heads (created uninitialised here; values are expected to be set during post_init)
        self.output_weights = nn.Parameter(torch.empty(config.hidden_size))
        self.column_output_weights = nn.Parameter(torch.empty(config.hidden_size))
        self.output_bias = nn.Parameter(torch.empty([]))
        self.column_output_bias = nn.Parameter(torch.empty([]))

        # aggregation head
        if config.num_aggregation_labels > 0:
            self.aggregation_classifier = nn.Linear(config.hidden_size, config.num_aggregation_labels)

        # Initialize weights and apply final processing
        self.post_init()

    @auto_docstring
    def forward(
        self,
        input_ids: torch.LongTensor | None = None,
        attention_mask: torch.FloatTensor | None = None,
        token_type_ids: torch.LongTensor | None = None,
        position_ids: torch.LongTensor | None = None,
        inputs_embeds: torch.FloatTensor | None = None,
        table_mask: torch.LongTensor | None = None,
        labels: torch.LongTensor | None = None,
        aggregation_labels: torch.LongTensor | None = None,
        float_answer: torch.FloatTensor | None = None,
        numeric_values: torch.FloatTensor | None = None,
        numeric_values_scale: torch.FloatTensor | None = None,
        output_attentions: bool | None = None,
        output_hidden_states: bool | None = None,
        return_dict: bool | None = None,
        **kwargs,
    ) -> tuple | TableQuestionAnsweringOutput:
        r"""
        token_type_ids (`torch.LongTensor` of shape `(batch_size, sequence_length, 7)`, *optional*):
            Token indices that encode tabular structure. Indices can be obtained using [`AutoTokenizer`]. See this
            class for more info.

            [What are token type IDs?](../glossary#token-type-ids)
        position_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
            Indices of positions of each input sequence tokens in the position embeddings. If
            `reset_position_index_per_cell` of [`TapasConfig`] is set to `True`, relative position embeddings will be
            used. Selected in the range `[0, config.max_position_embeddings - 1]`.

            [What are position IDs?](../glossary#position-ids)
        table_mask (`torch.LongTensor` of shape `(batch_size, seq_length)`, *optional*):
            Mask for the table. Indicates which tokens belong to the table (1). Question tokens, table headers and
            padding are 0.
        labels (`torch.LongTensor` of shape `(batch_size, seq_length)`, *optional*):
            Labels per token for computing the hierarchical cell selection loss. This encodes the positions of the
            answer appearing in the table. Can be obtained using [`AutoTokenizer`].

            - 1 for tokens that are **part of the answer**,
            - 0 for tokens that are **not part of the answer**.

        aggregation_labels (`torch.LongTensor` of shape `(batch_size, )`, *optional*):
            Aggregation function index for every example in the batch for computing the aggregation loss. Indices
            should be in `[0, ..., config.num_aggregation_labels - 1]`. Only required in case of strong supervision for
            aggregation (WikiSQL-supervised).
        float_answer (`torch.FloatTensor` of shape `(batch_size, )`, *optional*):
            Float answer for every example in the batch. Set to *float('nan')* for cell selection questions. Only
            required in case of weak supervision (WTQ) to calculate the aggregate mask and regression loss.
        numeric_values (`torch.FloatTensor` of shape `(batch_size, seq_length)`, *optional*):
            Numeric values of every token, NaN for tokens which are not numeric values. Can be obtained using
            [`AutoTokenizer`]. Only required in case of weak supervision for aggregation (WTQ) to calculate the
            regression loss.
        numeric_values_scale (`torch.FloatTensor` of shape `(batch_size, seq_length)`, *optional*):
            Scale of the numeric values of every token. Can be obtained using [`AutoTokenizer`]. Only required in case
            of weak supervision for aggregation (WTQ) to calculate the regression loss.

        Examples:

        ```python
        >>> from transformers import AutoTokenizer, TapasForQuestionAnswering
        >>> import pandas as pd

        >>> tokenizer = AutoTokenizer.from_pretrained("google/tapas-base-finetuned-wtq")
        >>> model = TapasForQuestionAnswering.from_pretrained("google/tapas-base-finetuned-wtq")

        >>> data = {
        ...     "Actors": ["Brad Pitt", "Leonardo Di Caprio", "George Clooney"],
        ...     "Age": ["56", "45", "59"],
        ...     "Number of movies": ["87", "53", "69"],
        ... }
        >>> table = pd.DataFrame.from_dict(data)
        >>> queries = ["How many movies has George Clooney played in?", "How old is Brad Pitt?"]

        >>> inputs = tokenizer(table=table, queries=queries, padding="max_length", return_tensors="pt")
        >>> outputs = model(**inputs)

        >>> logits = outputs.logits
        >>> logits_aggregation = outputs.logits_aggregation
        ```"""
        return_dict = return_dict if return_dict is not None else self.config.return_dict

        outputs = self.tapas(
            input_ids,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
            position_ids=position_ids,
            inputs_embeds=inputs_embeds,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        sequence_output = outputs[0]
        pooled_output = outputs[1]

        sequence_output = self.dropout(sequence_output)

        # Shape and device of the inputs; computed once and reused for every default tensor below.
        if input_ids is not None:
            input_shape = input_ids.size()
        else:
            input_shape = inputs_embeds.size()[:-1]

        device = input_ids.device if input_ids is not None else inputs_embeds.device

        # Construct indices for the table.
        if token_type_ids is None:
            token_type_ids = torch.zeros(
                (*input_shape, len(self.config.type_vocab_sizes)), dtype=torch.long, device=device
            )

        # Order of the channels along the last dimension of `token_type_ids`.
        token_types = [
            "segment_ids",
            "column_ids",
            "row_ids",
            "prev_labels",
            "column_ranks",
            "inv_column_ranks",
            "numeric_relations",
        ]

        row_ids = token_type_ids[:, :, token_types.index("row_ids")]
        column_ids = token_type_ids[:, :, token_types.index("column_ids")]

        # Ids beyond the configured maximum are clamped into the last row/column segment.
        row_index = IndexMap(
            indices=torch.min(row_ids, torch.as_tensor(self.config.max_num_rows - 1, device=row_ids.device)),
            num_segments=self.config.max_num_rows,
            batch_dims=1,
        )
        col_index = IndexMap(
            indices=torch.min(column_ids, torch.as_tensor(self.config.max_num_columns - 1, device=column_ids.device)),
            num_segments=self.config.max_num_columns,
            batch_dims=1,
        )
        cell_index = ProductIndexMap(row_index, col_index)

        # Masks.
        if attention_mask is None:
            attention_mask = torch.ones(input_shape, device=device)
        # Table cells only, without question tokens and table headers.
        if table_mask is None:
            table_mask = torch.where(row_ids > 0, torch.ones_like(row_ids), torch.zeros_like(row_ids))
        # torch.FloatTensor[batch_size, seq_length]
        input_mask_float = attention_mask.to(device=device, dtype=torch.float)
        table_mask_float = table_mask.to(device=device, dtype=torch.float)
        # Mask for cells that exist in the table (i.e. that are not padding).
        cell_mask, _ = reduce_mean(input_mask_float, cell_index)

        # Compute logits per token. These are used to select individual cells.
        logits = compute_token_logits(sequence_output, self.config.temperature, self.output_weights, self.output_bias)

        # Compute logits per column. These are used to select a column.
        # `column_logits` stays None unless `select_one_column` is enabled.
        column_logits = None
        if self.config.select_one_column:
            column_logits = compute_column_logits(
                sequence_output,
                self.column_output_weights,
                self.column_output_bias,
                cell_index,
                cell_mask,
                self.config.allow_empty_column_selection,
            )

        # Aggregation logits
        logits_aggregation = None
        if self.config.num_aggregation_labels > 0:
            logits_aggregation = self.aggregation_classifier(pooled_output)

        # Total loss calculation
        total_loss = 0.0
        calculate_loss = False
        if labels is not None:
            calculate_loss = True
            is_supervised = not self.config.num_aggregation_labels > 0 or not self.config.use_answer_as_supervision

            # Semi-supervised cell selection in case of no aggregation:
            # If the answer (the denotation) appears directly in the table we might
            # select the answer without applying any aggregation function. There are
            # some ambiguous cases, see utils._calculate_aggregate_mask for more info.
            # `aggregate_mask` is 1 for examples where we chose to aggregate and 0
            # for examples where we chose to select the answer directly.
            # `labels` encodes the positions of the answer appearing in the table.
            if is_supervised:
                aggregate_mask = None
            else:
                if float_answer is not None:
                    assert labels.shape[0] == float_answer.shape[0], (
                        "Make sure the answers are a FloatTensor of shape (batch_size,)"
                    )
                    # <float32>[batch_size]
                    aggregate_mask = _calculate_aggregate_mask(
                        float_answer,
                        pooled_output,
                        self.config.cell_selection_preference,
                        labels,
                        self.aggregation_classifier,
                    )
                else:
                    raise ValueError("You have to specify float answers in order to calculate the aggregate mask")

            # Cell selection log-likelihood
            if self.config.average_logits_per_cell:
                logits_per_cell, _ = reduce_mean(logits, cell_index)
                logits = gather(logits_per_cell, cell_index)
            dist_per_token = torch.distributions.Bernoulli(logits=logits)

            # Compute cell selection loss per example.
            selection_loss_per_example = None
            if not self.config.select_one_column:
                # Positive tokens are up-weighted by `positive_label_weight`.
                weight = torch.where(
                    labels == 0,
                    torch.ones_like(labels, dtype=torch.float32),
                    self.config.positive_label_weight * torch.ones_like(labels, dtype=torch.float32),
                )
                selection_loss_per_token = -dist_per_token.log_prob(labels) * weight
                selection_loss_per_example = torch.sum(selection_loss_per_token * input_mask_float, dim=1) / (
                    torch.sum(input_mask_float, dim=1) + EPSILON_ZERO_DIVISION
                )
            else:
                selection_loss_per_example, logits = _single_column_cell_selection_loss(
                    logits, column_logits, labels, cell_index, col_index, cell_mask
                )
                dist_per_token = torch.distributions.Bernoulli(logits=logits)

            # Supervised cell selection
            if self.config.disable_per_token_loss:
                pass
            elif is_supervised:
                total_loss += torch.mean(selection_loss_per_example)
            else:
                # For the not supervised case, do not assign loss for cell selection
                total_loss += torch.mean(selection_loss_per_example * (1.0 - aggregate_mask))

            # Semi-supervised regression loss and supervised loss for aggregations
            if self.config.num_aggregation_labels > 0:
                if is_supervised:
                    # Note that `aggregate_mask` is None if the setting is supervised.
                    if aggregation_labels is not None:
                        assert labels.shape[0] == aggregation_labels.shape[0], (
                            "Make sure the aggregation labels are a LongTensor of shape (batch_size,)"
                        )
                        per_example_additional_loss = _calculate_aggregation_loss(
                            logits_aggregation,
                            aggregate_mask,
                            aggregation_labels,
                            self.config.use_answer_as_supervision,
                            self.config.num_aggregation_labels,
                            self.config.aggregation_loss_weight,
                        )
                    else:
                        raise ValueError(
                            "You have to specify aggregation labels in order to calculate the aggregation loss"
                        )
                else:
                    # Set aggregation labels to zeros
                    aggregation_labels = torch.zeros(labels.shape[0], dtype=torch.long, device=labels.device)
                    per_example_additional_loss = _calculate_aggregation_loss(
                        logits_aggregation,
                        aggregate_mask,
                        aggregation_labels,
                        self.config.use_answer_as_supervision,
                        self.config.num_aggregation_labels,
                        self.config.aggregation_loss_weight,
                    )

                if self.config.use_answer_as_supervision:
                    if numeric_values is not None and numeric_values_scale is not None:
                        assert numeric_values.shape == numeric_values_scale.shape
                        # Add regression loss for numeric answers which require aggregation.
                        answer_loss, large_answer_loss_mask = _calculate_regression_loss(
                            float_answer,
                            aggregate_mask,
                            dist_per_token,
                            numeric_values,
                            numeric_values_scale,
                            table_mask_float,
                            logits_aggregation,
                            self.config,
                        )
                        per_example_additional_loss += answer_loss
                        # Zero loss for examples with answer_loss > cutoff.
                        per_example_additional_loss *= large_answer_loss_mask
                    else:
                        raise ValueError(
                            "You have to specify numeric values and numeric values scale in order to calculate the"
                            " regression loss"
                        )

                total_loss += torch.mean(per_example_additional_loss)

        elif self.config.select_one_column:
            # No labels: restrict the token logits to a single column exactly as the training branch does.
            # `_single_column_cell_selection_loss` consumes `column_logits`, which is only computed when
            # `select_one_column` is set, so without that flag the token logits are returned unchanged
            # (previously this call was made with `column_logits=None`).
            # Zero label ids are passed only so the helper can be called; its loss output is discarded.
            labels = torch.zeros_like(logits)
            _, logits = _single_column_cell_selection_loss(
                logits, column_logits, labels, cell_index, col_index, cell_mask
            )

        if not return_dict:
            output = (logits, logits_aggregation) + outputs[2:]
            return ((total_loss,) + output) if calculate_loss else output

        return TableQuestionAnsweringOutput(
            loss=total_loss if calculate_loss else None,
            logits=logits,
            logits_aggregation=logits_aggregation,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )
@auto_docstring(
    custom_intro="""
    Tapas Model with a sequence classification head on top (a linear layer on top of the pooled output), e.g. for table
    entailment tasks, such as TabFact (Chen et al., 2020).
    """
)
class TapasForSequenceClassification(TapasPreTrainedModel):
    def __init__(self, config):
        super().__init__(config)
        self.num_labels = config.num_labels

        # Encoder, then dropout + linear classifier over the pooled representation.
        self.tapas = TapasModel(config)
        self.dropout = nn.Dropout(config.hidden_dropout_prob)
        self.classifier = nn.Linear(config.hidden_size, config.num_labels)

        # Initialize weights and apply final processing
        self.post_init()

    @auto_docstring
    def forward(
        self,
        input_ids: torch.LongTensor | None = None,
        attention_mask: torch.FloatTensor | None = None,
        token_type_ids: torch.LongTensor | None = None,
        position_ids: torch.LongTensor | None = None,
        inputs_embeds: torch.FloatTensor | None = None,
        labels: torch.LongTensor | None = None,
        output_attentions: bool | None = None,
        output_hidden_states: bool | None = None,
        return_dict: bool | None = None,
        **kwargs,
    ) -> tuple[torch.Tensor] | SequenceClassifierOutput:
        r"""
        token_type_ids (`torch.LongTensor` of shape `(batch_size, sequence_length, 7)`, *optional*):
            Token indices that encode tabular structure. Indices can be obtained using [`AutoTokenizer`]. See this
            class for more info.

            [What are token type IDs?](../glossary#token-type-ids)
        position_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
            Indices of positions of each input sequence tokens in the position embeddings. If
            `reset_position_index_per_cell` of [`TapasConfig`] is set to `True`, relative position embeddings will be
            used. Selected in the range `[0, config.max_position_embeddings - 1]`.

            [What are position IDs?](../glossary#position-ids)
        labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
            Labels for computing the sequence classification/regression loss. Indices should be in `[0, ...,
            config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If
            `config.num_labels > 1` a classification loss is computed (Cross-Entropy). Note: this is called
            "classification_class_index" in the original implementation.

        Examples:

        ```python
        >>> from transformers import AutoTokenizer, TapasForSequenceClassification
        >>> import torch
        >>> import pandas as pd

        >>> tokenizer = AutoTokenizer.from_pretrained("google/tapas-base-finetuned-tabfact")
        >>> model = TapasForSequenceClassification.from_pretrained("google/tapas-base-finetuned-tabfact")

        >>> data = {
        ...     "Actors": ["Brad Pitt", "Leonardo Di Caprio", "George Clooney"],
        ...     "Age": ["56", "45", "59"],
        ...     "Number of movies": ["87", "53", "69"],
        ... }
        >>> table = pd.DataFrame.from_dict(data)
        >>> queries = [
        ...     "There is only one actor who is 45 years old",
        ...     "There are 3 actors which played in more than 60 movies",
        ... ]

        >>> inputs = tokenizer(table=table, queries=queries, padding="max_length", return_tensors="pt")
        >>> labels = torch.tensor([1, 0])  # 1 means entailed, 0 means refuted

        >>> outputs = model(**inputs, labels=labels)
        >>> loss = outputs.loss
        >>> logits = outputs.logits
        ```"""
        if return_dict is None:
            return_dict = self.config.return_dict

        tapas_outputs = self.tapas(
            input_ids,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
            position_ids=position_ids,
            inputs_embeds=inputs_embeds,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        # Classify from the pooled output (index 1 of the base model output).
        logits = self.classifier(self.dropout(tapas_outputs[1]))

        loss = None
        if labels is not None:
            # Infer the problem type once (and persist it on the config) when it was not configured.
            if self.config.problem_type is None:
                if self.num_labels == 1:
                    inferred_problem_type = "regression"
                elif self.num_labels > 1 and labels.dtype in (torch.long, torch.int):
                    inferred_problem_type = "single_label_classification"
                else:
                    inferred_problem_type = "multi_label_classification"
                self.config.problem_type = inferred_problem_type

            problem_type = self.config.problem_type
            if problem_type == "regression":
                mse = MSELoss()
                if self.num_labels == 1:
                    loss = mse(logits.squeeze(), labels.squeeze())
                else:
                    loss = mse(logits, labels)
            elif problem_type == "single_label_classification":
                loss = CrossEntropyLoss()(logits.view(-1, self.num_labels), labels.view(-1))
            elif problem_type == "multi_label_classification":
                loss = BCEWithLogitsLoss()(logits, labels)

        if return_dict:
            return SequenceClassifierOutput(
                loss=loss,
                logits=logits,
                hidden_states=tapas_outputs.hidden_states,
                attentions=tapas_outputs.attentions,
            )

        tuple_output = (logits,) + tapas_outputs[2:]
        if loss is None:
            return tuple_output
        return (loss,) + tuple_output
""" TAPAS utilities."""


class AverageApproximationFunction(str, enum.Enum):
    """
    String-valued enumeration of the supported average-approximation functions.

    Because of the `str` mixin, every member compares equal to its plain string value
    (e.g. `AverageApproximationFunction.RATIO == "ratio"`), and members can be looked up by value.
    """

    RATIO = "ratio"  # value "ratio"
    FIRST_ORDER = "first_order"  # value "first_order"
    SECOND_ORDER = "second_order"  # value "second_order"
# Beginning of everything related to segmented tensors


class IndexMap:
    """Index grouping entries within a tensor."""

    def __init__(self, indices, num_segments, batch_dims=0):
        """
        Creates an index

        Args:
            indices (`torch.LongTensor`, same shape as a *values* Tensor to which the indices refer):
                Tensor containing the indices.
            num_segments (`int` or scalar `torch.LongTensor`):
                The number of segments; it is stored as a scalar tensor on the device of *indices*. All elements in a
                batched segmented tensor must have the same number of segments (although many segments can be empty).
            batch_dims (`int`, *optional*, defaults to 0):
                The number of batch dimensions. The first *batch_dims* dimensions of a SegmentedTensor are treated as
                batch dimensions. Segments in different batch elements are always distinct even if they have the same
                index.
        """
        target_device = indices.device
        self.indices = torch.as_tensor(indices, device=target_device)
        self.num_segments = torch.as_tensor(num_segments, device=target_device)
        self.batch_dims = batch_dims

    def batch_shape(self):
        """Returns the leading `batch_dims` dimensions of the indices as a `torch.Size`."""
        leading_dims = self.batch_dims
        return self.indices.shape[:leading_dims]
class ProductIndexMap(IndexMap):
    """The product of two indices."""

    def __init__(self, outer_index, inner_index):
        """
        Combines indices i and j into pairs (i, j). The result is an index where each segment (i, j) is the
        intersection of segments i and j. For example if the inputs represent table cells indexed by respectively rows
        and columns the output will be a table indexed by (row, column) pairs, i.e. by cell. The implementation
        combines indices {0, .., n - 1} and {0, .., m - 1} into {0, .., nm - 1}. The output has *num_segments* equal to
        *outer_index.num_segments* * *inner_index.num_segments*

        Args:
            outer_index (`IndexMap`):
                IndexMap.
            inner_index (`IndexMap`):
                IndexMap, must have the same shape as *outer_index*.

        Raises:
            ValueError: if the two indices do not have the same number of batch dimensions.
        """
        if outer_index.batch_dims != inner_index.batch_dims:
            raise ValueError("outer_index.batch_dims and inner_index.batch_dims must be the same.")

        # Pair (i, j) is encoded as i * m + j, where m is the number of inner segments.
        super().__init__(
            indices=(inner_index.indices + outer_index.indices * inner_index.num_segments),
            num_segments=inner_index.num_segments * outer_index.num_segments,
            batch_dims=inner_index.batch_dims,
        )
        self.outer_index = outer_index
        self.inner_index = inner_index

    def project_outer(self, index):
        """Projects an index with the same index set onto the outer components."""
        # i = floor((i * m + j) / m)
        indices = torch.div(index.indices, self.inner_index.num_segments, rounding_mode="floor").type(torch.long)
        return IndexMap(indices=indices, num_segments=self.outer_index.num_segments, batch_dims=index.batch_dims)

    def project_inner(self, index):
        """Projects an index with the same index set onto the inner components."""
        # j = (i * m + j) mod m. For the non-negative integer ids used here `fmod` is already exact, so the result is
        # converted straight to long; the former float32 round-trip was a no-op that lost precision for ids >= 2**24.
        return IndexMap(
            indices=torch.fmod(index.indices, self.inner_index.num_segments).type(torch.long),
            num_segments=self.inner_index.num_segments,
            batch_dims=index.batch_dims,
        )
def gather(values, index, name="segmented_gather"):
    """
    Gathers from *values* using the index map. For each element in the domain of the index map this operation looks up
    a value for that index in *values*. Two elements from the same segment always get assigned the same value.

    Args:
        values (`torch.Tensor` of shape (B1, ..., Bn, num_segments, V1, ...)):
            Tensor with segment values.
        index (`IndexMap` of shape (B1, ..., Bn, I1, ..., Ik)):
            IndexMap.
        name (`str`, *optional*, defaults to 'segmented_gather'):
            Name for the operation. Currently not used

    Returns:
        `torch.Tensor`: Tensor of shape (B1, ..., Bn, I1, ..., Ik, V1, ...) with the gathered values.
    """
    batch_dims = index.batch_dims
    indices = index.indices
    batch_shape = indices.shape[:batch_dims]
    # Trailing "vector" dimensions of `values`, carried along unchanged (empty for scalar segment values).
    vector_shape = values.shape[batch_dims + 1 :]

    # torch.gather needs an index with the same rank as `values`: collapse the index dimensions (I1, ..., Ik) into a
    # single one (this works for any number of batch dims, including 0) and broadcast it over the vector dimensions.
    # `reshape` rather than `view` so non-contiguous indices are accepted too.
    flat_indices = indices.reshape(*batch_shape, -1)
    if len(vector_shape) > 0:
        flat_shape = flat_indices.shape
        flat_indices = flat_indices.reshape(*flat_shape, *([1] * len(vector_shape))).expand(*flat_shape, *vector_shape)

    gathered = torch.gather(values, batch_dims, flat_indices)
    return gathered.reshape(*indices.shape, *vector_shape)
def flatten(index, name="segmented_flatten"):
    """
    Flattens a batched index map (which is typically of shape batch_size, seq_length) to a 1d index map. This operation
    relabels the segments to keep batch elements distinct. The k-th batch element will have indices shifted by
    *num_segments* * (k - 1). The result is a tensor with *num_segments* multiplied by the number of elements in the
    batch.

    Args:
        index (`IndexMap`):
            IndexMap to flatten.
        name (`str`, *optional*, defaults to 'segmented_flatten'):
            Name for the operation. Currently not used

    Returns:
        (`IndexMap`): The flattened IndexMap.
    """
    # Number of batch elements as a plain int. `torch.Size.numel()` is 1 for an empty batch shape, so this also works
    # for `batch_dims == 0`; `torch.prod(torch.tensor([]))` would give a *float* tensor there and make the offsets and
    # the flattened indices floating point.
    batch_size = index.batch_shape().numel()
    # Offset of each batch element (multiples of num_segments), e.g. if batch size is 2 and there are 64 segments:
    # [0, 64]
    offset = torch.arange(start=0, end=batch_size, device=index.num_segments.device) * index.num_segments
    offset = offset.view(index.batch_shape())
    # Append singleton dims so the offset broadcasts over the non-batch index dims.
    for _ in range(index.batch_dims, len(index.indices.size())):  # typically range(1,2)
        offset = offset.unsqueeze(-1)

    indices = offset + index.indices
    return IndexMap(indices=indices.view(-1), num_segments=index.num_segments * batch_size, batch_dims=0)
def range_index_map(batch_shape, num_segments, name="range_index_map"):
    """
    Builds an IndexMap whose indices enumerate every segment once for each batch element.

    Args:
        batch_shape (`torch.Size`):
            Batch shape; one batch dimension is created per entry.
        num_segments (`int`):
            Number of segments (a scalar tensor is accepted too; its device is then used).
        name (`str`, *optional*, defaults to 'range_index_map'):
            Name for the operation. Currently not used

    Returns:
        (`IndexMap`): IndexMap with indices of shape `(*batch_shape, num_segments)`, each row equal to
        `range(num_segments)`, and `batch_dims == len(batch_shape)`.
    """
    device = num_segments.device if torch.is_tensor(num_segments) else "cpu"

    # Normalise the arguments: a rank-1 tensor for the batch shape (e.g. [2]) and a rank-0 tensor for the
    # number of segments (e.g. 64).
    batch_shape = torch.as_tensor(batch_shape, dtype=torch.long, device=device)
    assert len(batch_shape.size()) == 1
    num_segments = torch.as_tensor(num_segments, device=device)
    assert len(num_segments.size()) == 0

    batch_sizes = [int(size) for size in batch_shape.tolist()]
    num_batch_dims = len(batch_sizes)

    # [0, ..., num_segments - 1] reshaped to [1, ..., 1, num_segments], then tiled along every batch dim
    # (the torch counterpart of numpy.tile).
    segment_ids = torch.arange(start=0, end=num_segments, device=num_segments.device)
    segment_ids = segment_ids.view(*([1] * num_batch_dims), int(num_segments))
    segment_ids = segment_ids.repeat(*batch_sizes, 1)

    return IndexMap(indices=segment_ids, num_segments=num_segments, batch_dims=num_batch_dims)
def _segment_reduce(values, index, segment_reduce_fn, name):
    """
    Applies a segment reduction segment-wise.

    Args:
        values (`torch.Tensor` of shape [B1, ..., Bn, I1, ..., Ik, V1, ...]):
            Tensor with segment values.
        index (`IndexMap` with indices of shape [B1, ..., Bn, I1, ..., Ik]):
            IndexMap.
        segment_reduce_fn (`str`):
            Name of the reduce operation, passed as `reduce` to `torch.Tensor.scatter_reduce`: one of "sum", "prod",
            "mean", "amax" or "amin".
        name (`str`):
            Name for the operation. Currently not used

    Returns:
        A tuple `(output_values, output_index)`:
            - output_values (`torch.Tensor` of shape [B1, ..., Bn, num_segments, V1, ...], same dtype as *values*):
              the reduced value of every segment, 0 for empty segments.
            - output_index (`IndexMap`): IndexMap of shape [B1, ..., Bn, num_segments] with elements equal to
              range(num_segments).
    """
    # Flatten the batch dimensions, as segment ops (scatter) do not support batching.
    # However, if `values` has extra dimensions to the right, keep them unflattened:
    # every vector component is reduced independently.
    flat_index = flatten(index)
    vector_shape = values.size()[len(index.indices.size()) :]  # torch.Size object
    flat_values = values.reshape(-1, *vector_shape)

    # scatter_reduce requires `index` to have the same rank as `src`, so broadcast the segment ids over the vector
    # dimensions; the output buffer likewise keeps the vector dimensions (previously it was 1-D only, which made
    # vector-valued inputs fail).
    flat_indices = (
        flat_index.indices.long().reshape(-1, *([1] * len(vector_shape))).expand(flat_values.shape)
    )
    out = torch.zeros(
        (int(flat_index.num_segments), *vector_shape), dtype=torch.float, device=flat_values.device
    )
    # include_self=False: segments that receive no value keep the initial 0.
    segment_values = out.scatter_reduce(
        dim=0, index=flat_indices, src=flat_values.float(), reduce=segment_reduce_fn, include_self=False
    )

    # Unflatten the values: [B1 * ... * Bn * num_segments, V...] -> [B1, ..., Bn, num_segments, V...].
    new_shape = (*index.batch_shape(), int(index.num_segments), *vector_shape)
    output_values = segment_values.view(new_shape).to(values.dtype)
    output_index = range_index_map(index.batch_shape(), index.num_segments)
    return output_values, output_index
def reduce_sum(values, index, name="segmented_reduce_sum"):
    """
    Sums a tensor over its segments.

    Outputs 0 for empty segments.

    This operation computes the sum over segments, with support for:

        - Batching using the first dimensions [B1, B2, ..., Bn]. Each element in a batch can have different indices.
        - Vectorization using the last dimensions [V1, V2, ...]. If they are present, the output will be a sum of
          vectors rather than scalars.

    Only the middle dimensions [I1, ..., Ik] are reduced by the operation.

    Args:
        values (`torch.Tensor` of shape [B1, B2, ..., Bn, I1, .., Ik, V1, V2, ..]):
            Tensor containing the values of which the sum must be taken segment-wise.
        index (`IndexMap`, indices are of shape [B1, B2, ..., Bn, I1, .., Ik].):
            Index defining the segments.
        name (`str`, *optional*, defaults to 'segmented_reduce_sum'):
            Name for the operation. Currently not used

    Returns:
        output_values (`torch.Tensor` of shape [B1, B2, ..., Bn, num_segments, V1, V2, ..]): Tensor containing the
        output values.

        output_index (`IndexMap`): IndexMap with shape [B1, B2, ..., Bn, num_segments].
    """
    return _segment_reduce(values, index, "sum", name)
def reduce_mean(values, index, name="segmented_reduce_mean"):
    """
    Averages a tensor over its segments.

    Outputs 0 for empty segments.

    This operation computes the mean over segments, with support for:

        - Batching using the first dimensions [B1, B2, ..., Bn]. Each element in a batch can have different indices.
        - Vectorization using the last dimensions [V1, V2, ...]. If they are present, the output will be a mean of
          vectors rather than scalars.

    Only the middle dimensions [I1, ..., Ik] are reduced by the operation.

    Args:
        values (`torch.Tensor` of shape [B1, B2, ..., Bn, I1, .., Ik, V1, V2, ..]):
            Tensor containing the values of which the mean must be taken segment-wise.
        index (`IndexMap`, indices are of shape [B1, B2, ..., Bn, I1, .., Ik].):
            Index defining the segments.
        name (`str`, *optional*, defaults to 'segmented_reduce_mean'):
            Name for the operation. Currently not used

    Returns:
        output_values (`torch.Tensor` of shape [B1, B2, ..., Bn, num_segments, V1, V2, ..]): Tensor containing the
        output values.

        output_index (`IndexMap`): IndexMap with shape [B1, B2, ..., Bn, num_segments].
    """
    return _segment_reduce(values, index, "mean", name)
def reduce_max(values, index, name="segmented_reduce_max"):
    """
    Computes the maximum over segments.

    Outputs 0 (not -inf) for empty segments, because the reduction starts from a zero-initialised buffer that does
    not itself take part in the reduction.

    This operation computes the maximum over segments, with support for:

        - Batching using the first dimensions [B1, B2, ..., Bn]. Each element in a batch can have different indices.
        - Vectorization using the last dimensions [V1, V2, ...]. If they are present, the output will be an
          element-wise maximum of vectors rather than scalars.

    Only the middle dimensions [I1, ..., Ik] are reduced by the operation.

    Args:
        values (`torch.Tensor` of shape [B1, B2, ..., Bn, I1, .., Ik, V1, V2, ..]):
            Tensor containing the values of which the max must be taken segment-wise.
        index (`IndexMap`, indices are of shape [B1, B2, ..., Bn, I1, .., Ik].):
            Index defining the segments.
        name (`str`, *optional*, defaults to 'segmented_reduce_max'):
            Name for the operation. Currently not used

    Returns:
        output_values (`torch.Tensor` of shape [B1, B2, ..., Bn, num_segments, V1, V2, ..]): Tensor containing the
        output values.

        output_index (`IndexMap`): IndexMap with shape [B1, B2, ..., Bn, num_segments].
    """
    return _segment_reduce(values, index, "amax", name)
def reduce_min(values, index, name="segmented_reduce_min"):
    """
    Computes the minimum over segments.

    Outputs 0 (not +inf) for empty segments, because the reduction starts from a zero-initialised buffer that does
    not itself take part in the reduction.

    This operation computes the minimum over segments, with support for:

        - Batching using the first dimensions [B1, B2, ..., Bn]. Each element in a batch can have different indices.
        - Vectorization using the last dimensions [V1, V2, ...]. If they are present, the output will be an
          element-wise minimum of vectors rather than scalars.

    Only the middle dimensions [I1, ..., Ik] are reduced by the operation.

    Args:
        values (`torch.Tensor` of shape [B1, B2, ..., Bn, I1, .., Ik, V1, V2, ..]):
            Tensor containing the values of which the min must be taken segment-wise.
        index (`IndexMap`, indices are of shape [B1, B2, ..., Bn, I1, .., Ik].):
            Index defining the segments.
        name (`str`, *optional*, defaults to 'segmented_reduce_min'):
            Name for the operation. Currently not used

    Returns:
        output_values (`torch.Tensor` of shape [B1, B2, ..., Bn, num_segments, V1, V2, ..]): Tensor containing the
        output values.

        output_index (`IndexMap`): IndexMap with shape [B1, B2, ..., Bn, num_segments].
    """
    return _segment_reduce(values, index, "amin", name)
- # End of everything related to segmented tensors
def compute_column_logits(
    sequence_output, column_output_weights, column_output_bias, cell_index, cell_mask, allow_empty_column_selection
):
    """
    Computes the column logits.

    A linear layer scores every token. The token scores are averaged per cell, and the cell scores are then averaged
    over the existing (non-padding) cells of each column. Columns without any existing cell receive a very low
    logit. The same happens to column id 0 when `allow_empty_column_selection` is False.

    Args:
        sequence_output (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`):
            Also known as last_hidden_state. Sequence of hidden-states at the output of the last layer of the model.
        column_output_weights (`torch.FloatTensor` of shape `(hidden_size)`):
            Weights of the linear layer for column selection.
        column_output_bias (`torch.FloatTensor` of shape `()`):
            Bias of the linear layer for column selection.
        cell_index (`ProductIndexMap`):
            Index that groups tokens into cells.
        cell_mask (`torch.FloatTensor` of shape `(batch_size, max_num_rows * max_num_cols)`):
            Mask for cells that exist in the table (i.e. that are not padding).
        allow_empty_column_selection (`bool`):
            Whether column id 0 may be selected. If False, column 0 is pushed to a very low logit.

    Returns:
        column_logits (`torch.FloatTensor` of shape `(batch_size, max_num_cols)`): Tensor containing the column
        logits for every example in the batch.
    """
    # Linear layer applied to every token, without temperature: (batch_size, seq_len).
    per_token = torch.einsum("bsj,j->bs", sequence_output, column_output_weights) + column_output_bias

    # Mean token score of every cell: (batch_size, max_num_cols * max_num_rows).
    per_cell, per_cell_index = reduce_mean(per_token, cell_index)

    # Mean over the existing cells of every column: (batch_size, max_num_cols).
    col_index = cell_index.project_inner(per_cell_index)
    column_logits, column_ids = reduce_sum(per_cell * cell_mask, col_index)
    cells_in_column, _ = reduce_sum(cell_mask, col_index)
    column_logits /= cells_in_column + EPSILON_ZERO_DIVISION

    is_column_zero = torch.eq(column_ids.indices, 0)
    # A column other than id 0 that contains no existing cell does not appear in the example.
    is_padding = torch.logical_and(cells_in_column < 0.5, ~is_column_zero)
    column_logits += CLOSE_ENOUGH_TO_LOG_ZERO * is_padding.to(dtype=torch.float32)

    if not allow_empty_column_selection:
        column_logits += CLOSE_ENOUGH_TO_LOG_ZERO * is_column_zero.to(dtype=torch.float32)
    return column_logits
def _single_column_cell_selection_loss(token_logits, column_logits, labels, cell_index, col_index, cell_mask):
    """
    Computes the loss for cell selection constrained to a single column.

    The loss is a hierarchical log-likelihood: the model first predicts a column and then selects cells within that
    column (conditioned on the column). Cells outside the selected column are never selected.

    Args:
        token_logits (`torch.FloatTensor` of shape `(batch_size, sequence_length)`):
            Tensor containing the logits per token.
        column_logits (`torch.FloatTensor` of shape `(batch_size, max_num_cols)`):
            Tensor containing the logits per column.
        labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
            Labels per token.
        cell_index (`ProductIndexMap`):
            Index that groups tokens into cells.
        col_index (`IndexMap`):
            Index that groups tokens into columns.
        cell_mask (`torch.FloatTensor` of shape `(batch_size, max_num_rows * max_num_cols)`):
            Mask for cells that exist in the table (i.e. that are not padding).

    Returns:
        selection_loss_per_example (`torch.FloatTensor` of shape `(batch_size,)`): Loss for each example.

        logits (`torch.FloatTensor` of shape `(batch_size, sequence_length)`): New logits which are only allowed to
        select cells in a single column. Logits outside of the most likely column according to *column_logits* are
        set to a very low value (such that the probabilities are 0).
    """
    # ---- Column loss ----------------------------------------------------------------------------------------
    # Number of labelled tokens in every column: (batch_size, max_num_cols).
    label_counts, _ = reduce_sum(labels.to(dtype=torch.float32), col_index)
    # The target column is the one holding the most labelled tokens: (batch_size,).
    target_column = torch.argmax(label_counts, dim=-1)
    # Examples without any labelled token must predict the special column id 0 ("select nothing").
    nothing_selected = torch.eq(torch.max(label_counts, dim=-1).values, 0)  # (batch_size,)
    target_column = torch.where(
        nothing_selected.view(target_column.size()), torch.zeros_like(target_column), target_column
    )
    column_loss = -torch.distributions.Categorical(logits=column_logits).log_prob(target_column)

    # ---- Cell loss ------------------------------------------------------------------------------------------
    # Per-cell logits and labels: (batch_size, max_num_rows * max_num_cols).
    cell_logits, _ = reduce_mean(token_logits, cell_index)
    cell_labels, cell_labels_index = reduce_max(labels.to(dtype=torch.long), cell_index)

    # Column id of every cell, and whether that column is the target one.
    column_of_cell = cell_index.project_inner(cell_labels_index).indices
    in_target_column = torch.eq(column_of_cell, target_column.unsqueeze(-1)).to(
        dtype=torch.float32, device=cell_mask.device
    )

    # Log-likelihood of the cell labels, restricted to existing cells of the target column.
    cell_log_prob = torch.distributions.Bernoulli(logits=cell_logits).log_prob(cell_labels.to(torch.float32))
    cell_loss = -torch.sum(cell_log_prob * in_target_column * cell_mask, dim=1)
    # Normalise by the number of cells in the target column.
    cell_loss /= torch.sum(in_target_column * cell_mask, dim=1) + EPSILON_ZERO_DIVISION

    # The cell term only applies to examples that actually select something.
    selection_loss_per_example = column_loss + torch.where(
        nothing_selected.view(column_loss.size()),
        torch.zeros_like(column_loss),
        cell_loss,
    )

    # ---- Restrict the returned logits to the column picked by the *model* -----------------------------------
    # This keeps backwards compatibility with models that select cells from multiple columns.
    predicted_column = torch.argmax(column_logits, dim=-1)  # (batch_size,)
    in_predicted_column = torch.eq(column_of_cell, predicted_column.unsqueeze(-1)).to(
        dtype=torch.float32, device=predicted_column.device
    )
    # Never select cells with the special column id 0.
    in_predicted_column = torch.where(
        torch.eq(column_of_cell, 0).view(in_predicted_column.size()),
        torch.zeros_like(in_predicted_column),
        in_predicted_column,
    )
    masked_cell_logits = cell_logits + CLOSE_ENOUGH_TO_LOG_ZERO * (1.0 - cell_mask * in_predicted_column)
    return selection_loss_per_example, gather(masked_cell_logits, cell_index)
def compute_token_logits(sequence_output, temperature, output_weights, output_bias):
    """
    Computes logits per token.

    Every hidden state is projected onto `output_weights`, shifted by `output_bias` and divided by `temperature`.

    Args:
        sequence_output (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`):
            Also known as last_hidden_state. Sequence of hidden-states at the output of the last layer of the model.
        temperature (`float`):
            Temperature for the Bernoulli distribution.
        output_weights (`torch.FloatTensor` of shape `(hidden_size,)`):
            Weights of the linear layer for cell selection.
        output_bias (`torch.FloatTensor` of shape `()`):
            Bias of the linear layer for cell selection

    Returns:
        logits (`torch.FloatTensor` of shape `(batch_size, sequence_length)`): Logits per token.
    """
    # (batch, seq, hidden) x (hidden,) -> (batch, seq)
    projected = torch.einsum("bsj,j->bs", sequence_output, output_weights)
    return (projected + output_bias) / temperature
- def _calculate_aggregate_mask(answer, pooled_output, cell_selection_preference, labels, aggregation_classifier):
- """
- Finds examples where the model should select cells with no aggregation.
- Returns a mask that determines for which examples should the model select answers directly from the table, without
- any aggregation function. If the answer is a piece of text the case is unambiguous as aggregation functions only
- apply to numbers. If the answer is a number but does not appear in the table then we must use some aggregation
- case. The ambiguous case is when the answer is a number that also appears in the table. In this case we use the
- aggregation function probabilities predicted by the model to decide whether to select or aggregate. The threshold
- for this is a hyperparameter *cell_selection_preference*
- Args:
- answer (`torch.FloatTensor` of shape `(batch_size, )`):
- Answer for every example in the batch. Nan if there is no scalar answer.
- pooled_output (`torch.FloatTensor` of shape `(batch_size, hidden_size)`):
- Output of the pooler (BertPooler) on top of the encoder layer.
- cell_selection_preference (`float`):
- Preference for cell selection in ambiguous cases.
- labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
- Labels per token. aggregation_classifier (`torch.nn.Linear`): Aggregation head
- Returns:
- aggregate_mask (`torch.FloatTensor` of shape `(batch_size,)`): A mask set to 1 for examples that should use
- aggregation functions.
- """
- # torch.FloatTensor(batch_size,)
- aggregate_mask_init = torch.logical_not(torch.isnan(answer)).type(torch.FloatTensor).to(answer.device)
- logits_aggregation = aggregation_classifier(pooled_output)
- dist_aggregation = torch.distributions.categorical.Categorical(logits=logits_aggregation)
- # Index 0 corresponds to "no aggregation".
- aggregation_ops_total_mass = torch.sum(dist_aggregation.probs[:, 1:], dim=1)
- # Cell selection examples according to current model.
- is_pred_cell_selection = aggregation_ops_total_mass <= cell_selection_preference
- # Examples with non-empty cell selection supervision.
- is_cell_supervision_available = torch.sum(labels, dim=1) > 0
- aggregate_mask = torch.where(
- torch.logical_and(is_pred_cell_selection, is_cell_supervision_available).view(aggregate_mask_init.size()),
- torch.zeros_like(aggregate_mask_init, dtype=torch.float32),
- aggregate_mask_init,
- )
- aggregate_mask = aggregate_mask.detach()
- return aggregate_mask
- def _calculate_aggregation_loss_known(
- logits_aggregation, aggregate_mask, aggregation_labels, use_answer_as_supervision, num_aggregation_labels
- ):
- """
- Calculates aggregation loss when its type is known during training.
- In the weakly supervised setting, the only known information is that for cell selection examples, "no aggregation"
- should be predicted. For other examples (those that require aggregation), no loss is accumulated. In the setting
- where aggregation type is always known, standard cross entropy loss is accumulated for all examples
- Args:
- logits_aggregation (`torch.FloatTensor` of shape `(batch_size, num_aggregation_labels)`):
- Logits per aggregation operation.
- aggregate_mask (`torch.FloatTensor` of shape `(batch_size, )`):
- A mask set to 1 for examples that should use aggregation functions.
- aggregation_labels (`torch.LongTensor` of shape `(batch_size, )`):
- Aggregation function id for every example in the batch.
- use_answer_as_supervision (`bool`, *optional*):
- Whether to use the answer as the only supervision for aggregation examples.
- num_aggregation_labels (`int`, *optional*, defaults to 0):
- The number of aggregation operators to predict.
- Returns:
- aggregation_loss_known (`torch.FloatTensor` of shape `(batch_size,)`): Aggregation loss (when its type is known
- during training) per example.
- """
- if use_answer_as_supervision:
- # Prepare "no aggregation" targets for cell selection examples.
- target_aggregation = torch.zeros_like(aggregate_mask, dtype=torch.long)
- else:
- # Use aggregation supervision as the target.
- target_aggregation = aggregation_labels
- one_hot_labels = nn.functional.one_hot(target_aggregation, num_classes=num_aggregation_labels).type(torch.float32)
- log_probs = nn.functional.log_softmax(logits_aggregation, dim=-1)
- # torch.FloatTensor[batch_size]
- per_example_aggregation_intermediate = -torch.sum(one_hot_labels * log_probs, dim=-1)
- if use_answer_as_supervision:
- # Accumulate loss only for examples requiring cell selection
- # (no aggregation).
- return per_example_aggregation_intermediate * (1 - aggregate_mask)
- else:
- return per_example_aggregation_intermediate
- def _calculate_aggregation_loss_unknown(logits_aggregation, aggregate_mask):
- """
- Calculates aggregation loss in the case of answer supervision.
- Args:
- logits_aggregation (`torch.FloatTensor` of shape `(batch_size, num_aggregation_labels)`):
- Logits per aggregation operation.
- aggregate_mask (`torch.FloatTensor` of shape `(batch_size, )`):
- A mask set to 1 for examples that should use aggregation functions
- Returns:
- aggregation_loss_unknown (`torch.FloatTensor` of shape `(batch_size,)`): Aggregation loss (in case of answer
- supervision) per example.
- """
- dist_aggregation = torch.distributions.categorical.Categorical(logits=logits_aggregation)
- # Index 0 corresponds to "no aggregation".
- aggregation_ops_total_mass = torch.sum(dist_aggregation.probs[:, 1:], dim=1)
- # Predict some aggregation in case of an answer that needs aggregation.
- # This increases the probability of all aggregation functions, in a way
- # similar to MML, but without considering whether the function gives the
- # correct answer.
- return -torch.log(aggregation_ops_total_mass) * aggregate_mask
def _calculate_aggregation_loss(
    logits_aggregation,
    aggregate_mask,
    aggregation_labels,
    use_answer_as_supervision,
    num_aggregation_labels,
    aggregation_loss_weight,
):
    """
    Calculates the aggregation loss per example.

    The loss is the "known" aggregation loss, plus the "unknown" aggregation loss when `use_answer_as_supervision`
    is True. The sum is scaled by `aggregation_loss_weight`.

    Args:
        logits_aggregation (`torch.FloatTensor` of shape `(batch_size, num_aggregation_labels)`):
            Logits per aggregation operation.
        aggregate_mask (`torch.FloatTensor` of shape `(batch_size, )`):
            A mask set to 1 for examples that should use aggregation functions.
        aggregation_labels (`torch.LongTensor` of shape `(batch_size, )`):
            Aggregation function id for every example in the batch.
        use_answer_as_supervision (`bool`):
            Whether to use the answer as the only supervision for aggregation examples.
        num_aggregation_labels (`int`):
            The number of aggregation operators to predict.
        aggregation_loss_weight (`float`):
            Importance weight for the aggregation loss.

    Returns:
        aggregation_loss (`torch.FloatTensor` of shape `(batch_size,)`): Aggregation loss per example.
    """
    per_example_aggregation_loss = _calculate_aggregation_loss_known(
        logits_aggregation, aggregate_mask, aggregation_labels, use_answer_as_supervision, num_aggregation_labels
    )

    if use_answer_as_supervision:
        # Add aggregation loss for numeric answers that need aggregation.
        per_example_aggregation_loss += _calculate_aggregation_loss_unknown(logits_aggregation, aggregate_mask)
    return aggregation_loss_weight * per_example_aggregation_loss
def _calculate_expected_result(
    dist_per_cell, numeric_values, numeric_values_scale, input_mask_float, logits_aggregation, config
):
    """
    Calculates the expected result given cell and aggregation probabilities.

    The cell selection probabilities (optionally Gumbel-relaxed samples) are combined with the numeric cell values
    into three soft results: a sum, an average and a count. These are then weighted by the probabilities of the
    aggregation operators (every aggregation class except index 0).

    Args:
        dist_per_cell (`torch.distributions.Bernoulli`):
            Cell selection distribution for each cell.
        numeric_values (`torch.FloatTensor` of shape `(batch_size, seq_length)`):
            Numeric values of every token. Nan for tokens which are not numeric values.
        numeric_values_scale (`torch.FloatTensor` of shape `(batch_size, seq_length)`):
            Scale of the numeric values of every token.
        input_mask_float (`torch.FloatTensor` of shape `(batch_size, seq_length)`):
            Mask for the table, without question tokens and table headers.
        logits_aggregation (`torch.FloatTensor` of shape `(batch_size, num_aggregation_labels)`):
            Logits per aggregation operation.
        config ([`TapasConfig`]):
            Model configuration class with all the hyperparameters of the model

    Returns:
        expected_result (`torch.FloatTensor` of shape `(batch_size,)`): The expected result per example.

    Raises:
        `ValueError`: If `config.average_approximation_function` is not RATIO, FIRST_ORDER or SECOND_ORDER.
    """
    if config.use_gumbel_for_cells:
        gumbel_dist = torch.distributions.RelaxedBernoulli(
            # The token logits were already divided by the temperature and used for
            # computing cell selection errors so we need to multiply it again here
            temperature=config.temperature,
            logits=dist_per_cell.logits * config.temperature,
        )
        # rsample() draws a reparameterized sample that stays differentiable w.r.t. the cell logits;
        # sample() runs under torch.no_grad() and would cut the regression loss off from them.
        scaled_probability_per_cell = gumbel_dist.rsample()
    else:
        scaled_probability_per_cell = dist_per_cell.probs

    # <float32>[batch_size, seq_length]
    scaled_probability_per_cell = (scaled_probability_per_cell / numeric_values_scale) * input_mask_float
    count_result = torch.sum(scaled_probability_per_cell, dim=1)
    # Mask non-numeric (NaN) table values to zero.
    numeric_values_masked = torch.where(
        torch.isnan(numeric_values), torch.zeros_like(numeric_values), numeric_values
    )
    sum_result = torch.sum(scaled_probability_per_cell * numeric_values_masked, dim=1)

    avg_approximation = config.average_approximation_function
    if avg_approximation == AverageApproximationFunction.RATIO:
        average_result = sum_result / (count_result + EPSILON_ZERO_DIVISION)
    elif avg_approximation == AverageApproximationFunction.FIRST_ORDER:
        # ex: for each cell, the expected number of *other* selected cells plus 1 (the cell itself), i.e. the sum
        # of all probabilities minus the cell's own probability plus 1. It corresponds to X_c in Appendix D of the
        # original TAPAS paper, which approximates the average of a random set.
        ex = torch.sum(scaled_probability_per_cell, dim=1, keepdim=True) - scaled_probability_per_cell + 1
        average_result = torch.sum(numeric_values_masked * scaled_probability_per_cell / ex, dim=1)
    elif avg_approximation == AverageApproximationFunction.SECOND_ORDER:
        # Same ex as FIRST_ORDER, plus a variance correction from the Bernoulli variances p * (1 - p) of the
        # other cells: multiplier = 1 / ex + var / ex**3.
        ex = torch.sum(scaled_probability_per_cell, dim=1, keepdim=True) - scaled_probability_per_cell + 1
        pointwise_var = scaled_probability_per_cell * (1 - scaled_probability_per_cell)
        var = torch.sum(pointwise_var, dim=1, keepdim=True) - pointwise_var
        multiplier = (var / torch.square(ex) + 1) / ex
        average_result = torch.sum(numeric_values_masked * scaled_probability_per_cell * multiplier, dim=1)
    else:
        raise ValueError(f"Invalid average_approximation_function: {config.average_approximation_function}")

    if config.use_gumbel_for_aggregation:
        gumbel_dist = torch.distributions.RelaxedOneHotCategorical(
            config.aggregation_temperature, logits=logits_aggregation[:, 1:]
        )
        # <float32>[batch_size, num_aggregation_labels - 1]; rsample() keeps gradients flowing to the logits.
        aggregation_op_only_probs = gumbel_dist.rsample()
    else:
        # <float32>[batch_size, num_aggregation_labels - 1]
        aggregation_op_only_probs = nn.functional.softmax(
            logits_aggregation[:, 1:] / config.aggregation_temperature, dim=-1
        )

    # [batch_size, 3]. Column j is weighted by the probability of aggregation class j + 1.
    # NOTE(review): presumably classes 1, 2, 3 are SUM, AVERAGE, COUNT; confirm against the config's labels.
    all_results = torch.stack([sum_result, average_result, count_result], dim=1)
    expected_result = torch.sum(all_results * aggregation_op_only_probs, dim=1)
    return expected_result
# Local element-wise Huber loss. PyTorch >= 1.9 also ships `torch.nn.functional.huber_loss(..., reduction="none",
# delta=...)` with the same formula; this helper is kept so the loss is computed exactly as written here.
def huber_loss(input, target, delta: float = 1.0):
    """
    Element-wise Huber loss (no reduction).

    With `e = |input - target|`, the loss is `0.5 * e**2` when `e < delta` and `delta * e - 0.5 * delta**2`
    otherwise.

    Args:
        input (`torch.Tensor`): Predictions.
        target (`torch.Tensor`): Targets, broadcastable against `input`.
        delta (`float`, *optional*, defaults to 1.0): Threshold between the quadratic and the linear regime.

    Returns:
        `torch.Tensor`: The loss, with the broadcast shape of `input - target`.
    """
    errors = torch.abs(input - target)  # shape (batch_size,)
    return torch.where(errors < delta, 0.5 * errors**2, errors * delta - (0.5 * delta**2))
def _calculate_regression_loss(
    answer,
    aggregate_mask,
    dist_per_cell,
    numeric_values,
    numeric_values_scale,
    input_mask_float,
    logits_aggregation,
    config,
):
    """
    Calculates the regression loss per example.

    The expected result of the soft aggregation is compared with the scalar answer through a Huber loss. The loss is
    relative when `config.use_normalized_answer_loss` is set, and is only counted for examples where
    `aggregate_mask` is 1.

    Args:
        answer (`torch.FloatTensor` of shape `(batch_size,)`):
            Answer for every example in the batch. Nan if there is no scalar answer.
        aggregate_mask (`torch.FloatTensor` of shape `(batch_size,)`):
            A mask set to 1 for examples that should use aggregation functions.
        dist_per_cell (`torch.distributions.Bernoulli`):
            Cell selection distribution for each cell.
        numeric_values (`torch.FloatTensor` of shape `(batch_size, seq_length)`):
            Numeric values of every token. Nan for tokens which are not numeric values.
        numeric_values_scale (`torch.FloatTensor` of shape `(batch_size, seq_length)`):
            Scale of the numeric values of every token.
        input_mask_float (`torch.FloatTensor` of shape `(batch_size, seq_length)`):
            Mask for the table, without question tokens and table headers.
        logits_aggregation (`torch.FloatTensor` of shape `(batch_size, num_aggregation_labels)`):
            Logits per aggregation operation.
        config ([`TapasConfig`]):
            Model configuration class with all the parameters of the model

    Returns:
        per_example_answer_loss_scaled (`torch.FloatTensor` of shape `(batch_size,)`): Scaled answer loss for each
        example in the batch.

        large_answer_loss_mask (`torch.FloatTensor` of shape `(batch_size,)`): A mask which is 0 for examples whose
        answer loss is larger than `config.answer_loss_cutoff` and 1 otherwise. It is all ones when no cutoff is
        configured.
    """
    # float32 (batch_size,)
    expected_result = _calculate_expected_result(
        dist_per_cell, numeric_values, numeric_values_scale, input_mask_float, logits_aggregation, config
    )

    # NaN answers (no scalar answer) are replaced by 0 so the loss stays finite.
    target = torch.where(torch.isnan(answer), torch.zeros_like(answer), answer)

    if config.use_normalized_answer_loss:
        # Divide both sides by a detached scale so that the loss measures the relative error.
        scale = (torch.max(torch.abs(expected_result), torch.abs(target)) + EPSILON_ZERO_DIVISION).detach()
        answer_loss = huber_loss(
            (expected_result / scale) * aggregate_mask, (target / scale) * aggregate_mask, delta=1.0
        )
    else:
        answer_loss = huber_loss(
            expected_result * aggregate_mask, target * aggregate_mask, delta=config.huber_loss_delta
        )

    cutoff = config.answer_loss_cutoff
    if cutoff is None:
        large_answer_loss_mask = torch.ones_like(answer_loss, dtype=torch.float32)
    else:
        # 0.0 where the loss exceeds the cutoff, 1.0 otherwise (a NaN loss does not "exceed" it).
        large_answer_loss_mask = torch.logical_not(answer_loss > cutoff).to(dtype=torch.float32)

    per_example_answer_loss_scaled = config.answer_loss_importance * (answer_loss * aggregate_mask)
    return per_example_answer_loss_scaled, large_answer_loss_mask
# Public names exported by this module (what `from ... import *` exposes).
__all__ = [
    "TapasForMaskedLM",
    "TapasForQuestionAnswering",
    "TapasForSequenceClassification",
    "TapasModel",
    "TapasPreTrainedModel",
]
|