| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656657658659660661662663664665666667668669670671672673674675676677678679680681682683684685686687688689690691692693694695696697698699700701702
70370470570670770870971071171271371471571671771871972072172272372472572672772872973073173273373473573673773873974074174274374474574674774874975075175275375475575675775875976076176276376476576676776876977077177277377477577677777877978078178278378478578678778878979079179279379479579679779879980080180280380480580680780880981081181281381481581681781881982082182282382482582682782882983083183283383483583683783883984084184284384484584684784884985085185285385485585685785885986086186286386486586686786886987087187287387487587687787887988088188288388488588688788888989089189289389489589689789889990090190290390490590690790890991091191291391491591691791891992092192292392492592692792892993093193293393493593693793893994094194294394494594694794894995095195295395495595695795895996096196296396496596696796896997097197297397497597697797897998098198298398498598698798898999099199299399499599699799899910001001100210031004100510061007100810091010101110121013101410151016101710181019102010211022102310241025102610271028102910301031103210331034103510361037103810391040104110421043104410451046104710481049105010511052105310541055105610571058105910601061106210631064106510661067106810691070 |
- # Copyright 2021 Google Research and The HuggingFace Inc. team. All rights reserved.
- #
- # Licensed under the Apache License, Version 2.0 (the "License");
- # you may not use this file except in compliance with the License.
- # You may obtain a copy of the License at
- #
- # http://www.apache.org/licenses/LICENSE-2.0
- #
- # Unless required by applicable law or agreed to in writing, software
- # distributed under the License is distributed on an "AS IS" BASIS,
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- # See the License for the specific language governing permissions and
- # limitations under the License.
- """PyTorch FNet model."""
- from dataclasses import dataclass
- from functools import partial
- import torch
- from torch import nn
- from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
- from ... import initialization as init
- from ...utils import auto_docstring, is_scipy_available
- if is_scipy_available():
- from scipy import linalg
- from ...activations import ACT2FN
- from ...modeling_layers import GradientCheckpointingLayer
- from ...modeling_outputs import (
- BaseModelOutput,
- BaseModelOutputWithPooling,
- MaskedLMOutput,
- ModelOutput,
- MultipleChoiceModelOutput,
- NextSentencePredictorOutput,
- QuestionAnsweringModelOutput,
- SequenceClassifierOutput,
- TokenClassifierOutput,
- )
- from ...modeling_utils import PreTrainedModel
- from ...pytorch_utils import apply_chunking_to_forward
- from ...utils import logging
- from .configuration_fnet import FNetConfig
- logger = logging.get_logger(__name__)
- # Adapted from https://github.com/google-research/google-research/blob/master/f_net/fourier.py
- def _two_dim_matmul(x, matrix_dim_one, matrix_dim_two):
- """Applies 2D matrix multiplication to 3D input arrays."""
- seq_length = x.shape[1]
- matrix_dim_one = matrix_dim_one[:seq_length, :seq_length]
- x = x.type(torch.complex64)
- return torch.einsum("bij,jk,ni->bnk", x, matrix_dim_two, matrix_dim_one)
- # Adapted from https://github.com/google-research/google-research/blob/master/f_net/fourier.py
def two_dim_matmul(x, matrix_dim_one, matrix_dim_two):
    """Public entry point; forwards its three arguments unchanged to `_two_dim_matmul` and returns its result."""
    return _two_dim_matmul(x=x, matrix_dim_one=matrix_dim_one, matrix_dim_two=matrix_dim_two)
- # Adapted from https://github.com/google-research/google-research/blob/master/f_net/fourier.py
def fftn(x):
    """
    Applies a one-dimensional Fast Fourier Transform (FFT) along every axis of the input except axis 0.

    Args:
        x: Input n-dimensional array.

    Returns:
        The input transformed along axes `x.ndim - 1` down to `1`, one axis at a time. A 1D input is returned as is.
    """
    transformed = x
    # Walk from the last axis down to axis 1; axis 0 is never transformed.
    for axis in range(x.ndim - 1, 0, -1):
        transformed = torch.fft.fft(transformed, axis=axis)
    return transformed
class FNetEmbeddings(nn.Module):
    """Sums word, token-type and position embeddings, then applies LayerNorm, a linear projection and dropout."""

    def __init__(self, config):
        super().__init__()
        hidden_size = config.hidden_size
        self.word_embeddings = nn.Embedding(config.vocab_size, hidden_size, padding_idx=config.pad_token_id)
        self.position_embeddings = nn.Embedding(config.max_position_embeddings, hidden_size)
        self.token_type_embeddings = nn.Embedding(config.type_vocab_size, hidden_size)

        self.LayerNorm = nn.LayerNorm(hidden_size, eps=config.layer_norm_eps)
        # NOTE: This projection layer is required. The original code allows the embedding dimension and the
        # model dimension to differ.
        self.projection = nn.Linear(hidden_size, hidden_size)

        self.dropout = nn.Dropout(config.hidden_dropout_prob)

        # Default ids of shape (1, max_position_embeddings). Both buffers are registered with persistent=False,
        # so they are not part of the module's state dict.
        default_position_ids = torch.arange(config.max_position_embeddings).expand((1, -1))
        self.register_buffer("position_ids", default_position_ids, persistent=False)
        default_token_type_ids = torch.zeros(self.position_ids.size(), dtype=torch.long)
        self.register_buffer("token_type_ids", default_token_type_ids, persistent=False)

    def forward(self, input_ids=None, token_type_ids=None, position_ids=None, inputs_embeds=None):
        """
        Build the embedding output for a batch.

        Args:
            input_ids: Token ids; looked up in `word_embeddings` when `inputs_embeds` is not given.
            token_type_ids: Token-type ids; default to the all-zero buffer cropped to the sequence length.
            position_ids: Position ids; default to the first `seq_length` entries of the `position_ids` buffer.
            inputs_embeds: Pre-computed word embeddings used instead of looking up `input_ids`.

        Returns:
            The embeddings after LayerNorm, projection and dropout.
        """
        input_shape = input_ids.size() if input_ids is not None else inputs_embeds.size()[:-1]
        seq_length = input_shape[1]

        if position_ids is None:
            position_ids = self.position_ids[:, :seq_length]

        # When token_type_ids is not passed, fall back to the all-zero buffer registered in the constructor.
        # Having it as a registered buffer helps users trace the model without passing token_type_ids
        # (solves issue #5664).
        if token_type_ids is None:
            if hasattr(self, "token_type_ids"):
                cropped_type_ids = self.token_type_ids[:, :seq_length]
                token_type_ids = cropped_type_ids.expand(input_shape[0], seq_length)
            else:
                token_type_ids = torch.zeros(input_shape, dtype=torch.long, device=self.position_ids.device)

        if inputs_embeds is None:
            inputs_embeds = self.word_embeddings(input_ids)

        embeddings = inputs_embeds + self.token_type_embeddings(token_type_ids)
        embeddings += self.position_embeddings(position_ids)

        normalized = self.LayerNorm(embeddings)
        projected = self.projection(normalized)
        return self.dropout(projected)
class FNetBasicFourierTransform(nn.Module):
    """
    Token mixing without learnable weights: applies a Fourier transform to the hidden states and keeps the real part.

    The transform implementation is picked once from the config by `_init_fourier_transform` and stored as the
    callable `self.fourier_transform`, which takes the hidden states as its only argument.
    """

    def __init__(self, config):
        super().__init__()
        self._init_fourier_transform(config)

    def _init_fourier_transform(self, config):
        """
        Select the Fourier transform implementation and store it in `self.fourier_transform`.

        - `config.use_tpu_fourier_optimizations` is falsy: `torch.fft.fftn` over dims `(1, 2)`.
        - Otherwise, when `config.max_position_embeddings <= 4096` and SciPy is available: multiplication by DFT
          matrices, registered as the buffers `dft_mat_hidden` (size `config.hidden_size`) and `dft_mat_seq` (size
          `config.tpu_short_seq_length`).
        - Otherwise: the module-level `fftn` helper. A warning is logged when this is due to SciPy being missing.
        """
        if not config.use_tpu_fourier_optimizations:
            self.fourier_transform = partial(torch.fft.fftn, dim=(1, 2))
        elif config.max_position_embeddings <= 4096:
            if is_scipy_available():
                self.register_buffer(
                    "dft_mat_hidden", torch.tensor(linalg.dft(config.hidden_size), dtype=torch.complex64)
                )
                self.register_buffer(
                    "dft_mat_seq", torch.tensor(linalg.dft(config.tpu_short_seq_length), dtype=torch.complex64)
                )
                # The buffers are looked up at call time rather than captured here: moving or casting the module
                # re-binds its buffers, and a `partial` built from the current tensors would keep using the ones
                # created in this constructor.
                self.fourier_transform = self._dft_matmul_transform
            else:
                # Use the module-level `logger`; `logging` in this file is the project's logging utility module,
                # which is used here only to create loggers.
                logger.warning(
                    "SciPy is needed for DFT matrix calculation and is not found. Using TPU optimized fast fourier"
                    " transform instead."
                )
                self.fourier_transform = fftn
        else:
            self.fourier_transform = fftn

    def _dft_matmul_transform(self, hidden_states):
        """Apply the transform as matrix products with the DFT buffers currently registered on this module."""
        return two_dim_matmul(hidden_states, matrix_dim_one=self.dft_mat_seq, matrix_dim_two=self.dft_mat_hidden)

    def forward(self, hidden_states):
        """
        Args:
            hidden_states: Input tensor handed to `self.fourier_transform`.

        Returns:
            A 1-tuple holding the real part of the transformed hidden states.
        """
        # Only the real part of the (complex) transform output is kept.
        outputs = self.fourier_transform(hidden_states).real
        return (outputs,)
class FNetBasicOutput(nn.Module):
    """Adds the block input back onto the mixed hidden states and applies layer normalization."""

    def __init__(self, config):
        super().__init__()
        self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)

    def forward(self, hidden_states, input_tensor):
        """Return `LayerNorm(input_tensor + hidden_states)`."""
        residual_sum = input_tensor + hidden_states
        return self.LayerNorm(residual_sum)
class FNetFourierTransform(nn.Module):
    """Fourier mixing sub-layer: `FNetBasicFourierTransform` followed by the residual + LayerNorm of `FNetBasicOutput`."""

    def __init__(self, config):
        super().__init__()
        self.self = FNetBasicFourierTransform(config)
        self.output = FNetBasicOutput(config)

    def forward(self, hidden_states):
        """Return a 1-tuple with the normalized sum of `hidden_states` and its Fourier-mixed version."""
        mixed_states = self.self(hidden_states)[0]
        # The un-mixed input is passed along as the residual branch.
        return (self.output(mixed_states, hidden_states),)
- # Copied from transformers.models.bert.modeling_bert.BertIntermediate with Bert->FNet
class FNetIntermediate(nn.Module):
    """Feed-forward expansion: a linear layer from `hidden_size` to `intermediate_size` followed by an activation."""

    def __init__(self, config):
        super().__init__()
        self.dense = nn.Linear(config.hidden_size, config.intermediate_size)
        activation = config.hidden_act
        # A string names an entry of ACT2FN; anything else is used directly as the activation callable.
        self.intermediate_act_fn = ACT2FN[activation] if isinstance(activation, str) else activation

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        """Return `intermediate_act_fn(dense(hidden_states))`."""
        return self.intermediate_act_fn(self.dense(hidden_states))
- # Copied from transformers.models.bert.modeling_bert.BertOutput with Bert->FNet
class FNetOutput(nn.Module):
    """Feed-forward contraction: linear layer back to `hidden_size`, dropout, then residual add and LayerNorm."""

    def __init__(self, config):
        super().__init__()
        self.dense = nn.Linear(config.intermediate_size, config.hidden_size)
        self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
        self.dropout = nn.Dropout(config.hidden_dropout_prob)

    def forward(self, hidden_states: torch.Tensor, input_tensor: torch.Tensor) -> torch.Tensor:
        """Return `LayerNorm(dropout(dense(hidden_states)) + input_tensor)`."""
        projected = self.dropout(self.dense(hidden_states))
        return self.LayerNorm(projected + input_tensor)
class FNetLayer(GradientCheckpointingLayer):
    """One encoder layer: the Fourier mixing sub-layer followed by a (optionally chunked) feed-forward sub-layer."""

    def __init__(self, config):
        super().__init__()
        self.chunk_size_feed_forward = config.chunk_size_feed_forward
        # The dimension which has the sequence length; the feed-forward step is chunked along it.
        self.seq_len_dim = 1
        self.fourier = FNetFourierTransform(config)
        self.intermediate = FNetIntermediate(config)
        self.output = FNetOutput(config)

    def forward(self, hidden_states):
        """Return a 1-tuple with the layer output for `hidden_states`."""
        fourier_output = self.fourier(hidden_states)[0]

        layer_output = apply_chunking_to_forward(
            self.feed_forward_chunk,
            self.chunk_size_feed_forward,
            self.seq_len_dim,
            fourier_output,
        )
        return (layer_output,)

    def feed_forward_chunk(self, fourier_output):
        """Run the feed-forward sub-layer on `fourier_output`, which also serves as the residual input."""
        expanded = self.intermediate(fourier_output)
        return self.output(expanded, fourier_output)
class FNetEncoder(nn.Module):
    """Stack of `config.num_hidden_layers` `FNetLayer` modules applied in sequence."""

    def __init__(self, config):
        super().__init__()
        self.config = config
        layers = [FNetLayer(config) for _ in range(config.num_hidden_layers)]
        self.layer = nn.ModuleList(layers)
        self.gradient_checkpointing = False

    def forward(self, hidden_states, output_hidden_states=False, return_dict=True):
        """
        Run the hidden states through every layer.

        Args:
            hidden_states: Input to the first layer.
            output_hidden_states: When truthy, also collect the input followed by the output of each layer.
            return_dict: When truthy, return a `BaseModelOutput`; otherwise a tuple of the non-`None` results.

        Returns:
            `BaseModelOutput` or a tuple `(last_hidden_state[, all_hidden_states])`.
        """
        # Start with the encoder input so the collection ends up holding num_layers + 1 entries.
        collected = [hidden_states] if output_hidden_states else None

        for layer_module in self.layer:
            hidden_states = layer_module(hidden_states)[0]
            if collected is not None:
                collected.append(hidden_states)

        all_hidden_states = tuple(collected) if collected is not None else None

        if return_dict:
            return BaseModelOutput(last_hidden_state=hidden_states, hidden_states=all_hidden_states)
        return tuple(v for v in [hidden_states, all_hidden_states] if v is not None)
- # Copied from transformers.models.bert.modeling_bert.BertPooler with Bert->FNet
class FNetPooler(nn.Module):
    """Pools a sequence into one vector: the hidden state at position 0 passed through a linear layer and tanh."""

    def __init__(self, config):
        super().__init__()
        self.dense = nn.Linear(config.hidden_size, config.hidden_size)
        self.activation = nn.Tanh()

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        """Return `tanh(dense(hidden_states[:, 0]))`."""
        # "Pooling" here simply means picking the hidden state that corresponds to the first token.
        first_token_state = hidden_states[:, 0]
        return self.activation(self.dense(first_token_state))
- # Copied from transformers.models.bert.modeling_bert.BertPredictionHeadTransform with Bert->FNet
class FNetPredictionHeadTransform(nn.Module):
    """Transform applied before the LM decoder: linear layer, activation, then layer normalization."""

    def __init__(self, config):
        super().__init__()
        self.dense = nn.Linear(config.hidden_size, config.hidden_size)
        activation = config.hidden_act
        # A string names an entry of ACT2FN; anything else is used directly as the activation callable.
        self.transform_act_fn = ACT2FN[activation] if isinstance(activation, str) else activation
        self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        """Return `LayerNorm(transform_act_fn(dense(hidden_states)))`."""
        activated = self.transform_act_fn(self.dense(hidden_states))
        return self.LayerNorm(activated)
class FNetLMPredictionHead(nn.Module):
    """Language modeling head: `FNetPredictionHeadTransform` followed by a linear decoder to vocabulary size."""

    def __init__(self, config):
        super().__init__()
        self.transform = FNetPredictionHeadTransform(config)
        self.decoder = nn.Linear(config.hidden_size, config.vocab_size)
        # Standalone bias parameter; it is not applied in `forward`. The LM model classes in this file list it
        # in `_tied_weights_keys` together with `decoder.bias`.
        self.bias = nn.Parameter(torch.zeros(config.vocab_size))

    def forward(self, hidden_states):
        """Return the decoder scores for the transformed `hidden_states`."""
        transformed = self.transform(hidden_states)
        return self.decoder(transformed)
class FNetOnlyMLMHead(nn.Module):
    """Wraps a single `FNetLMPredictionHead` under the attribute name `predictions`."""

    def __init__(self, config):
        super().__init__()
        self.predictions = FNetLMPredictionHead(config)

    def forward(self, sequence_output):
        """Return the prediction scores of the LM head for `sequence_output`."""
        return self.predictions(sequence_output)
- # Copied from transformers.models.bert.modeling_bert.BertOnlyNSPHead with Bert->FNet
class FNetOnlyNSPHead(nn.Module):
    """Next sentence prediction head: one linear layer mapping the pooled output to two scores."""

    def __init__(self, config):
        super().__init__()
        self.seq_relationship = nn.Linear(config.hidden_size, 2)

    def forward(self, pooled_output):
        """Return the two-way sequence relationship scores for `pooled_output`."""
        return self.seq_relationship(pooled_output)
- # Copied from transformers.models.bert.modeling_bert.BertPreTrainingHeads with Bert->FNet
class FNetPreTrainingHeads(nn.Module):
    """Both pretraining heads: the LM prediction head and a two-way sequence relationship classifier."""

    def __init__(self, config):
        super().__init__()
        self.predictions = FNetLMPredictionHead(config)
        self.seq_relationship = nn.Linear(config.hidden_size, 2)

    def forward(self, sequence_output, pooled_output):
        """Return `(prediction_scores, seq_relationship_score)` for the sequence and pooled outputs."""
        lm_scores = self.predictions(sequence_output)
        relationship_scores = self.seq_relationship(pooled_output)
        return lm_scores, relationship_scores
@auto_docstring
class FNetPreTrainedModel(PreTrainedModel):
    # Shared base of the FNet model classes in this file: fixes the config class and the base model prefix.
    config: FNetConfig
    base_model_prefix = "fnet"
    supports_gradient_checkpointing = True

    def _init_weights(self, module):
        """Run the parent initialization, then reset the id buffers of `FNetEmbeddings` modules."""
        super()._init_weights(module)
        if not isinstance(module, FNetEmbeddings):
            return
        # Restore position_ids to 0..N-1 (shape (1, N)) and token_type_ids to all zeros.
        num_positions = module.position_ids.shape[-1]
        init.copy_(module.position_ids, torch.arange(num_positions).expand((1, -1)))
        init.zeros_(module.token_type_ids)
@dataclass
@auto_docstring(
    custom_intro="""
    Output type of [`FNetForPreTraining`].
    """
)
class FNetForPreTrainingOutput(ModelOutput):
    r"""
    loss (*optional*, returned when `labels` is provided, `torch.FloatTensor` of shape `(1,)`):
        Total loss as the sum of the masked language modeling loss and the next sequence prediction
        (classification) loss.
    prediction_logits (`torch.FloatTensor` of shape `(batch_size, sequence_length, config.vocab_size)`):
        Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax).
    seq_relationship_logits (`torch.FloatTensor` of shape `(batch_size, 2)`):
        Prediction scores of the next sequence prediction (classification) head (scores of True/False continuation
        before SoftMax).
    """

    # Sum of the MLM and NSP losses; `FNetForPreTraining.forward` fills it only when both `labels` and
    # `next_sentence_label` are passed.
    loss: torch.FloatTensor | None = None
    # Scores produced by the language modeling head.
    prediction_logits: torch.FloatTensor | None = None
    # Scores produced by the sequence relationship (next sentence) head.
    seq_relationship_logits: torch.FloatTensor | None = None
    # Hidden states taken from the output of the underlying `FNetModel`.
    hidden_states: tuple[torch.FloatTensor] | None = None
@auto_docstring
class FNetModel(FNetPreTrainedModel):
    """
    The model can behave as an encoder, following the architecture described in [FNet: Mixing Tokens with Fourier
    Transforms](https://huggingface.co/papers/2105.03824) by James Lee-Thorp, Joshua Ainslie, Ilya Eckstein, Santiago Ontanon.
    """

    def __init__(self, config, add_pooling_layer=True):
        r"""
        add_pooling_layer (bool, *optional*, defaults to `True`):
            Whether to add a pooling layer
        """
        super().__init__(config)
        self.config = config

        self.embeddings = FNetEmbeddings(config)
        self.encoder = FNetEncoder(config)

        if add_pooling_layer:
            self.pooler = FNetPooler(config)
        else:
            self.pooler = None

        # Initialize weights and apply final processing
        self.post_init()

    def get_input_embeddings(self):
        return self.embeddings.word_embeddings

    def set_input_embeddings(self, value):
        self.embeddings.word_embeddings = value

    @auto_docstring
    def forward(
        self,
        input_ids: torch.LongTensor | None = None,
        token_type_ids: torch.LongTensor | None = None,
        position_ids: torch.LongTensor | None = None,
        inputs_embeds: torch.FloatTensor | None = None,
        output_hidden_states: bool | None = None,
        return_dict: bool | None = None,
        **kwargs,
    ) -> tuple | BaseModelOutput:
        # Unset flags fall back to the values stored on the config.
        if output_hidden_states is None:
            output_hidden_states = self.config.output_hidden_states
        if return_dict is None:
            return_dict = self.config.return_dict

        # Exactly one of input_ids / inputs_embeds has to be provided.
        if input_ids is not None and inputs_embeds is not None:
            raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
        if input_ids is None and inputs_embeds is None:
            raise ValueError("You have to specify either input_ids or inputs_embeds")

        input_shape = input_ids.size() if input_ids is not None else inputs_embeds.size()[:-1]
        batch_size, seq_length = input_shape

        # With TPU optimizations enabled, inputs of at most 4096 tokens must match the configured short length.
        short_sequence_on_tpu = self.config.use_tpu_fourier_optimizations and seq_length <= 4096
        if short_sequence_on_tpu and self.config.tpu_short_seq_length != seq_length:
            raise ValueError(
                "The `tpu_short_seq_length` in FNetConfig should be set equal to the sequence length being passed to"
                " the model when using TPU optimizations."
            )

        device = input_ids.device if input_ids is not None else inputs_embeds.device

        if token_type_ids is None:
            if hasattr(self.embeddings, "token_type_ids"):
                # Reuse the all-zero buffer of the embeddings, cropped and expanded to the batch.
                cropped_type_ids = self.embeddings.token_type_ids[:, :seq_length]
                token_type_ids = cropped_type_ids.expand(batch_size, seq_length)
            else:
                token_type_ids = torch.zeros(input_shape, dtype=torch.long, device=device)

        embedding_output = self.embeddings(
            input_ids=input_ids,
            position_ids=position_ids,
            token_type_ids=token_type_ids,
            inputs_embeds=inputs_embeds,
        )
        encoder_outputs = self.encoder(
            embedding_output,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )
        sequence_output = encoder_outputs[0]

        pooler_output = None
        if self.pooler is not None:
            pooler_output = self.pooler(sequence_output)

        if return_dict:
            return BaseModelOutputWithPooling(
                last_hidden_state=sequence_output,
                pooler_output=pooler_output,
                hidden_states=encoder_outputs.hidden_states,
            )
        return (sequence_output, pooler_output) + encoder_outputs[1:]
@auto_docstring(
    custom_intro="""
    FNet Model with two heads on top as done during the pretraining: a `masked language modeling` head and a `next
    sentence prediction (classification)` head.
    """
)
class FNetForPreTraining(FNetPreTrainedModel):
    _tied_weights_keys = {
        "cls.predictions.decoder.bias": "cls.predictions.bias",
        "cls.predictions.decoder.weight": "fnet.embeddings.word_embeddings.weight",
    }

    def __init__(self, config):
        super().__init__(config)

        self.fnet = FNetModel(config)
        self.cls = FNetPreTrainingHeads(config)

        # Initialize weights and apply final processing
        self.post_init()

    def get_output_embeddings(self):
        return self.cls.predictions.decoder

    def set_output_embeddings(self, new_embeddings):
        # Replace the decoder and point the head's standalone bias at the new decoder's bias.
        predictions = self.cls.predictions
        predictions.decoder = new_embeddings
        predictions.bias = new_embeddings.bias

    @auto_docstring
    def forward(
        self,
        input_ids: torch.Tensor | None = None,
        token_type_ids: torch.Tensor | None = None,
        position_ids: torch.Tensor | None = None,
        inputs_embeds: torch.Tensor | None = None,
        labels: torch.Tensor | None = None,
        next_sentence_label: torch.Tensor | None = None,
        output_hidden_states: bool | None = None,
        return_dict: bool | None = None,
        **kwargs,
    ) -> tuple | FNetForPreTrainingOutput:
        r"""
        labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
            Labels for computing the masked language modeling loss. Indices should be in `[-100, 0, ...,
            config.vocab_size - 1]` (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
            (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size - 1]`.
        next_sentence_label (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
            Labels for computing the next sequence prediction (classification) loss. Input should be a sequence pair
            (see `input_ids` docstring). Indices should be in `[0, 1]`:

            - 0 indicates sequence B is a continuation of sequence A,
            - 1 indicates sequence B is a random sequence.

        Example:

        ```python
        >>> from transformers import AutoTokenizer, FNetForPreTraining
        >>> import torch

        >>> tokenizer = AutoTokenizer.from_pretrained("google/fnet-base")
        >>> model = FNetForPreTraining.from_pretrained("google/fnet-base")
        >>> inputs = tokenizer("Hello, my dog is cute", return_tensors="pt")
        >>> outputs = model(**inputs)
        >>> prediction_logits = outputs.prediction_logits
        >>> seq_relationship_logits = outputs.seq_relationship_logits
        ```"""
        if return_dict is None:
            return_dict = self.config.return_dict

        outputs = self.fnet(
            input_ids,
            token_type_ids=token_type_ids,
            position_ids=position_ids,
            inputs_embeds=inputs_embeds,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )
        sequence_output, pooled_output = outputs[:2]
        prediction_scores, seq_relationship_score = self.cls(sequence_output, pooled_output)

        # The combined loss is only computed when both label tensors are available.
        total_loss = None
        if labels is not None and next_sentence_label is not None:
            cross_entropy = CrossEntropyLoss()
            flat_mlm_scores = prediction_scores.view(-1, self.config.vocab_size)
            masked_lm_loss = cross_entropy(flat_mlm_scores, labels.view(-1))
            flat_nsp_scores = seq_relationship_score.view(-1, 2)
            next_sentence_loss = cross_entropy(flat_nsp_scores, next_sentence_label.view(-1))
            total_loss = masked_lm_loss + next_sentence_loss

        if return_dict:
            return FNetForPreTrainingOutput(
                loss=total_loss,
                prediction_logits=prediction_scores,
                seq_relationship_logits=seq_relationship_score,
                hidden_states=outputs.hidden_states,
            )

        output = (prediction_scores, seq_relationship_score) + outputs[2:]
        return output if total_loss is None else (total_loss,) + output
@auto_docstring
class FNetForMaskedLM(FNetPreTrainedModel):
    _tied_weights_keys = {
        "cls.predictions.decoder.bias": "cls.predictions.bias",
        "cls.predictions.decoder.weight": "fnet.embeddings.word_embeddings.weight",
    }

    def __init__(self, config):
        super().__init__(config)

        self.fnet = FNetModel(config)
        self.cls = FNetOnlyMLMHead(config)

        # Initialize weights and apply final processing
        self.post_init()

    def get_output_embeddings(self):
        return self.cls.predictions.decoder

    def set_output_embeddings(self, new_embeddings):
        # Replace the decoder and point the head's standalone bias at the new decoder's bias.
        predictions = self.cls.predictions
        predictions.decoder = new_embeddings
        predictions.bias = new_embeddings.bias

    @auto_docstring
    def forward(
        self,
        input_ids: torch.Tensor | None = None,
        token_type_ids: torch.Tensor | None = None,
        position_ids: torch.Tensor | None = None,
        inputs_embeds: torch.Tensor | None = None,
        labels: torch.Tensor | None = None,
        output_hidden_states: bool | None = None,
        return_dict: bool | None = None,
        **kwargs,
    ) -> tuple | MaskedLMOutput:
        r"""
        labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
            Labels for computing the masked language modeling loss. Indices should be in `[-100, 0, ...,
            config.vocab_size - 1]` (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
            (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size - 1]`.
        """
        if return_dict is None:
            return_dict = self.config.return_dict

        outputs = self.fnet(
            input_ids,
            token_type_ids=token_type_ids,
            position_ids=position_ids,
            inputs_embeds=inputs_embeds,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )
        prediction_scores = self.cls(outputs[0])

        masked_lm_loss = None
        if labels is not None:
            cross_entropy = CrossEntropyLoss()  # -100 index = padding token
            flat_scores = prediction_scores.view(-1, self.config.vocab_size)
            masked_lm_loss = cross_entropy(flat_scores, labels.view(-1))

        if return_dict:
            return MaskedLMOutput(loss=masked_lm_loss, logits=prediction_scores, hidden_states=outputs.hidden_states)

        # Index 1 of the base model's tuple output (the pooled output) is not forwarded.
        output = (prediction_scores,) + outputs[2:]
        return output if masked_lm_loss is None else (masked_lm_loss,) + output
@auto_docstring(
    custom_intro="""
    FNet Model with a `next sentence prediction (classification)` head on top.
    """
)
class FNetForNextSentencePrediction(FNetPreTrainedModel):
    def __init__(self, config):
        super().__init__(config)

        self.fnet = FNetModel(config)
        self.cls = FNetOnlyNSPHead(config)

        # Initialize weights and apply final processing
        self.post_init()

    @auto_docstring
    def forward(
        self,
        input_ids: torch.Tensor | None = None,
        token_type_ids: torch.Tensor | None = None,
        position_ids: torch.Tensor | None = None,
        inputs_embeds: torch.Tensor | None = None,
        labels: torch.Tensor | None = None,
        output_hidden_states: bool | None = None,
        return_dict: bool | None = None,
        **kwargs,
    ) -> tuple | NextSentencePredictorOutput:
        r"""
        labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
            Labels for computing the next sequence prediction (classification) loss. Input should be a sequence pair
            (see `input_ids` docstring). Indices should be in `[0, 1]`:

            - 0 indicates sequence B is a continuation of sequence A,
            - 1 indicates sequence B is a random sequence.

        Example:

        ```python
        >>> from transformers import AutoTokenizer, FNetForNextSentencePrediction
        >>> import torch

        >>> tokenizer = AutoTokenizer.from_pretrained("google/fnet-base")
        >>> model = FNetForNextSentencePrediction.from_pretrained("google/fnet-base")
        >>> prompt = "In Italy, pizza served in formal settings, such as at a restaurant, is presented unsliced."
        >>> next_sentence = "The sky is blue due to the shorter wavelength of blue light."
        >>> encoding = tokenizer(prompt, next_sentence, return_tensors="pt")
        >>> outputs = model(**encoding, labels=torch.LongTensor([1]))
        >>> logits = outputs.logits
        >>> assert logits[0, 0] < logits[0, 1]  # next sentence was random
        ```"""
        if return_dict is None:
            return_dict = self.config.return_dict

        outputs = self.fnet(
            input_ids,
            token_type_ids=token_type_ids,
            position_ids=position_ids,
            inputs_embeds=inputs_embeds,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )
        # The head works on the pooled output, found at index 1 of the base model output.
        seq_relationship_scores = self.cls(outputs[1])

        next_sentence_loss = None
        if labels is not None:
            cross_entropy = CrossEntropyLoss()
            next_sentence_loss = cross_entropy(seq_relationship_scores.view(-1, 2), labels.view(-1))

        if return_dict:
            return NextSentencePredictorOutput(
                loss=next_sentence_loss,
                logits=seq_relationship_scores,
                hidden_states=outputs.hidden_states,
            )

        output = (seq_relationship_scores,) + outputs[2:]
        return output if next_sentence_loss is None else (next_sentence_loss,) + output
@auto_docstring(
    custom_intro="""
    FNet Model transformer with a sequence classification/regression head on top (a linear layer on top of the pooled
    output) e.g. for GLUE tasks.
    """
)
class FNetForSequenceClassification(FNetPreTrainedModel):
    def __init__(self, config):
        super().__init__(config)
        self.num_labels = config.num_labels

        self.fnet = FNetModel(config)
        self.dropout = nn.Dropout(config.hidden_dropout_prob)
        self.classifier = nn.Linear(config.hidden_size, config.num_labels)

        # Initialize weights and apply final processing
        self.post_init()

    @auto_docstring
    def forward(
        self,
        input_ids: torch.Tensor | None = None,
        token_type_ids: torch.Tensor | None = None,
        position_ids: torch.Tensor | None = None,
        inputs_embeds: torch.Tensor | None = None,
        labels: torch.Tensor | None = None,
        output_hidden_states: bool | None = None,
        return_dict: bool | None = None,
        **kwargs,
    ) -> tuple | SequenceClassifierOutput:
        r"""
        labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
            Labels for computing the sequence classification/regression loss. Indices should be in `[0, ...,
            config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If
            `config.num_labels > 1` a classification loss is computed (Cross-Entropy).
        """
        if return_dict is None:
            return_dict = self.config.return_dict

        outputs = self.fnet(
            input_ids,
            token_type_ids=token_type_ids,
            position_ids=position_ids,
            inputs_embeds=inputs_embeds,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )
        # Classify from the pooled output, found at index 1 of the base model output.
        pooled_output = self.dropout(outputs[1])
        logits = self.classifier(pooled_output)

        loss = None
        if labels is not None:
            # Infer the problem type once from the label count and label dtype; it is stored on the config.
            if self.config.problem_type is None:
                if self.num_labels == 1:
                    self.config.problem_type = "regression"
                elif self.num_labels > 1 and labels.dtype in (torch.long, torch.int):
                    self.config.problem_type = "single_label_classification"
                else:
                    self.config.problem_type = "multi_label_classification"

            problem_type = self.config.problem_type
            if problem_type == "regression":
                mse = MSELoss()
                if self.num_labels == 1:
                    loss = mse(logits.squeeze(), labels.squeeze())
                else:
                    loss = mse(logits, labels)
            elif problem_type == "single_label_classification":
                cross_entropy = CrossEntropyLoss()
                loss = cross_entropy(logits.view(-1, self.num_labels), labels.view(-1))
            elif problem_type == "multi_label_classification":
                bce = BCEWithLogitsLoss()
                loss = bce(logits, labels)

        if return_dict:
            return SequenceClassifierOutput(loss=loss, logits=logits, hidden_states=outputs.hidden_states)

        output = (logits,) + outputs[2:]
        return output if loss is None else (loss,) + output
@auto_docstring
class FNetForMultipleChoice(FNetPreTrainedModel):
    # FNet backbone followed by dropout and a linear layer that produces one score per choice.

    def __init__(self, config):
        super().__init__(config)

        self.fnet = FNetModel(config)
        self.dropout = nn.Dropout(config.hidden_dropout_prob)
        # A single output unit: each flattened (example, choice) row gets one score.
        self.classifier = nn.Linear(config.hidden_size, 1)

        # Initialize weights and apply final processing
        self.post_init()

    @auto_docstring
    def forward(
        self,
        input_ids: torch.Tensor | None = None,
        token_type_ids: torch.Tensor | None = None,
        position_ids: torch.Tensor | None = None,
        inputs_embeds: torch.Tensor | None = None,
        labels: torch.Tensor | None = None,
        output_hidden_states: bool | None = None,
        return_dict: bool | None = None,
        **kwargs,
    ) -> tuple | MultipleChoiceModelOutput:
        r"""
        input_ids (`torch.LongTensor` of shape `(batch_size, num_choices, sequence_length)`):
            Indices of input sequence tokens in the vocabulary.

            Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
            [`PreTrainedTokenizer.__call__`] for details.

            [What are input IDs?](../glossary#input-ids)
        token_type_ids (`torch.LongTensor` of shape `(batch_size, num_choices, sequence_length)`, *optional*):
            Segment token indices to indicate first and second portions of the inputs. Indices are selected in `[0,
            1]`:

            - 0 corresponds to a *sentence A* token,
            - 1 corresponds to a *sentence B* token.

            [What are token type IDs?](../glossary#token-type-ids)
        position_ids (`torch.LongTensor` of shape `(batch_size, num_choices, sequence_length)`, *optional*):
            Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0,
            config.max_position_embeddings - 1]`.

            [What are position IDs?](../glossary#position-ids)
        inputs_embeds (`torch.FloatTensor` of shape `(batch_size, num_choices, sequence_length, hidden_size)`, *optional*):
            Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This
            is useful if you want more control over how to convert *input_ids* indices into associated vectors than the
            model's internal embedding lookup matrix.
        labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
            Labels for computing the multiple choice classification loss. Indices should be in `[0, ...,
            num_choices-1]` where `num_choices` is the size of the second dimension of the input tensors. (See
            `input_ids` above)
        """
        # Fall back to the config default when the caller did not choose an output format.
        if return_dict is None:
            return_dict = self.config.return_dict

        # The number of choices is the size of dim 1 of whichever input was supplied.
        choice_source = input_ids if input_ids is not None else inputs_embeds
        num_choices = choice_source.shape[1]

        # Collapse the leading dimensions so that every choice becomes its own row for the backbone.
        if input_ids is not None:
            input_ids = input_ids.view(-1, input_ids.size(-1))
        if token_type_ids is not None:
            token_type_ids = token_type_ids.view(-1, token_type_ids.size(-1))
        if position_ids is not None:
            position_ids = position_ids.view(-1, position_ids.size(-1))
        if inputs_embeds is not None:
            inputs_embeds = inputs_embeds.view(-1, inputs_embeds.size(-2), inputs_embeds.size(-1))

        encoder_outputs = self.fnet(
            input_ids,
            token_type_ids=token_type_ids,
            position_ids=position_ids,
            inputs_embeds=inputs_embeds,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        # Entry 1 of the backbone output is used as the pooled sequence representation.
        pooled = self.dropout(encoder_outputs[1])
        choice_scores = self.classifier(pooled)
        # Regroup the per-row scores into rows of `num_choices` scores.
        reshaped_logits = choice_scores.view(-1, num_choices)

        loss = None
        if labels is not None:
            criterion = CrossEntropyLoss()
            loss = criterion(reshaped_logits, labels)

        if return_dict:
            return MultipleChoiceModelOutput(
                loss=loss, logits=reshaped_logits, hidden_states=encoder_outputs.hidden_states
            )

        # Tuple output: logits plus the trailing backbone entries, with the loss prepended when computed.
        plain_output = (reshaped_logits,) + encoder_outputs[2:]
        return plain_output if loss is None else (loss,) + plain_output
@auto_docstring
class FNetForTokenClassification(FNetPreTrainedModel):
    # FNet backbone followed by dropout and a linear layer applied to every position's hidden state.

    def __init__(self, config):
        super().__init__(config)
        self.num_labels = config.num_labels

        self.fnet = FNetModel(config)
        self.dropout = nn.Dropout(config.hidden_dropout_prob)
        self.classifier = nn.Linear(config.hidden_size, config.num_labels)

        # Initialize weights and apply final processing
        self.post_init()

    @auto_docstring
    def forward(
        self,
        input_ids: torch.Tensor | None = None,
        token_type_ids: torch.Tensor | None = None,
        position_ids: torch.Tensor | None = None,
        inputs_embeds: torch.Tensor | None = None,
        labels: torch.Tensor | None = None,
        output_hidden_states: bool | None = None,
        return_dict: bool | None = None,
        **kwargs,
    ) -> tuple | TokenClassifierOutput:
        r"""
        labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
            Labels for computing the token classification loss. Indices should be in `[0, ..., config.num_labels - 1]`.
        """
        # Fall back to the config default when the caller did not choose an output format.
        if return_dict is None:
            return_dict = self.config.return_dict

        encoder_outputs = self.fnet(
            input_ids,
            token_type_ids=token_type_ids,
            position_ids=position_ids,
            inputs_embeds=inputs_embeds,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        # Entry 0 of the backbone output is the per-position sequence output.
        token_states = self.dropout(encoder_outputs[0])
        logits = self.classifier(token_states)

        loss = None
        if labels is not None:
            criterion = CrossEntropyLoss()
            # Flatten logits to rows of `num_labels` and labels to 1-D so every position is one sample.
            loss = criterion(logits.view(-1, self.num_labels), labels.view(-1))

        if return_dict:
            return TokenClassifierOutput(loss=loss, logits=logits, hidden_states=encoder_outputs.hidden_states)

        # Tuple output: logits plus the trailing backbone entries, with the loss prepended when computed.
        plain_output = (logits,) + encoder_outputs[2:]
        return plain_output if loss is None else (loss,) + plain_output
@auto_docstring
class FNetForQuestionAnswering(FNetPreTrainedModel):
    # FNet backbone followed by a linear layer whose outputs are split into start and end logits.

    def __init__(self, config):
        super().__init__(config)
        self.num_labels = config.num_labels

        self.fnet = FNetModel(config)
        self.qa_outputs = nn.Linear(config.hidden_size, config.num_labels)

        # Initialize weights and apply final processing
        self.post_init()

    @auto_docstring
    def forward(
        self,
        input_ids: torch.Tensor | None = None,
        token_type_ids: torch.Tensor | None = None,
        position_ids: torch.Tensor | None = None,
        inputs_embeds: torch.Tensor | None = None,
        start_positions: torch.Tensor | None = None,
        end_positions: torch.Tensor | None = None,
        output_hidden_states: bool | None = None,
        return_dict: bool | None = None,
        **kwargs,
    ) -> tuple | QuestionAnsweringModelOutput:
        # Fall back to the config default when the caller did not choose an output format.
        if return_dict is None:
            return_dict = self.config.return_dict

        encoder_outputs = self.fnet(
            input_ids,
            token_type_ids=token_type_ids,
            position_ids=position_ids,
            inputs_embeds=inputs_embeds,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        # Entry 0 of the backbone output is the per-position sequence output.
        span_logits = self.qa_outputs(encoder_outputs[0])
        # Split the last dimension into width-1 slices (unpacked as start, end) and drop that axis.
        start_logits, end_logits = (
            piece.squeeze(-1).contiguous() for piece in span_logits.split(1, dim=-1)
        )

        total_loss = None
        if start_positions is not None and end_positions is not None:
            # Position tensors may arrive with an extra trailing dimension (e.g. in multi-GPU runs);
            # remove it.
            if start_positions.dim() > 1:
                start_positions = start_positions.squeeze(-1)
            if end_positions.dim() > 1:
                end_positions = end_positions.squeeze(-1)

            # Targets are clamped into [0, ignored_index], where ignored_index is the size of dim 1
            # of the start logits. Targets equal to ignored_index are skipped by the loss.
            ignored_index = start_logits.size(1)
            start_positions = start_positions.clamp(0, ignored_index)
            end_positions = end_positions.clamp(0, ignored_index)

            criterion = CrossEntropyLoss(ignore_index=ignored_index)
            start_loss = criterion(start_logits, start_positions)
            end_loss = criterion(end_logits, end_positions)
            total_loss = (start_loss + end_loss) / 2

        if return_dict:
            return QuestionAnsweringModelOutput(
                loss=total_loss,
                start_logits=start_logits,
                end_logits=end_logits,
                hidden_states=encoder_outputs.hidden_states,
            )

        # Tuple output: both logit tensors plus the trailing backbone entries, with the loss
        # prepended when computed.
        plain_output = (start_logits, end_logits) + encoder_outputs[2:]
        return plain_output if total_loss is None else (total_loss,) + plain_output
# Names exported by `from <module> import *`.
__all__ = [
    "FNetForMaskedLM",
    "FNetForMultipleChoice",
    "FNetForNextSentencePrediction",
    "FNetForPreTraining",
    "FNetForQuestionAnswering",
    "FNetForSequenceClassification",
    "FNetForTokenClassification",
    "FNetLayer",
    "FNetModel",
    "FNetPreTrainedModel",
]
|