modeling_mbart.py 60 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
37047057067077087097107117127137147157167177187197207217227237247257267277287297307317327337347357367377387397407417427437447457467477487497507517527537547557567577587597607617627637647657667677687697707717727737747757767777787797807817827837847857867877887897907917927937947957967977987998008018028038048058068078088098108118128138148158168178188198208218228238248258268278288298308318328338348358368378388398408418428438448458468478488498508518528538548558568578588598608618628638648658668678688698708718728738748758768778788798808818828838848858868878888898908918928938948958968978988999009019029039049059069079089099109119129139149159169179189199209219229239249259269279289299309319329339349359369379389399409419429439449459469479489499509519529539549559569579589599609619629639649659669679689699709719729739749759769779789799809819829839849859869879889899909919929939949959969979989991000100110021003100410051006100710081009101010111012101310141015101610171018101910201021102210231024102510261027102810291030103110321033103410351036103710381039104010411042104310441045104610471048104910501051105210531054105510561057105810591060106110621063106410651066106710681069107010711072107310741075107610771078107910801081108210831084108510861087108810891090109110921093109410951096109710981099110011011102110311041105110611071108110911101111111211131114111511161117111811191120112111221123112411251126112711281129113011311132113311341135113611371138113911401141114211431144114511461147114811491150115111521153115411551156115711581159116011611162116311641165116611671168116911701171117211731174117511761177117811791180118111821183118411851186118711881189119011911192119311941195119611971198119912001201120212031204120512061207120812091210121112121213121412151216121712181219122012211222122312241225122612271228122912301231123212331234123512361237123812391240124112421243124412451246124712481249125012511252125312541255125612571258125912601261126212631264126512661267126812691270127112721273127412751276127
712781279128012811282128312841285128612871288128912901291129212931294129512961297129812991300130113021303130413051306130713081309131013111312131313141315131613171318131913201321132213231324132513261327132813291330133113321333133413351336133713381339134013411342134313441345134613471348134913501351135213531354135513561357135813591360136113621363136413651366136713681369137013711372137313741375137613771378137913801381138213831384138513861387138813891390139113921393139413951396
  1. # Copyright 2021, The Facebook AI Research Team and The HuggingFace Inc. team. All rights reserved.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. """PyTorch MBART model."""
  15. import math
  16. from collections.abc import Callable
  17. import torch
  18. from torch import nn
  19. from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
  20. from ... import initialization as init
  21. from ...activations import ACT2FN
  22. from ...cache_utils import Cache, DynamicCache, EncoderDecoderCache
  23. from ...generation import GenerationMixin
  24. from ...masking_utils import create_bidirectional_mask, create_causal_mask
  25. from ...modeling_flash_attention_utils import (
  26. FlashAttentionKwargs,
  27. )
  28. from ...modeling_layers import GradientCheckpointingLayer
  29. from ...modeling_outputs import (
  30. BaseModelOutput,
  31. BaseModelOutputWithPastAndCrossAttentions,
  32. CausalLMOutputWithCrossAttentions,
  33. Seq2SeqLMOutput,
  34. Seq2SeqModelOutput,
  35. Seq2SeqQuestionAnsweringModelOutput,
  36. Seq2SeqSequenceClassifierOutput,
  37. )
  38. from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
  39. from ...processing_utils import Unpack
  40. from ...utils import (
  41. TransformersKwargs,
  42. auto_docstring,
  43. can_return_tuple,
  44. is_torch_flex_attn_available,
  45. is_torchdynamo_compiling,
  46. logging,
  47. torch_compilable_check,
  48. )
  49. from ...utils.generic import merge_with_config_defaults
  50. from ...utils.output_capturing import OutputRecorder, capture_outputs
  51. from .configuration_mbart import MBartConfig
# The guarded body is empty, so nothing is imported or configured in this branch; only the availability check
# itself is evaluated at import time.
if is_torch_flex_attn_available():
    pass

# Module-level logger, named after this module.
logger = logging.get_logger(__name__)
  55. def shift_tokens_right(input_ids: torch.Tensor, pad_token_id: int):
  56. """
  57. Shift input ids one token to the right, and wrap the last non pad token (the <LID> token) Note that MBart does not
  58. have a single `decoder_start_token_id` in contrast to other Bart-like models.
  59. """
  60. prev_output_tokens = input_ids.clone()
  61. if pad_token_id is None:
  62. raise ValueError("self.model.config.pad_token_id has to be defined.")
  63. # replace possible -100 values in labels by `pad_token_id`
  64. prev_output_tokens.masked_fill_(prev_output_tokens == -100, pad_token_id)
  65. index_of_eos = (prev_output_tokens.ne(pad_token_id).sum(dim=1) - 1).unsqueeze(-1)
  66. decoder_start_tokens = prev_output_tokens.gather(1, index_of_eos).squeeze()
  67. prev_output_tokens[:, 1:] = prev_output_tokens[:, :-1].clone()
  68. prev_output_tokens[:, 0] = decoder_start_tokens
  69. return prev_output_tokens
  70. # Copied from transformers.models.bart.modeling_bart.BartLearnedPositionalEmbedding with Bart->MBart
  71. class MBartLearnedPositionalEmbedding(nn.Embedding):
  72. """
  73. This module learns positional embeddings up to a fixed maximum size.
  74. """
  75. def __init__(self, num_embeddings: int, embedding_dim: int):
  76. # MBart is set up so that if padding_idx is specified then offset the embedding ids by 2
  77. # and adjust num_embeddings appropriately. Other models don't have this hack
  78. self.offset = 2
  79. super().__init__(num_embeddings + self.offset, embedding_dim)
  80. def forward(
  81. self, input_ids: torch.Tensor, past_key_values_length: int = 0, position_ids: torch.Tensor | None = None
  82. ):
  83. """`input_ids' shape is expected to be [bsz x seqlen]."""
  84. if position_ids is None:
  85. bsz, seq_len = input_ids.shape[:2]
  86. position_ids = torch.arange(
  87. past_key_values_length, past_key_values_length + seq_len, dtype=torch.long, device=self.weight.device
  88. ).expand(bsz, -1)
  89. else:
  90. position_ids = position_ids.unsqueeze(0)
  91. return super().forward(position_ids + self.offset)
  92. # Copied from transformers.models.bart.modeling_bart.BartScaledWordEmbedding with Bart->MBart
  93. class MBartScaledWordEmbedding(nn.Embedding):
  94. """
  95. This module overrides nn.Embeddings' forward by multiplying with embeddings scale.
  96. """
  97. def __init__(self, num_embeddings: int, embedding_dim: int, padding_idx: int, embed_scale: float | None = 1.0):
  98. super().__init__(num_embeddings, embedding_dim, padding_idx)
  99. self.embed_scale = embed_scale
  100. def forward(self, input_ids: torch.Tensor):
  101. return super().forward(input_ids) * self.embed_scale
  102. # Copied from transformers.models.bert.modeling_bert.eager_attention_forward
  103. def eager_attention_forward(
  104. module: nn.Module,
  105. query: torch.Tensor,
  106. key: torch.Tensor,
  107. value: torch.Tensor,
  108. attention_mask: torch.Tensor | None,
  109. scaling: float | None = None,
  110. dropout: float = 0.0,
  111. **kwargs: Unpack[TransformersKwargs],
  112. ):
  113. if scaling is None:
  114. scaling = query.size(-1) ** -0.5
  115. # Take the dot product between "query" and "key" to get the raw attention scores.
  116. attn_weights = torch.matmul(query, key.transpose(2, 3)) * scaling
  117. if attention_mask is not None:
  118. attn_weights = attn_weights + attention_mask
  119. attn_weights = nn.functional.softmax(attn_weights, dim=-1)
  120. attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training)
  121. attn_output = torch.matmul(attn_weights, value)
  122. attn_output = attn_output.transpose(1, 2).contiguous()
  123. return attn_output, attn_weights
  124. # Copied from transformers.models.bart.modeling_bart.BartAttention with Bart->MBart
class MBartAttention(nn.Module):
    """
    Multi-headed attention from 'Attention Is All You Need' paper

    The same module serves as self-attention and as cross-attention: `forward` treats the call as cross-attention
    whenever `key_value_states` is passed. Queries, keys, values and the output each go through their own
    `nn.Linear(embed_dim, embed_dim)` projection.
    """

    def __init__(
        self,
        embed_dim: int,
        num_heads: int,
        dropout: float = 0.0,
        is_decoder: bool = False,
        bias: bool = True,
        is_causal: bool = False,
        config: MBartConfig | None = None,
        layer_idx: int | None = None,
    ):
        """
        Args:
            embed_dim: model width; must be divisible by `num_heads`.
            num_heads: number of attention heads (`head_dim = embed_dim // num_heads`).
            dropout: attention dropout probability, passed to the attention implementation while training.
            is_decoder: stored on the module; together with a missing `layer_idx` it triggers a warning.
            bias: whether the four linear projections carry a bias term.
            is_causal: stored on the module as `self.is_causal`.
            config: model config; `forward` reads `config._attn_implementation` from it.
            layer_idx: index used to address this layer's entry in the key/value cache.

        Raises:
            ValueError: if `embed_dim` is not divisible by `num_heads`.
        """
        super().__init__()
        self.embed_dim = embed_dim
        self.num_heads = num_heads
        self.dropout = dropout
        self.head_dim = embed_dim // num_heads
        self.config = config
        if (self.head_dim * num_heads) != self.embed_dim:
            raise ValueError(
                f"embed_dim must be divisible by num_heads (got `embed_dim`: {self.embed_dim}"
                f" and `num_heads`: {num_heads})."
            )
        # standard 1/sqrt(head_dim) score scaling
        self.scaling = self.head_dim**-0.5
        self.is_decoder = is_decoder
        self.is_causal = is_causal
        self.layer_idx = layer_idx
        if layer_idx is None and self.is_decoder:
            logger.warning_once(
                f"Instantiating a decoder {self.__class__.__name__} without passing `layer_idx` is not recommended and "
                "will lead to errors during the forward call, if caching is used. Please make sure to provide a `layer_idx` "
                "when creating this class."
            )
        self.k_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
        self.v_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
        self.q_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
        self.out_proj = nn.Linear(embed_dim, embed_dim, bias=bias)

    def forward(
        self,
        hidden_states: torch.Tensor,
        key_value_states: torch.Tensor | None = None,
        past_key_values: Cache | None = None,
        attention_mask: torch.Tensor | None = None,
        # TODO: we need a refactor so that the different attention modules can get their specific kwargs
        # ATM, we have mixed things encoder, decoder, and encoder-decoder attn
        **kwargs: Unpack[FlashAttentionKwargs],
    ) -> tuple[torch.Tensor, torch.Tensor | None]:
        """
        Input shape: Batch x Time x Channel

        Args:
            hidden_states: source of the queries; also the source of keys/values when `key_value_states` is `None`.
            key_value_states: when given, keys and values are projected from it and the call is cross-attention.
            past_key_values: optional cache. For an `EncoderDecoderCache` the self- or cross-attention sub-cache is
                selected to match the call; any other cache object is used directly.
            attention_mask: forwarded unchanged to the selected attention implementation.
            **kwargs: forwarded unchanged to the selected attention implementation.

        Returns:
            `(attn_output, attn_weights)`: `attn_output` is reshaped back to `(*input_shape, -1)` and passed through
            `out_proj`; `attn_weights` is the second value returned by the attention implementation.
        """
        # if key_value_states are provided this layer is used as a cross-attention layer
        # for the decoder
        is_cross_attention = key_value_states is not None

        # determine input shapes
        input_shape = hidden_states.shape[:-1]
        hidden_shape = (*input_shape, -1, self.head_dim)

        # get query proj
        # split the channel axis into (heads, head_dim), then swap axes 1 and 2 so heads come before time
        query_states = self.q_proj(hidden_states).view(hidden_shape).transpose(1, 2)

        # `is_updated` stays False unless an EncoderDecoderCache reports this layer as already filled
        is_updated = False
        if past_key_values is not None:
            if isinstance(past_key_values, EncoderDecoderCache):
                is_updated = past_key_values.is_updated.get(self.layer_idx)
                if is_cross_attention:
                    # after the first generated id, we can subsequently re-use all key/value_states from cache
                    curr_past_key_values = past_key_values.cross_attention_cache
                else:
                    curr_past_key_values = past_key_values.self_attention_cache
            else:
                curr_past_key_values = past_key_values

        current_states = key_value_states if is_cross_attention else hidden_states
        if is_cross_attention and past_key_values is not None and is_updated:
            # reuse k,v, cross_attentions
            key_states = curr_past_key_values.layers[self.layer_idx].keys
            value_states = curr_past_key_values.layers[self.layer_idx].values
        else:
            key_states = self.k_proj(current_states)
            value_states = self.v_proj(current_states)
            # same head split as for the queries, but based on the key/value source's own leading dims
            kv_shape = (*current_states.shape[:-1], -1, self.head_dim)
            key_states = key_states.view(kv_shape).transpose(1, 2)
            value_states = value_states.view(kv_shape).transpose(1, 2)

            if past_key_values is not None:
                # store the new states in the cache; attention then runs on what `update` returns
                key_states, value_states = curr_past_key_values.update(key_states, value_states, self.layer_idx)
                # set flag that curr layer for cross-attn is already updated so we can re-use in subsequent calls
                if is_cross_attention and isinstance(past_key_values, EncoderDecoderCache):
                    past_key_values.is_updated[self.layer_idx] = True

        # pick the implementation named in the config, falling back to the eager reference implementation
        attention_interface: Callable = ALL_ATTENTION_FUNCTIONS.get_interface(
            self.config._attn_implementation, eager_attention_forward
        )

        attn_output, attn_weights = attention_interface(
            self,
            query_states,
            key_states,
            value_states,
            attention_mask,
            dropout=0.0 if not self.training else self.dropout,
            scaling=self.scaling,
            **kwargs,
        )

        # merge the per-head outputs back into a single channel axis before the output projection
        attn_output = attn_output.reshape(*input_shape, -1).contiguous()
        attn_output = self.out_proj(attn_output)

        return attn_output, attn_weights
  225. class MBartEncoderLayer(GradientCheckpointingLayer):
  226. def __init__(self, config: MBartConfig):
  227. super().__init__()
  228. self.embed_dim = config.d_model
  229. self.self_attn = MBartAttention(
  230. embed_dim=self.embed_dim,
  231. num_heads=config.encoder_attention_heads,
  232. dropout=config.attention_dropout,
  233. config=config,
  234. )
  235. self.self_attn_layer_norm = nn.LayerNorm(self.embed_dim)
  236. self.dropout = config.dropout
  237. self.activation_fn = ACT2FN[config.activation_function]
  238. self.activation_dropout = config.activation_dropout
  239. self.fc1 = nn.Linear(self.embed_dim, config.encoder_ffn_dim)
  240. self.fc2 = nn.Linear(config.encoder_ffn_dim, self.embed_dim)
  241. self.final_layer_norm = nn.LayerNorm(self.embed_dim)
  242. def forward(
  243. self,
  244. hidden_states: torch.Tensor,
  245. attention_mask: torch.Tensor,
  246. **kwargs: Unpack[TransformersKwargs],
  247. ) -> torch.Tensor:
  248. """
  249. Args:
  250. hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)`
  251. attention_mask (`torch.FloatTensor`): attention mask of size
  252. `(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values.
  253. """
  254. residual = hidden_states
  255. hidden_states = self.self_attn_layer_norm(hidden_states)
  256. hidden_states, _ = self.self_attn(
  257. hidden_states=hidden_states,
  258. attention_mask=attention_mask,
  259. **kwargs,
  260. )
  261. hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training)
  262. hidden_states = residual + hidden_states
  263. residual = hidden_states
  264. hidden_states = self.final_layer_norm(hidden_states)
  265. hidden_states = self.activation_fn(self.fc1(hidden_states))
  266. hidden_states = nn.functional.dropout(hidden_states, p=self.activation_dropout, training=self.training)
  267. hidden_states = self.fc2(hidden_states)
  268. hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training)
  269. hidden_states = residual + hidden_states
  270. if hidden_states.dtype == torch.float16:
  271. clamp_value = torch.finfo(hidden_states.dtype).max - 1000
  272. hidden_states = torch.clamp(hidden_states, min=-clamp_value, max=clamp_value)
  273. return hidden_states
  274. class MBartDecoderLayer(GradientCheckpointingLayer):
  275. def __init__(self, config: MBartConfig, layer_idx: int | None = None):
  276. super().__init__()
  277. self.embed_dim = config.d_model
  278. self.self_attn = MBartAttention(
  279. embed_dim=self.embed_dim,
  280. num_heads=config.decoder_attention_heads,
  281. dropout=config.attention_dropout,
  282. is_decoder=True,
  283. is_causal=True,
  284. config=config,
  285. layer_idx=layer_idx,
  286. )
  287. self.dropout = config.dropout
  288. self.activation_fn = ACT2FN[config.activation_function]
  289. self.activation_dropout = config.activation_dropout
  290. self.self_attn_layer_norm = nn.LayerNorm(self.embed_dim)
  291. self.encoder_attn = MBartAttention(
  292. self.embed_dim,
  293. config.decoder_attention_heads,
  294. dropout=config.attention_dropout,
  295. is_decoder=True,
  296. config=config,
  297. layer_idx=layer_idx,
  298. )
  299. self.encoder_attn_layer_norm = nn.LayerNorm(self.embed_dim)
  300. self.fc1 = nn.Linear(self.embed_dim, config.decoder_ffn_dim)
  301. self.fc2 = nn.Linear(config.decoder_ffn_dim, self.embed_dim)
  302. self.final_layer_norm = nn.LayerNorm(self.embed_dim)
  303. def forward(
  304. self,
  305. hidden_states: torch.Tensor,
  306. attention_mask: torch.Tensor | None = None,
  307. encoder_hidden_states: torch.Tensor | None = None,
  308. encoder_attention_mask: torch.Tensor | None = None,
  309. past_key_values: Cache | None = None,
  310. use_cache: bool | None = True,
  311. **kwargs: Unpack[TransformersKwargs],
  312. ) -> torch.Tensor:
  313. """
  314. Args:
  315. hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)`
  316. attention_mask (`torch.FloatTensor`): attention mask of size
  317. `(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values.
  318. encoder_hidden_states (`torch.FloatTensor`):
  319. cross attention input to the layer of shape `(batch, seq_len, embed_dim)`
  320. encoder_attention_mask (`torch.FloatTensor`): encoder attention mask of size
  321. `(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values.
  322. past_key_values (`Cache`): cached past key and value projection states
  323. """
  324. residual = hidden_states
  325. hidden_states = self.self_attn_layer_norm(hidden_states)
  326. # Self Attention
  327. hidden_states, _ = self.self_attn(
  328. hidden_states=hidden_states,
  329. past_key_values=past_key_values,
  330. attention_mask=attention_mask,
  331. **kwargs,
  332. )
  333. hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training)
  334. hidden_states = residual + hidden_states
  335. # Cross-Attention Block
  336. if encoder_hidden_states is not None:
  337. residual = hidden_states
  338. hidden_states = self.encoder_attn_layer_norm(hidden_states)
  339. hidden_states, _ = self.encoder_attn(
  340. hidden_states=hidden_states,
  341. key_value_states=encoder_hidden_states,
  342. attention_mask=encoder_attention_mask,
  343. past_key_values=past_key_values,
  344. **kwargs,
  345. )
  346. hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training)
  347. hidden_states = residual + hidden_states
  348. # Fully Connected
  349. residual = hidden_states
  350. hidden_states = self.final_layer_norm(hidden_states)
  351. hidden_states = self.activation_fn(self.fc1(hidden_states))
  352. hidden_states = nn.functional.dropout(hidden_states, p=self.activation_dropout, training=self.training)
  353. hidden_states = self.fc2(hidden_states)
  354. hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training)
  355. hidden_states = residual + hidden_states
  356. return hidden_states
  357. # Copied from transformers.models.bart.modeling_bart.BartClassificationHead with Bart->MBart
  358. class MBartClassificationHead(nn.Module):
  359. """Head for sentence-level classification tasks."""
  360. def __init__(
  361. self,
  362. input_dim: int,
  363. inner_dim: int,
  364. num_classes: int,
  365. pooler_dropout: float,
  366. ):
  367. super().__init__()
  368. self.dense = nn.Linear(input_dim, inner_dim)
  369. self.dropout = nn.Dropout(p=pooler_dropout)
  370. self.out_proj = nn.Linear(inner_dim, num_classes)
  371. def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
  372. hidden_states = self.dropout(hidden_states)
  373. hidden_states = self.dense(hidden_states)
  374. hidden_states = torch.tanh(hidden_states)
  375. hidden_states = self.dropout(hidden_states)
  376. hidden_states = self.out_proj(hidden_states)
  377. return hidden_states
  378. @auto_docstring
  379. class MBartPreTrainedModel(PreTrainedModel):
  380. config: MBartConfig
  381. base_model_prefix = "model"
  382. supports_gradient_checkpointing = True
  383. _no_split_modules = ["MBartDecoderLayer", "MBartEncoderLayer", "MBartAttention"]
  384. _supports_flash_attn = True
  385. _supports_sdpa = True
  386. _supports_flex_attn = True
  387. _can_compile_fullgraph = True
  388. def _init_weights(self, module):
  389. super()._init_weights(module)
  390. if isinstance(module, MBartForConditionalGeneration):
  391. init.zeros_(module.final_logits_bias)
  392. @property
  393. def dummy_inputs(self):
  394. pad_token = self.config.pad_token_id
  395. input_ids = torch.tensor([[0, 6, 10, 4, 2], [0, 8, 12, 2, pad_token]], device=self.device)
  396. dummy_inputs = {
  397. "attention_mask": input_ids.ne(pad_token),
  398. "input_ids": input_ids,
  399. }
  400. return dummy_inputs
  401. class MBartEncoder(MBartPreTrainedModel):
  402. """
  403. Transformer encoder consisting of *config.encoder_layers* self attention layers. Each layer is a
  404. [`MBartEncoderLayer`].
  405. Args:
  406. config: MBartConfig
  407. embed_tokens (nn.Embedding): output embedding
  408. """
  409. _can_record_outputs = {
  410. "hidden_states": MBartEncoderLayer,
  411. "attentions": OutputRecorder(MBartAttention, index=1, layer_name="self_attn"),
  412. }
  413. def __init__(self, config: MBartConfig):
  414. super().__init__(config)
  415. self.dropout = config.dropout
  416. self.layerdrop = config.encoder_layerdrop
  417. embed_dim = config.d_model
  418. self.padding_idx = config.pad_token_id
  419. self.max_source_positions = config.max_position_embeddings
  420. embed_scale = math.sqrt(embed_dim) if config.scale_embedding else 1.0
  421. self.embed_tokens = MBartScaledWordEmbedding(
  422. config.vocab_size, embed_dim, self.padding_idx, embed_scale=embed_scale
  423. )
  424. self.embed_positions = MBartLearnedPositionalEmbedding(
  425. config.max_position_embeddings,
  426. embed_dim,
  427. )
  428. self.layers = nn.ModuleList([MBartEncoderLayer(config) for _ in range(config.encoder_layers)])
  429. self.config = config
  430. self.layernorm_embedding = nn.LayerNorm(embed_dim)
  431. self.layer_norm = nn.LayerNorm(config.d_model)
  432. self.gradient_checkpointing = False
  433. # Initialize weights and apply final processing
  434. self.post_init()
  435. def _backward_compatibility_gradient_checkpointing(self):
  436. # Override to not delete the attribute from the config
  437. if self.supports_gradient_checkpointing and getattr(self.config, "gradient_checkpointing", False):
  438. self.gradient_checkpointing_enable()
  439. @merge_with_config_defaults
  440. @capture_outputs
  441. def forward(
  442. self,
  443. input_ids: torch.LongTensor | None = None,
  444. attention_mask: torch.Tensor | None = None,
  445. inputs_embeds: torch.FloatTensor | None = None,
  446. **kwargs: Unpack[TransformersKwargs],
  447. ) -> tuple | BaseModelOutput:
  448. r"""
  449. Args:
  450. input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
  451. Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you
  452. provide it.
  453. Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
  454. [`PreTrainedTokenizer.__call__`] for details.
  455. [What are input IDs?](../glossary#input-ids)
  456. attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
  457. Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
  458. - 1 for tokens that are **not masked**,
  459. - 0 for tokens that are **masked**.
  460. [What are attention masks?](../glossary#attention-mask)
  461. inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
  462. Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation.
  463. This is useful if you want more control over how to convert `input_ids` indices into associated vectors
  464. than the model's internal embedding lookup matrix.
  465. """
  466. if (input_ids is None) ^ (inputs_embeds is not None):
  467. raise ValueError("You must specify exactly one of input_ids or inputs_embeds")
  468. if inputs_embeds is None:
  469. inputs_embeds = self.embed_tokens(input_ids)
  470. embed_pos = self.embed_positions(inputs_embeds[..., -1]) # just for the shape
  471. hidden_states = inputs_embeds + embed_pos.to(inputs_embeds.device)
  472. hidden_states = self.layernorm_embedding(hidden_states)
  473. hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training)
  474. attention_mask = create_bidirectional_mask(
  475. config=self.config,
  476. inputs_embeds=inputs_embeds,
  477. attention_mask=attention_mask,
  478. )
  479. for idx, encoder_layer in enumerate(self.layers):
  480. # add LayerDrop (see https://huggingface.co/papers/1909.11556 for description)
  481. to_drop = False
  482. if self.training:
  483. dropout_probability = torch.rand([])
  484. if dropout_probability < self.layerdrop: # skip the layer
  485. to_drop = True
  486. if not to_drop:
  487. hidden_states = encoder_layer(
  488. hidden_states,
  489. attention_mask,
  490. **kwargs,
  491. )
  492. hidden_states = self.layer_norm(hidden_states)
  493. return BaseModelOutput(last_hidden_state=hidden_states)
  494. class MBartDecoder(MBartPreTrainedModel):
  495. """
  496. Transformer decoder consisting of *config.decoder_layers* layers. Each layer is a [`MBartDecoderLayer`]
  497. Args:
  498. config: MBartConfig
  499. embed_tokens (nn.Embedding): output embedding
  500. """
  501. _can_record_outputs = {
  502. "hidden_states": MBartDecoderLayer,
  503. "attentions": OutputRecorder(MBartAttention, index=1, layer_name="self_attn"),
  504. "cross_attentions": OutputRecorder(MBartAttention, index=1, layer_name="encoder_attn"),
  505. }
  506. def __init__(self, config: MBartConfig):
  507. super().__init__(config)
  508. self.dropout = config.dropout
  509. self.layerdrop = config.decoder_layerdrop
  510. self.padding_idx = config.pad_token_id
  511. self.max_target_positions = config.max_position_embeddings
  512. embed_scale = math.sqrt(config.d_model) if config.scale_embedding else 1.0
  513. self.embed_tokens = MBartScaledWordEmbedding(
  514. config.vocab_size, config.d_model, self.padding_idx, embed_scale=embed_scale
  515. )
  516. self.embed_positions = MBartLearnedPositionalEmbedding(
  517. config.max_position_embeddings,
  518. config.d_model,
  519. )
  520. self.layers = nn.ModuleList([MBartDecoderLayer(config, layer_idx=i) for i in range(config.decoder_layers)])
  521. self.config = config
  522. self.layernorm_embedding = nn.LayerNorm(config.d_model)
  523. self.layer_norm = nn.LayerNorm(config.d_model)
  524. self.gradient_checkpointing = False
  525. # Initialize weights and apply final processing
  526. self.post_init()
  527. @merge_with_config_defaults
  528. @capture_outputs
  529. def forward(
  530. self,
  531. input_ids: torch.LongTensor | None = None,
  532. attention_mask: torch.Tensor | None = None,
  533. encoder_hidden_states: torch.FloatTensor | None = None,
  534. encoder_attention_mask: torch.LongTensor | None = None,
  535. past_key_values: Cache | None = None,
  536. inputs_embeds: torch.FloatTensor | None = None,
  537. use_cache: bool | None = None,
  538. **kwargs: Unpack[TransformersKwargs],
  539. ) -> tuple | BaseModelOutputWithPastAndCrossAttentions:
  540. r"""
  541. Args:
  542. input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
  543. Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you
  544. provide it.
  545. Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
  546. [`PreTrainedTokenizer.__call__`] for details.
  547. [What are input IDs?](../glossary#input-ids)
  548. attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
  549. Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
  550. - 1 for tokens that are **not masked**,
  551. - 0 for tokens that are **masked**.
  552. [What are attention masks?](../glossary#attention-mask)
  553. encoder_hidden_states (`torch.FloatTensor` of shape `(batch_size, encoder_sequence_length, hidden_size)`, *optional*):
  554. Sequence of hidden-states at the output of the last layer of the encoder. Used in the cross-attention
  555. of the decoder.
  556. encoder_attention_mask (`torch.LongTensor` of shape `(batch_size, encoder_sequence_length)`, *optional*):
  557. Mask to avoid performing cross-attention on padding tokens indices of encoder input_ids. Mask values
  558. selected in `[0, 1]`:
  559. - 1 for tokens that are **not masked**,
  560. - 0 for tokens that are **masked**.
  561. [What are attention masks?](../glossary#attention-mask)
  562. past_key_values (`Cache`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
  563. It is a [`~cache_utils.Cache`] instance. For more details, see our [kv cache guide](https://huggingface.co/docs/transformers/en/kv_cache).
  564. Contains pre-computed hidden-states (key and values in the self-attention blocks and in the
  565. cross-attention blocks) that can be used (see `past_key_values` input) to speed up sequential decoding.
  566. If `past_key_values` are used, the user can optionally input only the last `decoder_input_ids` (those
  567. that don't have their past key value states given to this model) of shape `(batch_size, 1)` instead of
  568. all `decoder_input_ids` of shape `(batch_size, sequence_length)`.
  569. inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
  570. Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation.
  571. This is useful if you want more control over how to convert `input_ids` indices into associated vectors
  572. than the model's internal embedding lookup matrix.
  573. """
  574. if (input_ids is None) ^ (inputs_embeds is not None):
  575. raise ValueError("You cannot specify both decoder_input_ids and decoder_inputs_embeds at the same time")
  576. if inputs_embeds is None:
  577. inputs_embeds = self.embed_tokens(input_ids)
  578. # initialize `past_key_values`
  579. if use_cache and past_key_values is None:
  580. past_key_values = (
  581. EncoderDecoderCache(DynamicCache(config=self.config), DynamicCache(config=self.config))
  582. if encoder_hidden_states is not None or self.config.is_encoder_decoder
  583. else DynamicCache(config=self.config)
  584. )
  585. batch_size, seq_length = inputs_embeds.size()[:-1]
  586. past_key_values_length = past_key_values.get_seq_length() if past_key_values is not None else 0
  587. position_ids = torch.arange(seq_length, device=inputs_embeds.device) + past_key_values_length
  588. if attention_mask is None and not is_torchdynamo_compiling():
  589. # required mask seq length can be calculated via length of past cache
  590. mask_seq_length = past_key_values_length + seq_length
  591. attention_mask = torch.ones(batch_size, mask_seq_length, device=inputs_embeds.device)
  592. self_attn_cache = (
  593. past_key_values.self_attention_cache
  594. if isinstance(past_key_values, EncoderDecoderCache)
  595. else past_key_values
  596. )
  597. causal_mask = create_causal_mask(
  598. config=self.config,
  599. inputs_embeds=inputs_embeds,
  600. attention_mask=attention_mask,
  601. past_key_values=self_attn_cache,
  602. )
  603. encoder_attention_mask = create_bidirectional_mask(
  604. config=self.config,
  605. inputs_embeds=inputs_embeds,
  606. attention_mask=encoder_attention_mask,
  607. encoder_hidden_states=encoder_hidden_states,
  608. )
  609. # embed positions
  610. position_ids = self.embed_positions(input, past_key_values_length, position_ids=position_ids)
  611. hidden_states = inputs_embeds + position_ids.to(inputs_embeds.device)
  612. hidden_states = self.layernorm_embedding(hidden_states)
  613. hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training)
  614. for idx, decoder_layer in enumerate(self.layers):
  615. # add LayerDrop (see https://huggingface.co/papers/1909.11556 for description)
  616. if self.training:
  617. dropout_probability = torch.rand([])
  618. if dropout_probability < self.layerdrop:
  619. continue
  620. hidden_states = decoder_layer(
  621. hidden_states,
  622. causal_mask,
  623. encoder_hidden_states, # as a positional argument for gradient checkpointing
  624. encoder_attention_mask=encoder_attention_mask,
  625. past_key_values=past_key_values,
  626. use_cache=use_cache,
  627. **kwargs,
  628. )
  629. hidden_states = self.layer_norm(hidden_states)
  630. return BaseModelOutputWithPastAndCrossAttentions(
  631. last_hidden_state=hidden_states,
  632. past_key_values=past_key_values,
  633. )
  634. @auto_docstring
  635. class MBartModel(MBartPreTrainedModel):
  636. _tied_weights_keys = {
  637. "decoder.embed_tokens.weight": "shared.weight",
  638. "encoder.embed_tokens.weight": "shared.weight",
  639. }
  640. def __init__(self, config: MBartConfig):
  641. super().__init__(config)
  642. padding_idx, vocab_size = config.pad_token_id, config.vocab_size
  643. embed_scale = math.sqrt(config.d_model) if config.scale_embedding else 1.0
  644. self.shared = MBartScaledWordEmbedding(vocab_size, config.d_model, padding_idx, embed_scale=embed_scale)
  645. self.encoder = MBartEncoder(config)
  646. self.decoder = MBartDecoder(config)
  647. # Initialize weights and apply final processing
  648. self.post_init()
  649. def get_input_embeddings(self):
  650. return self.shared
  651. def set_input_embeddings(self, value):
  652. self.shared = value
  653. self.encoder.embed_tokens = self.shared
  654. self.decoder.embed_tokens = self.shared
  655. @can_return_tuple
  656. @auto_docstring
  657. def forward(
  658. self,
  659. input_ids: torch.LongTensor | None = None,
  660. attention_mask: torch.Tensor | None = None,
  661. decoder_input_ids: torch.LongTensor | None = None,
  662. decoder_attention_mask: torch.LongTensor | None = None,
  663. encoder_outputs: tuple[tuple[torch.FloatTensor]] | None = None,
  664. past_key_values: Cache | None = None,
  665. inputs_embeds: torch.FloatTensor | None = None,
  666. decoder_inputs_embeds: torch.FloatTensor | None = None,
  667. use_cache: bool | None = None,
  668. return_dict: bool | None = None,
  669. **kwargs: Unpack[TransformersKwargs],
  670. ) -> Seq2SeqModelOutput | tuple[torch.FloatTensor]:
  671. r"""
  672. decoder_input_ids (`torch.LongTensor` of shape `(batch_size, target_sequence_length)`, *optional*):
  673. Indices of decoder input sequence tokens in the vocabulary.
  674. Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
  675. [`PreTrainedTokenizer.__call__`] for details.
  676. [What are decoder input IDs?](../glossary#decoder-input-ids)
  677. MBart uses a specific language id token as the starting token for `decoder_input_ids` generation that
  678. varies according to source and target language, *e.g.* 25004 for *en_XX*, and 25003 for *de_DE*. If
  679. `past_key_values` is used, optionally only the last `decoder_input_ids` have to be input (see
  680. `past_key_values`).
  681. For translation and summarization training, `decoder_input_ids` should be provided. If no
  682. `decoder_input_ids` is provided, the model will create this tensor by shifting the `input_ids` to the right
  683. for denoising pre-training following the paper.
  684. decoder_attention_mask (`torch.LongTensor` of shape `(batch_size, target_sequence_length)`, *optional*):
  685. Default behavior: generate a tensor that ignores pad tokens in `decoder_input_ids`. Causal mask will also
  686. be used by default.
  687. """
  688. return_dict = return_dict if return_dict is not None else self.config.use_return_dict
  689. # different to other models, MBart automatically creates decoder_input_ids from
  690. # input_ids if no decoder_input_ids are provided
  691. if decoder_input_ids is None and decoder_inputs_embeds is None:
  692. decoder_input_ids = shift_tokens_right(input_ids, self.config.pad_token_id)
  693. if encoder_outputs is None:
  694. encoder_outputs = self.encoder(
  695. input_ids=input_ids,
  696. attention_mask=attention_mask,
  697. inputs_embeds=inputs_embeds,
  698. return_dict=return_dict,
  699. **kwargs,
  700. )
  701. # If the user passed a tuple for encoder_outputs, we wrap it in a BaseModelOutput when return_dict=True
  702. elif return_dict and not isinstance(encoder_outputs, BaseModelOutput):
  703. encoder_outputs = BaseModelOutput(
  704. last_hidden_state=encoder_outputs[0],
  705. hidden_states=encoder_outputs[1] if len(encoder_outputs) > 1 else None,
  706. attentions=encoder_outputs[2] if len(encoder_outputs) > 2 else None,
  707. )
  708. # decoder outputs consists of (dec_features, past_key_values, dec_hidden, dec_attn)
  709. decoder_outputs = self.decoder(
  710. input_ids=decoder_input_ids,
  711. attention_mask=decoder_attention_mask,
  712. encoder_hidden_states=encoder_outputs[0],
  713. encoder_attention_mask=attention_mask,
  714. past_key_values=past_key_values,
  715. inputs_embeds=decoder_inputs_embeds,
  716. use_cache=use_cache,
  717. return_dict=return_dict,
  718. **kwargs,
  719. )
  720. if not return_dict:
  721. return decoder_outputs + encoder_outputs
  722. return Seq2SeqModelOutput(
  723. last_hidden_state=decoder_outputs.last_hidden_state,
  724. past_key_values=decoder_outputs.past_key_values,
  725. decoder_hidden_states=decoder_outputs.hidden_states,
  726. decoder_attentions=decoder_outputs.attentions,
  727. cross_attentions=decoder_outputs.cross_attentions,
  728. encoder_last_hidden_state=encoder_outputs.last_hidden_state,
  729. encoder_hidden_states=encoder_outputs.hidden_states,
  730. encoder_attentions=encoder_outputs.attentions,
  731. )
  732. @auto_docstring(
  733. custom_intro="""
  734. The MBART Model with a language modeling head. Can be used for summarization, after fine-tuning the pretrained models.
  735. """
  736. )
  737. class MBartForConditionalGeneration(MBartPreTrainedModel, GenerationMixin):
  738. base_model_prefix = "model"
  739. _keys_to_ignore_on_load_missing = ["final_logits_bias"]
  740. _tied_weights_keys = {"lm_head.weight": "model.shared.weight"}
  741. def __init__(self, config: MBartConfig):
  742. super().__init__(config)
  743. self.model = MBartModel(config)
  744. self.register_buffer("final_logits_bias", torch.zeros((1, self.model.shared.num_embeddings)))
  745. self.lm_head = nn.Linear(config.d_model, self.model.shared.num_embeddings, bias=False)
  746. # Initialize weights and apply final processing
  747. self.post_init()
  748. def resize_token_embeddings(
  749. self, new_num_tokens: int, pad_to_multiple_of: int | None = None, mean_resizing: bool = True
  750. ) -> nn.Embedding:
  751. new_embeddings = super().resize_token_embeddings(new_num_tokens, pad_to_multiple_of, mean_resizing)
  752. self._resize_final_logits_bias(new_embeddings.weight.shape[0])
  753. return new_embeddings
  754. def _resize_final_logits_bias(self, new_num_tokens: int) -> None:
  755. old_num_tokens = self.final_logits_bias.shape[-1]
  756. if new_num_tokens <= old_num_tokens:
  757. new_bias = self.final_logits_bias[:, :new_num_tokens]
  758. else:
  759. extra_bias = torch.zeros((1, new_num_tokens - old_num_tokens), device=self.final_logits_bias.device)
  760. new_bias = torch.cat([self.final_logits_bias, extra_bias], dim=1)
  761. self.register_buffer("final_logits_bias", new_bias)
  762. @auto_docstring
  763. def forward(
  764. self,
  765. input_ids: torch.LongTensor | None = None,
  766. attention_mask: torch.Tensor | None = None,
  767. decoder_input_ids: torch.LongTensor | None = None,
  768. decoder_attention_mask: torch.LongTensor | None = None,
  769. encoder_outputs: tuple[tuple[torch.FloatTensor]] | None = None,
  770. past_key_values: Cache | None = None,
  771. inputs_embeds: torch.FloatTensor | None = None,
  772. decoder_inputs_embeds: torch.FloatTensor | None = None,
  773. labels: torch.LongTensor | None = None,
  774. use_cache: bool | None = None,
  775. return_dict: bool | None = None,
  776. **kwargs: Unpack[TransformersKwargs],
  777. ) -> Seq2SeqLMOutput | tuple[torch.FloatTensor]:
  778. r"""
  779. decoder_input_ids (`torch.LongTensor` of shape `(batch_size, target_sequence_length)`, *optional*):
  780. Indices of decoder input sequence tokens in the vocabulary.
  781. Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
  782. [`PreTrainedTokenizer.__call__`] for details.
  783. [What are decoder input IDs?](../glossary#decoder-input-ids)
  784. MBart uses a specific language id token as the starting token for `decoder_input_ids` generation that
  785. varies according to source and target language, *e.g.* 25004 for *en_XX*, and 25003 for *de_DE*. If
  786. `past_key_values` is used, optionally only the last `decoder_input_ids` have to be input (see
  787. `past_key_values`).
  788. For translation and summarization training, `decoder_input_ids` should be provided. If no
  789. `decoder_input_ids` is provided, the model will create this tensor by shifting the `input_ids` to the right
  790. for denoising pre-training following the paper.
  791. decoder_attention_mask (`torch.LongTensor` of shape `(batch_size, target_sequence_length)`, *optional*):
  792. Default behavior: generate a tensor that ignores pad tokens in `decoder_input_ids`. Causal mask will also
  793. be used by default.
  794. labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
  795. Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
  796. config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
  797. (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
  798. Example Translation:
  799. ```python
  800. >>> from transformers import AutoTokenizer, MBartForConditionalGeneration
  801. >>> model = MBartForConditionalGeneration.from_pretrained("facebook/mbart-large-en-ro")
  802. >>> tokenizer = AutoTokenizer.from_pretrained("facebook/mbart-large-en-ro")
  803. >>> example_english_phrase = "42 is the answer"
  804. >>> inputs = tokenizer(example_english_phrase, return_tensors="pt")
  805. >>> # Translate
  806. >>> generated_ids = model.generate(**inputs, num_beams=4, max_length=5)
  807. >>> tokenizer.batch_decode(generated_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
  808. '42 este răspuns'
  809. ```
  810. Mask filling example:
  811. ```python
  812. >>> from transformers import AutoTokenizer, MBartForConditionalGeneration
  813. >>> model = MBartForConditionalGeneration.from_pretrained("facebook/mbart-large-cc25")
  814. >>> tokenizer = AutoTokenizer.from_pretrained("facebook/mbart-large-cc25")
  815. >>> # de_DE is the language symbol id <LID> for German
  816. >>> TXT = "</s> Meine Freunde sind <mask> nett aber sie essen zu viel Kuchen. </s> de_DE"
  817. >>> input_ids = tokenizer([TXT], add_special_tokens=False, return_tensors="pt")["input_ids"]
  818. >>> logits = model(input_ids).logits
  819. >>> masked_index = (input_ids[0] == tokenizer.mask_token_id).nonzero().item()
  820. >>> probs = logits[0, masked_index].softmax(dim=0)
  821. >>> values, predictions = probs.topk(5)
  822. >>> tokenizer.decode(predictions).split()
  823. ['nett', 'sehr', 'ganz', 'nicht', 'so']
  824. ```
  825. """
  826. return_dict = return_dict if return_dict is not None else self.config.return_dict
  827. if labels is not None:
  828. if use_cache:
  829. logger.warning("The `use_cache` argument is changed to `False` since `labels` is provided.")
  830. use_cache = False
  831. if decoder_input_ids is None and decoder_inputs_embeds is None:
  832. decoder_input_ids = shift_tokens_right(labels, self.config.pad_token_id)
  833. outputs = self.model(
  834. input_ids,
  835. attention_mask=attention_mask,
  836. decoder_input_ids=decoder_input_ids,
  837. encoder_outputs=encoder_outputs,
  838. decoder_attention_mask=decoder_attention_mask,
  839. past_key_values=past_key_values,
  840. inputs_embeds=inputs_embeds,
  841. decoder_inputs_embeds=decoder_inputs_embeds,
  842. use_cache=use_cache,
  843. return_dict=return_dict,
  844. **kwargs,
  845. )
  846. lm_logits = self.lm_head(outputs[0]) + self.final_logits_bias
  847. masked_lm_loss = None
  848. if labels is not None:
  849. loss_fct = CrossEntropyLoss()
  850. masked_lm_loss = loss_fct(lm_logits.view(-1, self.config.vocab_size), labels.view(-1))
  851. if not return_dict:
  852. output = (lm_logits,) + outputs[1:]
  853. return ((masked_lm_loss,) + output) if masked_lm_loss is not None else output
  854. return Seq2SeqLMOutput(
  855. loss=masked_lm_loss,
  856. logits=lm_logits,
  857. past_key_values=outputs.past_key_values,
  858. decoder_hidden_states=outputs.decoder_hidden_states,
  859. decoder_attentions=outputs.decoder_attentions,
  860. cross_attentions=outputs.cross_attentions,
  861. encoder_last_hidden_state=outputs.encoder_last_hidden_state,
  862. encoder_hidden_states=outputs.encoder_hidden_states,
  863. encoder_attentions=outputs.encoder_attentions,
  864. )
  865. def prepare_decoder_input_ids_from_labels(self, labels: torch.Tensor):
  866. return shift_tokens_right(labels, self.config.pad_token_id)
  867. @auto_docstring(
  868. custom_intro="""
  869. MBart model with a sequence classification/head on top (a linear layer on top of the pooled output) e.g. for GLUE
  870. tasks.
  871. """
  872. )
  873. class MBartForSequenceClassification(MBartPreTrainedModel):
  874. def __init__(self, config: MBartConfig, **kwargs):
  875. super().__init__(config, **kwargs)
  876. self.model = MBartModel(config)
  877. self.classification_head = MBartClassificationHead(
  878. config.d_model,
  879. config.d_model,
  880. config.num_labels,
  881. config.classifier_dropout,
  882. )
  883. # Initialize weights and apply final processing
  884. self.post_init()
  885. @can_return_tuple
  886. @auto_docstring
  887. # Copied from transformers.models.bart.modeling_bart.BartForSequenceClassification.forward
  888. def forward(
  889. self,
  890. input_ids: torch.LongTensor | None = None,
  891. attention_mask: torch.Tensor | None = None,
  892. decoder_input_ids: torch.LongTensor | None = None,
  893. decoder_attention_mask: torch.LongTensor | None = None,
  894. encoder_outputs: list[torch.FloatTensor] | None = None,
  895. inputs_embeds: torch.FloatTensor | None = None,
  896. decoder_inputs_embeds: torch.FloatTensor | None = None,
  897. labels: torch.LongTensor | None = None,
  898. use_cache: bool | None = None,
  899. **kwargs: Unpack[TransformersKwargs],
  900. ) -> tuple | Seq2SeqSequenceClassifierOutput:
  901. r"""
  902. decoder_input_ids (`torch.LongTensor` of shape `(batch_size, target_sequence_length)`, *optional*):
  903. Indices of decoder input sequence tokens in the vocabulary.
  904. Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
  905. [`PreTrainedTokenizer.__call__`] for details.
  906. [What are decoder input IDs?](../glossary#decoder-input-ids)
  907. Bart uses the `eos_token_id` as the starting token for `decoder_input_ids` generation. If `past_key_values`
  908. is used, optionally only the last `decoder_input_ids` have to be input (see `past_key_values`).
  909. For translation and summarization training, `decoder_input_ids` should be provided. If no
  910. `decoder_input_ids` is provided, the model will create this tensor by shifting the `input_ids` to the right
  911. for denoising pre-training following the paper.
  912. decoder_attention_mask (`torch.LongTensor` of shape `(batch_size, target_sequence_length)`, *optional*):
  913. Default behavior: generate a tensor that ignores pad tokens in `decoder_input_ids`. Causal mask will also
  914. be used by default.
  915. If you want to change padding behavior, you should read [`modeling_bart._prepare_decoder_attention_mask`]
  916. and modify to your needs. See diagram 1 in [the paper](https://huggingface.co/papers/1910.13461) for more
  917. information on the default strategy.
  918. labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
  919. Labels for computing the sequence classification/regression loss. Indices should be in `[0, ...,
  920. config.num_labels - 1]`. If `config.num_labels > 1` a classification loss is computed (Cross-Entropy).
  921. """
  922. if labels is not None:
  923. use_cache = False
  924. if input_ids is None and inputs_embeds is not None:
  925. raise NotImplementedError(
  926. f"Passing input embeddings is currently not supported for {self.__class__.__name__}"
  927. )
  928. outputs: Seq2SeqModelOutput = self.model(
  929. input_ids,
  930. attention_mask=attention_mask,
  931. decoder_input_ids=decoder_input_ids,
  932. decoder_attention_mask=decoder_attention_mask,
  933. encoder_outputs=encoder_outputs,
  934. inputs_embeds=inputs_embeds,
  935. decoder_inputs_embeds=decoder_inputs_embeds,
  936. use_cache=use_cache,
  937. **kwargs,
  938. )
  939. hidden_states = outputs[0] # last hidden state
  940. eos_mask = input_ids.eq(self.config.eos_token_id).to(hidden_states.device)
  941. torch_compilable_check(
  942. torch.unique_consecutive(eos_mask.sum(1)).numel() == 1,
  943. "All examples must have the same number of <eos> tokens.",
  944. )
  945. sentence_representation = hidden_states[eos_mask, :].view(hidden_states.size(0), -1, hidden_states.size(-1))[
  946. :, -1, :
  947. ]
  948. logits = self.classification_head(sentence_representation)
  949. loss = None
  950. if labels is not None:
  951. labels = labels.to(logits.device)
  952. if self.config.problem_type is None:
  953. if self.config.num_labels == 1:
  954. self.config.problem_type = "regression"
  955. elif self.config.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int):
  956. self.config.problem_type = "single_label_classification"
  957. else:
  958. self.config.problem_type = "multi_label_classification"
  959. if self.config.problem_type == "regression":
  960. loss_fct = MSELoss()
  961. if self.config.num_labels == 1:
  962. loss = loss_fct(logits.squeeze(), labels.squeeze())
  963. else:
  964. loss = loss_fct(logits, labels)
  965. elif self.config.problem_type == "single_label_classification":
  966. loss_fct = CrossEntropyLoss()
  967. loss = loss_fct(logits.view(-1, self.config.num_labels), labels.view(-1))
  968. elif self.config.problem_type == "multi_label_classification":
  969. loss_fct = BCEWithLogitsLoss()
  970. loss = loss_fct(logits, labels)
  971. return Seq2SeqSequenceClassifierOutput(
  972. loss=loss,
  973. logits=logits,
  974. past_key_values=outputs.past_key_values,
  975. decoder_hidden_states=outputs.decoder_hidden_states,
  976. decoder_attentions=outputs.decoder_attentions,
  977. cross_attentions=outputs.cross_attentions,
  978. encoder_last_hidden_state=outputs.encoder_last_hidden_state,
  979. encoder_hidden_states=outputs.encoder_hidden_states,
  980. encoder_attentions=outputs.encoder_attentions,
  981. )
  982. @auto_docstring
  983. class MBartForQuestionAnswering(MBartPreTrainedModel):
  984. def __init__(self, config):
  985. super().__init__(config)
  986. config.num_labels = 2
  987. self.num_labels = config.num_labels
  988. self.model = MBartModel(config)
  989. self.qa_outputs = nn.Linear(config.hidden_size, config.num_labels)
  990. # Initialize weights and apply final processing
  991. self.post_init()
  992. @can_return_tuple
  993. @auto_docstring
  994. # Copied from transformers.models.bart.modeling_bart.BartForQuestionAnswering.forward
  995. def forward(
  996. self,
  997. input_ids: torch.Tensor | None = None,
  998. attention_mask: torch.Tensor | None = None,
  999. decoder_input_ids: torch.LongTensor | None = None,
  1000. decoder_attention_mask: torch.LongTensor | None = None,
  1001. encoder_outputs: list[torch.FloatTensor] | None = None,
  1002. start_positions: torch.LongTensor | None = None,
  1003. end_positions: torch.LongTensor | None = None,
  1004. inputs_embeds: torch.FloatTensor | None = None,
  1005. decoder_inputs_embeds: torch.FloatTensor | None = None,
  1006. use_cache: bool | None = None,
  1007. **kwargs: Unpack[TransformersKwargs],
  1008. ) -> tuple | Seq2SeqQuestionAnsweringModelOutput:
  1009. r"""
  1010. decoder_input_ids (`torch.LongTensor` of shape `(batch_size, target_sequence_length)`, *optional*):
  1011. Indices of decoder input sequence tokens in the vocabulary.
  1012. Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
  1013. [`PreTrainedTokenizer.__call__`] for details.
  1014. [What are decoder input IDs?](../glossary#decoder-input-ids)
  1015. Bart uses the `eos_token_id` as the starting token for `decoder_input_ids` generation. If `past_key_values`
  1016. is used, optionally only the last `decoder_input_ids` have to be input (see `past_key_values`).
  1017. For translation and summarization training, `decoder_input_ids` should be provided. If no
  1018. `decoder_input_ids` is provided, the model will create this tensor by shifting the `input_ids` to the right
  1019. for denoising pre-training following the paper.
  1020. decoder_attention_mask (`torch.LongTensor` of shape `(batch_size, target_sequence_length)`, *optional*):
  1021. Default behavior: generate a tensor that ignores pad tokens in `decoder_input_ids`. Causal mask will also
  1022. be used by default.
  1023. If you want to change padding behavior, you should read [`modeling_bart._prepare_decoder_attention_mask`]
  1024. and modify to your needs. See diagram 1 in [the paper](https://huggingface.co/papers/1910.13461) for more
  1025. information on the default strategy.
  1026. """
  1027. if start_positions is not None and end_positions is not None:
  1028. use_cache = False
  1029. outputs: Seq2SeqModelOutput = self.model(
  1030. input_ids,
  1031. attention_mask=attention_mask,
  1032. decoder_input_ids=decoder_input_ids,
  1033. decoder_attention_mask=decoder_attention_mask,
  1034. encoder_outputs=encoder_outputs,
  1035. inputs_embeds=inputs_embeds,
  1036. decoder_inputs_embeds=decoder_inputs_embeds,
  1037. use_cache=use_cache,
  1038. **kwargs,
  1039. )
  1040. sequence_output = outputs[0]
  1041. logits = self.qa_outputs(sequence_output)
  1042. start_logits, end_logits = logits.split(1, dim=-1)
  1043. start_logits = start_logits.squeeze(-1).contiguous()
  1044. end_logits = end_logits.squeeze(-1).contiguous()
  1045. total_loss = None
  1046. if start_positions is not None and end_positions is not None:
  1047. # If we are on multi-GPU, split add a dimension
  1048. if len(start_positions.size()) > 1:
  1049. start_positions = start_positions.squeeze(-1)
  1050. if len(end_positions.size()) > 1:
  1051. end_positions = end_positions.squeeze(-1)
  1052. # sometimes the start/end positions are outside our model inputs, we ignore these terms
  1053. ignored_index = start_logits.size(1)
  1054. start_positions = start_positions.clamp(0, ignored_index)
  1055. end_positions = end_positions.clamp(0, ignored_index)
  1056. loss_fct = CrossEntropyLoss(ignore_index=ignored_index)
  1057. start_loss = loss_fct(start_logits, start_positions)
  1058. end_loss = loss_fct(end_logits, end_positions)
  1059. total_loss = (start_loss + end_loss) / 2
  1060. return Seq2SeqQuestionAnsweringModelOutput(
  1061. loss=total_loss,
  1062. start_logits=start_logits,
  1063. end_logits=end_logits,
  1064. past_key_values=outputs.past_key_values,
  1065. decoder_hidden_states=outputs.decoder_hidden_states,
  1066. decoder_attentions=outputs.decoder_attentions,
  1067. cross_attentions=outputs.cross_attentions,
  1068. encoder_last_hidden_state=outputs.encoder_last_hidden_state,
  1069. encoder_hidden_states=outputs.encoder_hidden_states,
  1070. encoder_attentions=outputs.encoder_attentions,
  1071. )
  1072. # Copied from transformers.models.bart.modeling_bart.BartDecoderWrapper with Bart->MBart
  1073. class MBartDecoderWrapper(MBartPreTrainedModel):
  1074. """
  1075. This wrapper class is a helper class to correctly load pretrained checkpoints when the causal language model is
  1076. used in combination with the [`EncoderDecoderModel`] framework.
  1077. """
  1078. def __init__(self, config):
  1079. super().__init__(config)
  1080. self.decoder = MBartDecoder(config)
  1081. self.post_init()
  1082. def forward(self, *args, **kwargs):
  1083. return self.decoder(*args, **kwargs)
  1084. # Copied from transformers.models.bart.modeling_bart.BartForCausalLM with Bart->MBart, facebook/bart-base->facebook/mbart-large-cc25
  1085. class MBartForCausalLM(MBartPreTrainedModel, GenerationMixin):
  1086. _tied_weights_keys = {
  1087. "lm_head.weight": "model.decoder.embed_tokens.weight",
  1088. }
  1089. def __init__(self, config):
  1090. config.is_decoder = True
  1091. config.is_encoder_decoder = False
  1092. super().__init__(config)
  1093. self.model = MBartDecoderWrapper(config)
  1094. self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
  1095. # Initialize weights and apply final processing
  1096. self.post_init()
  1097. def get_input_embeddings(self):
  1098. return self.model.decoder.embed_tokens
  1099. def set_input_embeddings(self, value):
  1100. self.model.decoder.embed_tokens = value
  1101. @can_return_tuple
  1102. @auto_docstring
  1103. def forward(
  1104. self,
  1105. input_ids: torch.LongTensor | None = None,
  1106. attention_mask: torch.Tensor | None = None,
  1107. encoder_hidden_states: torch.FloatTensor | None = None,
  1108. encoder_attention_mask: torch.FloatTensor | None = None,
  1109. past_key_values: Cache | None = None,
  1110. inputs_embeds: torch.FloatTensor | None = None,
  1111. labels: torch.LongTensor | None = None,
  1112. use_cache: bool | None = None,
  1113. logits_to_keep: int | torch.Tensor = 0,
  1114. **kwargs: Unpack[TransformersKwargs],
  1115. ) -> tuple | CausalLMOutputWithCrossAttentions:
  1116. r"""
  1117. labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
  1118. Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
  1119. config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
  1120. (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
  1121. Example:
  1122. ```python
  1123. >>> from transformers import AutoTokenizer, MBartForCausalLM
  1124. >>> tokenizer = AutoTokenizer.from_pretrained("facebook/mbart-large-cc25")
  1125. >>> model = MBartForCausalLM.from_pretrained("facebook/mbart-large-cc25")
  1126. >>> assert model.config.is_decoder, f"{model.__class__} has to be configured as a decoder."
  1127. >>> inputs = tokenizer("Hello, my dog is cute", return_tensors="pt")
  1128. >>> outputs = model(**inputs)
  1129. >>> logits = outputs.logits
  1130. >>> expected_shape = [1, inputs.input_ids.shape[-1], model.config.vocab_size]
  1131. >>> list(logits.shape) == expected_shape
  1132. True
  1133. ```"""
  1134. outputs: BaseModelOutputWithPastAndCrossAttentions = self.model.decoder(
  1135. input_ids=input_ids,
  1136. attention_mask=attention_mask,
  1137. encoder_hidden_states=encoder_hidden_states,
  1138. encoder_attention_mask=encoder_attention_mask,
  1139. past_key_values=past_key_values,
  1140. inputs_embeds=inputs_embeds,
  1141. use_cache=use_cache,
  1142. **kwargs,
  1143. )
  1144. hidden_states = outputs[0]
  1145. # Only compute necessary logits
  1146. slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep
  1147. logits = self.lm_head(hidden_states[:, slice_indices, :])
  1148. loss = None
  1149. if labels is not None:
  1150. labels = labels.to(logits.device)
  1151. loss_fct = CrossEntropyLoss()
  1152. loss = loss_fct(logits.view(-1, self.config.vocab_size), labels.view(-1))
  1153. return CausalLMOutputWithCrossAttentions(
  1154. loss=loss,
  1155. logits=logits,
  1156. past_key_values=outputs.past_key_values,
  1157. hidden_states=outputs.hidden_states,
  1158. attentions=outputs.attentions,
  1159. cross_attentions=outputs.cross_attentions,
  1160. )
# Names exported by `from <this module> import *`.
__all__ = [
    "MBartForCausalLM",
    "MBartForConditionalGeneration",
    "MBartForQuestionAnswering",
    "MBartForSequenceClassification",
    "MBartModel",
    "MBartPreTrainedModel",
]