modeling_fsmt.py 45 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
370470570670770870971071171271371471571671771871972072172272372472572672772872973073173273373473573673773873974074174274374474574674774874975075175275375475575675775875976076176276376476576676776876977077177277377477577677777877978078178278378478578678778878979079179279379479579679779879980080180280380480580680780880981081181281381481581681781881982082182282382482582682782882983083183283383483583683783883984084184284384484584684784884985085185285385485585685785885986086186286386486586686786886987087187287387487587687787887988088188288388488588688788888989089189289389489589689789889990090190290390490590690790890991091191291391491591691791891992092192292392492592692792892993093193293393493593693793893994094194294394494594694794894995095195295395495595695795895996096196296396496596696796896997097197297397497597697797897998098198298398498598698798898999099199299399499599699799899910001001100210031004100510061007100810091010101110121013101410151016101710181019102010211022102310241025102610271028102910301031103210331034103510361037103810391040104110421043104410451046104710481049105010511052105310541055105610571058105910601061106210631064106510661067106810691070107110721073107410751076107710781079108010811082108310841085108610871088108910901091109210931094109510961097109810991100110111021103110411051106110711081109111011111112111311141115111611171118111911201121112211231124112511261127112811291130113111321133113411351136
  1. # Copyright 2020 The Facebook AI Research Team Authors and The HuggingFace Inc. team.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. #
  15. # Original implementation: https://github.com/pytorch/fairseq/tree/master/examples/wmt19
  16. # Authors:
  17. # - @alexeib Alexei Baevski
  18. # - @edunov Sergey Edunov
  19. # - @michaelauli Michael Auli
  20. # - @myleott Myle Ott
  21. # - @nng555 Nathan Ng
  22. # - David Grangier
  23. # - Kyra Yee
  24. #
  25. # Paper: Facebook FAIR's WMT19 News Translation Task Submission https://huggingface.co/papers/1907.06616
  26. #
  27. """PyTorch Fairseq model, ported from https://github.com/pytorch/fairseq/tree/master/examples/wmt19"""
  28. import math
  29. from typing import Any
  30. import torch
  31. from torch import Tensor, nn
  32. from torch.nn import CrossEntropyLoss, LayerNorm
  33. from ... import initialization as init
  34. from ...activations import ACT2FN
  35. from ...cache_utils import Cache, DynamicCache, EncoderDecoderCache
  36. from ...generation import GenerationMixin
  37. from ...modeling_outputs import (
  38. BaseModelOutput,
  39. BaseModelOutputWithPastAndCrossAttentions,
  40. Seq2SeqLMOutput,
  41. Seq2SeqModelOutput,
  42. )
  43. from ...modeling_utils import PreTrainedModel
  44. from ...utils import auto_docstring, logging
  45. from .configuration_fsmt import FSMTConfig
  46. logger = logging.get_logger(__name__)
  47. # See all FSMT models at https://huggingface.co/models?filter=fsmt
  48. # Porting notes:
  49. # this one is modeled after BartModel*
  50. #
  51. # Currently only translation (fairseq also has weights for LM)
  52. #
  53. # fairseq provides weights for ru-en, en-ru and de-en, en-de pairs. All have been ported.
  54. # - ru-en, en-ru use asymmetric vocab
  55. # - de-en, en-de use a merged single vocab (but the code works as if they are separate)
  56. #
  57. # Differences with Bart:
  58. # - not using bos token
  59. # - 2 separate vocabs (src and target)
  60. # - embed weights aren't tied
  61. # - uses a model Ensemble (but that part isn't ported/implemented yet) - so we
  62. # aren't getting as good of a BLEU score
  63. # - uses a projection layer at the end of the decoder
  64. # - doesn't use final_logits_bias
  65. # - beam search: stops as soon as num_beams == len(hypos) (whereas transformers
  66. # is not satisfied there and will continue searching until the next cycles
  67. # aren't promising something better), comparing BLEU scores - the transformers
  68. # algorithm is slightly superior, therefore using the latter. But if you want
  69. # to match fairseq outputs, you need to pass ``early_stopping=True`` to ``generate()``.
  70. #
  71. # SinusoidalPositionalEmbedding is slightly different from Bart's - generates
  72. # different embeddings. This implementation is copied verbatim from fairseq with
  73. # some small changes to make it work here.
  74. #
  75. # Other changes:
  76. # - doesn't support use_cache as Bart's version does
  77. #
  78. #
  79. # FSMTConfig changes compared with BartConfig
  80. #
  81. # Differences with BART:
  82. # - src/tgt vocabs aren't shared
  83. # - token embeddings aren't shared
  84. # - needs a language pair
  85. # - scale_embedding is True
  86. #
  87. # some unused args were removed too
  88. #
  89. #
  90. # TODO:
  91. # - port model ensemble (fs uses 4 model checkpoints)
  92. # - solve beam search discrepancies
  93. # docstyle-ignore
  94. """
  95. Here is how to compare BLEU scores against fairseq implementation:
  96. (don't forget to install sacrebleu: `pip install sacrebleu`)
  97. # en-ru
  98. export PAIR=en-ru
  99. export DATA_DIR=data/$PAIR
  100. export SAVE_DIR=data/$PAIR
  101. export BS=8
  102. export NUM_BEAMS=50
  103. mkdir -p $DATA_DIR
  104. sacrebleu -t wmt19 -l $PAIR --echo src > $DATA_DIR/val.source
  105. sacrebleu -t wmt19 -l $PAIR --echo ref > $DATA_DIR/val.target
  106. echo $PAIR
  107. PYTHONPATH="src:examples/seq2seq" python examples/seq2seq/run_eval.py facebook/wmt19-$PAIR $DATA_DIR/val.source $SAVE_DIR/test_translations.txt --reference_path $DATA_DIR/val.target --score_path $SAVE_DIR/test_bleu.json --bs $BS --task translation --num_beams $NUM_BEAMS
  108. # (fairseq BLEU: 36.4 http://matrix.statmt.org/matrix/output/1914?score_id=37605)
  109. # ru-en
  110. export PAIR=ru-en
  111. export DATA_DIR=data/$PAIR
  112. export SAVE_DIR=data/$PAIR
  113. export BS=8
  114. export NUM_BEAMS=50
  115. mkdir -p $DATA_DIR
  116. sacrebleu -t wmt19 -l $PAIR --echo src > $DATA_DIR/val.source
  117. sacrebleu -t wmt19 -l $PAIR --echo ref > $DATA_DIR/val.target
  118. PYTHONPATH="src:examples/seq2seq" python examples/seq2seq/run_eval.py facebook/wmt19-$PAIR $DATA_DIR/val.source $SAVE_DIR/test_translations.txt --reference_path $DATA_DIR/val.target --score_path $SAVE_DIR/test_bleu.json --bs $BS --task translation --num_beams $NUM_BEAMS
  119. # (fairseq BLEU: 41.3 http://matrix.statmt.org/matrix/output/1907?run_id=6937)
  120. # de-en
  121. export PAIR=de-en
  122. export DATA_DIR=data/$PAIR
  123. export SAVE_DIR=data/$PAIR
  124. export BS=8
  125. export NUM_BEAMS=50
  126. mkdir -p $DATA_DIR
  127. sacrebleu -t wmt19 -l $PAIR --echo src > $DATA_DIR/val.source
  128. sacrebleu -t wmt19 -l $PAIR --echo ref > $DATA_DIR/val.target
  129. echo $PAIR
  130. PYTHONPATH="src:examples/seq2seq" python examples/seq2seq/run_eval.py facebook/wmt19-$PAIR $DATA_DIR/val.source $SAVE_DIR/test_translations.txt --reference_path $DATA_DIR/val.target --score_path $SAVE_DIR/test_bleu.json --bs $BS --task translation --num_beams $NUM_BEAMS
  131. # (fairseq BLEU: 42.3 http://matrix.statmt.org/matrix/output/1902?run_id=6750)
  132. # en-de
  133. export PAIR=en-de
  134. export DATA_DIR=data/$PAIR
  135. export SAVE_DIR=data/$PAIR
  136. export BS=8
  137. mkdir -p $DATA_DIR
  138. sacrebleu -t wmt19 -l $PAIR --echo src > $DATA_DIR/val.source
  139. sacrebleu -t wmt19 -l $PAIR --echo ref > $DATA_DIR/val.target
  140. echo $PAIR
  141. PYTHONPATH="src:examples/seq2seq" python examples/seq2seq/run_eval.py facebook/wmt19-$PAIR $DATA_DIR/val.source $SAVE_DIR/test_translations.txt --reference_path $DATA_DIR/val.target --score_path $SAVE_DIR/test_bleu.json --bs $BS --task translation --num_beams $NUM_BEAMS
  142. # (fairseq BLEU: 43.1 http://matrix.statmt.org/matrix/output/1909?run_id=6862)
  143. """
  144. def invert_mask(attention_mask):
  145. """Turns 1->0, 0->1, False->True, True-> False"""
  146. assert attention_mask.dim() == 2
  147. return attention_mask.eq(0)
  148. def triu_onnx(x, diagonal=0):
  149. l = x.shape[0]
  150. arange = torch.arange(l, device=x.device)
  151. mask = arange.expand(l, l)
  152. arange = arange.unsqueeze(-1)
  153. if diagonal:
  154. arange = arange + diagonal
  155. mask = mask >= arange
  156. return x.masked_fill(mask == 0, 0)
  157. def _prepare_fsmt_decoder_inputs(
  158. config,
  159. input_ids,
  160. decoder_input_ids=None,
  161. decoder_padding_mask=None,
  162. causal_mask_dtype=torch.float32,
  163. ):
  164. """
  165. Prepare masks that ignore padding tokens in the decoder and a causal mask for the decoder if none are provided.
  166. This mimics the default behavior in fairseq. To override it pass in masks. Note: this is not called during
  167. generation
  168. """
  169. pad_token_id = config.pad_token_id
  170. if decoder_input_ids is None:
  171. decoder_input_ids = shift_tokens_right(input_ids, pad_token_id)
  172. bsz, tgt_len = decoder_input_ids.size()
  173. if decoder_padding_mask is None:
  174. decoder_padding_mask = make_padding_mask(decoder_input_ids, pad_token_id)
  175. else:
  176. decoder_padding_mask = invert_mask(decoder_padding_mask)
  177. causal_mask = triu_onnx(fill_with_neg_inf(torch.zeros(tgt_len, tgt_len, dtype=causal_mask_dtype)), 1).to(
  178. device=decoder_input_ids.device
  179. )
  180. return decoder_input_ids, decoder_padding_mask, causal_mask
  181. @auto_docstring
  182. class PretrainedFSMTModel(PreTrainedModel):
  183. config: FSMTConfig
  184. base_model_prefix = "model"
  185. @torch.no_grad()
  186. def _init_weights(self, module):
  187. std = self.config.init_std
  188. if isinstance(module, nn.Linear):
  189. init.normal_(module.weight, mean=0.0, std=std)
  190. if module.bias is not None:
  191. init.zeros_(module.bias)
  192. elif isinstance(module, SinusoidalPositionalEmbedding):
  193. weight = module.get_embedding(*module.weight.shape, module.padding_idx)
  194. init.copy_(module.weight, weight)
  195. elif isinstance(module, nn.Embedding):
  196. init.normal_(module.weight, mean=0.0, std=std)
  197. # Here we need the check explicitly, as we slice the weight in the `zeros_` call, so it looses the flag
  198. if module.padding_idx is not None and not getattr(module.weight, "_is_hf_initialized", False):
  199. init.zeros_(module.weight[module.padding_idx])
  200. @property
  201. def dummy_inputs(self):
  202. pad_token = self.config.pad_token_id
  203. input_ids = torch.tensor([[0, 6, 10, 4, 2], [0, 8, 12, 2, pad_token]], device=self.device)
  204. dummy_inputs = {
  205. "attention_mask": input_ids.ne(pad_token),
  206. "input_ids": input_ids,
  207. }
  208. return dummy_inputs
  209. def _make_linear_from_emb(emb):
  210. vocab_size, emb_size = emb.weight.shape
  211. lin_layer = nn.Linear(vocab_size, emb_size, bias=False)
  212. lin_layer.weight.data = emb.weight.data
  213. return lin_layer
  214. # Helper Functions, mostly for making masks
  215. def _check_shapes(shape_1, shape2):
  216. if shape_1 != shape2:
  217. raise AssertionError(f"shape mismatch: {shape_1} != {shape2}")
  218. def shift_tokens_right(input_ids, pad_token_id):
  219. """Shift input ids one token to the right, and wrap the last non pad token (usually <eos>)."""
  220. # replace possible -100 values in labels by `pad_token_id`
  221. input_ids.masked_fill_(input_ids == -100, pad_token_id)
  222. prev_output_tokens = input_ids.clone()
  223. index_of_eos = (input_ids.ne(pad_token_id).sum(dim=1) - 1).unsqueeze(-1)
  224. prev_output_tokens[:, 0] = input_ids.gather(1, index_of_eos).squeeze()
  225. prev_output_tokens[:, 1:] = input_ids[:, :-1]
  226. return prev_output_tokens
  227. def make_padding_mask(input_ids, padding_idx=1):
  228. """True for pad tokens"""
  229. padding_mask = input_ids.eq(padding_idx)
  230. if not padding_mask.any():
  231. padding_mask = None
  232. return padding_mask
  233. # Helper Modules
  234. class EncoderLayer(nn.Module):
  235. def __init__(self, config: FSMTConfig):
  236. super().__init__()
  237. self.embed_dim = config.d_model
  238. self.self_attn = Attention(self.embed_dim, config.encoder_attention_heads, dropout=config.attention_dropout)
  239. self.self_attn_layer_norm = LayerNorm(self.embed_dim)
  240. self.dropout = config.dropout
  241. self.activation_fn = ACT2FN[config.activation_function]
  242. self.activation_dropout = config.activation_dropout
  243. self.fc1 = nn.Linear(self.embed_dim, config.encoder_ffn_dim)
  244. self.fc2 = nn.Linear(config.encoder_ffn_dim, self.embed_dim)
  245. self.final_layer_norm = LayerNorm(self.embed_dim)
  246. def forward(self, x, encoder_padding_mask, output_attentions=False):
  247. """
  248. Args:
  249. x (`torch.Tensor`): input to the layer of shape *(seq_len, batch, embed_dim)*
  250. encoder_padding_mask (`torch.ByteTensor`): binary ByteTensor of shape
  251. *(batch, src_len)* where padding elements are indicated by `1`.
  252. for t_tgt, t_src is excluded (or masked out), =0 means it is
  253. included in attention
  254. Returns:
  255. encoded output of shape *(seq_len, batch, embed_dim)*
  256. """
  257. residual = x
  258. x, attn_weights = self.self_attn(
  259. query=x,
  260. key=x,
  261. key_padding_mask=encoder_padding_mask,
  262. output_attentions=output_attentions,
  263. )
  264. x = nn.functional.dropout(x, p=self.dropout, training=self.training)
  265. x = residual + x
  266. x = self.self_attn_layer_norm(x)
  267. residual = x
  268. x = self.activation_fn(self.fc1(x))
  269. x = nn.functional.dropout(x, p=self.activation_dropout, training=self.training)
  270. x = self.fc2(x)
  271. x = nn.functional.dropout(x, p=self.dropout, training=self.training)
  272. x = residual + x
  273. x = self.final_layer_norm(x)
  274. return x, attn_weights
  275. class FSMTEncoder(nn.Module):
  276. """
  277. Transformer encoder consisting of *config.encoder_layers* self attention layers. Each layer is a [`EncoderLayer`].
  278. Args:
  279. config: FSMTConfig
  280. """
  281. def __init__(self, config: FSMTConfig):
  282. super().__init__()
  283. self.dropout = config.dropout
  284. self.layerdrop = config.encoder_layerdrop
  285. self.padding_idx = config.pad_token_id
  286. self.embed_tokens = nn.Embedding(config.src_vocab_size, config.d_model, config.pad_token_id)
  287. embed_dim = self.embed_tokens.embedding_dim
  288. self.embed_scale = math.sqrt(embed_dim) if config.scale_embedding else 1.0
  289. self.embed_positions = SinusoidalPositionalEmbedding(
  290. config.max_position_embeddings + self.padding_idx + 1, embed_dim, self.padding_idx
  291. )
  292. self.layers = nn.ModuleList([EncoderLayer(config) for _ in range(config.encoder_layers)]) # type: list[EncoderLayer]
  293. def forward(
  294. self,
  295. input_ids: torch.Tensor,
  296. attention_mask: torch.Tensor | None = None,
  297. inputs_embeds: torch.Tensor | None = None,
  298. output_attentions: bool = False,
  299. output_hidden_states: bool = False,
  300. return_dict: bool = True,
  301. ):
  302. """
  303. Args:
  304. input_ids (`torch.LongTensor`): tokens in the source language of shape
  305. *(batch, src_len)*
  306. attention_mask (`torch.LongTensor`): indicating which indices are padding tokens
  307. inputs_embeds (`torch.FloatTensor`):
  308. embedding vectors of shape *(batch, src_len, embed_dim)*
  309. Returns:
  310. BaseModelOutput or Tuple comprised of:
  311. - **x** (`torch.Tensor`): the last encoder layer's output of shape *(src_len, batch, embed_dim)*
  312. - **encoder_states** (`Tuple(torch.FloatTensor)`): all intermediate hidden states of shape *(src_len,
  313. batch, embed_dim)*. Only populated if *output_hidden_states:* is True.
  314. - **all_attentions** (`Tuple(torch.FloatTensor)`): Attention weights for each layer.
  315. During training might not be of length n_layers because of layer dropout.
  316. """
  317. # check attention mask and invert
  318. if attention_mask is not None:
  319. attention_mask = invert_mask(attention_mask)
  320. if input_ids is not None and inputs_embeds is not None:
  321. raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
  322. elif input_ids is not None:
  323. inputs_embeds = self.embed_tokens(input_ids) * self.embed_scale
  324. embed_pos = self.embed_positions(input_ids)
  325. elif inputs_embeds is not None:
  326. inputs_embeds = inputs_embeds * self.embed_scale
  327. # We assume zeros hidden states correspond to padding tokens
  328. # and create `position_ids` where inputs_embeds[:, :, 0] == 0
  329. position_ids = inputs_embeds[:, :, 0].masked_fill(
  330. inputs_embeds[:, :, 0].eq(0), self.embed_positions.padding_idx
  331. )
  332. embed_pos = self.embed_positions(position_ids)
  333. else:
  334. raise ValueError("You have to specify either input_ids or inputs_embeds")
  335. x = inputs_embeds + embed_pos
  336. x = nn.functional.dropout(x, p=self.dropout, training=self.training)
  337. # B x T x C -> T x B x C
  338. x = x.transpose(0, 1)
  339. encoder_states = () if output_hidden_states else None
  340. all_attentions = () if output_attentions else None
  341. for idx, encoder_layer in enumerate(self.layers):
  342. if output_hidden_states:
  343. x = x.transpose(0, 1) # T x B x C -> B x T x C
  344. encoder_states += (x,)
  345. x = x.transpose(0, 1) # B x T x C -> T x B x C
  346. # add LayerDrop (see https://huggingface.co/papers/1909.11556 for description)
  347. dropout_probability = torch.rand([])
  348. if self.training and (dropout_probability < self.layerdrop): # skip the layer
  349. attn = None
  350. else:
  351. x, attn = encoder_layer(
  352. x,
  353. attention_mask,
  354. output_attentions=output_attentions,
  355. )
  356. if output_attentions:
  357. all_attentions = all_attentions + (attn,)
  358. # T x B x C -> B x T x C
  359. x = x.transpose(0, 1)
  360. if output_hidden_states:
  361. encoder_states += (x,)
  362. if not return_dict:
  363. return tuple(v for v in [x, encoder_states, all_attentions] if v is not None)
  364. return BaseModelOutput(last_hidden_state=x, hidden_states=encoder_states, attentions=all_attentions)
  365. class DecoderLayer(nn.Module):
  366. def __init__(self, config: FSMTConfig, layer_idx=None):
  367. super().__init__()
  368. self.embed_dim = config.d_model
  369. self.self_attn = Attention(
  370. embed_dim=self.embed_dim,
  371. num_heads=config.decoder_attention_heads,
  372. dropout=config.attention_dropout,
  373. layer_idx=layer_idx,
  374. )
  375. self.dropout = config.dropout
  376. self.activation_fn = ACT2FN[config.activation_function]
  377. self.activation_dropout = config.activation_dropout
  378. self.self_attn_layer_norm = LayerNorm(self.embed_dim)
  379. self.encoder_attn = Attention(
  380. self.embed_dim,
  381. config.decoder_attention_heads,
  382. dropout=config.attention_dropout,
  383. encoder_decoder_attention=True,
  384. layer_idx=layer_idx,
  385. )
  386. self.encoder_attn_layer_norm = LayerNorm(self.embed_dim)
  387. self.fc1 = nn.Linear(self.embed_dim, config.decoder_ffn_dim)
  388. self.fc2 = nn.Linear(config.decoder_ffn_dim, self.embed_dim)
  389. self.final_layer_norm = LayerNorm(self.embed_dim)
  390. def forward(
  391. self,
  392. x,
  393. encoder_hidden_states,
  394. encoder_attn_mask=None,
  395. layer_state=None,
  396. causal_mask=None,
  397. decoder_padding_mask=None,
  398. output_attentions=False,
  399. **kwargs,
  400. ):
  401. residual = x
  402. # Self Attention
  403. x, self_attn_weights = self.self_attn(
  404. query=x,
  405. key=x,
  406. layer_state=layer_state, # adds keys to layer state
  407. key_padding_mask=decoder_padding_mask,
  408. attn_mask=causal_mask,
  409. output_attentions=output_attentions,
  410. )
  411. x = nn.functional.dropout(x, p=self.dropout, training=self.training)
  412. x = residual + x
  413. x = self.self_attn_layer_norm(x)
  414. # Cross attention
  415. residual = x
  416. assert self.encoder_attn.cache_key != self.self_attn.cache_key
  417. x, cross_attn_weights = self.encoder_attn(
  418. query=x,
  419. key=encoder_hidden_states,
  420. key_padding_mask=encoder_attn_mask,
  421. layer_state=layer_state, # mutates layer state
  422. output_attentions=output_attentions,
  423. )
  424. x = nn.functional.dropout(x, p=self.dropout, training=self.training)
  425. x = residual + x
  426. x = self.encoder_attn_layer_norm(x)
  427. # Fully Connected
  428. residual = x
  429. x = self.activation_fn(self.fc1(x))
  430. x = nn.functional.dropout(x, p=self.activation_dropout, training=self.training)
  431. x = self.fc2(x)
  432. x = nn.functional.dropout(x, p=self.dropout, training=self.training)
  433. x = residual + x
  434. x = self.final_layer_norm(x)
  435. return (
  436. x,
  437. self_attn_weights,
  438. cross_attn_weights,
  439. )
  440. class FSMTDecoder(nn.Module):
  441. """
  442. Transformer decoder consisting of *config.decoder_layers* layers. Each layer is a [`DecoderLayer`]
  443. Args:
  444. config: FSMTConfig
  445. embed_tokens (nn.Embedding): output embedding
  446. """
  447. def __init__(self, config: FSMTConfig):
  448. super().__init__()
  449. self.dropout = config.dropout
  450. self.layerdrop = config.decoder_layerdrop
  451. self.padding_idx = config.pad_token_id
  452. self.embed_scale = math.sqrt(config.d_model) if config.scale_embedding else 1.0
  453. self.embed_tokens = nn.Embedding(config.tgt_vocab_size, config.d_model, self.padding_idx)
  454. embed_dim = self.embed_tokens.embedding_dim
  455. self.embed_positions = SinusoidalPositionalEmbedding(
  456. config.max_position_embeddings + self.padding_idx + 1, embed_dim, self.padding_idx
  457. )
  458. self.layers = nn.ModuleList([DecoderLayer(config, layer_idx=i) for i in range(config.decoder_layers)]) # type: list[DecoderLayer]
  459. self.output_projection = nn.Linear(config.d_model, config.tgt_vocab_size, bias=False)
  460. def forward(
  461. self,
  462. input_ids: torch.Tensor,
  463. encoder_hidden_states: torch.Tensor,
  464. encoder_padding_mask: torch.Tensor,
  465. decoder_padding_mask: torch.Tensor,
  466. decoder_causal_mask: torch.Tensor,
  467. inputs_embeds: torch.Tensor | None = None,
  468. past_key_values: Cache | None = None,
  469. use_cache: bool | None = False,
  470. output_attentions: bool | None = False,
  471. output_hidden_states: bool | None = False,
  472. return_dict: bool | None = True,
  473. **kwargs,
  474. ):
  475. """
  476. Includes several features from "Jointly Learning to Align and Translate with Transformer Models" (Garg et al.,
  477. EMNLP 2019).
  478. Args:
  479. input_ids (`torch.LongTensor` of shape `(batch, tgt_len)`):
  480. previous decoder outputs for teacher forcing
  481. encoder_hidden_states: output from the encoder, used for
  482. encoder-side attention
  483. encoder_padding_mask: for ignoring pad tokens
  484. past_key_values (dict or None): dictionary used for storing state during generation
  485. Returns:
  486. BaseModelOutputWithPast or tuple:
  487. - the decoder's features of shape *(batch, tgt_len, embed_dim)*
  488. - the cache
  489. - hidden states
  490. - attentions
  491. """
  492. # check attention mask and invert
  493. if encoder_padding_mask is not None:
  494. encoder_padding_mask = invert_mask(encoder_padding_mask)
  495. if input_ids is not None and inputs_embeds is not None:
  496. raise ValueError("You cannot specify both decoder_input_ids and decoder_inputs_embeds at the same time")
  497. elif input_ids is not None:
  498. # embed positions
  499. positions = self.embed_positions(input_ids)
  500. if use_cache:
  501. input_ids = input_ids[:, -1:]
  502. positions = positions[:, -1:] # happens after we embed them
  503. x = self.embed_tokens(input_ids) * self.embed_scale
  504. elif inputs_embeds is not None:
  505. # We assume zeros hidden states correspond to padding tokens
  506. # and create `position_ids` where inputs_embeds[:, :, 0] == 0
  507. position_ids = inputs_embeds[:, :, 0].masked_fill(
  508. inputs_embeds[:, :, 0].eq(0), self.embed_positions.padding_idx
  509. )
  510. positions = self.embed_positions(position_ids)
  511. x = inputs_embeds * self.embed_scale
  512. else:
  513. raise ValueError("You have to specify either decoder_input_ids or decoder_inputs_embeds")
  514. x += positions
  515. x = nn.functional.dropout(x, p=self.dropout, training=self.training)
  516. # Convert to FSMT output format: (BS, seq_len, model_dim) -> (seq_len, BS, model_dim)
  517. x = x.transpose(0, 1)
  518. encoder_hidden_states = encoder_hidden_states.transpose(0, 1)
  519. # decoder layers
  520. all_hidden_states = () if output_hidden_states else None
  521. all_self_attns = () if output_attentions else None
  522. all_cross_attns = () if output_attentions else None
  523. for idx, decoder_layer in enumerate(self.layers):
  524. # add LayerDrop (see https://huggingface.co/papers/1909.11556 for description)
  525. if output_hidden_states:
  526. x = x.transpose(0, 1)
  527. all_hidden_states += (x,)
  528. x = x.transpose(0, 1)
  529. if self.training:
  530. dropout_probability = torch.rand([])
  531. if dropout_probability < self.layerdrop:
  532. continue
  533. x, layer_self_attn, layer_cross_attn = decoder_layer(
  534. x,
  535. encoder_hidden_states,
  536. encoder_attn_mask=encoder_padding_mask,
  537. decoder_padding_mask=decoder_padding_mask,
  538. layer_state=past_key_values,
  539. causal_mask=decoder_causal_mask,
  540. output_attentions=output_attentions,
  541. )
  542. if output_attentions:
  543. all_self_attns += (layer_self_attn,)
  544. all_cross_attns += (layer_cross_attn,)
  545. # add hidden states from the last decoder layer
  546. if output_hidden_states:
  547. x = x.transpose(0, 1)
  548. all_hidden_states += (x,)
  549. x = x.transpose(0, 1)
  550. # Convert to standard output format: (seq_len, BS, model_dim) -> (BS, seq_len, model_dim)
  551. x = x.transpose(0, 1)
  552. encoder_hidden_states = encoder_hidden_states.transpose(0, 1)
  553. x = self.output_projection(x)
  554. if not return_dict:
  555. return tuple(
  556. v for v in [x, past_key_values, all_hidden_states, all_self_attns, all_cross_attns] if v is not None
  557. )
  558. return BaseModelOutputWithPastAndCrossAttentions(
  559. last_hidden_state=x,
  560. past_key_values=past_key_values,
  561. hidden_states=all_hidden_states,
  562. attentions=all_self_attns,
  563. cross_attentions=all_cross_attns,
  564. )
  565. def _reorder_buffer(attn_cache, new_order):
  566. for k, input_buffer_k in attn_cache.items():
  567. if input_buffer_k is not None:
  568. attn_cache[k] = input_buffer_k.index_select(0, new_order)
  569. return attn_cache
  570. class Attention(nn.Module):
  571. """Multi-headed attention from 'Attention Is All You Need' paper"""
  572. def __init__(
  573. self,
  574. embed_dim,
  575. num_heads,
  576. dropout=0.0,
  577. bias=True,
  578. encoder_decoder_attention=False, # otherwise self_attention
  579. layer_idx=None,
  580. ):
  581. super().__init__()
  582. self.embed_dim = embed_dim
  583. self.num_heads = num_heads
  584. self.dropout = dropout
  585. self.head_dim = embed_dim // num_heads
  586. assert self.head_dim * num_heads == self.embed_dim, "embed_dim must be divisible by num_heads"
  587. self.scaling = self.head_dim**-0.5
  588. self.layer_idx = layer_idx
  589. self.encoder_decoder_attention = encoder_decoder_attention
  590. self.k_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
  591. self.v_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
  592. self.q_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
  593. self.out_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
  594. self.cache_key = "encoder_decoder" if self.encoder_decoder_attention else "self"
  595. def forward(
  596. self,
  597. query,
  598. key: Tensor | None,
  599. key_padding_mask: Tensor | None = None,
  600. layer_state: Cache | None = None,
  601. attn_mask: Tensor | None = None,
  602. output_attentions: bool | None = False,
  603. **kwargs,
  604. ) -> tuple[Tensor, Tensor | None]:
  605. """Input shape: Time(SeqLen) x Batch x Channel"""
  606. tgt_len, bsz, embed_dim = query.size()
  607. assert embed_dim == self.embed_dim
  608. assert list(query.size()) == [tgt_len, bsz, embed_dim]
  609. if layer_state is not None:
  610. if isinstance(layer_state, EncoderDecoderCache):
  611. is_updated = layer_state.is_updated.get(self.layer_idx)
  612. if self.encoder_decoder_attention:
  613. # after the first generated id, we can subsequently re-use all key/value_states from cache
  614. curr_past_key_values = layer_state.cross_attention_cache
  615. else:
  616. curr_past_key_values = layer_state.self_attention_cache
  617. else:
  618. curr_past_key_values = layer_state
  619. # NOTE: FSMT has format (seq_len, BS, model_dim) for inputs
  620. current_states = key if self.encoder_decoder_attention else query
  621. if self.encoder_decoder_attention and layer_state is not None and is_updated:
  622. # reuse k,v, cross_attentions
  623. key_states = curr_past_key_values.layers[self.layer_idx].keys
  624. value_states = curr_past_key_values.layers[self.layer_idx].values
  625. else:
  626. key_states = self.k_proj(current_states)
  627. value_states = self.v_proj(current_states)
  628. key_states = key_states.view(-1, bsz, self.num_heads, self.head_dim).permute(1, 2, 0, 3)
  629. value_states = value_states.view(-1, bsz, self.num_heads, self.head_dim).permute(1, 2, 0, 3)
  630. if layer_state is not None:
  631. # save all key/value_states to cache to be re-used for fast auto-regressive generation
  632. key_states, value_states = curr_past_key_values.update(key_states, value_states, self.layer_idx)
  633. # set flag that curr layer for cross-attn is already updated so we can re-use in subsequent calls
  634. if self.encoder_decoder_attention:
  635. layer_state.is_updated[self.layer_idx] = True
  636. query_states = self.q_proj(query) * self.scaling
  637. # Reshape back to 3D tensors for `bmm`
  638. query_states = query_states.view(-1, bsz * self.num_heads, self.head_dim).transpose(0, 1)
  639. key_states = key_states.reshape(bsz * self.num_heads, -1, self.head_dim)
  640. value_states = value_states.reshape(bsz * self.num_heads, -1, self.head_dim)
  641. assert key_states is not None
  642. src_len = key_states.size(1)
  643. attn_weights = torch.bmm(query_states, key_states.transpose(1, 2))
  644. assert attn_weights.size() == (bsz * self.num_heads, tgt_len, src_len)
  645. if attn_mask is not None:
  646. attn_weights = attn_weights.view(bsz, self.num_heads, tgt_len, src_len) + attn_mask
  647. attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len)
  648. # This is part of a workaround to get around fork/join parallelism not supporting Optional types.
  649. if key_padding_mask is not None and key_padding_mask.dim() == 0:
  650. key_padding_mask = None
  651. assert key_padding_mask is None or key_padding_mask.size()[:2] == (
  652. bsz,
  653. src_len,
  654. )
  655. if key_padding_mask is not None: # don't attend to padding symbols
  656. attn_weights = attn_weights.view(bsz, self.num_heads, tgt_len, src_len)
  657. reshaped = key_padding_mask.unsqueeze(1).unsqueeze(2)
  658. attn_weights = attn_weights.masked_fill(reshaped, torch.finfo(attn_weights.dtype).min)
  659. attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len)
  660. attn_weights = nn.functional.softmax(attn_weights, dim=-1)
  661. if output_attentions:
  662. # make sure that attn_weights are included in graph
  663. attn_weights_reshaped = attn_weights.view(bsz, self.num_heads, tgt_len, src_len)
  664. attn_weights = attn_weights_reshaped.view(bsz * self.num_heads, tgt_len, src_len)
  665. else:
  666. attn_weights_reshaped = None
  667. attn_probs = nn.functional.dropout(
  668. attn_weights,
  669. p=self.dropout,
  670. training=self.training,
  671. )
  672. assert value_states is not None
  673. attn_output = torch.bmm(attn_probs, value_states)
  674. assert attn_output.size() == (bsz * self.num_heads, tgt_len, self.head_dim)
  675. attn_output = attn_output.transpose(0, 1).contiguous().view(tgt_len, bsz, embed_dim)
  676. attn_output = self.out_proj(attn_output)
  677. return attn_output, attn_weights_reshaped
  678. def fill_with_neg_inf(t):
  679. """FP16-compatible function that fills a input_ids with -inf."""
  680. return t.float().fill_(torch.finfo(t.dtype).min).type_as(t)
  681. # Public API
  682. def _get_shape(t):
  683. return getattr(t, "shape", None)
  684. @auto_docstring
  685. class FSMTModel(PretrainedFSMTModel):
  686. _tied_weights_keys = {
  687. "encoder.embed_tokens.weight": "decoder.embed_tokens.weight",
  688. "decoder.output_projection.weight": "decoder.embed_tokens.weight",
  689. }
  690. def __init__(self, config: FSMTConfig):
  691. super().__init__(config)
  692. self.encoder = FSMTEncoder(config)
  693. self.decoder = FSMTDecoder(config)
  694. self.post_init()
  695. @auto_docstring
  696. def forward(
  697. self,
  698. input_ids: torch.LongTensor,
  699. attention_mask: torch.Tensor | None = None,
  700. decoder_input_ids: torch.LongTensor | None = None,
  701. decoder_attention_mask: torch.BoolTensor | None = None,
  702. encoder_outputs: tuple[torch.FloatTensor] | None = None,
  703. past_key_values: Cache | None = None,
  704. use_cache: bool | None = None,
  705. output_attentions: bool | None = None,
  706. output_hidden_states: bool | None = None,
  707. inputs_embeds: torch.FloatTensor | None = None,
  708. decoder_inputs_embeds: torch.FloatTensor | None = None,
  709. return_dict: bool | None = None,
  710. **kwargs,
  711. ) -> tuple[torch.Tensor] | Seq2SeqModelOutput:
  712. r"""
  713. decoder_input_ids (`torch.LongTensor` of shape `(batch_size, target_sequence_length)`, *optional*):
  714. Indices of decoder input sequence tokens in the vocabulary.
  715. Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
  716. [`PreTrainedTokenizer.__call__`] for details.
  717. [What are decoder input IDs?](../glossary#decoder-input-ids)
  718. FSMT uses the `eos_token_id` as the starting token for `decoder_input_ids` generation. If `past_key_values`
  719. is used, optionally only the last `decoder_input_ids` have to be input (see `past_key_values`).
  720. decoder_attention_mask (`torch.BoolTensor` of shape `(batch_size, target_sequence_length)`, *optional*):
  721. Default behavior: generate a tensor that ignores pad tokens in `decoder_input_ids`. Causal mask will also
  722. be used by default.
  723. """
  724. if decoder_input_ids is None:
  725. use_cache = False
  726. output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
  727. output_hidden_states = (
  728. output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
  729. )
  730. use_cache = use_cache if use_cache is not None else self.config.use_cache
  731. return_dict = return_dict if return_dict is not None else self.config.return_dict
  732. # make masks if user doesn't supply
  733. if not use_cache and input_ids is not None:
  734. decoder_input_ids, decoder_padding_mask, causal_mask = _prepare_fsmt_decoder_inputs(
  735. self.config,
  736. input_ids,
  737. decoder_input_ids=decoder_input_ids,
  738. decoder_padding_mask=decoder_attention_mask,
  739. causal_mask_dtype=self.decoder.embed_tokens.weight.dtype,
  740. )
  741. else:
  742. decoder_padding_mask, causal_mask = None, None
  743. if decoder_input_ids is None and decoder_inputs_embeds is None:
  744. raise ValueError("Make sure that `decoder_input_ids` or `decoder_inputs_embeds` are passed.")
  745. if use_cache and past_key_values is None:
  746. past_key_values = EncoderDecoderCache(DynamicCache(config=self.config), DynamicCache(config=self.config))
  747. if encoder_outputs is None:
  748. encoder_outputs = self.encoder(
  749. input_ids=input_ids,
  750. attention_mask=attention_mask,
  751. inputs_embeds=inputs_embeds,
  752. output_attentions=output_attentions,
  753. output_hidden_states=output_hidden_states,
  754. return_dict=return_dict,
  755. )
  756. # If the user passed a tuple for encoder_outputs, we wrap it in a BaseModelOutput when return_dict=False
  757. elif return_dict and not isinstance(encoder_outputs, BaseModelOutput):
  758. encoder_outputs = BaseModelOutput(
  759. last_hidden_state=encoder_outputs[0],
  760. hidden_states=encoder_outputs[1] if len(encoder_outputs) > 1 else None,
  761. attentions=encoder_outputs[2] if len(encoder_outputs) > 2 else None,
  762. )
  763. # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
  764. decoder_outputs = self.decoder(
  765. decoder_input_ids,
  766. encoder_outputs[0],
  767. attention_mask,
  768. decoder_padding_mask,
  769. decoder_causal_mask=causal_mask,
  770. inputs_embeds=decoder_inputs_embeds,
  771. past_key_values=past_key_values,
  772. use_cache=use_cache,
  773. output_attentions=output_attentions,
  774. output_hidden_states=output_hidden_states,
  775. return_dict=return_dict,
  776. )
  777. if not return_dict:
  778. return decoder_outputs + encoder_outputs
  779. return Seq2SeqModelOutput(
  780. last_hidden_state=decoder_outputs.last_hidden_state,
  781. past_key_values=decoder_outputs.past_key_values,
  782. decoder_hidden_states=decoder_outputs.hidden_states,
  783. decoder_attentions=decoder_outputs.attentions,
  784. cross_attentions=decoder_outputs.cross_attentions,
  785. encoder_last_hidden_state=encoder_outputs.last_hidden_state,
  786. encoder_hidden_states=encoder_outputs.hidden_states,
  787. encoder_attentions=encoder_outputs.attentions,
  788. )
  789. def get_input_embeddings(self):
  790. return self.encoder.embed_tokens
  791. def set_input_embeddings(self, value):
  792. self.encoder.embed_tokens = value
  793. def get_output_embeddings(self):
  794. return self.decoder.embed_tokens
  795. def set_output_embeddings(self, value):
  796. self.decoder.embed_tokens = value
  797. @auto_docstring(
  798. custom_intro="""
  799. The FSMT Model with a language modeling head. Can be used for summarization.
  800. """
  801. )
  802. class FSMTForConditionalGeneration(PretrainedFSMTModel, GenerationMixin):
  803. base_model_prefix = "model"
  804. def __init__(self, config: FSMTConfig):
  805. super().__init__(config)
  806. base_model = FSMTModel(config)
  807. self.model = base_model
  808. # Initialize weights and apply final processing
  809. self.post_init()
  810. @auto_docstring
  811. def forward(
  812. self,
  813. input_ids: torch.LongTensor | None = None,
  814. attention_mask: torch.Tensor | None = None,
  815. decoder_input_ids: torch.LongTensor | None = None,
  816. decoder_attention_mask: torch.BoolTensor | None = None,
  817. encoder_outputs: tuple[torch.FloatTensor] | None = None,
  818. past_key_values: Cache | None = None,
  819. inputs_embeds: torch.Tensor | None = None,
  820. decoder_inputs_embeds: torch.Tensor | None = None,
  821. labels: torch.LongTensor | None = None,
  822. use_cache: bool | None = None,
  823. output_attentions: bool | None = None,
  824. output_hidden_states: bool | None = None,
  825. return_dict: bool | None = None,
  826. **kwargs,
  827. ) -> tuple[torch.Tensor] | Seq2SeqLMOutput:
  828. r"""
  829. decoder_input_ids (`torch.LongTensor` of shape `(batch_size, target_sequence_length)`, *optional*):
  830. Indices of decoder input sequence tokens in the vocabulary.
  831. Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
  832. [`PreTrainedTokenizer.__call__`] for details.
  833. [What are decoder input IDs?](../glossary#decoder-input-ids)
  834. FSMT uses the `eos_token_id` as the starting token for `decoder_input_ids` generation. If `past_key_values`
  835. is used, optionally only the last `decoder_input_ids` have to be input (see `past_key_values`).
  836. decoder_attention_mask (`torch.BoolTensor` of shape `(batch_size, target_sequence_length)`, *optional*):
  837. Default behavior: generate a tensor that ignores pad tokens in `decoder_input_ids`. Causal mask will also
  838. be used by default.
  839. labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
  840. Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
  841. config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
  842. (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
  843. Example Translation:
  844. ```python
  845. >>> from transformers import AutoTokenizer, FSMTForConditionalGeneration
  846. >>> mname = "facebook/wmt19-ru-en"
  847. >>> model = FSMTForConditionalGeneration.from_pretrained(mname)
  848. >>> tokenizer = AutoTokenizer.from_pretrained(mname)
  849. >>> src_text = "Машинное обучение - это здорово, не так ли?"
  850. >>> input_ids = tokenizer(src_text, return_tensors="pt").input_ids
  851. >>> outputs = model.generate(input_ids, num_beams=5, num_return_sequences=3)
  852. >>> tokenizer.decode(outputs[0], skip_special_tokens=True)
  853. "Machine learning is great, isn't it?"
  854. ```
  855. """
  856. return_dict = return_dict if return_dict is not None else self.config.return_dict
  857. if labels is not None:
  858. use_cache = False
  859. outputs = self.model(
  860. input_ids,
  861. inputs_embeds=inputs_embeds,
  862. attention_mask=attention_mask,
  863. decoder_input_ids=decoder_input_ids,
  864. decoder_inputs_embeds=decoder_inputs_embeds,
  865. encoder_outputs=encoder_outputs,
  866. decoder_attention_mask=decoder_attention_mask,
  867. past_key_values=past_key_values,
  868. use_cache=use_cache,
  869. output_attentions=output_attentions,
  870. output_hidden_states=output_hidden_states,
  871. return_dict=return_dict,
  872. )
  873. lm_logits = outputs[0]
  874. masked_lm_loss = None
  875. if labels is not None:
  876. loss_fct = CrossEntropyLoss()
  877. # TODO(SS): do we need to ignore pad tokens in labels?
  878. masked_lm_loss = loss_fct(lm_logits.view(-1, self.config.tgt_vocab_size), labels.view(-1))
  879. if not return_dict:
  880. output = (lm_logits,) + outputs[1:]
  881. return ((masked_lm_loss,) + output) if masked_lm_loss is not None else output
  882. return Seq2SeqLMOutput(
  883. loss=masked_lm_loss,
  884. logits=lm_logits,
  885. past_key_values=outputs.past_key_values,
  886. decoder_hidden_states=outputs.decoder_hidden_states,
  887. decoder_attentions=outputs.decoder_attentions,
  888. cross_attentions=outputs.cross_attentions,
  889. encoder_last_hidden_state=outputs.encoder_last_hidden_state,
  890. encoder_hidden_states=outputs.encoder_hidden_states,
  891. encoder_attentions=outputs.encoder_attentions,
  892. )
  893. def prepare_decoder_input_ids_from_labels(self, labels: torch.Tensor):
  894. return shift_tokens_right(labels, self.config.pad_token_id)
  895. def get_output_embeddings(self):
  896. return self.model.decoder.embed_tokens
  897. def set_output_embeddings(self, value):
  898. self.model.decoder.embed_tokens = value
  899. class SinusoidalPositionalEmbedding(nn.Embedding):
  900. """
  901. This module produces sinusoidal positional embeddings of any length.
  902. We don't want to save the weight of this embedding since it's not trained (deterministic) and it can be huge.
  903. Padding symbols are ignored.
  904. These embeddings get automatically extended in forward if more positions is needed.
  905. """
  906. def __init__(self, num_positions, embedding_dim, padding_idx):
  907. super().__init__(num_positions, embedding_dim, padding_idx)
  908. def make_weight(self, num_positions, embedding_dim, padding_idx):
  909. weight = self.get_embedding(num_positions, embedding_dim, padding_idx)
  910. # in forward put the weights on the correct dtype and device of the param
  911. weight = weight.to(dtype=self.weight.dtype, device=self.weight.device)
  912. self.weight = nn.Parameter(weight)
  913. self.weight.detach_()
  914. self.weight.requires_grad = False
  915. @staticmethod
  916. def get_embedding(num_embeddings, embedding_dim, padding_idx):
  917. """
  918. Build sinusoidal embeddings.
  919. This matches the implementation in tensor2tensor, but differs slightly from the description in Section 3.5 of
  920. "Attention Is All You Need".
  921. """
  922. half_dim = embedding_dim // 2
  923. emb = math.log(10000) / (half_dim - 1)
  924. emb = torch.exp(torch.arange(half_dim, dtype=torch.int64).float() * -emb)
  925. emb = torch.arange(num_embeddings, dtype=torch.int64).float().unsqueeze(1) * emb.unsqueeze(0)
  926. emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=1).view(num_embeddings, -1)
  927. if embedding_dim % 2 == 1:
  928. # zero pad
  929. emb = torch.cat([emb, torch.zeros(num_embeddings, 1)], dim=1)
  930. if padding_idx is not None:
  931. emb[padding_idx, :] = 0
  932. return emb
  933. @staticmethod
  934. def make_positions(tensor, padding_idx: int):
  935. """
  936. Replace non-padding symbols with their position numbers.
  937. Position numbers begin at padding_idx+1. Padding symbols are ignored.
  938. """
  939. # The series of casts and type-conversions here are carefully
  940. # balanced to both work with ONNX export and XLA. In particular XLA
  941. # prefers ints, cumsum defaults to output longs, and ONNX doesn't know
  942. # how to handle the dtype kwarg in cumsum.
  943. mask = tensor.ne(padding_idx).int()
  944. return (torch.cumsum(mask, dim=1).type_as(mask) * mask).long() + padding_idx
  945. def forward(
  946. self,
  947. input,
  948. incremental_state: Any | None = None,
  949. timestep: Tensor | None = None,
  950. ):
  951. """Input is expected to be of size [bsz x seqlen]."""
  952. bsz, seq_len = input.shape[:2]
  953. max_pos = self.padding_idx + 1 + seq_len
  954. if max_pos > self.weight.size(0):
  955. # expand embeddings if needed
  956. self.make_weight(max_pos, self.embedding_dim, self.padding_idx)
  957. positions = self.make_positions(input, self.padding_idx)
  958. return super().forward(positions)
  959. __all__ = ["FSMTForConditionalGeneration", "FSMTModel", "PretrainedFSMTModel"]