modeling_fnet.py 42 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
370470570670770870971071171271371471571671771871972072172272372472572672772872973073173273373473573673773873974074174274374474574674774874975075175275375475575675775875976076176276376476576676776876977077177277377477577677777877978078178278378478578678778878979079179279379479579679779879980080180280380480580680780880981081181281381481581681781881982082182282382482582682782882983083183283383483583683783883984084184284384484584684784884985085185285385485585685785885986086186286386486586686786886987087187287387487587687787887988088188288388488588688788888989089189289389489589689789889990090190290390490590690790890991091191291391491591691791891992092192292392492592692792892993093193293393493593693793893994094194294394494594694794894995095195295395495595695795895996096196296396496596696796896997097197297397497597697797897998098198298398498598698798898999099199299399499599699799899910001001100210031004100510061007100810091010101110121013101410151016101710181019102010211022102310241025102610271028102910301031103210331034103510361037103810391040104110421043104410451046104710481049105010511052105310541055105610571058105910601061106210631064106510661067106810691070
  1. # Copyright 2021 Google Research and The HuggingFace Inc. team. All rights reserved.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. """PyTorch FNet model."""
  15. from dataclasses import dataclass
  16. from functools import partial
  17. import torch
  18. from torch import nn
  19. from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
  20. from ... import initialization as init
  21. from ...utils import auto_docstring, is_scipy_available
  22. if is_scipy_available():
  23. from scipy import linalg
  24. from ...activations import ACT2FN
  25. from ...modeling_layers import GradientCheckpointingLayer
  26. from ...modeling_outputs import (
  27. BaseModelOutput,
  28. BaseModelOutputWithPooling,
  29. MaskedLMOutput,
  30. ModelOutput,
  31. MultipleChoiceModelOutput,
  32. NextSentencePredictorOutput,
  33. QuestionAnsweringModelOutput,
  34. SequenceClassifierOutput,
  35. TokenClassifierOutput,
  36. )
  37. from ...modeling_utils import PreTrainedModel
  38. from ...pytorch_utils import apply_chunking_to_forward
  39. from ...utils import logging
  40. from .configuration_fnet import FNetConfig
  41. logger = logging.get_logger(__name__)
  42. # Adapted from https://github.com/google-research/google-research/blob/master/f_net/fourier.py
  43. def _two_dim_matmul(x, matrix_dim_one, matrix_dim_two):
  44. """Applies 2D matrix multiplication to 3D input arrays."""
  45. seq_length = x.shape[1]
  46. matrix_dim_one = matrix_dim_one[:seq_length, :seq_length]
  47. x = x.type(torch.complex64)
  48. return torch.einsum("bij,jk,ni->bnk", x, matrix_dim_two, matrix_dim_one)
  49. # # Adapted from https://github.com/google-research/google-research/blob/master/f_net/fourier.py
  50. def two_dim_matmul(x, matrix_dim_one, matrix_dim_two):
  51. return _two_dim_matmul(x, matrix_dim_one, matrix_dim_two)
  52. # Adapted from https://github.com/google-research/google-research/blob/master/f_net/fourier.py
  53. def fftn(x):
  54. """
  55. Applies n-dimensional Fast Fourier Transform (FFT) to input array.
  56. Args:
  57. x: Input n-dimensional array.
  58. Returns:
  59. n-dimensional Fourier transform of input n-dimensional array.
  60. """
  61. out = x
  62. for axis in reversed(range(x.ndim)[1:]): # We don't need to apply FFT to last axis
  63. out = torch.fft.fft(out, axis=axis)
  64. return out
  65. class FNetEmbeddings(nn.Module):
  66. """Construct the embeddings from word, position and token_type embeddings."""
  67. def __init__(self, config):
  68. super().__init__()
  69. self.word_embeddings = nn.Embedding(config.vocab_size, config.hidden_size, padding_idx=config.pad_token_id)
  70. self.position_embeddings = nn.Embedding(config.max_position_embeddings, config.hidden_size)
  71. self.token_type_embeddings = nn.Embedding(config.type_vocab_size, config.hidden_size)
  72. self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
  73. # NOTE: This is the project layer and will be needed. The original code allows for different embedding and different model dimensions.
  74. self.projection = nn.Linear(config.hidden_size, config.hidden_size)
  75. self.dropout = nn.Dropout(config.hidden_dropout_prob)
  76. # position_ids (1, len position emb) is contiguous in memory and exported when serialized
  77. self.register_buffer(
  78. "position_ids", torch.arange(config.max_position_embeddings).expand((1, -1)), persistent=False
  79. )
  80. self.register_buffer(
  81. "token_type_ids", torch.zeros(self.position_ids.size(), dtype=torch.long), persistent=False
  82. )
  83. def forward(self, input_ids=None, token_type_ids=None, position_ids=None, inputs_embeds=None):
  84. if input_ids is not None:
  85. input_shape = input_ids.size()
  86. else:
  87. input_shape = inputs_embeds.size()[:-1]
  88. seq_length = input_shape[1]
  89. if position_ids is None:
  90. position_ids = self.position_ids[:, :seq_length]
  91. # Setting the token_type_ids to the registered buffer in constructor where it is all zeros, which usually occurs
  92. # when its auto-generated, registered buffer helps users when tracing the model without passing token_type_ids, solves
  93. # issue #5664
  94. if token_type_ids is None:
  95. if hasattr(self, "token_type_ids"):
  96. buffered_token_type_ids = self.token_type_ids[:, :seq_length]
  97. buffered_token_type_ids_expanded = buffered_token_type_ids.expand(input_shape[0], seq_length)
  98. token_type_ids = buffered_token_type_ids_expanded
  99. else:
  100. token_type_ids = torch.zeros(input_shape, dtype=torch.long, device=self.position_ids.device)
  101. if inputs_embeds is None:
  102. inputs_embeds = self.word_embeddings(input_ids)
  103. token_type_embeddings = self.token_type_embeddings(token_type_ids)
  104. embeddings = inputs_embeds + token_type_embeddings
  105. position_embeddings = self.position_embeddings(position_ids)
  106. embeddings += position_embeddings
  107. embeddings = self.LayerNorm(embeddings)
  108. embeddings = self.projection(embeddings)
  109. embeddings = self.dropout(embeddings)
  110. return embeddings
  111. class FNetBasicFourierTransform(nn.Module):
  112. def __init__(self, config):
  113. super().__init__()
  114. self._init_fourier_transform(config)
  115. def _init_fourier_transform(self, config):
  116. if not config.use_tpu_fourier_optimizations:
  117. self.fourier_transform = partial(torch.fft.fftn, dim=(1, 2))
  118. elif config.max_position_embeddings <= 4096:
  119. if is_scipy_available():
  120. self.register_buffer(
  121. "dft_mat_hidden", torch.tensor(linalg.dft(config.hidden_size), dtype=torch.complex64)
  122. )
  123. self.register_buffer(
  124. "dft_mat_seq", torch.tensor(linalg.dft(config.tpu_short_seq_length), dtype=torch.complex64)
  125. )
  126. self.fourier_transform = partial(
  127. two_dim_matmul, matrix_dim_one=self.dft_mat_seq, matrix_dim_two=self.dft_mat_hidden
  128. )
  129. else:
  130. logging.warning(
  131. "SciPy is needed for DFT matrix calculation and is not found. Using TPU optimized fast fourier"
  132. " transform instead."
  133. )
  134. self.fourier_transform = fftn
  135. else:
  136. self.fourier_transform = fftn
  137. def forward(self, hidden_states):
  138. # NOTE: We do not use torch.vmap as it is not integrated into PyTorch stable versions.
  139. # Interested users can modify the code to use vmap from the nightly versions, getting the vmap from here:
  140. # https://pytorch.org/docs/master/generated/torch.vmap.html. Note that fourier transform methods will need
  141. # change accordingly.
  142. outputs = self.fourier_transform(hidden_states).real
  143. return (outputs,)
  144. class FNetBasicOutput(nn.Module):
  145. def __init__(self, config):
  146. super().__init__()
  147. self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
  148. def forward(self, hidden_states, input_tensor):
  149. hidden_states = self.LayerNorm(input_tensor + hidden_states)
  150. return hidden_states
  151. class FNetFourierTransform(nn.Module):
  152. def __init__(self, config):
  153. super().__init__()
  154. self.self = FNetBasicFourierTransform(config)
  155. self.output = FNetBasicOutput(config)
  156. def forward(self, hidden_states):
  157. self_outputs = self.self(hidden_states)
  158. fourier_output = self.output(self_outputs[0], hidden_states)
  159. outputs = (fourier_output,)
  160. return outputs
  161. # Copied from transformers.models.bert.modeling_bert.BertIntermediate with Bert->FNet
  162. class FNetIntermediate(nn.Module):
  163. def __init__(self, config):
  164. super().__init__()
  165. self.dense = nn.Linear(config.hidden_size, config.intermediate_size)
  166. if isinstance(config.hidden_act, str):
  167. self.intermediate_act_fn = ACT2FN[config.hidden_act]
  168. else:
  169. self.intermediate_act_fn = config.hidden_act
  170. def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
  171. hidden_states = self.dense(hidden_states)
  172. hidden_states = self.intermediate_act_fn(hidden_states)
  173. return hidden_states
  174. # Copied from transformers.models.bert.modeling_bert.BertOutput with Bert->FNet
  175. class FNetOutput(nn.Module):
  176. def __init__(self, config):
  177. super().__init__()
  178. self.dense = nn.Linear(config.intermediate_size, config.hidden_size)
  179. self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
  180. self.dropout = nn.Dropout(config.hidden_dropout_prob)
  181. def forward(self, hidden_states: torch.Tensor, input_tensor: torch.Tensor) -> torch.Tensor:
  182. hidden_states = self.dense(hidden_states)
  183. hidden_states = self.dropout(hidden_states)
  184. hidden_states = self.LayerNorm(hidden_states + input_tensor)
  185. return hidden_states
  186. class FNetLayer(GradientCheckpointingLayer):
  187. def __init__(self, config):
  188. super().__init__()
  189. self.chunk_size_feed_forward = config.chunk_size_feed_forward
  190. self.seq_len_dim = 1 # The dimension which has the sequence length
  191. self.fourier = FNetFourierTransform(config)
  192. self.intermediate = FNetIntermediate(config)
  193. self.output = FNetOutput(config)
  194. def forward(self, hidden_states):
  195. self_fourier_outputs = self.fourier(hidden_states)
  196. fourier_output = self_fourier_outputs[0]
  197. layer_output = apply_chunking_to_forward(
  198. self.feed_forward_chunk, self.chunk_size_feed_forward, self.seq_len_dim, fourier_output
  199. )
  200. outputs = (layer_output,)
  201. return outputs
  202. def feed_forward_chunk(self, fourier_output):
  203. intermediate_output = self.intermediate(fourier_output)
  204. layer_output = self.output(intermediate_output, fourier_output)
  205. return layer_output
  206. class FNetEncoder(nn.Module):
  207. def __init__(self, config):
  208. super().__init__()
  209. self.config = config
  210. self.layer = nn.ModuleList([FNetLayer(config) for _ in range(config.num_hidden_layers)])
  211. self.gradient_checkpointing = False
  212. def forward(self, hidden_states, output_hidden_states=False, return_dict=True):
  213. all_hidden_states = () if output_hidden_states else None
  214. for i, layer_module in enumerate(self.layer):
  215. if output_hidden_states:
  216. all_hidden_states = all_hidden_states + (hidden_states,)
  217. layer_outputs = layer_module(hidden_states)
  218. hidden_states = layer_outputs[0]
  219. if output_hidden_states:
  220. all_hidden_states = all_hidden_states + (hidden_states,)
  221. if not return_dict:
  222. return tuple(v for v in [hidden_states, all_hidden_states] if v is not None)
  223. return BaseModelOutput(last_hidden_state=hidden_states, hidden_states=all_hidden_states)
  224. # Copied from transformers.models.bert.modeling_bert.BertPooler with Bert->FNet
  225. class FNetPooler(nn.Module):
  226. def __init__(self, config):
  227. super().__init__()
  228. self.dense = nn.Linear(config.hidden_size, config.hidden_size)
  229. self.activation = nn.Tanh()
  230. def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
  231. # We "pool" the model by simply taking the hidden state corresponding
  232. # to the first token.
  233. first_token_tensor = hidden_states[:, 0]
  234. pooled_output = self.dense(first_token_tensor)
  235. pooled_output = self.activation(pooled_output)
  236. return pooled_output
  237. # Copied from transformers.models.bert.modeling_bert.BertPredictionHeadTransform with Bert->FNet
  238. class FNetPredictionHeadTransform(nn.Module):
  239. def __init__(self, config):
  240. super().__init__()
  241. self.dense = nn.Linear(config.hidden_size, config.hidden_size)
  242. if isinstance(config.hidden_act, str):
  243. self.transform_act_fn = ACT2FN[config.hidden_act]
  244. else:
  245. self.transform_act_fn = config.hidden_act
  246. self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
  247. def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
  248. hidden_states = self.dense(hidden_states)
  249. hidden_states = self.transform_act_fn(hidden_states)
  250. hidden_states = self.LayerNorm(hidden_states)
  251. return hidden_states
  252. class FNetLMPredictionHead(nn.Module):
  253. def __init__(self, config):
  254. super().__init__()
  255. self.transform = FNetPredictionHeadTransform(config)
  256. self.decoder = nn.Linear(config.hidden_size, config.vocab_size)
  257. self.bias = nn.Parameter(torch.zeros(config.vocab_size))
  258. def forward(self, hidden_states):
  259. hidden_states = self.transform(hidden_states)
  260. hidden_states = self.decoder(hidden_states)
  261. return hidden_states
  262. class FNetOnlyMLMHead(nn.Module):
  263. def __init__(self, config):
  264. super().__init__()
  265. self.predictions = FNetLMPredictionHead(config)
  266. def forward(self, sequence_output):
  267. prediction_scores = self.predictions(sequence_output)
  268. return prediction_scores
  269. # Copied from transformers.models.bert.modeling_bert.BertOnlyNSPHead with Bert->FNet
  270. class FNetOnlyNSPHead(nn.Module):
  271. def __init__(self, config):
  272. super().__init__()
  273. self.seq_relationship = nn.Linear(config.hidden_size, 2)
  274. def forward(self, pooled_output):
  275. seq_relationship_score = self.seq_relationship(pooled_output)
  276. return seq_relationship_score
  277. # Copied from transformers.models.bert.modeling_bert.BertPreTrainingHeads with Bert->FNet
  278. class FNetPreTrainingHeads(nn.Module):
  279. def __init__(self, config):
  280. super().__init__()
  281. self.predictions = FNetLMPredictionHead(config)
  282. self.seq_relationship = nn.Linear(config.hidden_size, 2)
  283. def forward(self, sequence_output, pooled_output):
  284. prediction_scores = self.predictions(sequence_output)
  285. seq_relationship_score = self.seq_relationship(pooled_output)
  286. return prediction_scores, seq_relationship_score
  287. @auto_docstring
  288. class FNetPreTrainedModel(PreTrainedModel):
  289. config: FNetConfig
  290. base_model_prefix = "fnet"
  291. supports_gradient_checkpointing = True
  292. def _init_weights(self, module):
  293. super()._init_weights(module)
  294. if isinstance(module, FNetEmbeddings):
  295. init.copy_(module.position_ids, torch.arange(module.position_ids.shape[-1]).expand((1, -1)))
  296. init.zeros_(module.token_type_ids)
@dataclass
@auto_docstring(
    custom_intro="""
    Output type of [`FNetForPreTraining`].
    """
)
class FNetForPreTrainingOutput(ModelOutput):
    r"""
    loss (*optional*, returned when both `labels` and `next_sentence_label` are provided, `torch.FloatTensor` of shape `(1,)`):
        Total loss as the sum of the masked language modeling loss and the next sentence prediction
        (classification) loss.
    prediction_logits (`torch.FloatTensor` of shape `(batch_size, sequence_length, config.vocab_size)`):
        Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax).
    seq_relationship_logits (`torch.FloatTensor` of shape `(batch_size, 2)`):
        Prediction scores of the next sentence prediction (classification) head (scores of True/False continuation
        before SoftMax).
    """

    # NOTE(review): ModelOutput presumably exposes fields in declaration order for tuple/index access — keep this
    # field order stable.
    loss: torch.FloatTensor | None = None
    prediction_logits: torch.FloatTensor | None = None
    seq_relationship_logits: torch.FloatTensor | None = None
    # Filled from the base model when output_hidden_states is requested: the embedding output followed by
    # each encoder layer's output.
    hidden_states: tuple[torch.FloatTensor] | None = None
  318. @auto_docstring
  319. class FNetModel(FNetPreTrainedModel):
  320. """
  321. The model can behave as an encoder, following the architecture described in [FNet: Mixing Tokens with Fourier
  322. Transforms](https://huggingface.co/papers/2105.03824) by James Lee-Thorp, Joshua Ainslie, Ilya Eckstein, Santiago Ontanon.
  323. """
  324. def __init__(self, config, add_pooling_layer=True):
  325. r"""
  326. add_pooling_layer (bool, *optional*, defaults to `True`):
  327. Whether to add a pooling layer
  328. """
  329. super().__init__(config)
  330. self.config = config
  331. self.embeddings = FNetEmbeddings(config)
  332. self.encoder = FNetEncoder(config)
  333. self.pooler = FNetPooler(config) if add_pooling_layer else None
  334. # Initialize weights and apply final processing
  335. self.post_init()
  336. def get_input_embeddings(self):
  337. return self.embeddings.word_embeddings
  338. def set_input_embeddings(self, value):
  339. self.embeddings.word_embeddings = value
  340. @auto_docstring
  341. def forward(
  342. self,
  343. input_ids: torch.LongTensor | None = None,
  344. token_type_ids: torch.LongTensor | None = None,
  345. position_ids: torch.LongTensor | None = None,
  346. inputs_embeds: torch.FloatTensor | None = None,
  347. output_hidden_states: bool | None = None,
  348. return_dict: bool | None = None,
  349. **kwargs,
  350. ) -> tuple | BaseModelOutput:
  351. output_hidden_states = (
  352. output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
  353. )
  354. return_dict = return_dict if return_dict is not None else self.config.return_dict
  355. if input_ids is not None and inputs_embeds is not None:
  356. raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
  357. elif input_ids is not None:
  358. input_shape = input_ids.size()
  359. batch_size, seq_length = input_shape
  360. elif inputs_embeds is not None:
  361. input_shape = inputs_embeds.size()[:-1]
  362. batch_size, seq_length = input_shape
  363. else:
  364. raise ValueError("You have to specify either input_ids or inputs_embeds")
  365. if (
  366. self.config.use_tpu_fourier_optimizations
  367. and seq_length <= 4096
  368. and self.config.tpu_short_seq_length != seq_length
  369. ):
  370. raise ValueError(
  371. "The `tpu_short_seq_length` in FNetConfig should be set equal to the sequence length being passed to"
  372. " the model when using TPU optimizations."
  373. )
  374. device = input_ids.device if input_ids is not None else inputs_embeds.device
  375. if token_type_ids is None:
  376. if hasattr(self.embeddings, "token_type_ids"):
  377. buffered_token_type_ids = self.embeddings.token_type_ids[:, :seq_length]
  378. buffered_token_type_ids_expanded = buffered_token_type_ids.expand(batch_size, seq_length)
  379. token_type_ids = buffered_token_type_ids_expanded
  380. else:
  381. token_type_ids = torch.zeros(input_shape, dtype=torch.long, device=device)
  382. embedding_output = self.embeddings(
  383. input_ids=input_ids,
  384. position_ids=position_ids,
  385. token_type_ids=token_type_ids,
  386. inputs_embeds=inputs_embeds,
  387. )
  388. encoder_outputs = self.encoder(
  389. embedding_output,
  390. output_hidden_states=output_hidden_states,
  391. return_dict=return_dict,
  392. )
  393. sequence_output = encoder_outputs[0]
  394. pooler_output = self.pooler(sequence_output) if self.pooler is not None else None
  395. if not return_dict:
  396. return (sequence_output, pooler_output) + encoder_outputs[1:]
  397. return BaseModelOutputWithPooling(
  398. last_hidden_state=sequence_output,
  399. pooler_output=pooler_output,
  400. hidden_states=encoder_outputs.hidden_states,
  401. )
  402. @auto_docstring(
  403. custom_intro="""
  404. FNet Model with two heads on top as done during the pretraining: a `masked language modeling` head and a `next
  405. sentence prediction (classification)` head.
  406. """
  407. )
  408. class FNetForPreTraining(FNetPreTrainedModel):
  409. _tied_weights_keys = {
  410. "cls.predictions.decoder.bias": "cls.predictions.bias",
  411. "cls.predictions.decoder.weight": "fnet.embeddings.word_embeddings.weight",
  412. }
  413. def __init__(self, config):
  414. super().__init__(config)
  415. self.fnet = FNetModel(config)
  416. self.cls = FNetPreTrainingHeads(config)
  417. # Initialize weights and apply final processing
  418. self.post_init()
  419. def get_output_embeddings(self):
  420. return self.cls.predictions.decoder
  421. def set_output_embeddings(self, new_embeddings):
  422. self.cls.predictions.decoder = new_embeddings
  423. self.cls.predictions.bias = new_embeddings.bias
  424. @auto_docstring
  425. def forward(
  426. self,
  427. input_ids: torch.Tensor | None = None,
  428. token_type_ids: torch.Tensor | None = None,
  429. position_ids: torch.Tensor | None = None,
  430. inputs_embeds: torch.Tensor | None = None,
  431. labels: torch.Tensor | None = None,
  432. next_sentence_label: torch.Tensor | None = None,
  433. output_hidden_states: bool | None = None,
  434. return_dict: bool | None = None,
  435. **kwargs,
  436. ) -> tuple | FNetForPreTrainingOutput:
  437. r"""
  438. labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
  439. Labels for computing the masked language modeling loss. Indices should be in `[-100, 0, ...,
  440. config.vocab_size]` (see `input_ids` docstring) Tokens with indices set to `-100` are ignored (masked), the
  441. loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`
  442. next_sentence_label (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
  443. Labels for computing the next sequence prediction (classification) loss. Input should be a sequence pair
  444. (see `input_ids` docstring) Indices should be in `[0, 1]`:
  445. - 0 indicates sequence B is a continuation of sequence A,
  446. - 1 indicates sequence B is a random sequence.
  447. Example:
  448. ```python
  449. >>> from transformers import AutoTokenizer, FNetForPreTraining
  450. >>> import torch
  451. >>> tokenizer = AutoTokenizer.from_pretrained("google/fnet-base")
  452. >>> model = FNetForPreTraining.from_pretrained("google/fnet-base")
  453. >>> inputs = tokenizer("Hello, my dog is cute", return_tensors="pt")
  454. >>> outputs = model(**inputs)
  455. >>> prediction_logits = outputs.prediction_logits
  456. >>> seq_relationship_logits = outputs.seq_relationship_logits
  457. ```"""
  458. return_dict = return_dict if return_dict is not None else self.config.return_dict
  459. outputs = self.fnet(
  460. input_ids,
  461. token_type_ids=token_type_ids,
  462. position_ids=position_ids,
  463. inputs_embeds=inputs_embeds,
  464. output_hidden_states=output_hidden_states,
  465. return_dict=return_dict,
  466. )
  467. sequence_output, pooled_output = outputs[:2]
  468. prediction_scores, seq_relationship_score = self.cls(sequence_output, pooled_output)
  469. total_loss = None
  470. if labels is not None and next_sentence_label is not None:
  471. loss_fct = CrossEntropyLoss()
  472. masked_lm_loss = loss_fct(prediction_scores.view(-1, self.config.vocab_size), labels.view(-1))
  473. next_sentence_loss = loss_fct(seq_relationship_score.view(-1, 2), next_sentence_label.view(-1))
  474. total_loss = masked_lm_loss + next_sentence_loss
  475. if not return_dict:
  476. output = (prediction_scores, seq_relationship_score) + outputs[2:]
  477. return ((total_loss,) + output) if total_loss is not None else output
  478. return FNetForPreTrainingOutput(
  479. loss=total_loss,
  480. prediction_logits=prediction_scores,
  481. seq_relationship_logits=seq_relationship_score,
  482. hidden_states=outputs.hidden_states,
  483. )
  484. @auto_docstring
  485. class FNetForMaskedLM(FNetPreTrainedModel):
  486. _tied_weights_keys = {
  487. "cls.predictions.decoder.bias": "cls.predictions.bias",
  488. "cls.predictions.decoder.weight": "fnet.embeddings.word_embeddings.weight",
  489. }
  490. def __init__(self, config):
  491. super().__init__(config)
  492. self.fnet = FNetModel(config)
  493. self.cls = FNetOnlyMLMHead(config)
  494. # Initialize weights and apply final processing
  495. self.post_init()
  496. def get_output_embeddings(self):
  497. return self.cls.predictions.decoder
  498. def set_output_embeddings(self, new_embeddings):
  499. self.cls.predictions.decoder = new_embeddings
  500. self.cls.predictions.bias = new_embeddings.bias
  501. @auto_docstring
  502. def forward(
  503. self,
  504. input_ids: torch.Tensor | None = None,
  505. token_type_ids: torch.Tensor | None = None,
  506. position_ids: torch.Tensor | None = None,
  507. inputs_embeds: torch.Tensor | None = None,
  508. labels: torch.Tensor | None = None,
  509. output_hidden_states: bool | None = None,
  510. return_dict: bool | None = None,
  511. **kwargs,
  512. ) -> tuple | MaskedLMOutput:
  513. r"""
  514. labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
  515. Labels for computing the masked language modeling loss. Indices should be in `[-100, 0, ...,
  516. config.vocab_size]` (see `input_ids` docstring) Tokens with indices set to `-100` are ignored (masked), the
  517. loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
  518. """
  519. return_dict = return_dict if return_dict is not None else self.config.return_dict
  520. outputs = self.fnet(
  521. input_ids,
  522. token_type_ids=token_type_ids,
  523. position_ids=position_ids,
  524. inputs_embeds=inputs_embeds,
  525. output_hidden_states=output_hidden_states,
  526. return_dict=return_dict,
  527. )
  528. sequence_output = outputs[0]
  529. prediction_scores = self.cls(sequence_output)
  530. masked_lm_loss = None
  531. if labels is not None:
  532. loss_fct = CrossEntropyLoss() # -100 index = padding token
  533. masked_lm_loss = loss_fct(prediction_scores.view(-1, self.config.vocab_size), labels.view(-1))
  534. if not return_dict:
  535. output = (prediction_scores,) + outputs[2:]
  536. return ((masked_lm_loss,) + output) if masked_lm_loss is not None else output
  537. return MaskedLMOutput(loss=masked_lm_loss, logits=prediction_scores, hidden_states=outputs.hidden_states)
  538. @auto_docstring(
  539. custom_intro="""
  540. FNet Model with a `next sentence prediction (classification)` head on top.
  541. """
  542. )
  543. class FNetForNextSentencePrediction(FNetPreTrainedModel):
  544. def __init__(self, config):
  545. super().__init__(config)
  546. self.fnet = FNetModel(config)
  547. self.cls = FNetOnlyNSPHead(config)
  548. # Initialize weights and apply final processing
  549. self.post_init()
  550. @auto_docstring
  551. def forward(
  552. self,
  553. input_ids: torch.Tensor | None = None,
  554. token_type_ids: torch.Tensor | None = None,
  555. position_ids: torch.Tensor | None = None,
  556. inputs_embeds: torch.Tensor | None = None,
  557. labels: torch.Tensor | None = None,
  558. output_hidden_states: bool | None = None,
  559. return_dict: bool | None = None,
  560. **kwargs,
  561. ) -> tuple | NextSentencePredictorOutput:
  562. r"""
  563. labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
  564. Labels for computing the next sequence prediction (classification) loss. Input should be a sequence pair
  565. (see `input_ids` docstring). Indices should be in `[0, 1]`:
  566. - 0 indicates sequence B is a continuation of sequence A,
  567. - 1 indicates sequence B is a random sequence.
  568. Example:
  569. ```python
  570. >>> from transformers import AutoTokenizer, FNetForNextSentencePrediction
  571. >>> import torch
  572. >>> tokenizer = AutoTokenizer.from_pretrained("google/fnet-base")
  573. >>> model = FNetForNextSentencePrediction.from_pretrained("google/fnet-base")
  574. >>> prompt = "In Italy, pizza served in formal settings, such as at a restaurant, is presented unsliced."
  575. >>> next_sentence = "The sky is blue due to the shorter wavelength of blue light."
  576. >>> encoding = tokenizer(prompt, next_sentence, return_tensors="pt")
  577. >>> outputs = model(**encoding, labels=torch.LongTensor([1]))
  578. >>> logits = outputs.logits
  579. >>> assert logits[0, 0] < logits[0, 1] # next sentence was random
  580. ```"""
  581. return_dict = return_dict if return_dict is not None else self.config.return_dict
  582. outputs = self.fnet(
  583. input_ids,
  584. token_type_ids=token_type_ids,
  585. position_ids=position_ids,
  586. inputs_embeds=inputs_embeds,
  587. output_hidden_states=output_hidden_states,
  588. return_dict=return_dict,
  589. )
  590. pooled_output = outputs[1]
  591. seq_relationship_scores = self.cls(pooled_output)
  592. next_sentence_loss = None
  593. if labels is not None:
  594. loss_fct = CrossEntropyLoss()
  595. next_sentence_loss = loss_fct(seq_relationship_scores.view(-1, 2), labels.view(-1))
  596. if not return_dict:
  597. output = (seq_relationship_scores,) + outputs[2:]
  598. return ((next_sentence_loss,) + output) if next_sentence_loss is not None else output
  599. return NextSentencePredictorOutput(
  600. loss=next_sentence_loss,
  601. logits=seq_relationship_scores,
  602. hidden_states=outputs.hidden_states,
  603. )
  604. @auto_docstring(
  605. custom_intro="""
  606. FNet Model transformer with a sequence classification/regression head on top (a linear layer on top of the pooled
  607. output) e.g. for GLUE tasks.
  608. """
  609. )
  610. class FNetForSequenceClassification(FNetPreTrainedModel):
  611. def __init__(self, config):
  612. super().__init__(config)
  613. self.num_labels = config.num_labels
  614. self.fnet = FNetModel(config)
  615. self.dropout = nn.Dropout(config.hidden_dropout_prob)
  616. self.classifier = nn.Linear(config.hidden_size, config.num_labels)
  617. # Initialize weights and apply final processing
  618. self.post_init()
  619. @auto_docstring
  620. def forward(
  621. self,
  622. input_ids: torch.Tensor | None = None,
  623. token_type_ids: torch.Tensor | None = None,
  624. position_ids: torch.Tensor | None = None,
  625. inputs_embeds: torch.Tensor | None = None,
  626. labels: torch.Tensor | None = None,
  627. output_hidden_states: bool | None = None,
  628. return_dict: bool | None = None,
  629. **kwargs,
  630. ) -> tuple | SequenceClassifierOutput:
  631. r"""
  632. labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
  633. Labels for computing the sequence classification/regression loss. Indices should be in `[0, ...,
  634. config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If
  635. `config.num_labels > 1` a classification loss is computed (Cross-Entropy).
  636. """
  637. return_dict = return_dict if return_dict is not None else self.config.return_dict
  638. outputs = self.fnet(
  639. input_ids,
  640. token_type_ids=token_type_ids,
  641. position_ids=position_ids,
  642. inputs_embeds=inputs_embeds,
  643. output_hidden_states=output_hidden_states,
  644. return_dict=return_dict,
  645. )
  646. pooled_output = outputs[1]
  647. pooled_output = self.dropout(pooled_output)
  648. logits = self.classifier(pooled_output)
  649. loss = None
  650. if labels is not None:
  651. if self.config.problem_type is None:
  652. if self.num_labels == 1:
  653. self.config.problem_type = "regression"
  654. elif self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int):
  655. self.config.problem_type = "single_label_classification"
  656. else:
  657. self.config.problem_type = "multi_label_classification"
  658. if self.config.problem_type == "regression":
  659. loss_fct = MSELoss()
  660. if self.num_labels == 1:
  661. loss = loss_fct(logits.squeeze(), labels.squeeze())
  662. else:
  663. loss = loss_fct(logits, labels)
  664. elif self.config.problem_type == "single_label_classification":
  665. loss_fct = CrossEntropyLoss()
  666. loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
  667. elif self.config.problem_type == "multi_label_classification":
  668. loss_fct = BCEWithLogitsLoss()
  669. loss = loss_fct(logits, labels)
  670. if not return_dict:
  671. output = (logits,) + outputs[2:]
  672. return ((loss,) + output) if loss is not None else output
  673. return SequenceClassifierOutput(loss=loss, logits=logits, hidden_states=outputs.hidden_states)
  674. @auto_docstring
  675. class FNetForMultipleChoice(FNetPreTrainedModel):
  676. def __init__(self, config):
  677. super().__init__(config)
  678. self.fnet = FNetModel(config)
  679. self.dropout = nn.Dropout(config.hidden_dropout_prob)
  680. self.classifier = nn.Linear(config.hidden_size, 1)
  681. # Initialize weights and apply final processing
  682. self.post_init()
  683. @auto_docstring
  684. def forward(
  685. self,
  686. input_ids: torch.Tensor | None = None,
  687. token_type_ids: torch.Tensor | None = None,
  688. position_ids: torch.Tensor | None = None,
  689. inputs_embeds: torch.Tensor | None = None,
  690. labels: torch.Tensor | None = None,
  691. output_hidden_states: bool | None = None,
  692. return_dict: bool | None = None,
  693. **kwargs,
  694. ) -> tuple | MultipleChoiceModelOutput:
  695. r"""
  696. input_ids (`torch.LongTensor` of shape `(batch_size, num_choices, sequence_length)`):
  697. Indices of input sequence tokens in the vocabulary.
  698. Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
  699. [`PreTrainedTokenizer.__call__`] for details.
  700. [What are input IDs?](../glossary#input-ids)
  701. token_type_ids (`torch.LongTensor` of shape `(batch_size, num_choices, sequence_length)`, *optional*):
  702. Segment token indices to indicate first and second portions of the inputs. Indices are selected in `[0,
  703. 1]`:
  704. - 0 corresponds to a *sentence A* token,
  705. - 1 corresponds to a *sentence B* token.
  706. [What are token type IDs?](../glossary#token-type-ids)
  707. position_ids (`torch.LongTensor` of shape `(batch_size, num_choices, sequence_length)`, *optional*):
  708. Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0,
  709. config.max_position_embeddings - 1]`.
  710. [What are position IDs?](../glossary#position-ids)
  711. inputs_embeds (`torch.FloatTensor` of shape `(batch_size, num_choices, sequence_length, hidden_size)`, *optional*):
  712. Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This
  713. is useful if you want more control over how to convert *input_ids* indices into associated vectors than the
  714. model's internal embedding lookup matrix.
  715. labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
  716. Labels for computing the multiple choice classification loss. Indices should be in `[0, ...,
  717. num_choices-1]` where `num_choices` is the size of the second dimension of the input tensors. (See
  718. `input_ids` above)
  719. """
  720. return_dict = return_dict if return_dict is not None else self.config.return_dict
  721. num_choices = input_ids.shape[1] if input_ids is not None else inputs_embeds.shape[1]
  722. input_ids = input_ids.view(-1, input_ids.size(-1)) if input_ids is not None else None
  723. token_type_ids = token_type_ids.view(-1, token_type_ids.size(-1)) if token_type_ids is not None else None
  724. position_ids = position_ids.view(-1, position_ids.size(-1)) if position_ids is not None else None
  725. inputs_embeds = (
  726. inputs_embeds.view(-1, inputs_embeds.size(-2), inputs_embeds.size(-1))
  727. if inputs_embeds is not None
  728. else None
  729. )
  730. outputs = self.fnet(
  731. input_ids,
  732. token_type_ids=token_type_ids,
  733. position_ids=position_ids,
  734. inputs_embeds=inputs_embeds,
  735. output_hidden_states=output_hidden_states,
  736. return_dict=return_dict,
  737. )
  738. pooled_output = outputs[1]
  739. pooled_output = self.dropout(pooled_output)
  740. logits = self.classifier(pooled_output)
  741. reshaped_logits = logits.view(-1, num_choices)
  742. loss = None
  743. if labels is not None:
  744. loss_fct = CrossEntropyLoss()
  745. loss = loss_fct(reshaped_logits, labels)
  746. if not return_dict:
  747. output = (reshaped_logits,) + outputs[2:]
  748. return ((loss,) + output) if loss is not None else output
  749. return MultipleChoiceModelOutput(loss=loss, logits=reshaped_logits, hidden_states=outputs.hidden_states)
  750. @auto_docstring
  751. class FNetForTokenClassification(FNetPreTrainedModel):
  752. def __init__(self, config):
  753. super().__init__(config)
  754. self.num_labels = config.num_labels
  755. self.fnet = FNetModel(config)
  756. self.dropout = nn.Dropout(config.hidden_dropout_prob)
  757. self.classifier = nn.Linear(config.hidden_size, config.num_labels)
  758. # Initialize weights and apply final processing
  759. self.post_init()
  760. @auto_docstring
  761. def forward(
  762. self,
  763. input_ids: torch.Tensor | None = None,
  764. token_type_ids: torch.Tensor | None = None,
  765. position_ids: torch.Tensor | None = None,
  766. inputs_embeds: torch.Tensor | None = None,
  767. labels: torch.Tensor | None = None,
  768. output_hidden_states: bool | None = None,
  769. return_dict: bool | None = None,
  770. **kwargs,
  771. ) -> tuple | TokenClassifierOutput:
  772. r"""
  773. labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
  774. Labels for computing the token classification loss. Indices should be in `[0, ..., config.num_labels - 1]`.
  775. """
  776. return_dict = return_dict if return_dict is not None else self.config.return_dict
  777. outputs = self.fnet(
  778. input_ids,
  779. token_type_ids=token_type_ids,
  780. position_ids=position_ids,
  781. inputs_embeds=inputs_embeds,
  782. output_hidden_states=output_hidden_states,
  783. return_dict=return_dict,
  784. )
  785. sequence_output = outputs[0]
  786. sequence_output = self.dropout(sequence_output)
  787. logits = self.classifier(sequence_output)
  788. loss = None
  789. if labels is not None:
  790. loss_fct = CrossEntropyLoss()
  791. # Only keep active parts of the loss
  792. loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
  793. if not return_dict:
  794. output = (logits,) + outputs[2:]
  795. return ((loss,) + output) if loss is not None else output
  796. return TokenClassifierOutput(loss=loss, logits=logits, hidden_states=outputs.hidden_states)
  797. @auto_docstring
  798. class FNetForQuestionAnswering(FNetPreTrainedModel):
  799. def __init__(self, config):
  800. super().__init__(config)
  801. self.num_labels = config.num_labels
  802. self.fnet = FNetModel(config)
  803. self.qa_outputs = nn.Linear(config.hidden_size, config.num_labels)
  804. # Initialize weights and apply final processing
  805. self.post_init()
  806. @auto_docstring
  807. def forward(
  808. self,
  809. input_ids: torch.Tensor | None = None,
  810. token_type_ids: torch.Tensor | None = None,
  811. position_ids: torch.Tensor | None = None,
  812. inputs_embeds: torch.Tensor | None = None,
  813. start_positions: torch.Tensor | None = None,
  814. end_positions: torch.Tensor | None = None,
  815. output_hidden_states: bool | None = None,
  816. return_dict: bool | None = None,
  817. **kwargs,
  818. ) -> tuple | QuestionAnsweringModelOutput:
  819. return_dict = return_dict if return_dict is not None else self.config.return_dict
  820. outputs = self.fnet(
  821. input_ids,
  822. token_type_ids=token_type_ids,
  823. position_ids=position_ids,
  824. inputs_embeds=inputs_embeds,
  825. output_hidden_states=output_hidden_states,
  826. return_dict=return_dict,
  827. )
  828. sequence_output = outputs[0]
  829. logits = self.qa_outputs(sequence_output)
  830. start_logits, end_logits = logits.split(1, dim=-1)
  831. start_logits = start_logits.squeeze(-1).contiguous()
  832. end_logits = end_logits.squeeze(-1).contiguous()
  833. total_loss = None
  834. if start_positions is not None and end_positions is not None:
  835. # If we are on multi-GPU, split add a dimension
  836. if len(start_positions.size()) > 1:
  837. start_positions = start_positions.squeeze(-1)
  838. if len(end_positions.size()) > 1:
  839. end_positions = end_positions.squeeze(-1)
  840. # sometimes the start/end positions are outside our model inputs, we ignore these terms
  841. ignored_index = start_logits.size(1)
  842. start_positions = start_positions.clamp(0, ignored_index)
  843. end_positions = end_positions.clamp(0, ignored_index)
  844. loss_fct = CrossEntropyLoss(ignore_index=ignored_index)
  845. start_loss = loss_fct(start_logits, start_positions)
  846. end_loss = loss_fct(end_logits, end_positions)
  847. total_loss = (start_loss + end_loss) / 2
  848. if not return_dict:
  849. output = (start_logits, end_logits) + outputs[2:]
  850. return ((total_loss,) + output) if total_loss is not None else output
  851. return QuestionAnsweringModelOutput(
  852. loss=total_loss, start_logits=start_logits, end_logits=end_logits, hidden_states=outputs.hidden_states
  853. )
  854. __all__ = [
  855. "FNetForMaskedLM",
  856. "FNetForMultipleChoice",
  857. "FNetForNextSentencePrediction",
  858. "FNetForPreTraining",
  859. "FNetForQuestionAnswering",
  860. "FNetForSequenceClassification",
  861. "FNetForTokenClassification",
  862. "FNetLayer",
  863. "FNetModel",
  864. "FNetPreTrainedModel",
  865. ]