| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656657658659660661662663664665666667668669670671672673674675676677678679680681682683684685686687688689690691692693694695696697698699700701702
703704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792793794795796797798799800801802803804805806807808809810811812813814815816817818819820821822823824825826827828829830831832833834835836837838839840841842843844845846847848849850851852853854855856857858859860861862863864865866867868869870871872873874875876877878879880881882883884885886887888889890891892893894895896897898899900901902903904905906907908909910911912913914915916917918919920 |
- # Copyright 2020 The SqueezeBert authors and The HuggingFace Inc. team.
- #
- # Licensed under the Apache License, Version 2.0 (the "License");
- # you may not use this file except in compliance with the License.
- # You may obtain a copy of the License at
- #
- # http://www.apache.org/licenses/LICENSE-2.0
- #
- # Unless required by applicable law or agreed to in writing, software
- # distributed under the License is distributed on an "AS IS" BASIS,
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- # See the License for the specific language governing permissions and
- # limitations under the License.
- """PyTorch SqueezeBert model."""
- import math
- import torch
- from torch import nn
- from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
- from ... import initialization as init
- from ...activations import ACT2FN
- from ...modeling_outputs import (
- BaseModelOutput,
- BaseModelOutputWithPooling,
- MaskedLMOutput,
- MultipleChoiceModelOutput,
- QuestionAnsweringModelOutput,
- SequenceClassifierOutput,
- TokenClassifierOutput,
- )
- from ...modeling_utils import PreTrainedModel
- from ...utils import (
- auto_docstring,
- logging,
- )
- from .configuration_squeezebert import SqueezeBertConfig
# Module-level logger, named after this module via the package's logging helper.
logger = logging.get_logger(__name__)
class SqueezeBertEmbeddings(nn.Module):
    """Sum word, position and token-type embeddings, then apply LayerNorm and dropout."""

    def __init__(self, config):
        super().__init__()
        embed_dim = config.embedding_size
        self.word_embeddings = nn.Embedding(config.vocab_size, embed_dim, padding_idx=config.pad_token_id)
        self.position_embeddings = nn.Embedding(config.max_position_embeddings, embed_dim)
        self.token_type_embeddings = nn.Embedding(config.type_vocab_size, embed_dim)

        # NOTE: the norm is sized with hidden_size while the tables use embedding_size,
        # so the two config values have to agree for forward() to work.
        self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
        self.dropout = nn.Dropout(config.hidden_dropout_prob)

        # Default positions 0..max_position_embeddings-1 with shape (1, max_position_embeddings).
        # Registered as a non-persistent buffer, i.e. it is not part of the state dict.
        default_positions = torch.arange(config.max_position_embeddings).expand((1, -1))
        self.register_buffer("position_ids", default_positions, persistent=False)

    def forward(self, input_ids=None, token_type_ids=None, position_ids=None, inputs_embeds=None):
        """
        Build the embedding for a batch.

        Either `input_ids` or `inputs_embeds` must be given. Missing `position_ids` default to the first
        `seq_length` entries of the registered buffer; missing `token_type_ids` default to all zeros.
        """
        # Shape of the batch without the feature dimension.
        input_shape = inputs_embeds.size()[:-1] if input_ids is None else input_ids.size()
        seq_length = input_shape[1]

        if position_ids is None:
            position_ids = self.position_ids[:, :seq_length]
        if token_type_ids is None:
            token_type_ids = torch.zeros(input_shape, dtype=torch.long, device=self.position_ids.device)
        if inputs_embeds is None:
            inputs_embeds = self.word_embeddings(input_ids)

        summed = inputs_embeds + self.position_embeddings(position_ids) + self.token_type_embeddings(token_type_ids)
        return self.dropout(self.LayerNorm(summed))
class MatMulWrapper(nn.Module):
    """
    Module wrapper around torch.matmul().

    Routing the product through a module makes flop-counting easier to implement: a flop counter will
    typically ignore a bare torch.matmul() call made directly in model code.
    """

    def forward(self, mat1, mat2):
        """
        Return `torch.matmul(mat1, mat2)`.

        Typical dimensions found in BERT (the leading B is optional):
            mat1: [B, <optional extra dims>, M, K]
            mat2: [B, <optional extra dims>, K, N]
            result: [B, <optional extra dims>, M, N]
        """
        product = torch.matmul(mat1, mat2)
        return product
class SqueezeBertLayerNorm(nn.LayerNorm):
    """
    nn.LayerNorm variant for NCW tensors that normalizes over the channel axis.

    N = batch, C = channels, W = sequence length.
    """

    def __init__(self, hidden_size, eps=1e-12):
        # The parent constructor creates self.weight, self.bias and self.eps.
        super().__init__(normalized_shape=hidden_size, eps=eps)

    def forward(self, x):
        # nn.LayerNorm normalizes the trailing axis, so move channels last, normalize, move them back.
        channels_last = x.permute(0, 2, 1)
        normalized = super().forward(channels_last)
        return normalized.permute(0, 2, 1)
class ConvDropoutLayerNorm(nn.Module):
    """
    Pointwise (kernel_size=1) grouped Conv1d, dropout, residual add, then channel-wise LayerNorm.
    """

    def __init__(self, cin, cout, groups, dropout_prob):
        super().__init__()
        self.conv1d = nn.Conv1d(in_channels=cin, out_channels=cout, kernel_size=1, groups=groups)
        self.layernorm = SqueezeBertLayerNorm(cout)
        self.dropout = nn.Dropout(dropout_prob)

    def forward(self, hidden_states, input_tensor):
        """Project `hidden_states`, apply dropout, add the residual `input_tensor`, and normalize."""
        projected = self.dropout(self.conv1d(hidden_states))
        return self.layernorm(projected + input_tensor)
class ConvActivation(nn.Module):
    """
    Pointwise (kernel_size=1) grouped Conv1d followed by the activation looked up in ACT2FN.
    """

    def __init__(self, cin, cout, groups, act):
        super().__init__()
        self.conv1d = nn.Conv1d(in_channels=cin, out_channels=cout, kernel_size=1, groups=groups)
        # `act` is the key of the activation in the ACT2FN table.
        self.act = ACT2FN[act]

    def forward(self, x):
        return self.act(self.conv1d(x))
class SqueezeBertSelfAttention(nn.Module):
    """Multi-head self-attention whose Q/K/V projections are pointwise grouped Conv1d layers on NCW input."""

    def __init__(self, config, cin, q_groups=1, k_groups=1, v_groups=1):
        """
        Args:
            config: supplies `num_attention_heads` and `attention_probs_dropout_prob`.
            cin: number of input channels; the projections keep the same number of output channels.
            q_groups / k_groups / v_groups: number of groups of the query / key / value Conv1d.

        Raises:
            ValueError: if `cin` is not divisible by the number of attention heads.
        """
        super().__init__()
        num_heads = config.num_attention_heads
        if cin % num_heads != 0:
            raise ValueError(f"cin ({cin}) is not a multiple of the number of attention heads ({num_heads})")

        self.num_attention_heads = num_heads
        self.attention_head_size = cin // num_heads
        self.all_head_size = self.num_attention_heads * self.attention_head_size

        self.query = nn.Conv1d(in_channels=cin, out_channels=cin, kernel_size=1, groups=q_groups)
        self.key = nn.Conv1d(in_channels=cin, out_channels=cin, kernel_size=1, groups=k_groups)
        self.value = nn.Conv1d(in_channels=cin, out_channels=cin, kernel_size=1, groups=v_groups)

        self.dropout = nn.Dropout(config.attention_probs_dropout_prob)
        self.softmax = nn.Softmax(dim=-1)

        self.matmul_qk = MatMulWrapper()
        self.matmul_qkv = MatMulWrapper()

    def _split_heads(self, x):
        """View [N, C, W] as [N, C1, C2, W] (C1 = head index, C2 = one head's contents)."""
        return x.view(x.size(0), self.num_attention_heads, self.attention_head_size, x.size(-1))

    def transpose_for_scores(self, x):
        """
        - input: [N, C, W]
        - output: [N, C1, W, C2] where C1 is the head index, and C2 is one head's contents
        """
        return self._split_heads(x).permute(0, 1, 3, 2)

    def transpose_key_for_scores(self, x):
        """
        - input: [N, C, W]
        - output: [N, C1, C2, W] where C1 is the head index, and C2 is one head's contents
        """
        # Keys stay as [N, C1, C2, W]: that is already the right-hand operand layout for Q @ K.
        return self._split_heads(x)

    def transpose_output(self, x):
        """
        - input: [N, C1, W, C2]
        - output: [N, C, W]
        """
        heads_first = x.permute(0, 1, 3, 2).contiguous()  # [N, C1, C2, W]
        return heads_first.view(heads_first.size(0), self.all_head_size, heads_first.size(3))

    def forward(self, hidden_states, attention_mask, output_attentions):
        """
        Run self-attention on `hidden_states` given in [N, C, W] layout.

        `attention_mask` is added to the scaled scores of shape [N, C1, W, W], so it must be broadcastable
        to that shape.

        Returns a dict with "context_layer" ([N, C, W]) and, when `output_attentions` is truthy,
        "attention_score" holding the masked scores from before the softmax.
        """
        queries = self.transpose_for_scores(self.query(hidden_states))
        keys = self.transpose_key_for_scores(self.key(hidden_states))
        values = self.transpose_for_scores(self.value(hidden_states))

        # Scaled dot-product scores with the additive mask applied.
        scores = self.matmul_qk(queries, keys)
        scores = scores / math.sqrt(self.attention_head_size)
        scores = scores + attention_mask

        # Softmax over the key axis, then dropout on the resulting probabilities.
        probs = self.dropout(self.softmax(scores))

        context = self.transpose_output(self.matmul_qkv(probs, values))

        result = {"context_layer": context}
        if output_attentions:
            result["attention_score"] = scores
        return result
class SqueezeBertModule(nn.Module):
    """
    One encoder layer: self-attention, post-attention projection, intermediate expansion and output projection.

    All tensors are in [N, C, W] layout. Channel counts:
    - hidden_size: input channels, Q/K/V channels and output channels of the module
    - intermediate_size: output channels of the intermediate layer
    The group count of each convolution is read from its own config field.
    """

    def __init__(self, config):
        super().__init__()
        hidden = config.hidden_size
        expanded = config.intermediate_size

        self.attention = SqueezeBertSelfAttention(
            config=config,
            cin=hidden,
            q_groups=config.q_groups,
            k_groups=config.k_groups,
            v_groups=config.v_groups,
        )
        self.post_attention = ConvDropoutLayerNorm(
            cin=hidden,
            cout=hidden,
            groups=config.post_attention_groups,
            dropout_prob=config.hidden_dropout_prob,
        )
        self.intermediate = ConvActivation(
            cin=hidden, cout=expanded, groups=config.intermediate_groups, act=config.hidden_act
        )
        self.output = ConvDropoutLayerNorm(
            cin=expanded,
            cout=hidden,
            groups=config.output_groups,
            dropout_prob=config.hidden_dropout_prob,
        )

    def forward(self, hidden_states, attention_mask, output_attentions):
        """
        Returns a dict with "feature_map" (the layer output) and, when `output_attentions` is truthy,
        "attention_score" forwarded from the attention sub-module.
        """
        attn = self.attention(hidden_states, attention_mask, output_attentions)

        # Residual connection around the attention block.
        attended = self.post_attention(attn["context_layer"], hidden_states)
        # Residual connection around the feed-forward block.
        layer_output = self.output(self.intermediate(attended), attended)

        if not output_attentions:
            return {"feature_map": layer_output}
        return {"feature_map": layer_output, "attention_score": attn["attention_score"]}
class SqueezeBertEncoder(nn.Module):
    """Stack of `config.num_hidden_layers` SqueezeBertModule layers operating in [N, C, W] layout."""

    def __init__(self, config):
        super().__init__()

        assert config.embedding_size == config.hidden_size, (
            "If you want embedding_size != intermediate hidden_size, "
            "please insert a Conv1d layer to adjust the number of channels "
            "before the first SqueezeBertModule."
        )

        self.layers = nn.ModuleList(SqueezeBertModule(config) for _ in range(config.num_hidden_layers))

    def forward(
        self,
        hidden_states,
        attention_mask=None,
        output_attentions=False,
        output_hidden_states=False,
        return_dict=True,
    ):
        """
        Run every layer over `hidden_states`.

        Args:
            hidden_states: input in [batch_size, sequence_length, hidden_size] layout.
            attention_mask: passed unchanged to every layer.
            output_attentions: collect each layer's "attention_score" entry.
            output_hidden_states: collect the input of every layer plus the final output, each in
                [batch_size, sequence_length, hidden_size] layout.
            return_dict: return a `BaseModelOutput` instead of a plain tuple.

        Returns:
            A `BaseModelOutput`, or a tuple of the non-None values among
            (last hidden state, all hidden states, all attentions).
        """
        # [batch_size, sequence_length, hidden_size] --> [batch_size, hidden_size, sequence_length]
        hidden_states = hidden_states.permute(0, 2, 1)

        all_hidden_states = () if output_hidden_states else None
        all_attentions = () if output_attentions else None

        for layer in self.layers:
            if output_hidden_states:
                # Record in the external layout; `hidden_states` itself stays in [N, C, W].
                all_hidden_states += (hidden_states.permute(0, 2, 1),)

            # Call the module rather than `.forward` so that registered forward hooks
            # and pre-hooks on the layer are executed.
            layer_output = layer(hidden_states, attention_mask, output_attentions)
            hidden_states = layer_output["feature_map"]

            if output_attentions:
                all_attentions += (layer_output["attention_score"],)

        # [batch_size, hidden_size, sequence_length] --> [batch_size, sequence_length, hidden_size]
        hidden_states = hidden_states.permute(0, 2, 1)

        if output_hidden_states:
            all_hidden_states += (hidden_states,)

        if not return_dict:
            return tuple(v for v in [hidden_states, all_hidden_states, all_attentions] if v is not None)
        return BaseModelOutput(
            last_hidden_state=hidden_states, hidden_states=all_hidden_states, attentions=all_attentions
        )
class SqueezeBertPooler(nn.Module):
    """Pool a sequence by passing the first token's hidden state through a Linear layer and tanh."""

    def __init__(self, config):
        super().__init__()
        self.dense = nn.Linear(config.hidden_size, config.hidden_size)
        self.activation = nn.Tanh()

    def forward(self, hidden_states):
        # "Pooling" here just means using the hidden state at sequence position 0.
        first_token = hidden_states[:, 0]
        return self.activation(self.dense(first_token))
class SqueezeBertPredictionHeadTransform(nn.Module):
    """Linear projection, activation and LayerNorm applied before the LM decoder."""

    def __init__(self, config):
        super().__init__()
        self.dense = nn.Linear(config.hidden_size, config.hidden_size)
        # `hidden_act` is either a key into ACT2FN or something that is used as the activation directly.
        activation = config.hidden_act
        self.transform_act_fn = ACT2FN[activation] if isinstance(activation, str) else activation
        self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)

    def forward(self, hidden_states):
        projected = self.dense(hidden_states)
        activated = self.transform_act_fn(projected)
        return self.LayerNorm(activated)
class SqueezeBertLMPredictionHead(nn.Module):
    """Transform hidden states and project them onto the vocabulary."""

    def __init__(self, config):
        super().__init__()
        self.transform = SqueezeBertPredictionHeadTransform(config)

        # Projection from hidden_size to vocab_size with its own per-token bias. The intent stated by
        # the original author is that the projection weights are the same as the input embeddings;
        # no such sharing is set up inside this class.
        self.decoder = nn.Linear(config.hidden_size, config.vocab_size, bias=True)

        # Separate vocab-sized bias parameter. It is not used in forward(); the original note says a link
        # to the decoder bias is needed so it is resized correctly by `resize_token_embeddings`.
        self.bias = nn.Parameter(torch.zeros(config.vocab_size))

    def forward(self, hidden_states):
        return self.decoder(self.transform(hidden_states))
class SqueezeBertOnlyMLMHead(nn.Module):
    """Container exposing the LM prediction head under the `predictions` attribute."""

    def __init__(self, config):
        super().__init__()
        self.predictions = SqueezeBertLMPredictionHead(config)

    def forward(self, sequence_output):
        # Delegates entirely to the prediction head and returns its scores.
        return self.predictions(sequence_output)
@auto_docstring
class SqueezeBertPreTrainedModel(PreTrainedModel):
    config: SqueezeBertConfig
    base_model_prefix = "transformer"

    @torch.no_grad()
    def _init_weights(self, module):
        """Initialize the weights"""
        # Generic initialization first, then the SqueezeBERT-specific cases below.
        super()._init_weights(module)
        if isinstance(module, SqueezeBertLMPredictionHead):
            # The stand-alone vocabulary bias starts at zero.
            init.zeros_(module.bias)
            return
        if isinstance(module, SqueezeBertEmbeddings):
            # Refill the position buffer with 0..L-1, where L is the buffer's last dimension.
            num_positions = module.position_ids.shape[-1]
            init.copy_(module.position_ids, torch.arange(num_positions).expand((1, -1)))
@auto_docstring
class SqueezeBertModel(SqueezeBertPreTrainedModel):
    def __init__(self, config):
        super().__init__(config)

        self.embeddings = SqueezeBertEmbeddings(config)
        self.encoder = SqueezeBertEncoder(config)
        self.pooler = SqueezeBertPooler(config)

        # Initialize weights and apply final processing
        self.post_init()

    def get_input_embeddings(self):
        return self.embeddings.word_embeddings

    def set_input_embeddings(self, new_embeddings):
        self.embeddings.word_embeddings = new_embeddings

    @auto_docstring
    def forward(
        self,
        input_ids: torch.Tensor | None = None,
        attention_mask: torch.Tensor | None = None,
        token_type_ids: torch.Tensor | None = None,
        position_ids: torch.Tensor | None = None,
        inputs_embeds: torch.FloatTensor | None = None,
        output_attentions: bool | None = None,
        output_hidden_states: bool | None = None,
        return_dict: bool | None = None,
        **kwargs,
    ) -> tuple | BaseModelOutputWithPooling:
        # Explicit arguments win; otherwise fall back to the values stored on the config.
        if output_attentions is None:
            output_attentions = self.config.output_attentions
        if output_hidden_states is None:
            output_hidden_states = self.config.output_hidden_states
        if return_dict is None:
            return_dict = self.config.return_dict

        # Exactly one of input_ids / inputs_embeds has to be supplied.
        has_ids = input_ids is not None
        has_embeds = inputs_embeds is not None
        if has_ids and has_embeds:
            raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
        if not has_ids and not has_embeds:
            raise ValueError("You have to specify either input_ids or inputs_embeds")

        if has_ids:
            self.warn_if_padding_and_no_attention_mask(input_ids, attention_mask)
            input_shape = input_ids.size()
            device = input_ids.device
        else:
            input_shape = inputs_embeds.size()[:-1]
            device = inputs_embeds.device

        # Defaults: attend to every position, and use token type 0 everywhere.
        if attention_mask is None:
            attention_mask = torch.ones(input_shape, device=device)
        if token_type_ids is None:
            token_type_ids = torch.zeros(input_shape, dtype=torch.long, device=device)

        # Helper provided by the base class (not visible in this file); its result is what the
        # attention layers add to their scores.
        extended_attention_mask = self.get_extended_attention_mask(attention_mask, input_shape)

        embedding_output = self.embeddings(
            input_ids=input_ids,
            position_ids=position_ids,
            token_type_ids=token_type_ids,
            inputs_embeds=inputs_embeds,
        )
        encoder_outputs = self.encoder(
            hidden_states=embedding_output,
            attention_mask=extended_attention_mask,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )
        sequence_output = encoder_outputs[0]
        pooled_output = self.pooler(sequence_output)

        if return_dict:
            return BaseModelOutputWithPooling(
                last_hidden_state=sequence_output,
                pooler_output=pooled_output,
                hidden_states=encoder_outputs.hidden_states,
                attentions=encoder_outputs.attentions,
            )
        # Tuple form: sequence output, pooled output, then whatever extras the encoder returned.
        return (sequence_output, pooled_output) + encoder_outputs[1:]
@auto_docstring
class SqueezeBertForMaskedLM(SqueezeBertPreTrainedModel):
    # Maps each tied parameter name to the name of the parameter it is tied to.
    _tied_weights_keys = {
        "cls.predictions.decoder.bias": "cls.predictions.bias",
        "cls.predictions.decoder.weight": "transformer.embeddings.word_embeddings.weight",
    }

    def __init__(self, config):
        super().__init__(config)

        self.transformer = SqueezeBertModel(config)
        self.cls = SqueezeBertOnlyMLMHead(config)

        # Initialize weights and apply final processing
        self.post_init()

    def get_output_embeddings(self):
        return self.cls.predictions.decoder

    def set_output_embeddings(self, new_embeddings):
        # Replace the decoder and point the head's stand-alone bias at the new decoder's bias.
        head = self.cls.predictions
        head.decoder = new_embeddings
        head.bias = new_embeddings.bias

    @auto_docstring
    def forward(
        self,
        input_ids: torch.Tensor | None = None,
        attention_mask: torch.Tensor | None = None,
        token_type_ids: torch.Tensor | None = None,
        position_ids: torch.Tensor | None = None,
        inputs_embeds: torch.Tensor | None = None,
        labels: torch.Tensor | None = None,
        output_attentions: bool | None = None,
        output_hidden_states: bool | None = None,
        return_dict: bool | None = None,
        **kwargs,
    ) -> tuple | MaskedLMOutput:
        r"""
        labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
            Labels for computing the masked language modeling loss. Indices should be in `[-100, 0, ...,
            config.vocab_size]` (see `input_ids` docstring) Tokens with indices set to `-100` are ignored (masked), the
            loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`
        """
        if return_dict is None:
            return_dict = self.config.return_dict

        outputs = self.transformer(
            input_ids,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
            position_ids=position_ids,
            inputs_embeds=inputs_embeds,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        prediction_scores = self.cls(outputs[0])

        masked_lm_loss = None
        if labels is not None:
            # CrossEntropyLoss skips targets equal to -100 (its default ignore_index).
            flat_scores = prediction_scores.view(-1, self.config.vocab_size)
            masked_lm_loss = CrossEntropyLoss()(flat_scores, labels.view(-1))

        if return_dict:
            return MaskedLMOutput(
                loss=masked_lm_loss,
                logits=prediction_scores,
                hidden_states=outputs.hidden_states,
                attentions=outputs.attentions,
            )

        # Tuple form: drop the backbone's sequence and pooled outputs, prepend the loss when present.
        output = (prediction_scores,) + outputs[2:]
        if masked_lm_loss is None:
            return output
        return (masked_lm_loss,) + output
@auto_docstring(
    custom_intro="""
    SqueezeBERT Model transformer with a sequence classification/regression head on top (a linear layer on top of the
    pooled output) e.g. for GLUE tasks.
    """
)
class SqueezeBertForSequenceClassification(SqueezeBertPreTrainedModel):
    def __init__(self, config):
        super().__init__(config)
        self.num_labels = config.num_labels
        self.config = config

        self.transformer = SqueezeBertModel(config)
        self.dropout = nn.Dropout(config.hidden_dropout_prob)
        self.classifier = nn.Linear(config.hidden_size, self.config.num_labels)

        # Initialize weights and apply final processing
        self.post_init()

    @auto_docstring
    def forward(
        self,
        input_ids: torch.Tensor | None = None,
        attention_mask: torch.Tensor | None = None,
        token_type_ids: torch.Tensor | None = None,
        position_ids: torch.Tensor | None = None,
        inputs_embeds: torch.Tensor | None = None,
        labels: torch.Tensor | None = None,
        output_attentions: bool | None = None,
        output_hidden_states: bool | None = None,
        return_dict: bool | None = None,
        **kwargs,
    ) -> tuple | SequenceClassifierOutput:
        r"""
        labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
            Labels for computing the sequence classification/regression loss. Indices should be in `[0, ...,
            config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If
            `config.num_labels > 1` a classification loss is computed (Cross-Entropy).
        """
        if return_dict is None:
            return_dict = self.config.return_dict

        outputs = self.transformer(
            input_ids,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
            position_ids=position_ids,
            inputs_embeds=inputs_embeds,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        # Classify from the pooled output (index 1 of the backbone outputs).
        logits = self.classifier(self.dropout(outputs[1]))

        loss = None
        if labels is not None:
            # Infer the problem type once and store it on the config for later calls.
            if self.config.problem_type is None:
                if self.num_labels == 1:
                    self.config.problem_type = "regression"
                elif self.num_labels > 1 and labels.dtype in (torch.long, torch.int):
                    self.config.problem_type = "single_label_classification"
                else:
                    self.config.problem_type = "multi_label_classification"

            problem_type = self.config.problem_type
            if problem_type == "regression":
                mse = MSELoss()
                if self.num_labels == 1:
                    loss = mse(logits.squeeze(), labels.squeeze())
                else:
                    loss = mse(logits, labels)
            elif problem_type == "single_label_classification":
                loss = CrossEntropyLoss()(logits.view(-1, self.num_labels), labels.view(-1))
            elif problem_type == "multi_label_classification":
                loss = BCEWithLogitsLoss()(logits, labels)

        if return_dict:
            return SequenceClassifierOutput(
                loss=loss,
                logits=logits,
                hidden_states=outputs.hidden_states,
                attentions=outputs.attentions,
            )

        output = (logits,) + outputs[2:]
        if loss is None:
            return output
        return (loss,) + output
@auto_docstring
class SqueezeBertForMultipleChoice(SqueezeBertPreTrainedModel):
    def __init__(self, config):
        super().__init__(config)

        self.transformer = SqueezeBertModel(config)
        self.dropout = nn.Dropout(config.hidden_dropout_prob)
        # A single score per (example, choice) pair.
        self.classifier = nn.Linear(config.hidden_size, 1)

        # Initialize weights and apply final processing
        self.post_init()

    @auto_docstring
    def forward(
        self,
        input_ids: torch.Tensor | None = None,
        attention_mask: torch.Tensor | None = None,
        token_type_ids: torch.Tensor | None = None,
        position_ids: torch.Tensor | None = None,
        inputs_embeds: torch.Tensor | None = None,
        labels: torch.Tensor | None = None,
        output_attentions: bool | None = None,
        output_hidden_states: bool | None = None,
        return_dict: bool | None = None,
        **kwargs,
    ) -> tuple | MultipleChoiceModelOutput:
        r"""
        input_ids (`torch.LongTensor` of shape `(batch_size, num_choices, sequence_length)`):
            Indices of input sequence tokens in the vocabulary.
            Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
            [`PreTrainedTokenizer.__call__`] for details.
            [What are input IDs?](../glossary#input-ids)
        token_type_ids (`torch.LongTensor` of shape `(batch_size, num_choices, sequence_length)`, *optional*):
            Segment token indices to indicate first and second portions of the inputs. Indices are selected in `[0,
            1]`:
            - 0 corresponds to a *sentence A* token,
            - 1 corresponds to a *sentence B* token.
            [What are token type IDs?](../glossary#token-type-ids)
        position_ids (`torch.LongTensor` of shape `(batch_size, num_choices, sequence_length)`, *optional*):
            Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0,
            config.max_position_embeddings - 1]`.
            [What are position IDs?](../glossary#position-ids)
        inputs_embeds (`torch.FloatTensor` of shape `(batch_size, num_choices, sequence_length, hidden_size)`, *optional*):
            Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This
            is useful if you want more control over how to convert `input_ids` indices into associated vectors than the
            model's internal embedding lookup matrix.
        labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
            Labels for computing the multiple choice classification loss. Indices should be in `[0, ...,
            num_choices-1]` where *num_choices* is the size of the second dimension of the input tensors. (see
            *input_ids* above)
        """
        if return_dict is None:
            return_dict = self.config.return_dict

        # Must be read before the choice axis is folded into the batch axis below.
        num_choices = input_ids.shape[1] if input_ids is not None else inputs_embeds.shape[1]

        def _fold_choices(tensor):
            # Collapse all leading axes into one, keeping the last axis; pass None through.
            if tensor is None:
                return None
            return tensor.view(-1, tensor.size(-1))

        input_ids = _fold_choices(input_ids)
        attention_mask = _fold_choices(attention_mask)
        token_type_ids = _fold_choices(token_type_ids)
        position_ids = _fold_choices(position_ids)
        if inputs_embeds is not None:
            # Embeddings carry one extra trailing axis, so the last two axes are kept.
            inputs_embeds = inputs_embeds.view(-1, inputs_embeds.size(-2), inputs_embeds.size(-1))

        outputs = self.transformer(
            input_ids,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
            position_ids=position_ids,
            inputs_embeds=inputs_embeds,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        logits = self.classifier(self.dropout(outputs[1]))
        # One row per example, one column per choice.
        reshaped_logits = logits.view(-1, num_choices)

        loss = None
        if labels is not None:
            loss = CrossEntropyLoss()(reshaped_logits, labels)

        if return_dict:
            return MultipleChoiceModelOutput(
                loss=loss,
                logits=reshaped_logits,
                hidden_states=outputs.hidden_states,
                attentions=outputs.attentions,
            )

        output = (reshaped_logits,) + outputs[2:]
        if loss is None:
            return output
        return (loss,) + output
@auto_docstring
class SqueezeBertForTokenClassification(SqueezeBertPreTrainedModel):
    def __init__(self, config):
        super().__init__(config)
        self.num_labels = config.num_labels

        self.transformer = SqueezeBertModel(config)
        self.dropout = nn.Dropout(config.hidden_dropout_prob)
        self.classifier = nn.Linear(config.hidden_size, config.num_labels)

        # Initialize weights and apply final processing
        self.post_init()

    @auto_docstring
    def forward(
        self,
        input_ids: torch.Tensor | None = None,
        attention_mask: torch.Tensor | None = None,
        token_type_ids: torch.Tensor | None = None,
        position_ids: torch.Tensor | None = None,
        inputs_embeds: torch.Tensor | None = None,
        labels: torch.Tensor | None = None,
        output_attentions: bool | None = None,
        output_hidden_states: bool | None = None,
        return_dict: bool | None = None,
        **kwargs,
    ) -> tuple | TokenClassifierOutput:
        r"""
        labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
            Labels for computing the token classification loss. Indices should be in `[0, ..., config.num_labels - 1]`.
        """
        if return_dict is None:
            return_dict = self.config.return_dict

        outputs = self.transformer(
            input_ids,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
            position_ids=position_ids,
            inputs_embeds=inputs_embeds,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        # Per-token logits computed from the sequence output (index 0 of the backbone outputs).
        logits = self.classifier(self.dropout(outputs[0]))

        loss = None
        if labels is not None:
            # Flatten tokens so every position is scored as an independent classification.
            loss = CrossEntropyLoss()(logits.view(-1, self.num_labels), labels.view(-1))

        if return_dict:
            return TokenClassifierOutput(
                loss=loss,
                logits=logits,
                hidden_states=outputs.hidden_states,
                attentions=outputs.attentions,
            )

        output = (logits,) + outputs[2:]
        if loss is None:
            return output
        return (loss,) + output
@auto_docstring
class SqueezeBertForQuestionAnswering(SqueezeBertPreTrainedModel):
    def __init__(self, config):
        super().__init__(config)
        self.num_labels = config.num_labels

        self.transformer = SqueezeBertModel(config)
        self.qa_outputs = nn.Linear(config.hidden_size, config.num_labels)

        # Initialize weights and apply final processing
        self.post_init()

    @auto_docstring
    def forward(
        self,
        input_ids: torch.Tensor | None = None,
        attention_mask: torch.Tensor | None = None,
        token_type_ids: torch.Tensor | None = None,
        position_ids: torch.Tensor | None = None,
        inputs_embeds: torch.Tensor | None = None,
        start_positions: torch.Tensor | None = None,
        end_positions: torch.Tensor | None = None,
        output_attentions: bool | None = None,
        output_hidden_states: bool | None = None,
        return_dict: bool | None = None,
        **kwargs,
    ) -> tuple | QuestionAnsweringModelOutput:
        if return_dict is None:
            return_dict = self.config.return_dict

        outputs = self.transformer(
            input_ids,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
            position_ids=position_ids,
            inputs_embeds=inputs_embeds,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        # Split the last axis into width-1 chunks; unpacking into two names requires exactly two chunks.
        logits = self.qa_outputs(outputs[0])
        start_logits, end_logits = logits.split(1, dim=-1)
        start_logits = start_logits.squeeze(-1).contiguous()
        end_logits = end_logits.squeeze(-1).contiguous()

        total_loss = None
        if start_positions is not None and end_positions is not None:
            # Positions may arrive with an extra trailing axis (the original note attributes this to
            # multi-GPU runs); squeeze it away.
            if start_positions.dim() > 1:
                start_positions = start_positions.squeeze(-1)
            if end_positions.dim() > 1:
                end_positions = end_positions.squeeze(-1)

            # Positions outside the sequence are clamped to `ignored_index`, which the loss then skips.
            ignored_index = start_logits.size(1)
            start_positions = start_positions.clamp(0, ignored_index)
            end_positions = end_positions.clamp(0, ignored_index)

            span_loss = CrossEntropyLoss(ignore_index=ignored_index)
            start_loss = span_loss(start_logits, start_positions)
            end_loss = span_loss(end_logits, end_positions)
            total_loss = (start_loss + end_loss) / 2

        if return_dict:
            return QuestionAnsweringModelOutput(
                loss=total_loss,
                start_logits=start_logits,
                end_logits=end_logits,
                hidden_states=outputs.hidden_states,
                attentions=outputs.attentions,
            )

        output = (start_logits, end_logits) + outputs[2:]
        if total_loss is None:
            return output
        return (total_loss,) + output
# Public names of this module: these are what `from <module> import *` exports.
__all__ = [
    "SqueezeBertForMaskedLM",
    "SqueezeBertForMultipleChoice",
    "SqueezeBertForQuestionAnswering",
    "SqueezeBertForSequenceClassification",
    "SqueezeBertForTokenClassification",
    "SqueezeBertModel",
    "SqueezeBertModule",
    "SqueezeBertPreTrainedModel",
]
|