modeling_marian.py 46 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
3704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792793794795796797798799800801802803804805806807808809810811812813814815816817818819820821822823824825826827828829830831832833834835836837838839840841842843844845846847848849850851852853854855856857858859860861862863864865866867868869870871872873874875876877878879880881882883884885886887888889890891892893894895896897898899900901902903904905906907908909910911912913914915916917918919920921922923924925926927928929930931932933934935936937938939940941942943944945946947948949950951952953954955956957958959960961962963964965966967968969970971972973974975976977978979980981982983984985986987988989990991992993994995996997998999100010011002100310041005100610071008100910101011101210131014101510161017101810191020102110221023102410251026102710281029103010311032103310341035103610371038103910401041104210431044104510461047104810491050105110521053105410551056105710581059106010611062106310641065106610671068106910701071107210731074107510761077107810791080108110821083108410851086108710881089109010911092109310941095109610971098109911001101110211031104110511061107110811091110111111121113
  1. # Copyright 2021 The Marian Team Authors and The HuggingFace Inc. team. All rights reserved.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. """PyTorch MarianMTModel model, ported from the Marian C++ repo."""
  15. import math
  16. from collections.abc import Callable
  17. import numpy as np
  18. import torch
  19. from torch import nn
  20. from torch.nn import CrossEntropyLoss
  21. from ... import initialization as init
  22. from ...activations import ACT2FN
  23. from ...cache_utils import Cache, DynamicCache, EncoderDecoderCache
  24. from ...generation import GenerationMixin
  25. from ...masking_utils import create_bidirectional_mask, create_causal_mask
  26. from ...modeling_flash_attention_utils import FlashAttentionKwargs
  27. from ...modeling_layers import GradientCheckpointingLayer
  28. from ...modeling_outputs import (
  29. BaseModelOutput,
  30. BaseModelOutputWithPastAndCrossAttentions,
  31. CausalLMOutputWithCrossAttentions,
  32. Seq2SeqLMOutput,
  33. Seq2SeqModelOutput,
  34. )
  35. from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
  36. from ...processing_utils import Unpack
  37. from ...utils import (
  38. TransformersKwargs,
  39. auto_docstring,
  40. can_return_tuple,
  41. is_torchdynamo_compiling,
  42. logging,
  43. )
  44. from ...utils.generic import merge_with_config_defaults
  45. from ...utils.output_capturing import OutputRecorder, capture_outputs
  46. from .configuration_marian import MarianConfig
  47. logger = logging.get_logger(__name__)
  48. # Copied from transformers.models.bart.modeling_bart.shift_tokens_right
  49. def shift_tokens_right(input_ids: torch.Tensor, pad_token_id: int, decoder_start_token_id: int):
  50. """
  51. Shift input ids one token to the right.
  52. """
  53. shifted_input_ids = input_ids.new_zeros(input_ids.shape)
  54. shifted_input_ids[:, 1:] = input_ids[:, :-1].clone()
  55. shifted_input_ids[:, 0] = decoder_start_token_id
  56. if pad_token_id is None:
  57. raise ValueError("self.model.config.pad_token_id has to be defined.")
  58. # replace possible -100 values in labels by `pad_token_id`
  59. shifted_input_ids.masked_fill_(shifted_input_ids == -100, pad_token_id)
  60. return shifted_input_ids
  61. class MarianSinusoidalPositionalEmbedding(nn.Embedding):
  62. """This module produces sinusoidal positional embeddings of any length."""
  63. def __init__(self, num_positions: int, embedding_dim: int, padding_idx: int | None = None) -> None:
  64. super().__init__(num_positions, embedding_dim, _freeze=True)
  65. def create_weight(self):
  66. """
  67. Identical to the XLM create_sinusoidal_embeddings except features are not interleaved. The cos features are in
  68. the 2nd half of the vector. [dim // 2:]
  69. """
  70. n_pos, dim = self.weight.shape
  71. position_enc = np.array(
  72. [[pos / np.power(10000, 2 * (j // 2) / dim) for j in range(dim)] for pos in range(n_pos)]
  73. )
  74. out = torch.empty(n_pos, dim, dtype=self.weight.dtype, requires_grad=False)
  75. sentinel = dim // 2 if dim % 2 == 0 else (dim // 2) + 1
  76. out[:, 0:sentinel] = torch.FloatTensor(np.sin(position_enc[:, 0::2]))
  77. out[:, sentinel:] = torch.FloatTensor(np.cos(position_enc[:, 1::2]))
  78. return out
  79. @torch.no_grad()
  80. def forward(
  81. self, input_ids_shape: torch.Size, past_key_values_length: int = 0, position_ids: torch.Tensor | None = None
  82. ) -> torch.Tensor:
  83. """`input_ids_shape` is expected to be [bsz x seqlen]."""
  84. if position_ids is None:
  85. bsz, seq_len = input_ids_shape[:2]
  86. position_ids = torch.arange(
  87. past_key_values_length, past_key_values_length + seq_len, dtype=torch.long, device=self.weight.device
  88. )
  89. return super().forward(position_ids)
  90. # Copied from transformers.models.bert.modeling_bert.eager_attention_forward
  91. def eager_attention_forward(
  92. module: nn.Module,
  93. query: torch.Tensor,
  94. key: torch.Tensor,
  95. value: torch.Tensor,
  96. attention_mask: torch.Tensor | None,
  97. scaling: float | None = None,
  98. dropout: float = 0.0,
  99. **kwargs: Unpack[TransformersKwargs],
  100. ):
  101. if scaling is None:
  102. scaling = query.size(-1) ** -0.5
  103. # Take the dot product between "query" and "key" to get the raw attention scores.
  104. attn_weights = torch.matmul(query, key.transpose(2, 3)) * scaling
  105. if attention_mask is not None:
  106. attn_weights = attn_weights + attention_mask
  107. attn_weights = nn.functional.softmax(attn_weights, dim=-1)
  108. attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training)
  109. attn_output = torch.matmul(attn_weights, value)
  110. attn_output = attn_output.transpose(1, 2).contiguous()
  111. return attn_output, attn_weights
  112. # Copied from transformers.models.bart.modeling_bart.BartAttention with Bart->Marian
class MarianAttention(nn.Module):
    """Multi-headed attention from 'Attention Is All You Need' paper

    Serves as self-attention (encoder or causal decoder) or, when `key_value_states` is passed to `forward`, as
    decoder cross-attention. The score computation is delegated to the backend named by
    `config._attn_implementation`, with `eager_attention_forward` as the fallback.
    """

    def __init__(
        self,
        embed_dim: int,
        num_heads: int,
        dropout: float = 0.0,
        is_decoder: bool = False,
        bias: bool = True,
        is_causal: bool = False,
        config: MarianConfig | None = None,
        layer_idx: int | None = None,
    ):
        # `config` is typed optional, but `forward` reads `self.config._attn_implementation`, so a config is
        # required before the module can be called.
        # `layer_idx` selects this layer's slot in a `Cache`; decoder layers need it whenever caching is used.
        # `is_causal` is only stored here; presumably attention backends read it from the module — confirm.
        super().__init__()
        self.embed_dim = embed_dim
        self.num_heads = num_heads
        self.dropout = dropout
        self.head_dim = embed_dim // num_heads
        self.config = config

        if (self.head_dim * num_heads) != self.embed_dim:
            raise ValueError(
                f"embed_dim must be divisible by num_heads (got `embed_dim`: {self.embed_dim}"
                f" and `num_heads`: {num_heads})."
            )
        # Standard 1/sqrt(head_dim) softmax temperature, passed to the attention backend.
        self.scaling = self.head_dim**-0.5
        self.is_decoder = is_decoder
        self.is_causal = is_causal
        self.layer_idx = layer_idx
        if layer_idx is None and self.is_decoder:
            logger.warning_once(
                f"Instantiating a decoder {self.__class__.__name__} without passing `layer_idx` is not recommended and "
                "will lead to errors during the forward call, if caching is used. Please make sure to provide a `layer_idx` "
                "when creating this class."
            )

        self.k_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
        self.v_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
        self.q_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
        self.out_proj = nn.Linear(embed_dim, embed_dim, bias=bias)

    def forward(
        self,
        hidden_states: torch.Tensor,
        key_value_states: torch.Tensor | None = None,
        past_key_values: Cache | None = None,
        attention_mask: torch.Tensor | None = None,
        # TODO: we need a refactor so that the different attention modules can get their specific kwargs
        # ATM, we have mixed things encoder, decoder, and encoder-decoder attn
        **kwargs: Unpack[FlashAttentionKwargs],
    ) -> tuple[torch.Tensor, torch.Tensor | None]:
        """Input shape: Batch x Time x Channel

        Args:
            hidden_states: queries (and keys/values for self-attention).
            key_value_states: if given, keys/values are projected from it instead (cross-attention).
            past_key_values: optional cache. For an `EncoderDecoderCache`, the self- or cross-attention sub-cache
                is chosen according to the attention type; any other `Cache` is used as-is.
            attention_mask: forwarded unchanged to the attention backend.

        Returns:
            `(attn_output, attn_weights)`: `attn_output` has the leading shape of `hidden_states` and is passed
            through `out_proj`. `attn_weights` is whatever the backend returns.
        """

        # if key_value_states are provided this layer is used as a cross-attention layer
        # for the decoder
        is_cross_attention = key_value_states is not None

        # determine input shapes
        input_shape = hidden_states.shape[:-1]
        # Split the channel dim into (num_heads, head_dim); -1 lets the view infer the head count.
        hidden_shape = (*input_shape, -1, self.head_dim)

        # get query proj
        # For Batch x Time x Channel input this yields (batch, heads, time, head_dim).
        query_states = self.q_proj(hidden_states).view(hidden_shape).transpose(1, 2)

        # `is_updated` records whether this layer's cross-attention K/V have already been written to the cache.
        is_updated = False
        if past_key_values is not None:
            if isinstance(past_key_values, EncoderDecoderCache):
                is_updated = past_key_values.is_updated.get(self.layer_idx)
                if is_cross_attention:
                    # after the first generated id, we can subsequently re-use all key/value_states from cache
                    curr_past_key_values = past_key_values.cross_attention_cache
                else:
                    curr_past_key_values = past_key_values.self_attention_cache
            else:
                curr_past_key_values = past_key_values

        current_states = key_value_states if is_cross_attention else hidden_states
        if is_cross_attention and past_key_values is not None and is_updated:
            # reuse k,v, cross_attentions
            # The encoder output does not change across decoding steps, so cached K/V are read back directly.
            key_states = curr_past_key_values.layers[self.layer_idx].keys
            value_states = curr_past_key_values.layers[self.layer_idx].values
        else:
            key_states = self.k_proj(current_states)
            value_states = self.v_proj(current_states)
            kv_shape = (*current_states.shape[:-1], -1, self.head_dim)
            key_states = key_states.view(kv_shape).transpose(1, 2)
            value_states = value_states.view(kv_shape).transpose(1, 2)

            if past_key_values is not None:
                # `update` stores the new K/V and returns the tensors to attend over (for self-attention this
                # presumably includes earlier steps — confirm against the Cache implementation).
                key_states, value_states = curr_past_key_values.update(key_states, value_states, self.layer_idx)
                # set flag that curr layer for cross-attn is already updated so we can re-use in subsequent calls
                if is_cross_attention and isinstance(past_key_values, EncoderDecoderCache):
                    past_key_values.is_updated[self.layer_idx] = True

        # Pick the configured attention kernel; fall back to the eager implementation in this file.
        attention_interface: Callable = ALL_ATTENTION_FUNCTIONS.get_interface(
            self.config._attn_implementation, eager_attention_forward
        )

        attn_output, attn_weights = attention_interface(
            self,
            query_states,
            key_states,
            value_states,
            attention_mask,
            # Attention dropout only while training.
            dropout=0.0 if not self.training else self.dropout,
            scaling=self.scaling,
            **kwargs,
        )

        # Merge heads back into the channel dimension, then project.
        attn_output = attn_output.reshape(*input_shape, -1).contiguous()
        attn_output = self.out_proj(attn_output)

        return attn_output, attn_weights
  213. # Copied from transformers.models.bart.modeling_bart.BartEncoderLayer with Bart->Marian, BART->MARIAN
  214. class MarianEncoderLayer(GradientCheckpointingLayer):
  215. def __init__(self, config: MarianConfig, layer_idx: int | None = None):
  216. super().__init__()
  217. self.embed_dim = config.d_model
  218. self.self_attn = MarianAttention(
  219. embed_dim=self.embed_dim,
  220. num_heads=config.encoder_attention_heads,
  221. dropout=config.attention_dropout,
  222. config=config,
  223. layer_idx=layer_idx,
  224. )
  225. self.self_attn_layer_norm = nn.LayerNorm(self.embed_dim)
  226. self.dropout = config.dropout
  227. self.activation_fn = ACT2FN[config.activation_function]
  228. self.activation_dropout = config.activation_dropout
  229. self.fc1 = nn.Linear(self.embed_dim, config.encoder_ffn_dim)
  230. self.fc2 = nn.Linear(config.encoder_ffn_dim, self.embed_dim)
  231. self.final_layer_norm = nn.LayerNorm(self.embed_dim)
  232. def forward(
  233. self,
  234. hidden_states: torch.FloatTensor,
  235. attention_mask: torch.FloatTensor,
  236. **kwargs: Unpack[TransformersKwargs],
  237. ) -> torch.Tensor:
  238. residual = hidden_states
  239. hidden_states, _ = self.self_attn(
  240. hidden_states,
  241. attention_mask=attention_mask,
  242. **kwargs,
  243. )
  244. hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training)
  245. hidden_states = residual + hidden_states
  246. hidden_states = self.self_attn_layer_norm(hidden_states)
  247. residual = hidden_states
  248. hidden_states = self.activation_fn(self.fc1(hidden_states))
  249. hidden_states = nn.functional.dropout(hidden_states, p=self.activation_dropout, training=self.training)
  250. hidden_states = self.fc2(hidden_states)
  251. hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training)
  252. hidden_states = residual + hidden_states
  253. hidden_states = self.final_layer_norm(hidden_states)
  254. if hidden_states.dtype == torch.float16 and not torch.isfinite(hidden_states).all():
  255. clamp_value = torch.finfo(hidden_states.dtype).max - 1000
  256. hidden_states = torch.clamp(hidden_states, min=-clamp_value, max=clamp_value)
  257. return hidden_states
  258. # Copied from transformers.models.bart.modeling_bart.BartDecoderLayer with Bart->Marian, BART->MARIAN
  259. class MarianDecoderLayer(GradientCheckpointingLayer):
  260. def __init__(self, config: MarianConfig, layer_idx: int | None = None):
  261. super().__init__()
  262. self.embed_dim = config.d_model
  263. self.self_attn = MarianAttention(
  264. embed_dim=self.embed_dim,
  265. num_heads=config.decoder_attention_heads,
  266. dropout=config.attention_dropout,
  267. is_decoder=True,
  268. is_causal=True,
  269. config=config,
  270. layer_idx=layer_idx,
  271. )
  272. self.dropout = config.dropout
  273. self.activation_fn = ACT2FN[config.activation_function]
  274. self.activation_dropout = config.activation_dropout
  275. self.self_attn_layer_norm = nn.LayerNorm(self.embed_dim)
  276. self.encoder_attn = MarianAttention(
  277. self.embed_dim,
  278. config.decoder_attention_heads,
  279. dropout=config.attention_dropout,
  280. is_decoder=True,
  281. config=config,
  282. layer_idx=layer_idx,
  283. )
  284. self.encoder_attn_layer_norm = nn.LayerNorm(self.embed_dim)
  285. self.fc1 = nn.Linear(self.embed_dim, config.decoder_ffn_dim)
  286. self.fc2 = nn.Linear(config.decoder_ffn_dim, self.embed_dim)
  287. self.final_layer_norm = nn.LayerNorm(self.embed_dim)
  288. def forward(
  289. self,
  290. hidden_states: torch.Tensor,
  291. attention_mask: torch.Tensor | None = None,
  292. encoder_hidden_states: torch.Tensor | None = None,
  293. encoder_attention_mask: torch.Tensor | None = None,
  294. past_key_values: Cache | None = None,
  295. use_cache: bool | None = True,
  296. **kwargs: Unpack[TransformersKwargs],
  297. ) -> torch.Tensor:
  298. residual = hidden_states
  299. # Self Attention
  300. hidden_states, _ = self.self_attn(
  301. hidden_states,
  302. past_key_values=past_key_values,
  303. attention_mask=attention_mask,
  304. **kwargs,
  305. )
  306. hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training)
  307. hidden_states = residual + hidden_states
  308. hidden_states = self.self_attn_layer_norm(hidden_states)
  309. # Cross-Attention Block
  310. if encoder_hidden_states is not None:
  311. residual = hidden_states
  312. hidden_states, _ = self.encoder_attn(
  313. hidden_states,
  314. key_value_states=encoder_hidden_states,
  315. attention_mask=encoder_attention_mask,
  316. past_key_values=past_key_values,
  317. **kwargs,
  318. )
  319. hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training)
  320. hidden_states = residual + hidden_states
  321. hidden_states = self.encoder_attn_layer_norm(hidden_states)
  322. # Fully Connected
  323. residual = hidden_states
  324. hidden_states = self.activation_fn(self.fc1(hidden_states))
  325. hidden_states = nn.functional.dropout(hidden_states, p=self.activation_dropout, training=self.training)
  326. hidden_states = self.fc2(hidden_states)
  327. hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training)
  328. hidden_states = residual + hidden_states
  329. hidden_states = self.final_layer_norm(hidden_states)
  330. return hidden_states
  331. @auto_docstring
  332. class MarianPreTrainedModel(PreTrainedModel):
  333. config: MarianConfig
  334. base_model_prefix = "model"
  335. supports_gradient_checkpointing = True
  336. _supports_flash_attn = True
  337. _supports_sdpa = True
  338. _supports_flex_attn = True
  339. _can_compile_fullgraph = True
  340. @torch.no_grad()
  341. def _init_weights(self, module):
  342. super()._init_weights(module)
  343. if isinstance(module, MarianSinusoidalPositionalEmbedding):
  344. init.copy_(module.weight, module.create_weight())
  345. elif isinstance(module, MarianMTModel):
  346. init.zeros_(module.final_logits_bias)
  347. @property
  348. def dummy_inputs(self):
  349. pad_token = self.config.pad_token_id
  350. input_ids = torch.tensor([[0, 6, 10, 4, 2], [0, 8, 12, 2, pad_token]], device=self.device)
  351. dummy_inputs = {
  352. "attention_mask": input_ids.ne(pad_token),
  353. "input_ids": input_ids,
  354. "decoder_input_ids": input_ids,
  355. }
  356. return dummy_inputs
  357. class MarianEncoder(MarianPreTrainedModel):
  358. """
  359. Transformer encoder consisting of *config.encoder_layers* self attention layers. Each layer is a
  360. [`MarianEncoderLayer`].
  361. Args:
  362. config: MarianConfig
  363. embed_tokens (nn.Embedding): output embedding
  364. """
  365. _can_record_outputs = {
  366. "hidden_states": MarianEncoderLayer,
  367. "attentions": MarianAttention,
  368. }
  369. def __init__(self, config: MarianConfig):
  370. super().__init__(config)
  371. self.dropout = config.dropout
  372. self.layerdrop = config.encoder_layerdrop
  373. embed_dim = config.d_model
  374. self.padding_idx = config.pad_token_id
  375. self.max_source_positions = config.max_position_embeddings
  376. self.embed_scale = math.sqrt(embed_dim) if config.scale_embedding else 1.0
  377. self.embed_tokens = nn.Embedding(config.vocab_size, embed_dim, self.padding_idx)
  378. self.embed_positions = MarianSinusoidalPositionalEmbedding(
  379. config.max_position_embeddings, embed_dim, self.padding_idx
  380. )
  381. self.layers = nn.ModuleList([MarianEncoderLayer(config) for _ in range(config.encoder_layers)])
  382. self.gradient_checkpointing = False
  383. # Initialize weights and apply final processing
  384. self.post_init()
  385. @merge_with_config_defaults
  386. @capture_outputs
  387. @auto_docstring
  388. def forward(
  389. self,
  390. input_ids: torch.LongTensor | None = None,
  391. attention_mask: torch.LongTensor | None = None,
  392. inputs_embeds: torch.FloatTensor | None = None,
  393. **kwargs: Unpack[TransformersKwargs],
  394. ) -> BaseModelOutput:
  395. if (input_ids is None) ^ (inputs_embeds is not None):
  396. raise ValueError("You must specify exactly one of input_ids or inputs_embeds")
  397. if inputs_embeds is None:
  398. inputs_embeds = self.embed_tokens(input_ids) * self.embed_scale
  399. embed_pos = self.embed_positions(inputs_embeds.shape[:-1])
  400. hidden_states = inputs_embeds + embed_pos
  401. hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training)
  402. attention_mask = create_bidirectional_mask(
  403. config=self.config,
  404. inputs_embeds=inputs_embeds,
  405. attention_mask=attention_mask,
  406. )
  407. for idx, encoder_layer in enumerate(self.layers):
  408. # add LayerDrop (see https://huggingface.co/papers/1909.11556 for description)
  409. to_drop = False
  410. if self.training:
  411. dropout_probability = torch.rand([])
  412. if dropout_probability < self.layerdrop: # skip the layer
  413. to_drop = True
  414. if not to_drop:
  415. hidden_states = encoder_layer(
  416. hidden_states,
  417. attention_mask,
  418. **kwargs,
  419. )
  420. return BaseModelOutput(
  421. last_hidden_state=hidden_states,
  422. )
  423. class MarianDecoder(MarianPreTrainedModel):
  424. """
  425. Transformer decoder consisting of *config.decoder_layers* layers. Each layer is a [`MarianDecoderLayer`]
  426. Args:
  427. config: MarianConfig
  428. embed_tokens (nn.Embedding): output embedding
  429. """
  430. _can_record_outputs = {
  431. "hidden_states": MarianDecoderLayer,
  432. "attentions": OutputRecorder(MarianAttention, index=1, layer_name="self_attn"),
  433. "cross_attentions": OutputRecorder(MarianAttention, index=1, layer_name="encoder_attn"),
  434. }
  435. def __init__(self, config: MarianConfig):
  436. super().__init__(config)
  437. self.dropout = config.dropout
  438. self.layerdrop = config.decoder_layerdrop
  439. self.padding_idx = config.pad_token_id
  440. self.max_target_positions = config.max_position_embeddings
  441. self.embed_scale = math.sqrt(config.d_model) if config.scale_embedding else 1.0
  442. self.embed_tokens = nn.Embedding(config.decoder_vocab_size, config.d_model, self.padding_idx)
  443. self.embed_positions = MarianSinusoidalPositionalEmbedding(
  444. config.max_position_embeddings, config.d_model, self.padding_idx
  445. )
  446. self.layers = nn.ModuleList([MarianDecoderLayer(config, layer_idx=i) for i in range(config.decoder_layers)])
  447. self.gradient_checkpointing = False
  448. # Initialize weights and apply final processing
  449. self.post_init()
  450. @merge_with_config_defaults
  451. @capture_outputs
  452. @auto_docstring
  453. def forward(
  454. self,
  455. input_ids: torch.LongTensor | None = None,
  456. attention_mask: torch.Tensor | None = None,
  457. encoder_hidden_states: torch.FloatTensor | None = None,
  458. encoder_attention_mask: torch.LongTensor | None = None,
  459. past_key_values: Cache | None = None,
  460. inputs_embeds: torch.FloatTensor | None = None,
  461. use_cache: bool | None = None,
  462. **kwargs: Unpack[TransformersKwargs],
  463. ) -> BaseModelOutputWithPastAndCrossAttentions:
  464. if (input_ids is None) ^ (inputs_embeds is not None):
  465. raise ValueError("You cannot specify both decoder_input_ids and decoder_inputs_embeds at the same time")
  466. if inputs_embeds is None:
  467. inputs_embeds = self.embed_tokens(input_ids)
  468. # Important to apply outside of the above `if`, in case user passes `embeds`
  469. inputs_embeds = inputs_embeds * self.embed_scale
  470. # initialize `past_key_values`
  471. if use_cache and past_key_values is None:
  472. past_key_values = (
  473. EncoderDecoderCache(DynamicCache(config=self.config), DynamicCache(config=self.config))
  474. if encoder_hidden_states is not None or self.config.is_encoder_decoder
  475. else DynamicCache(config=self.config)
  476. )
  477. batch_size, seq_length = inputs_embeds.size()[:-1]
  478. past_key_values_length = past_key_values.get_seq_length() if past_key_values is not None else 0
  479. position_ids = torch.arange(seq_length, device=inputs_embeds.device) + past_key_values_length
  480. if attention_mask is None and not is_torchdynamo_compiling():
  481. # required mask seq length can be calculated via length of past cache
  482. mask_seq_length = past_key_values_length + seq_length
  483. attention_mask = torch.ones(batch_size, mask_seq_length, device=inputs_embeds.device)
  484. self_attn_cache = (
  485. past_key_values.self_attention_cache
  486. if isinstance(past_key_values, EncoderDecoderCache)
  487. else past_key_values
  488. )
  489. causal_mask = create_causal_mask(
  490. config=self.config,
  491. inputs_embeds=inputs_embeds,
  492. attention_mask=attention_mask,
  493. past_key_values=self_attn_cache,
  494. )
  495. encoder_attention_mask = create_bidirectional_mask(
  496. config=self.config,
  497. inputs_embeds=inputs_embeds,
  498. attention_mask=encoder_attention_mask,
  499. encoder_hidden_states=encoder_hidden_states,
  500. )
  501. # embed positions
  502. position_ids = self.embed_positions(
  503. (batch_size, seq_length), past_key_values_length, position_ids=position_ids
  504. )
  505. hidden_states = inputs_embeds + position_ids
  506. hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training)
  507. for idx, decoder_layer in enumerate(self.layers):
  508. # add LayerDrop (see https://huggingface.co/papers/1909.11556 for description)
  509. if self.training:
  510. dropout_probability = torch.rand([])
  511. if dropout_probability < self.layerdrop:
  512. continue
  513. hidden_states = decoder_layer(
  514. hidden_states,
  515. causal_mask,
  516. encoder_hidden_states, # as a positional argument for gradient checkpointing
  517. encoder_attention_mask=encoder_attention_mask,
  518. past_key_values=past_key_values,
  519. use_cache=use_cache,
  520. **kwargs,
  521. )
  522. return BaseModelOutputWithPastAndCrossAttentions(
  523. last_hidden_state=hidden_states,
  524. past_key_values=past_key_values,
  525. )
  526. @auto_docstring
  527. class MarianModel(MarianPreTrainedModel):
  528. _keys_to_ignore_on_load_missing = [
  529. "model.encoder.embed_positions.weight",
  530. "model.decoder.embed_positions.weight",
  531. ]
  532. def __init__(self, config: MarianConfig):
  533. super().__init__(config)
  534. padding_idx, vocab_size = config.pad_token_id, config.vocab_size
  535. # We always use self.shared for token embeddings to ensure compatibility with all marian models
  536. if self.config.share_encoder_decoder_embeddings:
  537. self.shared = nn.Embedding(vocab_size, config.d_model, padding_idx)
  538. self._tied_weights_keys = {
  539. "decoder.embed_tokens.weight": "shared.weight",
  540. "encoder.embed_tokens.weight": "shared.weight",
  541. }
  542. else:
  543. self._tied_weights_keys = None
  544. self.encoder = MarianEncoder(config)
  545. self.decoder = MarianDecoder(config)
  546. # Initialize weights and apply final processing
  547. self.post_init()
  548. def get_input_embeddings(self):
  549. # This will return shared embeddings if they are shared else specific to encoder.
  550. return self.get_encoder().get_input_embeddings()
  551. def set_input_embeddings(self, value):
  552. if self.config.share_encoder_decoder_embeddings:
  553. self.shared = value
  554. self.encoder.embed_tokens = self.shared
  555. self.decoder.embed_tokens = self.shared
  556. else: # if not shared only set encoder embeedings
  557. self.encoder.embed_tokens = value
  558. def get_decoder_input_embeddings(self):
  559. if self.config.share_encoder_decoder_embeddings:
  560. raise ValueError(
  561. "`get_decoder_input_embeddings` should not be called if `config.share_encoder_decoder_embeddings` "
  562. "is `True`. Please use `get_input_embeddings` instead."
  563. )
  564. return self.get_decoder().get_input_embeddings()
  565. def set_decoder_input_embeddings(self, value):
  566. if self.config.share_encoder_decoder_embeddings:
  567. raise ValueError(
  568. "`config.share_encoder_decoder_embeddings` is set to `True` meaning the decoder input embeddings "
  569. "are shared with the encoder. In order to set the decoder input embeddings, you should simply set "
  570. "the encoder input embeddings by calling `set_input_embeddings` with the appropriate embeddings."
  571. )
  572. self.decoder.embed_tokens = value
  573. def resize_decoder_token_embeddings(self, new_num_tokens: int) -> nn.Embedding:
  574. if self.config.share_encoder_decoder_embeddings:
  575. raise ValueError(
  576. "`resize_decoder_token_embeddings` should not be called if `config.share_encoder_decoder_embeddings` "
  577. "is `True`. Please use `resize_token_embeddings` instead."
  578. )
  579. old_embeddings = self.get_decoder_input_embeddings()
  580. new_embeddings = self._get_resized_embeddings(old_embeddings, new_num_tokens)
  581. self.set_decoder_input_embeddings(new_embeddings)
  582. model_embeds = self.get_decoder_input_embeddings()
  583. if new_num_tokens is None:
  584. return model_embeds
  585. # Update base model and current model config
  586. self.config.decoder_vocab_size = new_num_tokens
  587. # Tie weights again if needed
  588. self.tie_weights()
  589. return model_embeds
  590. @can_return_tuple
  591. @auto_docstring
  592. def forward(
  593. self,
  594. input_ids: torch.LongTensor | None = None,
  595. attention_mask: torch.Tensor | None = None,
  596. decoder_input_ids: torch.LongTensor | None = None,
  597. decoder_attention_mask: torch.Tensor | None = None,
  598. encoder_outputs: tuple[torch.Tensor] | BaseModelOutput | None = None,
  599. past_key_values: Cache | None = None,
  600. inputs_embeds: torch.FloatTensor | None = None,
  601. decoder_inputs_embeds: torch.FloatTensor | None = None,
  602. use_cache: bool | None = None,
  603. **kwargs: Unpack[TransformersKwargs],
  604. ) -> Seq2SeqModelOutput:
  605. r"""
  606. decoder_input_ids (`torch.LongTensor` of shape `(batch_size, target_sequence_length)`, *optional*):
  607. Indices of decoder input sequence tokens in the vocabulary.
  608. Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
  609. [`PreTrainedTokenizer.__call__`] for details.
  610. [What are decoder input IDs?](../glossary#decoder-input-ids)
  611. Marian uses the `pad_token_id` as the starting token for `decoder_input_ids` generation. If
  612. `past_key_values` is used, optionally only the last `decoder_input_ids` have to be input (see
  613. `past_key_values`).
  614. decoder_attention_mask (`torch.LongTensor` of shape `(batch_size, target_sequence_length)`, *optional*):
  615. Default behavior: generate a tensor that ignores pad tokens in `decoder_input_ids`. Causal mask will also
  616. be used by default.
  617. Example:
  618. ```python
  619. >>> from transformers import AutoTokenizer, MarianModel
  620. >>> tokenizer = AutoTokenizer.from_pretrained("Helsinki-NLP/opus-mt-en-de")
  621. >>> model = MarianModel.from_pretrained("Helsinki-NLP/opus-mt-en-de")
  622. >>> inputs = tokenizer("Studies have been shown that owning a dog is good for you", return_tensors="pt")
  623. >>> decoder_inputs = tokenizer(
  624. ... "<pad> Studien haben gezeigt dass es hilfreich ist einen Hund zu besitzen",
  625. ... return_tensors="pt",
  626. ... add_special_tokens=False,
  627. ... )
  628. >>> outputs = model(input_ids=inputs.input_ids, decoder_input_ids=decoder_inputs.input_ids)
  629. >>> last_hidden_states = outputs.last_hidden_state
  630. >>> list(last_hidden_states.shape)
  631. [1, 26, 512]
  632. ```"""
  633. # If encoder_outputs are not given, pass the inputs to the encoder
  634. if encoder_outputs is None:
  635. encoder_outputs = self.encoder(
  636. input_ids=input_ids,
  637. attention_mask=attention_mask,
  638. inputs_embeds=inputs_embeds,
  639. **kwargs,
  640. )
  641. # If the user passed a tuple for encoder_outputs, we wrap it in a BaseModelOutput
  642. elif not isinstance(encoder_outputs, BaseModelOutput):
  643. encoder_outputs = BaseModelOutput(
  644. last_hidden_state=encoder_outputs[0],
  645. hidden_states=encoder_outputs[1] if len(encoder_outputs) > 1 else None,
  646. attentions=encoder_outputs[2] if len(encoder_outputs) > 2 else None,
  647. )
  648. # decoder outputs consists of (dec_features, past_key_values, dec_hidden, dec_attn)
  649. decoder_outputs = self.decoder(
  650. input_ids=decoder_input_ids,
  651. attention_mask=decoder_attention_mask,
  652. encoder_hidden_states=encoder_outputs[0],
  653. encoder_attention_mask=attention_mask,
  654. past_key_values=past_key_values,
  655. inputs_embeds=decoder_inputs_embeds,
  656. use_cache=use_cache,
  657. **kwargs,
  658. )
  659. return Seq2SeqModelOutput(
  660. last_hidden_state=decoder_outputs.last_hidden_state,
  661. past_key_values=decoder_outputs.past_key_values,
  662. decoder_hidden_states=decoder_outputs.hidden_states,
  663. decoder_attentions=decoder_outputs.attentions,
  664. cross_attentions=decoder_outputs.cross_attentions,
  665. encoder_last_hidden_state=encoder_outputs.last_hidden_state,
  666. encoder_hidden_states=encoder_outputs.hidden_states,
  667. encoder_attentions=encoder_outputs.attentions,
  668. )
  669. @auto_docstring(
  670. custom_intro="""
  671. The Marian Model with a language modeling head. Can be used for summarization.
  672. """
  673. )
  674. class MarianMTModel(MarianPreTrainedModel, GenerationMixin):
  675. base_model_prefix = "model"
  676. _keys_to_ignore_on_load_missing = [
  677. "final_logits_bias",
  678. "model.encoder.embed_positions.weight",
  679. "model.decoder.embed_positions.weight",
  680. ]
  681. _keys_to_ignore_on_save = ["model.encoder.embed_positions.weight", "model.decoder.embed_positions.weight"]
  682. _tied_weights_keys = {"lm_head.weight": "model.decoder.embed_tokens.weight"}
  683. def __init__(self, config: MarianConfig):
  684. super().__init__(config)
  685. self.model = MarianModel(config)
  686. if self.config.share_encoder_decoder_embeddings:
  687. self._tied_weights_keys = {
  688. "lm_head.weight": "model.shared.weight",
  689. "model.decoder.embed_tokens.weight": "model.shared.weight",
  690. "model.encoder.embed_tokens.weight": "model.shared.weight",
  691. }
  692. target_vocab_size = config.vocab_size if config.share_encoder_decoder_embeddings else config.decoder_vocab_size
  693. self.register_buffer("final_logits_bias", torch.zeros((1, target_vocab_size)))
  694. self.lm_head = nn.Linear(config.d_model, target_vocab_size, bias=False)
  695. # Initialize weights and apply final processing
  696. self.post_init()
  697. def resize_token_embeddings(
  698. self, new_num_tokens: int, pad_to_multiple_of: int | None = None, mean_resizing: bool = True
  699. ) -> nn.Embedding:
  700. new_embeddings = super().resize_token_embeddings(new_num_tokens, pad_to_multiple_of, mean_resizing)
  701. if self.config.share_encoder_decoder_embeddings:
  702. self._resize_final_logits_bias(new_num_tokens)
  703. return new_embeddings
  704. # NOTE: `_resize_token_embeddings` was rewritten in the base class, *args exists to absorb the extra arg
  705. def _resize_token_embeddings(self, new_num_tokens: int, pad_to_multiple_of=None, *args) -> nn.Embedding:
  706. old_embeddings = self.get_input_embeddings()
  707. new_embeddings = self._get_resized_embeddings(old_embeddings, new_num_tokens, pad_to_multiple_of)
  708. self.set_input_embeddings(new_embeddings)
  709. new_num_tokens = new_embeddings.weight.shape[0]
  710. # update config.decoder_vocab_size if embeddings are tied
  711. if self.config.share_encoder_decoder_embeddings:
  712. self.config.decoder_vocab_size = new_num_tokens
  713. # if word embeddings are not tied, make sure that lm head is resized as well
  714. if (
  715. self.config.share_encoder_decoder_embeddings
  716. and self.get_output_embeddings() is not None
  717. and not self.config.tie_word_embeddings
  718. ):
  719. old_lm_head = self.get_output_embeddings()
  720. new_lm_head = self._get_resized_lm_head(old_lm_head, new_num_tokens)
  721. self.set_output_embeddings(new_lm_head)
  722. return self.get_input_embeddings()
  723. def resize_decoder_token_embeddings(self, new_num_tokens):
  724. if self.config.share_encoder_decoder_embeddings:
  725. raise ValueError(
  726. "`resize_decoder_token_embeddings` should not be called if `config.share_encoder_decoder_embeddings` "
  727. "is `True`. Please use `resize_token_embeddings` instead."
  728. )
  729. old_embeddings = self.model.get_decoder_input_embeddings()
  730. new_embeddings = self._get_resized_embeddings(old_embeddings, new_num_tokens)
  731. self.model.set_decoder_input_embeddings(new_embeddings)
  732. # if word embeddings are not tied, make sure that lm head is resized as well
  733. if self.get_output_embeddings() is not None and not self.config.tie_word_embeddings:
  734. old_lm_head = self.get_output_embeddings()
  735. new_lm_head = self._get_resized_lm_head(old_lm_head, new_num_tokens)
  736. self.set_output_embeddings(new_lm_head)
  737. model_embeds = self.model.get_decoder_input_embeddings()
  738. if new_num_tokens is None:
  739. return model_embeds
  740. # Update base model and current model config
  741. self.config.decoder_vocab_size = new_num_tokens
  742. # Tie weights again if needed
  743. self.tie_weights()
  744. self._resize_final_logits_bias(new_num_tokens)
  745. return model_embeds
  746. def _resize_final_logits_bias(self, new_num_tokens: int) -> None:
  747. old_num_tokens = self.final_logits_bias.shape[-1]
  748. if new_num_tokens <= old_num_tokens:
  749. new_bias = self.final_logits_bias[:, :new_num_tokens]
  750. else:
  751. extra_bias = torch.zeros((1, new_num_tokens - old_num_tokens), device=self.final_logits_bias.device)
  752. new_bias = torch.cat([self.final_logits_bias, extra_bias], dim=1)
  753. self.register_buffer("final_logits_bias", new_bias)
  754. def set_output_embeddings(self, new_embeddings: nn.Embedding):
  755. self.lm_head = new_embeddings
  756. @can_return_tuple
  757. @auto_docstring
  758. def forward(
  759. self,
  760. input_ids: torch.LongTensor | None = None,
  761. attention_mask: torch.Tensor | None = None,
  762. decoder_input_ids: torch.LongTensor | None = None,
  763. decoder_attention_mask: torch.Tensor | None = None,
  764. encoder_outputs: tuple[torch.Tensor] | BaseModelOutput | None = None,
  765. past_key_values: Cache | None = None,
  766. inputs_embeds: torch.FloatTensor | None = None,
  767. decoder_inputs_embeds: torch.FloatTensor | None = None,
  768. labels: torch.LongTensor | None = None,
  769. use_cache: bool | None = None,
  770. **kwargs: Unpack[TransformersKwargs],
  771. ) -> Seq2SeqLMOutput:
  772. r"""
  773. decoder_input_ids (`torch.LongTensor` of shape `(batch_size, target_sequence_length)`, *optional*):
  774. Indices of decoder input sequence tokens in the vocabulary.
  775. Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
  776. [`PreTrainedTokenizer.__call__`] for details.
  777. [What are decoder input IDs?](../glossary#decoder-input-ids)
  778. Marian uses the `pad_token_id` as the starting token for `decoder_input_ids` generation. If
  779. `past_key_values` is used, optionally only the last `decoder_input_ids` have to be input (see
  780. `past_key_values`).
  781. decoder_attention_mask (`torch.LongTensor` of shape `(batch_size, target_sequence_length)`, *optional*):
  782. Default behavior: generate a tensor that ignores pad tokens in `decoder_input_ids`. Causal mask will also
  783. be used by default.
  784. labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
  785. Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
  786. config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
  787. (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
  788. Example:
  789. ```python
  790. >>> from transformers import AutoTokenizer, MarianMTModel
  791. >>> src = "fr" # source language
  792. >>> trg = "en" # target language
  793. >>> model_name = f"Helsinki-NLP/opus-mt-{src}-{trg}"
  794. >>> model = MarianMTModel.from_pretrained(model_name)
  795. >>> tokenizer = AutoTokenizer.from_pretrained(model_name)
  796. >>> sample_text = "où est l'arrêt de bus ?"
  797. >>> batch = tokenizer([sample_text], return_tensors="pt")
  798. >>> generated_ids = model.generate(**batch)
  799. >>> tokenizer.batch_decode(generated_ids, skip_special_tokens=True)[0]
  800. "Where's the bus stop?"
  801. ```
  802. """
  803. if labels is not None:
  804. if use_cache:
  805. logger.warning("The `use_cache` argument is changed to `False` since `labels` is provided.")
  806. use_cache = False
  807. if decoder_input_ids is None and decoder_inputs_embeds is None:
  808. decoder_input_ids = shift_tokens_right(
  809. labels, self.config.pad_token_id, self.config.decoder_start_token_id
  810. )
  811. outputs: Seq2SeqModelOutput = self.model(
  812. input_ids,
  813. attention_mask=attention_mask,
  814. decoder_input_ids=decoder_input_ids,
  815. encoder_outputs=encoder_outputs,
  816. decoder_attention_mask=decoder_attention_mask,
  817. past_key_values=past_key_values,
  818. inputs_embeds=inputs_embeds,
  819. decoder_inputs_embeds=decoder_inputs_embeds,
  820. use_cache=use_cache,
  821. **kwargs,
  822. )
  823. lm_logits = self.lm_head(outputs[0]) + self.final_logits_bias
  824. masked_lm_loss = None
  825. if labels is not None:
  826. loss_fct = CrossEntropyLoss()
  827. masked_lm_loss = loss_fct(lm_logits.view(-1, self.config.decoder_vocab_size), labels.view(-1))
  828. return Seq2SeqLMOutput(
  829. loss=masked_lm_loss,
  830. logits=lm_logits,
  831. past_key_values=outputs.past_key_values,
  832. decoder_hidden_states=outputs.decoder_hidden_states,
  833. decoder_attentions=outputs.decoder_attentions,
  834. cross_attentions=outputs.cross_attentions,
  835. encoder_last_hidden_state=outputs.encoder_last_hidden_state,
  836. encoder_hidden_states=outputs.encoder_hidden_states,
  837. encoder_attentions=outputs.encoder_attentions,
  838. )
  839. def prepare_decoder_input_ids_from_labels(self, labels: torch.Tensor):
  840. return shift_tokens_right(labels, self.config.pad_token_id, self.config.decoder_start_token_id)
# Copied from transformers.models.bart.modeling_bart.BartDecoderWrapper with Bart->Marian
# NOTE(review): this class is kept in sync with BartDecoderWrapper by the "Copied from" tooling; make code changes
# in the Bart source and propagate them rather than editing here.
class MarianDecoderWrapper(MarianPreTrainedModel):
    """
    This wrapper class is a helper class to correctly load pretrained checkpoints when the causal language model is
    used in combination with the [`EncoderDecoderModel`] framework.
    """

    def __init__(self, config):
        super().__init__(config)
        # Held under the attribute name `decoder`, so the decoder's parameters are reachable as `<owner>.decoder.*`
        # (e.g. `model.decoder.embed_tokens` from MarianForCausalLM).
        self.decoder = MarianDecoder(config)
        self.post_init()

    def forward(self, *args, **kwargs):
        # Pure pass-through: every positional and keyword argument goes to the wrapped decoder unchanged.
        return self.decoder(*args, **kwargs)
# Copied from transformers.models.bart.modeling_bart.BartForCausalLM with Bart->Marian, facebook/bart-base->Helsinki-NLP/opus-mt-fr-en
# Decoder-only Marian with a linear LM head on top. Because this class is kept in sync with BartForCausalLM by
# the "Copied from" tooling, code changes should be made in the Bart source and propagated.
class MarianForCausalLM(MarianPreTrainedModel, GenerationMixin):
    # The LM head weight is tied to the decoder's token embeddings.
    _tied_weights_keys = {
        "lm_head.weight": "model.decoder.embed_tokens.weight",
    }

    def __init__(self, config):
        # NOTE: these two assignments mutate the caller's config object in place.
        config.is_decoder = True
        config.is_encoder_decoder = False
        super().__init__(config)
        self.model = MarianDecoderWrapper(config)

        # Projects hidden states to `config.vocab_size` logits (no bias).
        self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)

        # Initialize weights and apply final processing
        self.post_init()

    def get_input_embeddings(self):
        # The decoder's token-embedding module serves as this model's input embeddings.
        return self.model.decoder.embed_tokens

    def set_input_embeddings(self, value):
        # Replace the decoder's token-embedding module.
        self.model.decoder.embed_tokens = value

    @can_return_tuple
    @auto_docstring
    def forward(
        self,
        input_ids: torch.LongTensor | None = None,
        attention_mask: torch.Tensor | None = None,
        encoder_hidden_states: torch.FloatTensor | None = None,
        encoder_attention_mask: torch.FloatTensor | None = None,
        past_key_values: Cache | None = None,
        inputs_embeds: torch.FloatTensor | None = None,
        labels: torch.LongTensor | None = None,
        use_cache: bool | None = None,
        logits_to_keep: int | torch.Tensor = 0,
        **kwargs: Unpack[TransformersKwargs],
    ) -> tuple | CausalLMOutputWithCrossAttentions:
        r"""
        labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
            Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
            config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
            (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.

        Example:

        ```python
        >>> from transformers import AutoTokenizer, MarianForCausalLM

        >>> tokenizer = AutoTokenizer.from_pretrained("Helsinki-NLP/opus-mt-fr-en")
        >>> model = MarianForCausalLM.from_pretrained("Helsinki-NLP/opus-mt-fr-en")
        >>> assert model.config.is_decoder, f"{model.__class__} has to be configured as a decoder."
        >>> inputs = tokenizer("Hello, my dog is cute", return_tensors="pt")
        >>> outputs = model(**inputs)

        >>> logits = outputs.logits
        >>> expected_shape = [1, inputs.input_ids.shape[-1], model.config.vocab_size]
        >>> list(logits.shape) == expected_shape
        True
        ```"""
        # Optional encoder states are used for cross-attention, e.g. when this decoder is part of an
        # EncoderDecoderModel.
        outputs: BaseModelOutputWithPastAndCrossAttentions = self.model.decoder(
            input_ids=input_ids,
            attention_mask=attention_mask,
            encoder_hidden_states=encoder_hidden_states,
            encoder_attention_mask=encoder_attention_mask,
            past_key_values=past_key_values,
            inputs_embeds=inputs_embeds,
            use_cache=use_cache,
            **kwargs,
        )

        hidden_states = outputs[0]
        # Only compute necessary logits
        # logits_to_keep == 0 -> slice(0, None), i.e. every position; a positive int keeps the last N positions;
        # a tensor is used directly as an index along the sequence dimension.
        slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep
        logits = self.lm_head(hidden_states[:, slice_indices, :])

        loss = None
        if labels is not None:
            # Labels are compared position-by-position with `logits`; no shift is applied in this method. If
            # `logits_to_keep` trims positions, `labels` must have the same trimmed length or the flattened shapes
            # will not match.
            labels = labels.to(logits.device)
            loss_fct = CrossEntropyLoss()
            loss = loss_fct(logits.view(-1, self.config.vocab_size), labels.view(-1))

        return CausalLMOutputWithCrossAttentions(
            loss=loss,
            logits=logits,
            past_key_values=outputs.past_key_values,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
            cross_attentions=outputs.cross_attentions,
        )
  930. __all__ = ["MarianForCausalLM", "MarianModel", "MarianMTModel", "MarianPreTrainedModel"]