| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656657658659660661662663664665666667668669670671672673674675676677678679680681682683684685686687688689690691692693694695696697698699700701702
703704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792793794795796797798799800801802803804805806807808809810811812813814815816817818819820821822823824825826827828829830831832833834835836837838839840841842843844845846847848849850851852853854855856857858859860861862863864865866867868869870871872873874875876877878879880881882883884885886887888889890891892893894895896897898899900901902903904905906907908909910911912913914915916917918919920921922923924925926927928929930931932933934935936937938939940941942943944945946947948949950951952953954955956957958959960961962963964965966967968969970971972973974975976977978979980981982983984985986987988989990991992993994995996997998999100010011002100310041005100610071008100910101011101210131014101510161017101810191020102110221023102410251026102710281029103010311032103310341035103610371038103910401041104210431044104510461047104810491050105110521053105410551056105710581059106010611062106310641065106610671068106910701071107210731074107510761077107810791080108110821083108410851086108710881089109010911092109310941095109610971098109911001101110211031104110511061107110811091110111111121113111411151116111711181119112011211122112311241125112611271128112911301131113211331134113511361137113811391140114111421143114411451146114711481149115011511152115311541155115611571158115911601161116211631164116511661167116811691170117111721173117411751176117711781179118011811182118311841185118611871188118911901191119211931194119511961197119811991200120112021203120412051206120712081209121012111212121312141215121612171218121912201221122212231224122512261227122812291230123112321233123412351236123712381239124012411242124312441245124612471248124912501251 |
- # MIT License
- #
- # Copyright (c) 2020 The Google AI Language Team Authors, The HuggingFace Inc. team and github/lonePatient
- #
- # Permission is hereby granted, free of charge, to any person obtaining a copy
- # of this software and associated documentation files (the "Software"), to deal
- # in the Software without restriction, including without limitation the rights
- # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
- # copies of the Software, and to permit persons to whom the Software is
- # furnished to do so, subject to the following conditions:
- #
- # The above copyright notice and this permission notice shall be included in all
- # copies or substantial portions of the Software.
- #
- # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
- # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
- # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
- # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
- # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
- # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
- # SOFTWARE.
- from collections.abc import Callable
- from dataclasses import dataclass
- import torch
- from torch import nn
- from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
- from ... import initialization as init
- from ...activations import ACT2FN
- from ...masking_utils import create_bidirectional_mask
- from ...modeling_layers import GradientCheckpointingLayer
- from ...modeling_outputs import (
- BaseModelOutput,
- BaseModelOutputWithPooling,
- MaskedLMOutput,
- MultipleChoiceModelOutput,
- NextSentencePredictorOutput,
- QuestionAnsweringModelOutput,
- SequenceClassifierOutput,
- TokenClassifierOutput,
- )
- from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
- from ...processing_utils import Unpack
- from ...utils import ModelOutput, TransformersKwargs, auto_docstring, logging
- from ...utils.generic import can_return_tuple, merge_with_config_defaults
- from ...utils.output_capturing import capture_outputs
- from .configuration_mobilebert import MobileBertConfig
- logger = logging.get_logger(__name__)
class NoNorm(nn.Module):
    """Element-wise affine transform usable wherever a normalization layer is expected.

    The input is multiplied by a learnable ``weight`` and shifted by a learnable
    ``bias``; no statistics are computed over the input.
    """

    def __init__(self, feat_size, eps=None):
        """
        Args:
            feat_size: Size of the trailing feature dimension the parameters cover.
            eps: Accepted so the constructor can be called the same way as
                `nn.LayerNorm`; it is not used.
        """
        super().__init__()
        self.bias = nn.Parameter(torch.zeros(feat_size))
        self.weight = nn.Parameter(torch.ones(feat_size))

    def forward(self, input_tensor: torch.Tensor) -> torch.Tensor:
        """Return ``input_tensor * weight + bias``."""
        scaled = input_tensor * self.weight
        return scaled + self.bias


# Maps a normalization-type name to the module class implementing it.
NORM2FN = {"layer_norm": nn.LayerNorm, "no_norm": NoNorm}
class MobileBertEmbeddings(nn.Module):
    """Build the input embeddings from word, position and token-type embeddings."""

    def __init__(self, config):
        super().__init__()
        self.trigram_input = config.trigram_input
        self.embedding_size = config.embedding_size
        self.hidden_size = config.hidden_size

        self.word_embeddings = nn.Embedding(config.vocab_size, config.embedding_size, padding_idx=config.pad_token_id)
        self.position_embeddings = nn.Embedding(config.max_position_embeddings, config.hidden_size)
        self.token_type_embeddings = nn.Embedding(config.type_vocab_size, config.hidden_size)

        # With trigram input every position is concatenated with its two
        # neighbours, so the projection receives three word embeddings at once.
        transformation_in_features = self.embedding_size * (3 if self.trigram_input else 1)
        self.embedding_transformation = nn.Linear(transformation_in_features, config.hidden_size)

        self.LayerNorm = NORM2FN[config.normalization_type](config.hidden_size)
        self.dropout = nn.Dropout(config.hidden_dropout_prob)

        # Default position ids of shape (1, max_position_embeddings), registered
        # as a non-persistent buffer.
        default_position_ids = torch.arange(config.max_position_embeddings).expand((1, -1))
        self.register_buffer("position_ids", default_position_ids, persistent=False)

    def forward(
        self,
        input_ids: torch.LongTensor | None = None,
        token_type_ids: torch.LongTensor | None = None,
        position_ids: torch.LongTensor | None = None,
        inputs_embeds: torch.FloatTensor | None = None,
    ) -> torch.Tensor:
        """Embed the inputs and return the normalized, dropped-out sum of all embeddings.

        Either `input_ids` or `inputs_embeds` supplies the tokens. Missing
        `position_ids` default to the leading slice of the registered buffer and
        missing `token_type_ids` default to zeros.
        """
        input_shape = inputs_embeds.size()[:-1] if input_ids is None else input_ids.size()
        seq_length = input_shape[1]

        if position_ids is None:
            position_ids = self.position_ids[:, :seq_length]
        if token_type_ids is None:
            token_type_ids = torch.zeros(input_shape, dtype=torch.long, device=self.position_ids.device)
        if inputs_embeds is None:
            inputs_embeds = self.word_embeddings(input_ids)

        if self.trigram_input:
            # From the MobileBERT paper (https://huggingface.co/papers/2004.02984):
            # the embedding table is a large share of BERT's size, so MobileBERT
            # shrinks the embedding dimension to 128 and applies a 1D convolution
            # with kernel size 3 over the raw token embeddings to produce a
            # 512-dimensional output. Here the three taps are built by shifting
            # the sequence one step each way, zero-padded at the boundary.
            following = nn.functional.pad(inputs_embeds[:, 1:], [0, 0, 0, 1, 0, 0], value=0.0)
            preceding = nn.functional.pad(inputs_embeds[:, :-1], [0, 0, 1, 0, 0, 0], value=0.0)
            inputs_embeds = torch.cat([following, inputs_embeds, preceding], dim=2)

        needs_projection = self.trigram_input or self.embedding_size != self.hidden_size
        if needs_projection:
            inputs_embeds = self.embedding_transformation(inputs_embeds)

        # Sum the three embedding sources, then normalize and apply dropout.
        position_embeddings = self.position_embeddings(position_ids)
        token_type_embeddings = self.token_type_embeddings(token_type_ids)
        summed = inputs_embeds + position_embeddings + token_type_embeddings
        return self.dropout(self.LayerNorm(summed))
- # Copied from transformers.models.bert.modeling_bert.eager_attention_forward
def eager_attention_forward(
    module: nn.Module,
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    attention_mask: torch.Tensor | None,
    scaling: float | None = None,
    dropout: float = 0.0,
    **kwargs: Unpack[TransformersKwargs],
):
    """Scaled dot-product attention computed with plain tensor operations.

    Args:
        module: Module whose `training` flag decides whether dropout is active.
        query, key, value: Projected tensors; `key` is transposed over dims 2 and 3.
        attention_mask: Optional additive mask applied to the scores before softmax.
        scaling: Score multiplier; defaults to ``query.size(-1) ** -0.5``.
        dropout: Dropout probability applied to the attention weights.

    Returns:
        A tuple ``(attn_output, attn_weights)`` where `attn_output` has dims 1
        and 2 swapped and is contiguous.
    """
    if scaling is None:
        scaling = query.size(-1) ** -0.5

    # Raw attention scores: dot product between "query" and "key".
    scores = torch.matmul(query, key.transpose(2, 3)) * scaling
    if attention_mask is not None:
        scores = scores + attention_mask

    attn_weights = nn.functional.softmax(scores, dim=-1)
    attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training)

    attn_output = torch.matmul(attn_weights, value).transpose(1, 2).contiguous()
    return attn_output, attn_weights
class MobileBertSelfAttention(nn.Module):
    """Multi-head attention with separate query, key and value inputs."""

    def __init__(self, config):
        super().__init__()
        self.config = config
        self.num_attention_heads = config.num_attention_heads
        self.attention_head_size = int(config.true_hidden_size / config.num_attention_heads)
        self.all_head_size = self.num_attention_heads * self.attention_head_size
        self.scaling = self.attention_head_size**-0.5

        self.query = nn.Linear(config.true_hidden_size, self.all_head_size)
        self.key = nn.Linear(config.true_hidden_size, self.all_head_size)
        # The value projection reads the bottlenecked width only when attention
        # itself runs on the bottleneck; otherwise it reads the full hidden width.
        value_in_features = config.true_hidden_size if config.use_bottleneck_attention else config.hidden_size
        self.value = nn.Linear(value_in_features, self.all_head_size)

        self.dropout = nn.Dropout(config.attention_probs_dropout_prob)
        self.is_causal = False

    def forward(
        self,
        query_tensor: torch.Tensor,
        key_tensor: torch.Tensor,
        value_tensor: torch.Tensor,
        attention_mask: torch.FloatTensor | None = None,
        **kwargs: Unpack[TransformersKwargs],
    ) -> tuple[torch.Tensor]:
        """Project the three inputs, run the configured attention function and merge heads.

        Returns:
            ``(attn_output, attn_weights)`` as produced by the attention
            interface, with `attn_output` reshaped back to the leading
            dimensions of `query_tensor`.
        """
        input_shape = query_tensor.shape[:-1]
        hidden_shape = (*input_shape, -1, self.attention_head_size)

        def split_heads(projected: torch.Tensor) -> torch.Tensor:
            # Split the feature dim into heads, then swap dims 1 and 2.
            return projected.view(*hidden_shape).transpose(1, 2)

        query_layer = split_heads(self.query(query_tensor))
        key_layer = split_heads(self.key(key_tensor))
        value_layer = split_heads(self.value(value_tensor))

        attention_interface: Callable = ALL_ATTENTION_FUNCTIONS.get_interface(
            self.config._attn_implementation, eager_attention_forward
        )
        # Dropout on the attention weights is only active in training mode.
        dropout_p = self.dropout.p if self.training else 0.0

        attn_output, attn_weights = attention_interface(
            self,
            query_layer,
            key_layer,
            value_layer,
            attention_mask,
            dropout=dropout_p,
            scaling=self.scaling,
            **kwargs,
        )
        attn_output = attn_output.reshape(*input_shape, -1).contiguous()
        return attn_output, attn_weights
class MobileBertSelfOutput(nn.Module):
    """Project the attention output, add a residual and normalize."""

    def __init__(self, config):
        super().__init__()
        self.use_bottleneck = config.use_bottleneck
        self.dense = nn.Linear(config.true_hidden_size, config.true_hidden_size)
        self.LayerNorm = NORM2FN[config.normalization_type](config.true_hidden_size, eps=config.layer_norm_eps)
        # Dropout is only created (and applied) when no bottleneck is in use.
        if not self.use_bottleneck:
            self.dropout = nn.Dropout(config.hidden_dropout_prob)

    def forward(self, hidden_states: torch.Tensor, residual_tensor: torch.Tensor) -> torch.Tensor:
        """Return ``LayerNorm(dense(hidden_states) [+ dropout] + residual_tensor)``."""
        projected = self.dense(hidden_states)
        if not self.use_bottleneck:
            projected = self.dropout(projected)
        return self.LayerNorm(projected + residual_tensor)
class MobileBertAttention(nn.Module):
    """Self-attention followed by its output projection and residual connection."""

    def __init__(self, config):
        super().__init__()
        self.self = MobileBertSelfAttention(config)
        self.output = MobileBertSelfOutput(config)

    def forward(
        self,
        query_tensor: torch.Tensor,
        key_tensor: torch.Tensor,
        value_tensor: torch.Tensor,
        layer_input: torch.Tensor,
        attention_mask: torch.FloatTensor | None = None,
        **kwargs: Unpack[TransformersKwargs],
    ) -> tuple[torch.Tensor]:
        """Attend over the inputs and combine the result with `layer_input`.

        Returns:
            ``(attention_output, attn_weights)``.
        """
        context, attn_weights = self.self(query_tensor, key_tensor, value_tensor, attention_mask, **kwargs)
        # Linear projection followed by a residual connection with `layer_input`.
        return self.output(context, layer_input), attn_weights
class MobileBertIntermediate(nn.Module):
    """Feed-forward expansion: a linear layer followed by the configured activation."""

    def __init__(self, config):
        super().__init__()
        self.dense = nn.Linear(config.true_hidden_size, config.intermediate_size)
        # `hidden_act` may be a name to look up in ACT2FN or already a callable.
        activation = config.hidden_act
        self.intermediate_act_fn = ACT2FN[activation] if isinstance(activation, str) else activation

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        """Return ``activation(dense(hidden_states))``."""
        return self.intermediate_act_fn(self.dense(hidden_states))
class OutputBottleneck(nn.Module):
    """Project from `true_hidden_size` back to `hidden_size` with residual and normalization."""

    def __init__(self, config):
        super().__init__()
        self.dense = nn.Linear(config.true_hidden_size, config.hidden_size)
        self.LayerNorm = NORM2FN[config.normalization_type](config.hidden_size, eps=config.layer_norm_eps)
        self.dropout = nn.Dropout(config.hidden_dropout_prob)

    def forward(self, hidden_states: torch.Tensor, residual_tensor: torch.Tensor) -> torch.Tensor:
        """Return ``LayerNorm(dropout(dense(hidden_states)) + residual_tensor)``."""
        projected = self.dropout(self.dense(hidden_states))
        return self.LayerNorm(projected + residual_tensor)
class MobileBertOutput(nn.Module):
    """Final feed-forward projection of a layer, optionally followed by an output bottleneck."""

    def __init__(self, config):
        super().__init__()
        self.use_bottleneck = config.use_bottleneck
        self.dense = nn.Linear(config.intermediate_size, config.true_hidden_size)
        self.LayerNorm = NORM2FN[config.normalization_type](config.true_hidden_size)
        # Exactly one of `bottleneck` / `dropout` exists, depending on the config.
        if self.use_bottleneck:
            self.bottleneck = OutputBottleneck(config)
        else:
            self.dropout = nn.Dropout(config.hidden_dropout_prob)

    def forward(
        self, intermediate_states: torch.Tensor, residual_tensor_1: torch.Tensor, residual_tensor_2: torch.Tensor
    ) -> torch.Tensor:
        """Project the intermediate states and apply the residual connection(s).

        `residual_tensor_1` is added before the layer's own normalization.
        `residual_tensor_2` is only used by the output bottleneck, when enabled.
        """
        projected = self.dense(intermediate_states)
        if self.use_bottleneck:
            normalized = self.LayerNorm(projected + residual_tensor_1)
            return self.bottleneck(normalized, residual_tensor_2)

        projected = self.dropout(projected)
        return self.LayerNorm(projected + residual_tensor_1)
class BottleneckLayer(nn.Module):
    """Linear projection from `hidden_size` to `intra_bottleneck_size`, then normalization."""

    def __init__(self, config):
        super().__init__()
        self.dense = nn.Linear(config.hidden_size, config.intra_bottleneck_size)
        self.LayerNorm = NORM2FN[config.normalization_type](config.intra_bottleneck_size, eps=config.layer_norm_eps)

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        """Return ``LayerNorm(dense(hidden_states))``."""
        return self.LayerNorm(self.dense(hidden_states))
class Bottleneck(nn.Module):
    """Produce the query, key, value and residual inputs for the attention block."""

    def __init__(self, config):
        super().__init__()
        self.key_query_shared_bottleneck = config.key_query_shared_bottleneck
        self.use_bottleneck_attention = config.use_bottleneck_attention
        self.input = BottleneckLayer(config)
        if self.key_query_shared_bottleneck:
            self.attention = BottleneckLayer(config)

    def forward(self, hidden_states: torch.Tensor) -> tuple[torch.Tensor]:
        """Return ``(query, key, value, layer_input)`` for the attention layer.

        Bottlenecks are learned linear layers that project the hidden states to
        a lower-dimensional vector, reducing memory usage. The last element is
        always the hidden states passed through the `input` bottleneck; it is
        used as the residual tensor in the attention self output. The first
        three elements depend on the configuration:

        * `use_bottleneck_attention`: query, key and value are all that same
          bottlenecked tensor.
        * otherwise, with `key_query_shared_bottleneck`: query and key share
          the output of a second bottleneck, while the value is the raw hidden
          states.
        * otherwise: query, key and value are the raw hidden states.
        """
        residual_input = self.input(hidden_states)

        if self.use_bottleneck_attention:
            return (residual_input,) * 4

        if self.key_query_shared_bottleneck:
            shared_query_key = self.attention(hidden_states)
            return (shared_query_key, shared_query_key, hidden_states, residual_input)

        return (hidden_states, hidden_states, hidden_states, residual_input)
class FFNOutput(nn.Module):
    """Down-projection of a stacked feed-forward network with residual and normalization."""

    def __init__(self, config):
        super().__init__()
        self.dense = nn.Linear(config.intermediate_size, config.true_hidden_size)
        self.LayerNorm = NORM2FN[config.normalization_type](config.true_hidden_size, eps=config.layer_norm_eps)

    def forward(self, hidden_states: torch.Tensor, residual_tensor: torch.Tensor) -> torch.Tensor:
        """Return ``LayerNorm(dense(hidden_states) + residual_tensor)``."""
        projected = self.dense(hidden_states)
        return self.LayerNorm(projected + residual_tensor)
class FFNLayer(nn.Module):
    """One stacked feed-forward network: expansion followed by a residual down-projection."""

    def __init__(self, config):
        super().__init__()
        self.intermediate = MobileBertIntermediate(config)
        self.output = FFNOutput(config)

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        """Expand `hidden_states`, project back, and add the input as the residual."""
        expanded = self.intermediate(hidden_states)
        return self.output(expanded, hidden_states)
class MobileBertLayer(GradientCheckpointingLayer):
    """One MobileBERT encoder layer.

    The layer runs an optional input bottleneck, attention, zero or more
    stacked feed-forward networks, and a final feed-forward block.
    """

    def __init__(self, config):
        super().__init__()
        self.use_bottleneck = config.use_bottleneck
        self.num_feedforward_networks = config.num_feedforward_networks

        self.attention = MobileBertAttention(config)
        self.intermediate = MobileBertIntermediate(config)
        self.output = MobileBertOutput(config)
        if self.use_bottleneck:
            self.bottleneck = Bottleneck(config)
        # The last feed-forward network is `intermediate` + `output`; only the
        # additional ones are stored in `ffn`.
        if config.num_feedforward_networks > 1:
            self.ffn = nn.ModuleList([FFNLayer(config) for _ in range(config.num_feedforward_networks - 1)])

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: torch.FloatTensor | None = None,
        **kwargs: Unpack[TransformersKwargs],
    ) -> torch.Tensor:
        """Run the layer on `hidden_states` and return the new hidden states.

        Args:
            hidden_states: Input to the layer; also used as the outer residual
                passed to the output block.
            attention_mask: Optional mask forwarded to the attention block.
            **kwargs: Extra keyword arguments forwarded to the attention block.
        """
        if self.use_bottleneck:
            query_tensor, key_tensor, value_tensor, layer_input = self.bottleneck(hidden_states)
        else:
            query_tensor, key_tensor, value_tensor, layer_input = [hidden_states] * 4

        attention_output, _ = self.attention(
            query_tensor,
            key_tensor,
            value_tensor,
            layer_input,
            attention_mask,
            **kwargs,
        )

        # Use the same condition as `__init__`: `self.ffn` only exists when more
        # than one feed-forward network is configured. The previous `!= 1` check
        # raised AttributeError for values below 1.
        if self.num_feedforward_networks > 1:
            for ffn_module in self.ffn:
                attention_output = ffn_module(attention_output)

        intermediate_output = self.intermediate(attention_output)
        layer_output = self.output(intermediate_output, attention_output, hidden_states)
        return layer_output
class MobileBertEncoder(nn.Module):
    """Stack of `config.num_hidden_layers` MobileBERT layers applied in sequence."""

    def __init__(self, config):
        super().__init__()
        layers = [MobileBertLayer(config) for _ in range(config.num_hidden_layers)]
        self.layer = nn.ModuleList(layers)

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: torch.FloatTensor | None = None,
        **kwargs: Unpack[TransformersKwargs],
    ) -> tuple | BaseModelOutput:
        """Feed `hidden_states` through every layer and wrap the final states in a `BaseModelOutput`."""
        for layer_module in self.layer:
            hidden_states = layer_module(hidden_states, attention_mask, **kwargs)
        return BaseModelOutput(last_hidden_state=hidden_states)
class MobileBertPooler(nn.Module):
    """Pool a sequence by taking the hidden state of its first token.

    When `config.classifier_activation` is set, the first-token state is also
    passed through a dense layer and a tanh.
    """

    def __init__(self, config):
        super().__init__()
        self.do_activate = config.classifier_activation
        if self.do_activate:
            self.dense = nn.Linear(config.hidden_size, config.hidden_size)

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        """Return the pooled representation of `hidden_states`."""
        # "Pooling" is simply selecting the hidden state of the first token.
        first_token_tensor = hidden_states[:, 0]
        if self.do_activate:
            return torch.tanh(self.dense(first_token_tensor))
        return first_token_tensor
class MobileBertPredictionHeadTransform(nn.Module):
    """Dense layer, activation and layer normalization applied before the LM decoder."""

    def __init__(self, config):
        super().__init__()
        self.dense = nn.Linear(config.hidden_size, config.hidden_size)
        # `hidden_act` may be a name to look up in ACT2FN or already a callable.
        activation = config.hidden_act
        self.transform_act_fn = ACT2FN[activation] if isinstance(activation, str) else activation
        # This head always uses a real layer norm, regardless of `normalization_type`.
        self.LayerNorm = NORM2FN["layer_norm"](config.hidden_size, eps=config.layer_norm_eps)

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        """Return ``LayerNorm(activation(dense(hidden_states)))``."""
        activated = self.transform_act_fn(self.dense(hidden_states))
        return self.LayerNorm(activated)
class MobileBertLMPredictionHead(nn.Module):
    """Language-modeling head producing one score per vocabulary entry."""

    def __init__(self, config):
        super().__init__()
        self.transform = MobileBertPredictionHeadTransform(config)
        # The projection to the vocabulary is assembled from two weight blocks:
        # `decoder` covers the first `embedding_size` input features and `dense`
        # covers the remaining `hidden_size - embedding_size` features.
        self.dense = nn.Linear(config.vocab_size, config.hidden_size - config.embedding_size, bias=False)
        self.decoder = nn.Linear(config.embedding_size, config.vocab_size, bias=True)
        self.bias = nn.Parameter(torch.zeros(config.vocab_size))

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        """Transform `hidden_states` and project them onto the vocabulary."""
        hidden_states = self.transform(hidden_states)
        # Stack both blocks along the input dimension to form one projection matrix.
        projection = torch.cat([self.decoder.weight.t(), self.dense.weight], dim=0)
        hidden_states = hidden_states.matmul(projection)
        hidden_states += self.decoder.bias
        return hidden_states
class MobileBertOnlyMLMHead(nn.Module):
    """Wrapper holding only the masked-language-modeling prediction head."""

    def __init__(self, config):
        super().__init__()
        self.predictions = MobileBertLMPredictionHead(config)

    def forward(self, sequence_output: torch.Tensor) -> torch.Tensor:
        """Return the prediction scores for `sequence_output`."""
        return self.predictions(sequence_output)
class MobileBertPreTrainingHeads(nn.Module):
    """The two pretraining heads: language modeling and a 2-way sequence-relationship classifier."""

    def __init__(self, config):
        super().__init__()
        self.predictions = MobileBertLMPredictionHead(config)
        self.seq_relationship = nn.Linear(config.hidden_size, 2)

    def forward(self, sequence_output: torch.Tensor, pooled_output: torch.Tensor) -> tuple[torch.Tensor]:
        """Return ``(prediction_scores, seq_relationship_score)``.

        The LM head reads `sequence_output`; the classifier reads `pooled_output`.
        """
        lm_scores = self.predictions(sequence_output)
        relationship_scores = self.seq_relationship(pooled_output)
        return lm_scores, relationship_scores
@auto_docstring
class MobileBertPreTrainedModel(PreTrainedModel):
    config: MobileBertConfig
    base_model_prefix = "mobilebert"
    supports_gradient_checkpointing = True

    # Attention backends this model declares support for.
    _supports_flash_attn = True
    _supports_sdpa = True
    _supports_flex_attn = True
    _supports_attention_backend = True

    # Module classes whose outputs are recorded under each output name.
    _can_record_outputs = {
        "hidden_states": MobileBertLayer,
        "attentions": MobileBertSelfAttention,
    }

    @torch.no_grad()
    def _init_weights(self, module):
        """Apply the parent initialization, then the MobileBERT-specific cases."""
        super()._init_weights(module)

        if isinstance(module, NoNorm):
            # Start as the identity transform: zero shift, unit scale.
            init.zeros_(module.bias)
            init.ones_(module.weight)
        elif isinstance(module, MobileBertLMPredictionHead):
            init.zeros_(module.bias)
        elif isinstance(module, MobileBertEmbeddings):
            # Refill the position-id buffer with 0..N-1, shaped (1, N).
            num_positions = module.position_ids.shape[-1]
            init.copy_(module.position_ids, torch.arange(num_positions).expand((1, -1)))
@dataclass
@auto_docstring(
    custom_intro="""
    Output type of [`MobileBertForPreTraining`].
    """
)
class MobileBertForPreTrainingOutput(ModelOutput):
    r"""
    loss (*optional*, returned when `labels` is provided, `torch.FloatTensor` of shape `(1,)`):
        Total loss as the sum of the masked language modeling loss and the next sequence prediction
        (classification) loss.
    prediction_logits (`torch.FloatTensor` of shape `(batch_size, sequence_length, config.vocab_size)`):
        Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax).
    seq_relationship_logits (`torch.FloatTensor` of shape `(batch_size, 2)`):
        Prediction scores of the next sequence prediction (classification) head (scores of True/False continuation
        before SoftMax).
    """

    # Every field defaults to None so an output can be built from any subset.

    # Training signal and the two heads' raw scores.
    loss: torch.FloatTensor | None = None
    prediction_logits: torch.FloatTensor | None = None
    seq_relationship_logits: torch.FloatTensor | None = None

    # Optional per-layer diagnostics passed through from the base model.
    hidden_states: tuple[torch.FloatTensor] | None = None
    attentions: tuple[torch.FloatTensor] | None = None
@auto_docstring
class MobileBertModel(MobileBertPreTrainedModel):
    """
    https://huggingface.co/papers/2004.02984
    """

    def __init__(self, config, add_pooling_layer=True):
        r"""
        add_pooling_layer (bool, *optional*, defaults to `True`):
            Whether to add a pooling layer
        """
        super().__init__(config)
        self.config = config
        self.gradient_checkpointing = False

        self.embeddings = MobileBertEmbeddings(config)
        self.encoder = MobileBertEncoder(config)
        # The pooler is optional; heads that only need token states skip it.
        self.pooler = MobileBertPooler(config) if add_pooling_layer else None

        # Initialize weights and apply final processing
        self.post_init()

    def get_input_embeddings(self):
        """Return the word-embedding module."""
        return self.embeddings.word_embeddings

    def set_input_embeddings(self, value):
        """Replace the word-embedding module with `value`."""
        self.embeddings.word_embeddings = value

    @merge_with_config_defaults
    @capture_outputs
    @auto_docstring
    def forward(
        self,
        input_ids: torch.LongTensor | None = None,
        attention_mask: torch.FloatTensor | None = None,
        token_type_ids: torch.LongTensor | None = None,
        position_ids: torch.LongTensor | None = None,
        inputs_embeds: torch.FloatTensor | None = None,
        **kwargs: Unpack[TransformersKwargs],
    ) -> tuple | BaseModelOutputWithPooling:
        # Exactly one of the two input forms must be given: reject both the
        # "neither" and the "both" case.
        if (input_ids is None) == (inputs_embeds is None):
            raise ValueError("You must specify exactly one of input_ids or inputs_embeds")

        embedding_output = self.embeddings(
            input_ids=input_ids,
            position_ids=position_ids,
            token_type_ids=token_type_ids,
            inputs_embeds=inputs_embeds,
        )

        # Build the bidirectional mask in the form the attention backend expects.
        attention_mask = create_bidirectional_mask(
            config=self.config,
            inputs_embeds=embedding_output,
            attention_mask=attention_mask,
        )

        encoder_outputs = self.encoder(embedding_output, attention_mask=attention_mask, **kwargs)
        sequence_output = encoder_outputs[0]
        pooled_output = None if self.pooler is None else self.pooler(sequence_output)

        return BaseModelOutputWithPooling(
            last_hidden_state=sequence_output,
            pooler_output=pooled_output,
        )
@auto_docstring(
    custom_intro="""
    MobileBert Model with two heads on top as done during the pretraining: a `masked language modeling` head and a
    `next sentence prediction (classification)` head.
    """
)
class MobileBertForPreTraining(MobileBertPreTrainedModel):
    # Parameters that are tied to another parameter (target -> source).
    _tied_weights_keys = {
        "cls.predictions.decoder.bias": "cls.predictions.bias",
        "cls.predictions.decoder.weight": "mobilebert.embeddings.word_embeddings.weight",
    }

    def __init__(self, config):
        super().__init__(config)
        self.mobilebert = MobileBertModel(config)
        self.cls = MobileBertPreTrainingHeads(config)

        # Initialize weights and apply final processing
        self.post_init()

    def get_output_embeddings(self):
        """Return the decoder layer of the LM prediction head."""
        return self.cls.predictions.decoder

    def set_output_embeddings(self, new_embeddings):
        """Install `new_embeddings` as the decoder and point the head's bias at its bias."""
        predictions = self.cls.predictions
        predictions.decoder = new_embeddings
        predictions.bias = new_embeddings.bias

    def resize_token_embeddings(self, new_num_tokens: int | None = None) -> nn.Embedding:
        """Resize the extra dense output block first, then defer to the parent implementation."""
        # Resize the dense output embeddings first.
        predictions = self.cls.predictions
        predictions.dense = self._get_resized_lm_head(
            predictions.dense, new_num_tokens=new_num_tokens, transposed=True
        )
        return super().resize_token_embeddings(new_num_tokens=new_num_tokens)

    @can_return_tuple
    @auto_docstring
    def forward(
        self,
        input_ids: torch.LongTensor | None = None,
        attention_mask: torch.FloatTensor | None = None,
        token_type_ids: torch.LongTensor | None = None,
        position_ids: torch.LongTensor | None = None,
        inputs_embeds: torch.FloatTensor | None = None,
        labels: torch.LongTensor | None = None,
        next_sentence_label: torch.LongTensor | None = None,
        **kwargs: Unpack[TransformersKwargs],
    ) -> tuple | MobileBertForPreTrainingOutput:
        r"""
        labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
            Labels for computing the masked language modeling loss. Indices should be in `[-100, 0, ...,
            config.vocab_size]` (see `input_ids` docstring) Tokens with indices set to `-100` are ignored (masked), the
            loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`
        next_sentence_label (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
            Labels for computing the next sequence prediction (classification) loss. Input should be a sequence pair
            (see `input_ids` docstring) Indices should be in `[0, 1]`:

            - 0 indicates sequence B is a continuation of sequence A,
            - 1 indicates sequence B is a random sequence.

        Examples:

        ```python
        >>> from transformers import AutoTokenizer, MobileBertForPreTraining
        >>> import torch

        >>> tokenizer = AutoTokenizer.from_pretrained("google/mobilebert-uncased")
        >>> model = MobileBertForPreTraining.from_pretrained("google/mobilebert-uncased")

        >>> input_ids = torch.tensor(tokenizer.encode("Hello, my dog is cute", add_special_tokens=True)).unsqueeze(0)
        >>> # Batch size 1
        >>> outputs = model(input_ids)

        >>> prediction_logits = outputs.prediction_logits
        >>> seq_relationship_logits = outputs.seq_relationship_logits
        ```"""
        outputs = self.mobilebert(
            input_ids,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
            position_ids=position_ids,
            inputs_embeds=inputs_embeds,
            return_dict=True,
            **kwargs,
        )
        sequence_output, pooled_output = outputs[:2]
        prediction_scores, seq_relationship_score = self.cls(sequence_output, pooled_output)

        # The combined loss is only defined when both label sets are supplied.
        total_loss = None
        missing_labels = labels is None or next_sentence_label is None
        if not missing_labels:
            loss_fct = CrossEntropyLoss()
            flat_lm_scores = prediction_scores.view(-1, self.config.vocab_size)
            masked_lm_loss = loss_fct(flat_lm_scores, labels.view(-1))
            flat_nsp_scores = seq_relationship_score.view(-1, 2)
            next_sentence_loss = loss_fct(flat_nsp_scores, next_sentence_label.view(-1))
            total_loss = masked_lm_loss + next_sentence_loss

        return MobileBertForPreTrainingOutput(
            loss=total_loss,
            prediction_logits=prediction_scores,
            seq_relationship_logits=seq_relationship_score,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )
@auto_docstring
class MobileBertForMaskedLM(MobileBertPreTrainedModel):
    # Names of parameters that are tied to one another.
    _tied_weights_keys = {
        "cls.predictions.decoder.bias": "cls.predictions.bias",
        "cls.predictions.decoder.weight": "mobilebert.embeddings.word_embeddings.weight",
    }

    def __init__(self, config):
        super().__init__(config)
        # The masked-LM head only consumes per-token states, so the encoder is built without a pooling layer.
        self.mobilebert = MobileBertModel(config, add_pooling_layer=False)
        self.cls = MobileBertOnlyMLMHead(config)
        self.config = config

        # Initialize weights and apply final processing
        self.post_init()

    def get_output_embeddings(self):
        # The output embeddings are the decoder of the prediction head.
        predictions = self.cls.predictions
        return predictions.decoder

    def set_output_embeddings(self, new_embeddings):
        # Swap in the new decoder and keep the head's standalone bias pointing at the decoder's bias.
        predictions = self.cls.predictions
        predictions.decoder = new_embeddings
        predictions.bias = new_embeddings.bias

    def resize_token_embeddings(self, new_num_tokens: int | None = None) -> nn.Embedding:
        # Resize the dense projection of the prediction head first, then let the parent class resize the
        # token embeddings themselves.
        predictions = self.cls.predictions
        resized_dense = self._get_resized_lm_head(
            predictions.dense,
            new_num_tokens=new_num_tokens,
            transposed=True,
        )
        predictions.dense = resized_dense
        return super().resize_token_embeddings(new_num_tokens=new_num_tokens)

    @can_return_tuple
    @auto_docstring
    def forward(
        self,
        input_ids: torch.LongTensor | None = None,
        attention_mask: torch.FloatTensor | None = None,
        token_type_ids: torch.LongTensor | None = None,
        position_ids: torch.LongTensor | None = None,
        inputs_embeds: torch.FloatTensor | None = None,
        labels: torch.LongTensor | None = None,
        **kwargs: Unpack[TransformersKwargs],
    ) -> tuple | MaskedLMOutput:
        r"""
        labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
            Labels for computing the masked language modeling loss. Indices should be in `[-100, 0, ...,
            config.vocab_size]` (see `input_ids` docstring) Tokens with indices set to `-100` are ignored (masked), the
            loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`
        """
        encoder_outputs = self.mobilebert(
            input_ids,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
            position_ids=position_ids,
            inputs_embeds=inputs_embeds,
            return_dict=True,
            **kwargs,
        )

        # First encoder output -> vocabulary scores for every position.
        logits = self.cls(encoder_outputs[0])

        if labels is None:
            loss = None
        else:
            # Flatten to (tokens, vocab) vs (tokens,); CrossEntropyLoss skips targets equal to its default
            # ignore_index of -100.
            flat_logits = logits.view(-1, self.config.vocab_size)
            loss = CrossEntropyLoss()(flat_logits, labels.view(-1))

        return MaskedLMOutput(
            loss=loss,
            logits=logits,
            hidden_states=encoder_outputs.hidden_states,
            attentions=encoder_outputs.attentions,
        )
class MobileBertOnlyNSPHead(nn.Module):
    """Next-sentence-prediction head: one linear layer mapping a pooled vector to two scores."""

    def __init__(self, config):
        super().__init__()
        self.seq_relationship = nn.Linear(in_features=config.hidden_size, out_features=2)

    def forward(self, pooled_output: torch.Tensor) -> torch.Tensor:
        """Return the two-way sequence-relationship scores computed from `pooled_output`."""
        return self.seq_relationship(pooled_output)
@auto_docstring(
    custom_intro="""
    MobileBert Model with a `next sentence prediction (classification)` head on top.
    """
)
class MobileBertForNextSentencePrediction(MobileBertPreTrainedModel):
    def __init__(self, config):
        super().__init__(config)

        # Encoder plus a two-way classifier applied to its second (pooled) output.
        self.mobilebert = MobileBertModel(config)
        self.cls = MobileBertOnlyNSPHead(config)

        # Initialize weights and apply final processing
        self.post_init()

    @can_return_tuple
    @auto_docstring
    def forward(
        self,
        input_ids: torch.LongTensor | None = None,
        attention_mask: torch.FloatTensor | None = None,
        token_type_ids: torch.LongTensor | None = None,
        position_ids: torch.LongTensor | None = None,
        inputs_embeds: torch.FloatTensor | None = None,
        labels: torch.LongTensor | None = None,
        **kwargs: Unpack[TransformersKwargs],
    ) -> tuple | NextSentencePredictorOutput:
        r"""
        labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
            Labels for computing the next sequence prediction (classification) loss. Input should be a sequence pair
            (see `input_ids` docstring) Indices should be in `[0, 1]`.

            - 0 indicates sequence B is a continuation of sequence A,
            - 1 indicates sequence B is a random sequence.

        Examples:

        ```python
        >>> from transformers import AutoTokenizer, MobileBertForNextSentencePrediction
        >>> import torch

        >>> tokenizer = AutoTokenizer.from_pretrained("google/mobilebert-uncased")
        >>> model = MobileBertForNextSentencePrediction.from_pretrained("google/mobilebert-uncased")

        >>> prompt = "In Italy, pizza served in formal settings, such as at a restaurant, is presented unsliced."
        >>> next_sentence = "The sky is blue due to the shorter wavelength of blue light."
        >>> encoding = tokenizer(prompt, next_sentence, return_tensors="pt")

        >>> outputs = model(**encoding, labels=torch.LongTensor([1]))
        >>> loss = outputs.loss
        >>> logits = outputs.logits
        ```"""
        encoder_outputs = self.mobilebert(
            input_ids,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
            position_ids=position_ids,
            inputs_embeds=inputs_embeds,
            return_dict=True,
            **kwargs,
        )

        # The head scores the second encoder output (index 1).
        logits = self.cls(encoder_outputs[1])

        # Two-class cross entropy, only when labels are supplied.
        loss = None if labels is None else CrossEntropyLoss()(logits.view(-1, 2), labels.view(-1))

        return NextSentencePredictorOutput(
            loss=loss,
            logits=logits,
            hidden_states=encoder_outputs.hidden_states,
            attentions=encoder_outputs.attentions,
        )
@auto_docstring(
    custom_intro="""
    MobileBert Model transformer with a sequence classification/regression head on top (a linear layer on top of the
    pooled output) e.g. for GLUE tasks.
    """
)
# Adapted from transformers.models.bert.modeling_bert.BertForSequenceClassification with Bert->MobileBert all-casing
class MobileBertForSequenceClassification(MobileBertPreTrainedModel):
    def __init__(self, config):
        super().__init__(config)
        self.num_labels = config.num_labels
        self.config = config

        self.mobilebert = MobileBertModel(config)

        # Use the classifier-specific dropout when configured, otherwise the general hidden dropout.
        dropout_prob = config.classifier_dropout
        if dropout_prob is None:
            dropout_prob = config.hidden_dropout_prob
        self.dropout = nn.Dropout(dropout_prob)
        self.classifier = nn.Linear(config.hidden_size, config.num_labels)

        # Initialize weights and apply final processing
        self.post_init()

    @can_return_tuple
    @auto_docstring
    def forward(
        self,
        input_ids: torch.Tensor | None = None,
        attention_mask: torch.Tensor | None = None,
        token_type_ids: torch.Tensor | None = None,
        position_ids: torch.Tensor | None = None,
        inputs_embeds: torch.Tensor | None = None,
        labels: torch.Tensor | None = None,
        **kwargs: Unpack[TransformersKwargs],
    ) -> tuple[torch.Tensor] | SequenceClassifierOutput:
        r"""
        labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
            Labels for computing the sequence classification/regression loss. Indices should be in `[0, ...,
            config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If
            `config.num_labels > 1` a classification loss is computed (Cross-Entropy).
        """
        encoder_outputs = self.mobilebert(
            input_ids,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
            position_ids=position_ids,
            inputs_embeds=inputs_embeds,
            return_dict=True,
            **kwargs,
        )

        # Second encoder output -> dropout -> linear classifier.
        logits = self.classifier(self.dropout(encoder_outputs[1]))

        loss = None
        if labels is not None:
            config = self.config

            # The problem type is inferred once from the label count and label dtype, then stored on the config.
            if config.problem_type is None:
                if self.num_labels == 1:
                    config.problem_type = "regression"
                elif self.num_labels > 1 and labels.dtype in (torch.long, torch.int):
                    config.problem_type = "single_label_classification"
                else:
                    config.problem_type = "multi_label_classification"

            problem_type = config.problem_type
            if problem_type == "regression":
                mse = MSELoss()
                # With a single target, squeeze both sides so their shapes line up.
                loss = mse(logits.squeeze(), labels.squeeze()) if self.num_labels == 1 else mse(logits, labels)
            elif problem_type == "single_label_classification":
                loss = CrossEntropyLoss()(logits.view(-1, self.num_labels), labels.view(-1))
            elif problem_type == "multi_label_classification":
                loss = BCEWithLogitsLoss()(logits, labels)

        return SequenceClassifierOutput(
            loss=loss,
            logits=logits,
            hidden_states=encoder_outputs.hidden_states,
            attentions=encoder_outputs.attentions,
        )
@auto_docstring
# Adapted from transformers.models.bert.modeling_bert.BertForQuestionAnswering with Bert->MobileBert all-casing
class MobileBertForQuestionAnswering(MobileBertPreTrainedModel):
    def __init__(self, config):
        super().__init__(config)
        self.num_labels = config.num_labels

        # Span prediction works on per-token states, so the encoder is built without a pooling layer.
        self.mobilebert = MobileBertModel(config, add_pooling_layer=False)
        self.qa_outputs = nn.Linear(config.hidden_size, config.num_labels)

        # Initialize weights and apply final processing
        self.post_init()

    @can_return_tuple
    @auto_docstring
    def forward(
        self,
        input_ids: torch.Tensor | None = None,
        attention_mask: torch.Tensor | None = None,
        token_type_ids: torch.Tensor | None = None,
        position_ids: torch.Tensor | None = None,
        inputs_embeds: torch.Tensor | None = None,
        start_positions: torch.Tensor | None = None,
        end_positions: torch.Tensor | None = None,
        **kwargs: Unpack[TransformersKwargs],
    ) -> tuple[torch.Tensor] | QuestionAnsweringModelOutput:
        encoder_outputs = self.mobilebert(
            input_ids,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
            position_ids=position_ids,
            inputs_embeds=inputs_embeds,
            return_dict=True,
            **kwargs,
        )

        # Split the last dimension into width-1 slices (start, end) and drop that trailing axis from each.
        span_logits = self.qa_outputs(encoder_outputs[0])
        start_logits, end_logits = (piece.squeeze(-1).contiguous() for piece in span_logits.split(1, dim=-1))

        if start_positions is None or end_positions is None:
            total_loss = None
        else:
            # Positions may arrive with an extra trailing dimension (e.g. on multi-GPU); remove it.
            if start_positions.dim() > 1:
                start_positions = start_positions.squeeze(-1)
            if end_positions.dim() > 1:
                end_positions = end_positions.squeeze(-1)

            # Positions outside the model inputs are clamped onto `ignored_index`, which the loss then skips.
            ignored_index = start_logits.size(1)
            span_loss = CrossEntropyLoss(ignore_index=ignored_index)
            start_loss = span_loss(start_logits, start_positions.clamp(0, ignored_index))
            end_loss = span_loss(end_logits, end_positions.clamp(0, ignored_index))
            total_loss = (start_loss + end_loss) / 2

        return QuestionAnsweringModelOutput(
            loss=total_loss,
            start_logits=start_logits,
            end_logits=end_logits,
            hidden_states=encoder_outputs.hidden_states,
            attentions=encoder_outputs.attentions,
        )
@auto_docstring
# Adapted from transformers.models.bert.modeling_bert.BertForMultipleChoice with Bert->MobileBert all-casing
class MobileBertForMultipleChoice(MobileBertPreTrainedModel):
    def __init__(self, config):
        super().__init__(config)

        self.mobilebert = MobileBertModel(config)

        # Use the classifier-specific dropout when configured, otherwise the general hidden dropout.
        dropout_prob = config.classifier_dropout
        if dropout_prob is None:
            dropout_prob = config.hidden_dropout_prob
        self.dropout = nn.Dropout(dropout_prob)
        # One score per (example, choice) row.
        self.classifier = nn.Linear(config.hidden_size, 1)

        # Initialize weights and apply final processing
        self.post_init()

    @can_return_tuple
    @auto_docstring
    def forward(
        self,
        input_ids: torch.Tensor | None = None,
        attention_mask: torch.Tensor | None = None,
        token_type_ids: torch.Tensor | None = None,
        position_ids: torch.Tensor | None = None,
        inputs_embeds: torch.Tensor | None = None,
        labels: torch.Tensor | None = None,
        **kwargs: Unpack[TransformersKwargs],
    ) -> tuple[torch.Tensor] | MultipleChoiceModelOutput:
        r"""
        input_ids (`torch.LongTensor` of shape `(batch_size, num_choices, sequence_length)`):
            Indices of input sequence tokens in the vocabulary.

            Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
            [`PreTrainedTokenizer.__call__`] for details.

            [What are input IDs?](../glossary#input-ids)
        token_type_ids (`torch.LongTensor` of shape `(batch_size, num_choices, sequence_length)`, *optional*):
            Segment token indices to indicate first and second portions of the inputs. Indices are selected in `[0,
            1]`:

            - 0 corresponds to a *sentence A* token,
            - 1 corresponds to a *sentence B* token.

            [What are token type IDs?](../glossary#token-type-ids)
        position_ids (`torch.LongTensor` of shape `(batch_size, num_choices, sequence_length)`, *optional*):
            Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0,
            config.max_position_embeddings - 1]`.

            [What are position IDs?](../glossary#position-ids)
        inputs_embeds (`torch.FloatTensor` of shape `(batch_size, num_choices, sequence_length, hidden_size)`, *optional*):
            Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This
            is useful if you want more control over how to convert `input_ids` indices into associated vectors than the
            model's internal embedding lookup matrix.
        labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
            Labels for computing the multiple choice classification loss. Indices should be in `[0, ...,
            num_choices-1]` where `num_choices` is the size of the second dimension of the input tensors. (See
            `input_ids` above)
        """

        def _merge_choices(tensor):
            # Collapse every leading dimension into one, keeping only the last dimension intact.
            return None if tensor is None else tensor.view(-1, tensor.size(-1))

        # The number of choices is read from dimension 1 of whichever input was provided.
        reference = inputs_embeds if input_ids is None else input_ids
        num_choices = reference.shape[1]

        input_ids = _merge_choices(input_ids)
        attention_mask = _merge_choices(attention_mask)
        token_type_ids = _merge_choices(token_type_ids)
        position_ids = _merge_choices(position_ids)
        if inputs_embeds is not None:
            # Embeddings keep their last two dimensions; only the leading ones are collapsed.
            inputs_embeds = inputs_embeds.view(-1, inputs_embeds.size(-2), inputs_embeds.size(-1))

        encoder_outputs = self.mobilebert(
            input_ids,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
            position_ids=position_ids,
            inputs_embeds=inputs_embeds,
            return_dict=True,
            **kwargs,
        )

        # Score each flattened row, then regroup the scores so each row holds one example's choices.
        choice_scores = self.classifier(self.dropout(encoder_outputs[1]))
        reshaped_logits = choice_scores.view(-1, num_choices)

        loss = None if labels is None else CrossEntropyLoss()(reshaped_logits, labels)

        return MultipleChoiceModelOutput(
            loss=loss,
            logits=reshaped_logits,
            hidden_states=encoder_outputs.hidden_states,
            attentions=encoder_outputs.attentions,
        )
@auto_docstring
# Adapted from transformers.models.bert.modeling_bert.BertForTokenClassification with Bert->MobileBert all-casing
class MobileBertForTokenClassification(MobileBertPreTrainedModel):
    def __init__(self, config):
        super().__init__(config)
        self.num_labels = config.num_labels

        # Token-level labels are predicted from per-token states, so the encoder is built without a pooling layer.
        self.mobilebert = MobileBertModel(config, add_pooling_layer=False)

        # Use the classifier-specific dropout when configured, otherwise the general hidden dropout.
        dropout_prob = config.classifier_dropout
        if dropout_prob is None:
            dropout_prob = config.hidden_dropout_prob
        self.dropout = nn.Dropout(dropout_prob)
        self.classifier = nn.Linear(config.hidden_size, config.num_labels)

        # Initialize weights and apply final processing
        self.post_init()

    @can_return_tuple
    @auto_docstring
    def forward(
        self,
        input_ids: torch.Tensor | None = None,
        attention_mask: torch.Tensor | None = None,
        token_type_ids: torch.Tensor | None = None,
        position_ids: torch.Tensor | None = None,
        inputs_embeds: torch.Tensor | None = None,
        labels: torch.Tensor | None = None,
        **kwargs: Unpack[TransformersKwargs],
    ) -> tuple[torch.Tensor] | TokenClassifierOutput:
        r"""
        labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
            Labels for computing the token classification loss. Indices should be in `[0, ..., config.num_labels - 1]`.
        """
        encoder_outputs = self.mobilebert(
            input_ids,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
            position_ids=position_ids,
            inputs_embeds=inputs_embeds,
            return_dict=True,
            **kwargs,
        )

        # First encoder output -> dropout -> per-token linear classifier.
        token_states = self.dropout(encoder_outputs[0])
        logits = self.classifier(token_states)

        if labels is None:
            loss = None
        else:
            # Flatten to (tokens, num_labels) vs (tokens,) before taking the cross entropy.
            loss = CrossEntropyLoss()(logits.view(-1, self.num_labels), labels.view(-1))

        return TokenClassifierOutput(
            loss=loss,
            logits=logits,
            hidden_states=encoder_outputs.hidden_states,
            attentions=encoder_outputs.attentions,
        )
# Names exported by `from <this module> import *`.
__all__ = [
    "MobileBertForMaskedLM",
    "MobileBertForMultipleChoice",
    "MobileBertForNextSentencePrediction",
    "MobileBertForPreTraining",
    "MobileBertForQuestionAnswering",
    "MobileBertForSequenceClassification",
    "MobileBertForTokenClassification",
    "MobileBertLayer",
    "MobileBertModel",
    "MobileBertPreTrainedModel",
]
|