modeling_plbart.py 47 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
3704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792793794795796797798799800801802803804805806807808809810811812813814815816817818819820821822823824825826827828829830831832833834835836837838839840841842843844845846847848849850851852853854855856857858859860861862863864865866867868869870871872873874875876877878879880881882883884885886887888889890891892893894895896897898899900901902903904905906907908909910911912913914915916917918919920921922923924925926927928929930931932933934935936937938939940941942943944945946947948949950951952953954955956957958959960961962963964965966967968969970971972973974975976977978979980981982983984985986987988989990991992993994995996997998999100010011002100310041005100610071008100910101011101210131014101510161017101810191020102110221023102410251026102710281029103010311032103310341035103610371038103910401041104210431044104510461047104810491050105110521053105410551056105710581059106010611062106310641065106610671068106910701071107210731074107510761077107810791080108110821083108410851086108710881089109010911092109310941095109610971098109911001101110211031104110511061107110811091110111111121113111411151116111711181119112011211122112311241125112611271128112911301131113211331134113511361137113811391140114111421143114411451146114711481149
  1. # 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
  2. # This file was automatically generated from src/transformers/models/plbart/modular_plbart.py.
  3. # Do NOT edit this file manually as any edits will be overwritten by the generation of
  4. # the file from the modular. If any change should be done, please apply the change to the
  5. # modular_plbart.py file directly. One of our CI enforces this.
  6. # 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
  7. # Copyright 2022, UCLA NLP, The Facebook AI Research Team and The HuggingFace Inc. team. All rights reserved.
  8. #
  9. # Licensed under the Apache License, Version 2.0 (the "License");
  10. # you may not use this file except in compliance with the License.
  11. # You may obtain a copy of the License at
  12. #
  13. # http://www.apache.org/licenses/LICENSE-2.0
  14. #
  15. # Unless required by applicable law or agreed to in writing, software
  16. # distributed under the License is distributed on an "AS IS" BASIS,
  17. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  18. # See the License for the specific language governing permissions and
  19. # limitations under the License.
  20. import math
  21. from collections.abc import Callable
  22. import torch
  23. from torch import nn
  24. from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
  25. from ... import initialization as init
  26. from ...activations import ACT2FN
  27. from ...cache_utils import Cache, DynamicCache, EncoderDecoderCache
  28. from ...generation import GenerationMixin
  29. from ...masking_utils import create_bidirectional_mask, create_causal_mask
  30. from ...modeling_flash_attention_utils import FlashAttentionKwargs
  31. from ...modeling_layers import GradientCheckpointingLayer
  32. from ...modeling_outputs import (
  33. BaseModelOutput,
  34. BaseModelOutputWithPastAndCrossAttentions,
  35. CausalLMOutputWithCrossAttentions,
  36. Seq2SeqLMOutput,
  37. Seq2SeqModelOutput,
  38. Seq2SeqSequenceClassifierOutput,
  39. )
  40. from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
  41. from ...processing_utils import Unpack
  42. from ...utils import (
  43. TransformersKwargs,
  44. auto_docstring,
  45. can_return_tuple,
  46. is_torchdynamo_compiling,
  47. logging,
  48. torch_compilable_check,
  49. )
  50. from ...utils.generic import merge_with_config_defaults
  51. from ...utils.output_capturing import OutputRecorder, capture_outputs
  52. from .configuration_plbart import PLBartConfig
  53. logger = logging.get_logger(__name__)
  54. class PLBartScaledWordEmbedding(nn.Embedding):
  55. """
  56. This module overrides nn.Embeddings' forward by multiplying with embeddings scale.
  57. """
  58. def __init__(self, num_embeddings: int, embedding_dim: int, padding_idx: int, embed_scale: float | None = 1.0):
  59. super().__init__(num_embeddings, embedding_dim, padding_idx)
  60. self.embed_scale = embed_scale
  61. def forward(self, input_ids: torch.Tensor):
  62. return super().forward(input_ids) * self.embed_scale
  63. @auto_docstring
  64. class PLBartPreTrainedModel(PreTrainedModel):
  65. config: PLBartConfig
  66. base_model_prefix = "model"
  67. supports_gradient_checkpointing = True
  68. _no_split_modules = ["PLBartDecoderLayer", "PLBartEncoderLayer"]
  69. _supports_flash_attn = True
  70. _supports_sdpa = True
  71. _supports_flex_attn = True
  72. def _init_weights(self, module):
  73. super()._init_weights(module)
  74. if isinstance(module, PLBartForConditionalGeneration):
  75. init.zeros_(module.final_logits_bias)
  76. class PLBartLearnedPositionalEmbedding(nn.Embedding):
  77. """
  78. This module learns positional embeddings up to a fixed maximum size.
  79. """
  80. def __init__(self, num_embeddings: int, embedding_dim: int):
  81. # PLBart is set up so that if padding_idx is specified then offset the embedding ids by 2
  82. # and adjust num_embeddings appropriately. Other models don't have this hack
  83. self.offset = 2
  84. super().__init__(num_embeddings + self.offset, embedding_dim)
  85. def forward(
  86. self, input_ids: torch.Tensor, past_key_values_length: int = 0, position_ids: torch.Tensor | None = None
  87. ):
  88. """`input_ids' shape is expected to be [bsz x seqlen]."""
  89. if position_ids is None:
  90. bsz, seq_len = input_ids.shape[:2]
  91. position_ids = torch.arange(
  92. past_key_values_length, past_key_values_length + seq_len, dtype=torch.long, device=self.weight.device
  93. ).expand(bsz, -1)
  94. else:
  95. position_ids = position_ids.unsqueeze(0)
  96. return super().forward(position_ids + self.offset)
  97. def eager_attention_forward(
  98. module: nn.Module,
  99. query: torch.Tensor,
  100. key: torch.Tensor,
  101. value: torch.Tensor,
  102. attention_mask: torch.Tensor | None,
  103. scaling: float | None = None,
  104. dropout: float = 0.0,
  105. **kwargs: Unpack[TransformersKwargs],
  106. ):
  107. if scaling is None:
  108. scaling = query.size(-1) ** -0.5
  109. # Take the dot product between "query" and "key" to get the raw attention scores.
  110. attn_weights = torch.matmul(query, key.transpose(2, 3)) * scaling
  111. if attention_mask is not None:
  112. attn_weights = attn_weights + attention_mask
  113. attn_weights = nn.functional.softmax(attn_weights, dim=-1)
  114. attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training)
  115. attn_output = torch.matmul(attn_weights, value)
  116. attn_output = attn_output.transpose(1, 2).contiguous()
  117. return attn_output, attn_weights
  118. class PLBartAttention(nn.Module):
  119. """Multi-headed attention from 'Attention Is All You Need' paper"""
  120. def __init__(
  121. self,
  122. embed_dim: int,
  123. num_heads: int,
  124. dropout: float = 0.0,
  125. is_decoder: bool = False,
  126. bias: bool = True,
  127. is_causal: bool = False,
  128. config: PLBartConfig | None = None,
  129. layer_idx: int | None = None,
  130. ):
  131. super().__init__()
  132. self.embed_dim = embed_dim
  133. self.num_heads = num_heads
  134. self.dropout = dropout
  135. self.head_dim = embed_dim // num_heads
  136. self.config = config
  137. if (self.head_dim * num_heads) != self.embed_dim:
  138. raise ValueError(
  139. f"embed_dim must be divisible by num_heads (got `embed_dim`: {self.embed_dim}"
  140. f" and `num_heads`: {num_heads})."
  141. )
  142. self.scaling = self.head_dim**-0.5
  143. self.is_decoder = is_decoder
  144. self.is_causal = is_causal
  145. self.layer_idx = layer_idx
  146. if layer_idx is None and self.is_decoder:
  147. logger.warning_once(
  148. f"Instantiating a decoder {self.__class__.__name__} without passing `layer_idx` is not recommended and "
  149. "will lead to errors during the forward call, if caching is used. Please make sure to provide a `layer_idx` "
  150. "when creating this class."
  151. )
  152. self.k_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
  153. self.v_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
  154. self.q_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
  155. self.out_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
  156. def forward(
  157. self,
  158. hidden_states: torch.Tensor,
  159. key_value_states: torch.Tensor | None = None,
  160. past_key_values: Cache | None = None,
  161. attention_mask: torch.Tensor | None = None,
  162. # TODO: we need a refactor so that the different attention modules can get their specific kwargs
  163. # ATM, we have mixed things encoder, decoder, and encoder-decoder attn
  164. **kwargs: Unpack[FlashAttentionKwargs],
  165. ) -> tuple[torch.Tensor, torch.Tensor | None]:
  166. """Input shape: Batch x Time x Channel"""
  167. # if key_value_states are provided this layer is used as a cross-attention layer
  168. # for the decoder
  169. is_cross_attention = key_value_states is not None
  170. # determine input shapes
  171. input_shape = hidden_states.shape[:-1]
  172. hidden_shape = (*input_shape, -1, self.head_dim)
  173. # get query proj
  174. query_states = self.q_proj(hidden_states).view(hidden_shape).transpose(1, 2)
  175. is_updated = False
  176. if past_key_values is not None:
  177. if isinstance(past_key_values, EncoderDecoderCache):
  178. is_updated = past_key_values.is_updated.get(self.layer_idx)
  179. if is_cross_attention:
  180. # after the first generated id, we can subsequently re-use all key/value_states from cache
  181. curr_past_key_values = past_key_values.cross_attention_cache
  182. else:
  183. curr_past_key_values = past_key_values.self_attention_cache
  184. else:
  185. curr_past_key_values = past_key_values
  186. current_states = key_value_states if is_cross_attention else hidden_states
  187. if is_cross_attention and past_key_values is not None and is_updated:
  188. # reuse k,v, cross_attentions
  189. key_states = curr_past_key_values.layers[self.layer_idx].keys
  190. value_states = curr_past_key_values.layers[self.layer_idx].values
  191. else:
  192. key_states = self.k_proj(current_states)
  193. value_states = self.v_proj(current_states)
  194. kv_shape = (*current_states.shape[:-1], -1, self.head_dim)
  195. key_states = key_states.view(kv_shape).transpose(1, 2)
  196. value_states = value_states.view(kv_shape).transpose(1, 2)
  197. if past_key_values is not None:
  198. key_states, value_states = curr_past_key_values.update(key_states, value_states, self.layer_idx)
  199. # set flag that curr layer for cross-attn is already updated so we can re-use in subsequent calls
  200. if is_cross_attention and isinstance(past_key_values, EncoderDecoderCache):
  201. past_key_values.is_updated[self.layer_idx] = True
  202. attention_interface: Callable = ALL_ATTENTION_FUNCTIONS.get_interface(
  203. self.config._attn_implementation, eager_attention_forward
  204. )
  205. attn_output, attn_weights = attention_interface(
  206. self,
  207. query_states,
  208. key_states,
  209. value_states,
  210. attention_mask,
  211. dropout=0.0 if not self.training else self.dropout,
  212. scaling=self.scaling,
  213. **kwargs,
  214. )
  215. attn_output = attn_output.reshape(*input_shape, -1).contiguous()
  216. attn_output = self.out_proj(attn_output)
  217. return attn_output, attn_weights
  218. class PLBartEncoderLayer(GradientCheckpointingLayer):
  219. def __init__(self, config: PLBartConfig, layer_idx: int | None = None):
  220. super().__init__()
  221. self.embed_dim = config.d_model
  222. self.self_attn = PLBartAttention(
  223. embed_dim=self.embed_dim,
  224. num_heads=config.encoder_attention_heads,
  225. dropout=config.attention_dropout,
  226. config=config,
  227. layer_idx=layer_idx,
  228. )
  229. self.self_attn_layer_norm = nn.LayerNorm(self.embed_dim)
  230. self.dropout = config.dropout
  231. self.activation_fn = ACT2FN[config.activation_function]
  232. self.activation_dropout = config.activation_dropout
  233. self.fc1 = nn.Linear(self.embed_dim, config.encoder_ffn_dim)
  234. self.fc2 = nn.Linear(config.encoder_ffn_dim, self.embed_dim)
  235. self.final_layer_norm = nn.LayerNorm(self.embed_dim)
  236. def forward(
  237. self,
  238. hidden_states: torch.FloatTensor,
  239. attention_mask: torch.FloatTensor,
  240. **kwargs: Unpack[TransformersKwargs],
  241. ) -> torch.Tensor:
  242. residual = hidden_states
  243. hidden_states, _ = self.self_attn(
  244. hidden_states,
  245. attention_mask=attention_mask,
  246. **kwargs,
  247. )
  248. hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training)
  249. hidden_states = residual + hidden_states
  250. hidden_states = self.self_attn_layer_norm(hidden_states)
  251. residual = hidden_states
  252. hidden_states = self.activation_fn(self.fc1(hidden_states))
  253. hidden_states = nn.functional.dropout(hidden_states, p=self.activation_dropout, training=self.training)
  254. hidden_states = self.fc2(hidden_states)
  255. hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training)
  256. hidden_states = residual + hidden_states
  257. hidden_states = self.final_layer_norm(hidden_states)
  258. if hidden_states.dtype == torch.float16 and not torch.isfinite(hidden_states).all():
  259. clamp_value = torch.finfo(hidden_states.dtype).max - 1000
  260. hidden_states = torch.clamp(hidden_states, min=-clamp_value, max=clamp_value)
  261. return hidden_states
  262. class PLBartEncoder(PLBartPreTrainedModel):
  263. """
  264. Transformer encoder consisting of *config.encoder_layers* self attention layers. Each layer is a
  265. [`PLBartEncoderLayer`].
  266. Args:
  267. config: PLBartConfig
  268. embed_tokens (nn.Embedding): output embedding
  269. """
  270. _can_record_outputs = {
  271. "hidden_states": PLBartEncoderLayer,
  272. "attentions": PLBartAttention,
  273. }
  274. def __init__(self, config: PLBartConfig):
  275. super().__init__(config)
  276. self.dropout = config.dropout
  277. self.layerdrop = config.encoder_layerdrop
  278. embed_dim = config.d_model
  279. self.padding_idx = config.pad_token_id
  280. self.max_source_positions = config.max_position_embeddings
  281. embed_scale = math.sqrt(embed_dim) if config.scale_embedding else 1.0
  282. self.embed_tokens = PLBartScaledWordEmbedding(
  283. config.vocab_size, embed_dim, self.padding_idx, embed_scale=embed_scale
  284. )
  285. self.embed_positions = PLBartLearnedPositionalEmbedding(
  286. config.max_position_embeddings,
  287. embed_dim,
  288. )
  289. self.layers = nn.ModuleList([PLBartEncoderLayer(config, layer_idx=i) for i in range(config.encoder_layers)])
  290. self.layernorm_embedding = nn.LayerNorm(embed_dim)
  291. self.gradient_checkpointing = False
  292. # Initialize weights and apply final processing
  293. self.post_init()
  294. @merge_with_config_defaults
  295. @capture_outputs
  296. @auto_docstring
  297. def forward(
  298. self,
  299. input_ids: torch.LongTensor | None = None,
  300. attention_mask: torch.Tensor | None = None,
  301. inputs_embeds: torch.FloatTensor | None = None,
  302. **kwargs: Unpack[TransformersKwargs],
  303. ) -> BaseModelOutput:
  304. if (input_ids is None) ^ (inputs_embeds is not None):
  305. raise ValueError("You must specify exactly one of input_ids or inputs_embeds")
  306. if inputs_embeds is None:
  307. inputs_embeds = self.embed_tokens(input_ids)
  308. embed_pos = self.embed_positions(inputs_embeds[:, :, -1]) # needed for the shape only
  309. embed_pos = embed_pos.to(inputs_embeds.device)
  310. hidden_states = inputs_embeds + embed_pos
  311. hidden_states = self.layernorm_embedding(hidden_states)
  312. hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training)
  313. attention_mask = create_bidirectional_mask(
  314. config=self.config,
  315. inputs_embeds=inputs_embeds,
  316. attention_mask=attention_mask,
  317. )
  318. for idx, encoder_layer in enumerate(self.layers):
  319. # add LayerDrop (see https://huggingface.co/papers/1909.11556 for description)
  320. to_drop = False
  321. if self.training:
  322. dropout_probability = torch.rand([])
  323. if dropout_probability < self.layerdrop: # skip the layer
  324. to_drop = True
  325. if not to_drop:
  326. hidden_states = encoder_layer(
  327. hidden_states,
  328. attention_mask,
  329. **kwargs,
  330. )
  331. return BaseModelOutput(
  332. last_hidden_state=hidden_states,
  333. )
  334. class PLBartDecoderLayer(GradientCheckpointingLayer):
  335. def __init__(self, config: PLBartConfig, layer_idx: int | None = None):
  336. super().__init__()
  337. self.embed_dim = config.d_model
  338. self.self_attn = PLBartAttention(
  339. embed_dim=self.embed_dim,
  340. num_heads=config.decoder_attention_heads,
  341. dropout=config.attention_dropout,
  342. is_decoder=True,
  343. is_causal=True,
  344. config=config,
  345. layer_idx=layer_idx,
  346. )
  347. self.dropout = config.dropout
  348. self.activation_fn = ACT2FN[config.activation_function]
  349. self.activation_dropout = config.activation_dropout
  350. self.self_attn_layer_norm = nn.LayerNorm(self.embed_dim)
  351. self.encoder_attn = PLBartAttention(
  352. self.embed_dim,
  353. config.decoder_attention_heads,
  354. dropout=config.attention_dropout,
  355. is_decoder=True,
  356. config=config,
  357. layer_idx=layer_idx,
  358. )
  359. self.encoder_attn_layer_norm = nn.LayerNorm(self.embed_dim)
  360. self.fc1 = nn.Linear(self.embed_dim, config.decoder_ffn_dim)
  361. self.fc2 = nn.Linear(config.decoder_ffn_dim, self.embed_dim)
  362. self.final_layer_norm = nn.LayerNorm(self.embed_dim)
  363. def forward(
  364. self,
  365. hidden_states: torch.Tensor,
  366. attention_mask: torch.Tensor | None = None,
  367. encoder_hidden_states: torch.Tensor | None = None,
  368. encoder_attention_mask: torch.Tensor | None = None,
  369. past_key_values: Cache | None = None,
  370. use_cache: bool | None = True,
  371. **kwargs: Unpack[TransformersKwargs],
  372. ) -> torch.Tensor:
  373. residual = hidden_states
  374. # Self Attention
  375. hidden_states, _ = self.self_attn(
  376. hidden_states,
  377. past_key_values=past_key_values,
  378. attention_mask=attention_mask,
  379. **kwargs,
  380. )
  381. hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training)
  382. hidden_states = residual + hidden_states
  383. hidden_states = self.self_attn_layer_norm(hidden_states)
  384. # Cross-Attention Block
  385. if encoder_hidden_states is not None:
  386. residual = hidden_states
  387. hidden_states, _ = self.encoder_attn(
  388. hidden_states,
  389. key_value_states=encoder_hidden_states,
  390. attention_mask=encoder_attention_mask,
  391. past_key_values=past_key_values,
  392. **kwargs,
  393. )
  394. hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training)
  395. hidden_states = residual + hidden_states
  396. hidden_states = self.encoder_attn_layer_norm(hidden_states)
  397. # Fully Connected
  398. residual = hidden_states
  399. hidden_states = self.activation_fn(self.fc1(hidden_states))
  400. hidden_states = nn.functional.dropout(hidden_states, p=self.activation_dropout, training=self.training)
  401. hidden_states = self.fc2(hidden_states)
  402. hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training)
  403. hidden_states = residual + hidden_states
  404. hidden_states = self.final_layer_norm(hidden_states)
  405. return hidden_states
  406. class PLBartDecoder(PLBartPreTrainedModel):
  407. """
  408. Transformer decoder consisting of *config.decoder_layers* layers. Each layer is a [`PLBartDecoderLayer`]
  409. Args:
  410. config: PLBartConfig
  411. embed_tokens (nn.Embedding): output embedding
  412. """
  413. _can_record_outputs = {
  414. "hidden_states": PLBartDecoderLayer,
  415. "attentions": OutputRecorder(PLBartAttention, index=1, layer_name="self_attn"),
  416. "cross_attentions": OutputRecorder(PLBartAttention, index=1, layer_name="encoder_attn"),
  417. }
  418. def __init__(self, config: PLBartConfig):
  419. super().__init__(config)
  420. self.dropout = config.dropout
  421. self.layerdrop = config.decoder_layerdrop
  422. self.padding_idx = config.pad_token_id
  423. self.max_target_positions = config.max_position_embeddings
  424. embed_scale = math.sqrt(config.d_model) if config.scale_embedding else 1.0
  425. self.embed_tokens = PLBartScaledWordEmbedding(
  426. config.vocab_size, config.d_model, self.padding_idx, embed_scale=embed_scale
  427. )
  428. self.embed_positions = PLBartLearnedPositionalEmbedding(
  429. config.max_position_embeddings,
  430. config.d_model,
  431. )
  432. self.layers = nn.ModuleList([PLBartDecoderLayer(config, layer_idx=i) for i in range(config.decoder_layers)])
  433. self.layernorm_embedding = nn.LayerNorm(config.d_model)
  434. self.gradient_checkpointing = False
  435. # Initialize weights and apply final processing
  436. self.post_init()
  437. @merge_with_config_defaults
  438. @capture_outputs
  439. @auto_docstring
  440. def forward(
  441. self,
  442. input_ids: torch.LongTensor | None = None,
  443. attention_mask: torch.Tensor | None = None,
  444. encoder_hidden_states: torch.FloatTensor | None = None,
  445. encoder_attention_mask: torch.LongTensor | None = None,
  446. past_key_values: Cache | None = None,
  447. inputs_embeds: torch.FloatTensor | None = None,
  448. use_cache: bool | None = None,
  449. **kwargs: Unpack[TransformersKwargs],
  450. ) -> BaseModelOutputWithPastAndCrossAttentions:
  451. if (input_ids is None) ^ (inputs_embeds is not None):
  452. raise ValueError("You must specify exactly one of decoder_input_ids or decoder_inputs_embeds")
  453. if inputs_embeds is None:
  454. inputs_embeds = self.embed_tokens(input_ids)
  455. # initialize `past_key_values`
  456. if use_cache and past_key_values is None:
  457. past_key_values = (
  458. EncoderDecoderCache(DynamicCache(config=self.config), DynamicCache(config=self.config))
  459. if encoder_hidden_states is not None or self.config.is_encoder_decoder
  460. else DynamicCache(config=self.config)
  461. )
  462. batch_size, seq_length = inputs_embeds.size()[:-1]
  463. past_key_values_length = past_key_values.get_seq_length() if past_key_values is not None else 0
  464. position_ids = torch.arange(seq_length, device=inputs_embeds.device) + past_key_values_length
  465. if attention_mask is None and not is_torchdynamo_compiling():
  466. # required mask seq length can be calculated via length of past cache
  467. mask_seq_length = past_key_values_length + seq_length
  468. attention_mask = torch.ones(batch_size, mask_seq_length, device=inputs_embeds.device)
  469. self_attn_cache = (
  470. past_key_values.self_attention_cache
  471. if isinstance(past_key_values, EncoderDecoderCache)
  472. else past_key_values
  473. )
  474. attention_mask = create_causal_mask(
  475. config=self.config,
  476. inputs_embeds=inputs_embeds,
  477. attention_mask=attention_mask,
  478. past_key_values=self_attn_cache,
  479. )
  480. encoder_attention_mask = create_bidirectional_mask(
  481. config=self.config,
  482. inputs_embeds=inputs_embeds,
  483. attention_mask=encoder_attention_mask,
  484. encoder_hidden_states=encoder_hidden_states,
  485. )
  486. # embed positions
  487. positions = self.embed_positions(input, past_key_values_length, position_ids=position_ids)
  488. positions = positions.to(inputs_embeds.device)
  489. hidden_states = inputs_embeds + positions
  490. hidden_states = self.layernorm_embedding(hidden_states)
  491. hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training)
  492. for idx, decoder_layer in enumerate(self.layers):
  493. # add LayerDrop (see https://huggingface.co/papers/1909.11556 for description)
  494. if self.training:
  495. dropout_probability = torch.rand([])
  496. if dropout_probability < self.layerdrop:
  497. continue
  498. hidden_states = decoder_layer(
  499. hidden_states,
  500. attention_mask,
  501. encoder_hidden_states, # as a positional argument for gradient checkpointing
  502. encoder_attention_mask=encoder_attention_mask,
  503. past_key_values=past_key_values,
  504. use_cache=use_cache,
  505. **kwargs,
  506. )
  507. return BaseModelOutputWithPastAndCrossAttentions(
  508. last_hidden_state=hidden_states,
  509. past_key_values=past_key_values,
  510. )
  511. def shift_tokens_right(input_ids: torch.Tensor, pad_token_id: int):
  512. """
  513. Shift input ids one token to the right, and wrap the last non pad token (the <LID> token) Note that PLBart does not
  514. have a single `decoder_start_token_id` in contrast to other Bart-like models.
  515. """
  516. prev_output_tokens = input_ids.clone()
  517. if pad_token_id is None:
  518. raise ValueError("self.model.config.pad_token_id has to be defined.")
  519. # replace possible -100 values in labels by `pad_token_id`
  520. prev_output_tokens.masked_fill_(prev_output_tokens == -100, pad_token_id)
  521. index_of_eos = (prev_output_tokens.ne(pad_token_id).sum(dim=1) - 1).unsqueeze(-1)
  522. decoder_start_tokens = prev_output_tokens.gather(1, index_of_eos).squeeze()
  523. prev_output_tokens[:, 1:] = prev_output_tokens[:, :-1].clone()
  524. prev_output_tokens[:, 0] = decoder_start_tokens
  525. return prev_output_tokens
  526. @auto_docstring
  527. class PLBartModel(PLBartPreTrainedModel):
  528. _tied_weights_keys = {
  529. "encoder.embed_tokens.weight": "shared.weight",
  530. "decoder.embed_tokens.weight": "shared.weight",
  531. }
  532. def __init__(self, config: PLBartConfig):
  533. super().__init__(config)
  534. padding_idx, vocab_size = config.pad_token_id, config.vocab_size
  535. embed_scale = math.sqrt(config.d_model) if config.scale_embedding else 1.0
  536. self.shared = PLBartScaledWordEmbedding(vocab_size, config.d_model, padding_idx, embed_scale=embed_scale)
  537. self.encoder = PLBartEncoder(config)
  538. self.decoder = PLBartDecoder(config)
  539. self.post_init()
  540. def get_input_embeddings(self):
  541. return self.shared
  542. def set_input_embeddings(self, value):
  543. self.shared = value
  544. self.encoder.embed_tokens = self.shared
  545. self.decoder.embed_tokens = self.shared
  546. @merge_with_config_defaults
  547. @capture_outputs
  548. @auto_docstring
  549. def forward(
  550. self,
  551. input_ids: torch.LongTensor | None = None,
  552. attention_mask: torch.LongTensor | None = None,
  553. decoder_input_ids: torch.LongTensor | None = None,
  554. decoder_attention_mask: torch.Tensor | None = None,
  555. encoder_outputs: list[torch.FloatTensor] | None = None,
  556. past_key_values: Cache | None = None,
  557. inputs_embeds: torch.FloatTensor | None = None,
  558. decoder_inputs_embeds: torch.FloatTensor | None = None,
  559. use_cache: bool | None = None,
  560. **kwargs: Unpack[TransformersKwargs],
  561. ) -> tuple[torch.Tensor] | Seq2SeqModelOutput:
  562. r"""
  563. decoder_input_ids (`torch.LongTensor` of shape `(batch_size, target_sequence_length)`, *optional*):
  564. Indices of decoder input sequence tokens in the vocabulary.
  565. Indices can be obtained using [`AutoTokenizer`] or [`PLBartMultiTokenizer`] depending on the checkpoint.
  566. See [`PreTrainedTokenizer.encode`] and [`PreTrainedTokenizer.__call__`] for details.
  567. [What are decoder input IDs?](../glossary#decoder-input-ids)
  568. PLBart uses a specific language id token as the starting token for `decoder_input_ids` generation that
  569. varies according to source and target language, *e.g.* 50003 for *en_XX*, and 50001 for *java*. If
  570. `past_key_values` is used, optionally only the last `decoder_input_ids` have to be input (see
  571. `past_key_values`).
  572. For translation and summarization training, `decoder_input_ids` should be provided. If no
  573. `decoder_input_ids` is provided, the model will create this tensor by shifting the `input_ids` to the right
  574. for denoising pre-training following the paper.
  575. decoder_attention_mask (:
  576. obj:*torch.LongTensor* of shape `(batch_size, target_sequence_length)`, *optional*):
  577. Default behavior:
  578. generate a tensor that ignores pad tokens in `decoder_input_ids`. Causal mask will also be used by default.
  579. """
  580. # different to other models, PLBart automatically creates decoder_input_ids from
  581. # input_ids if no decoder_input_ids are provided
  582. if decoder_input_ids is None and decoder_inputs_embeds is None:
  583. decoder_input_ids = shift_tokens_right(input_ids, self.config.pad_token_id)
  584. if encoder_outputs is None:
  585. encoder_outputs: BaseModelOutput = self.encoder(
  586. input_ids=input_ids,
  587. attention_mask=attention_mask,
  588. inputs_embeds=inputs_embeds,
  589. **kwargs,
  590. )
  591. elif not isinstance(encoder_outputs, BaseModelOutput):
  592. encoder_outputs = BaseModelOutput(
  593. last_hidden_state=encoder_outputs[0],
  594. hidden_states=encoder_outputs[1] if len(encoder_outputs) > 1 else None,
  595. attentions=encoder_outputs[2] if len(encoder_outputs) > 2 else None,
  596. )
  597. decoder_outputs = self.decoder(
  598. input_ids=decoder_input_ids,
  599. attention_mask=decoder_attention_mask,
  600. encoder_hidden_states=encoder_outputs[0],
  601. encoder_attention_mask=attention_mask,
  602. past_key_values=past_key_values,
  603. inputs_embeds=decoder_inputs_embeds,
  604. use_cache=use_cache,
  605. **kwargs,
  606. )
  607. return Seq2SeqModelOutput(
  608. last_hidden_state=decoder_outputs.last_hidden_state,
  609. past_key_values=decoder_outputs.past_key_values,
  610. decoder_hidden_states=decoder_outputs.hidden_states,
  611. decoder_attentions=decoder_outputs.attentions,
  612. cross_attentions=decoder_outputs.cross_attentions,
  613. encoder_last_hidden_state=encoder_outputs.last_hidden_state,
  614. encoder_hidden_states=encoder_outputs.hidden_states,
  615. encoder_attentions=encoder_outputs.attentions,
  616. )
  617. @auto_docstring(
  618. custom_intro="""
  619. The PLBART Model with a language modeling head. Can be used for code-to-text, text-to-code and code-to-code.
  620. """
  621. )
  622. class PLBartForConditionalGeneration(PLBartPreTrainedModel, GenerationMixin):
  623. base_model_prefix = "model"
  624. _keys_to_ignore_on_load_missing = ["final_logits_bias"]
  625. _tied_weights_keys = {
  626. "lm_head.weight": "model.shared.weight",
  627. }
  628. def __init__(self, config: PLBartConfig):
  629. super().__init__(config)
  630. self.model = PLBartModel(config)
  631. self.register_buffer("final_logits_bias", torch.zeros((1, self.model.shared.num_embeddings)))
  632. self.lm_head = nn.Linear(config.d_model, self.model.shared.num_embeddings, bias=False)
  633. self.post_init()
  634. def resize_token_embeddings(
  635. self, new_num_tokens: int, pad_to_multiple_of: int | None = None, mean_resizing: bool = True
  636. ) -> nn.Embedding:
  637. new_embeddings = super().resize_token_embeddings(new_num_tokens, pad_to_multiple_of, mean_resizing)
  638. self._resize_final_logits_bias(new_embeddings.weight.shape[0])
  639. return new_embeddings
  640. def _resize_final_logits_bias(self, new_num_tokens: int) -> None:
  641. old_num_tokens = self.final_logits_bias.shape[-1]
  642. if new_num_tokens <= old_num_tokens:
  643. new_bias = self.final_logits_bias[:, :new_num_tokens]
  644. else:
  645. extra_bias = torch.zeros((1, new_num_tokens - old_num_tokens), device=self.final_logits_bias.device)
  646. new_bias = torch.cat([self.final_logits_bias, extra_bias], dim=1)
  647. self.register_buffer("final_logits_bias", new_bias)
  648. @merge_with_config_defaults
  649. @capture_outputs
  650. @auto_docstring
  651. def forward(
  652. self,
  653. input_ids: torch.LongTensor | None = None,
  654. attention_mask: torch.LongTensor | None = None,
  655. decoder_input_ids: torch.LongTensor | None = None,
  656. decoder_attention_mask: torch.Tensor | None = None,
  657. encoder_outputs: list[torch.FloatTensor] | None = None,
  658. past_key_values: Cache | None = None,
  659. inputs_embeds: torch.FloatTensor | None = None,
  660. decoder_inputs_embeds: torch.FloatTensor | None = None,
  661. labels: torch.Tensor | None = None,
  662. use_cache: bool | None = None,
  663. **kwargs: Unpack[TransformersKwargs],
  664. ) -> tuple[torch.Tensor] | Seq2SeqLMOutput:
  665. r"""
  666. decoder_input_ids (`torch.LongTensor` of shape `(batch_size, target_sequence_length)`, *optional*):
  667. Indices of decoder input sequence tokens in the vocabulary.
  668. Indices can be obtained using [`AutoTokenizer`] or [`PLBartMultiTokenizer`] depending on the checkpoint.
  669. See [`PreTrainedTokenizer.encode`] and [`PreTrainedTokenizer.__call__`] for details.
  670. [What are decoder input IDs?](../glossary#decoder-input-ids)
  671. PLBart uses a specific language id token as the starting token for `decoder_input_ids` generation that
  672. varies according to source and target language, *e.g.* 50003 for *en_XX*, and 50001 for *java*. If
  673. `past_key_values` is used, optionally only the last `decoder_input_ids` have to be input (see
  674. `past_key_values`).
  675. For translation and summarization training, `decoder_input_ids` should be provided. If no
  676. `decoder_input_ids` is provided, the model will create this tensor by shifting the `input_ids` to the right
  677. for denoising pre-training following the paper.
  678. decoder_attention_mask (:
  679. obj:*torch.LongTensor* of shape `(batch_size, target_sequence_length)`, *optional*):
  680. Default behavior:
  681. generate a tensor that ignores pad tokens in `decoder_input_ids`. Causal mask will also be used by default.
  682. labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
  683. Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
  684. config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
  685. (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
  686. Example Mask-filling:
  687. ```python
  688. >>> from transformers import AutoTokenizer, PLBartForConditionalGeneration
  689. >>> model = PLBartForConditionalGeneration.from_pretrained("uclanlp/plbart-base")
  690. >>> tokenizer = AutoTokenizer.from_pretrained("uclanlp/plbart-base")
  691. >>> # en_XX is the language symbol id <LID> for English
  692. >>> TXT = "<s> Is 0 the <mask> Fibonacci number ? </s> en_XX"
  693. >>> input_ids = tokenizer([TXT], add_special_tokens=False, return_tensors="pt").input_ids
  694. >>> logits = model(input_ids).logits
  695. >>> masked_index = (input_ids[0] == tokenizer.mask_token_id).nonzero().item()
  696. >>> probs = logits[0, masked_index].softmax(dim=0)
  697. >>> values, predictions = probs.topk(5)
  698. >>> tokenizer.decode(predictions).split()
  699. ['first', 'same', 'highest', 'result', 'number']
  700. ```
  701. """
  702. if labels is not None:
  703. if decoder_input_ids is None and decoder_inputs_embeds is None:
  704. decoder_input_ids = shift_tokens_right(labels, self.config.pad_token_id)
  705. outputs: Seq2SeqModelOutput = self.model(
  706. input_ids,
  707. attention_mask=attention_mask,
  708. decoder_input_ids=decoder_input_ids,
  709. encoder_outputs=encoder_outputs,
  710. decoder_attention_mask=decoder_attention_mask,
  711. past_key_values=past_key_values,
  712. inputs_embeds=inputs_embeds,
  713. decoder_inputs_embeds=decoder_inputs_embeds,
  714. use_cache=use_cache,
  715. **kwargs,
  716. )
  717. lm_logits = self.lm_head(outputs.last_hidden_state)
  718. lm_logits = lm_logits + self.final_logits_bias.to(lm_logits.device)
  719. masked_lm_loss = None
  720. if labels is not None:
  721. loss_fct = CrossEntropyLoss()
  722. masked_lm_loss = loss_fct(lm_logits.view(-1, self.config.vocab_size), labels.view(-1))
  723. return Seq2SeqLMOutput(
  724. loss=masked_lm_loss,
  725. logits=lm_logits,
  726. past_key_values=outputs.past_key_values,
  727. decoder_hidden_states=outputs.decoder_hidden_states,
  728. decoder_attentions=outputs.decoder_attentions,
  729. cross_attentions=outputs.cross_attentions,
  730. encoder_last_hidden_state=outputs.encoder_last_hidden_state,
  731. encoder_hidden_states=outputs.encoder_hidden_states,
  732. encoder_attentions=outputs.encoder_attentions,
  733. )
  734. def prepare_decoder_input_ids_from_labels(self, labels: torch.Tensor):
  735. return shift_tokens_right(labels, self.config.pad_token_id)
  736. class PLBartClassificationHead(nn.Module):
  737. """Head for sentence-level classification tasks."""
  738. def __init__(
  739. self,
  740. input_dim: int,
  741. inner_dim: int,
  742. num_classes: int,
  743. pooler_dropout: float,
  744. ):
  745. super().__init__()
  746. self.dense = nn.Linear(input_dim, inner_dim)
  747. self.dropout = nn.Dropout(p=pooler_dropout)
  748. self.out_proj = nn.Linear(inner_dim, num_classes)
  749. def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
  750. hidden_states = self.dropout(hidden_states)
  751. hidden_states = self.dense(hidden_states)
  752. hidden_states = torch.tanh(hidden_states)
  753. hidden_states = self.dropout(hidden_states)
  754. hidden_states = self.out_proj(hidden_states)
  755. return hidden_states
  756. @auto_docstring(
  757. custom_intro="""
  758. PLBart model with a sequence classification/head on top (a linear layer on top of the pooled output) e.g.
  759. for GLUE tasks.
  760. """
  761. )
  762. class PLBartForSequenceClassification(PLBartPreTrainedModel):
  763. def __init__(self, config: PLBartConfig, **kwargs):
  764. super().__init__(config, **kwargs)
  765. self.model = PLBartModel(config)
  766. self.classification_head = PLBartClassificationHead(
  767. config.d_model,
  768. config.d_model,
  769. config.num_labels,
  770. config.classifier_dropout,
  771. )
  772. # Initialize weights and apply final processing
  773. self.post_init()
  774. @can_return_tuple
  775. @auto_docstring
  776. def forward(
  777. self,
  778. input_ids: torch.LongTensor | None = None,
  779. attention_mask: torch.Tensor | None = None,
  780. decoder_input_ids: torch.LongTensor | None = None,
  781. decoder_attention_mask: torch.LongTensor | None = None,
  782. encoder_outputs: list[torch.FloatTensor] | None = None,
  783. inputs_embeds: torch.FloatTensor | None = None,
  784. decoder_inputs_embeds: torch.FloatTensor | None = None,
  785. labels: torch.LongTensor | None = None,
  786. use_cache: bool | None = None,
  787. **kwargs: Unpack[TransformersKwargs],
  788. ) -> tuple | Seq2SeqSequenceClassifierOutput:
  789. r"""
  790. decoder_input_ids (`torch.LongTensor` of shape `(batch_size, target_sequence_length)`, *optional*):
  791. Indices of decoder input sequence tokens in the vocabulary.
  792. Indices can be obtained using [`AutoTokenizer`] or [`PLBartMultiTokenizer`] depending on the checkpoint.
  793. See [`PreTrainedTokenizer.encode`] and [`PreTrainedTokenizer.__call__`] for details.
  794. [What are decoder input IDs?](../glossary#decoder-input-ids)
  795. PLBart uses a specific language id token as the starting token for `decoder_input_ids` generation that
  796. varies according to source and target language, *e.g.* 50003 for *en_XX*, and 50001 for *java*. If
  797. `past_key_values` is used, optionally only the last `decoder_input_ids` have to be input (see
  798. `past_key_values`).
  799. For translation and summarization training, `decoder_input_ids` should be provided. If no
  800. `decoder_input_ids` is provided, the model will create this tensor by shifting the `input_ids` to the right
  801. for denoising pre-training following the paper.
  802. decoder_attention_mask (:
  803. obj:*torch.LongTensor* of shape `(batch_size, target_sequence_length)`, *optional*):
  804. Default behavior:
  805. generate a tensor that ignores pad tokens in `decoder_input_ids`. Causal mask will also be used by default.
  806. labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
  807. Labels for computing the sequence classification/regression loss. Indices should be in `[0, ...,
  808. config.num_labels - 1]`. If `config.num_labels > 1` a classification loss is computed (Cross-Entropy).
  809. """
  810. if labels is not None:
  811. use_cache = False
  812. if input_ids is None and inputs_embeds is not None:
  813. raise NotImplementedError(
  814. f"Passing input embeddings is currently not supported for {self.__class__.__name__}"
  815. )
  816. outputs: Seq2SeqModelOutput = self.model(
  817. input_ids,
  818. attention_mask=attention_mask,
  819. decoder_input_ids=decoder_input_ids,
  820. decoder_attention_mask=decoder_attention_mask,
  821. encoder_outputs=encoder_outputs,
  822. inputs_embeds=inputs_embeds,
  823. decoder_inputs_embeds=decoder_inputs_embeds,
  824. use_cache=use_cache,
  825. **kwargs,
  826. )
  827. hidden_states = outputs[0] # last hidden state
  828. eos_mask = input_ids.eq(self.config.eos_token_id).to(hidden_states.device)
  829. torch_compilable_check(
  830. torch.unique_consecutive(eos_mask.sum(1)).numel() == 1,
  831. "All examples must have the same number of <eos> tokens.",
  832. )
  833. sentence_representation = hidden_states[eos_mask, :].view(hidden_states.size(0), -1, hidden_states.size(-1))[
  834. :, -1, :
  835. ]
  836. logits = self.classification_head(sentence_representation)
  837. loss = None
  838. if labels is not None:
  839. labels = labels.to(logits.device)
  840. if self.config.problem_type is None:
  841. if self.config.num_labels == 1:
  842. self.config.problem_type = "regression"
  843. elif self.config.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int):
  844. self.config.problem_type = "single_label_classification"
  845. else:
  846. self.config.problem_type = "multi_label_classification"
  847. if self.config.problem_type == "regression":
  848. loss_fct = MSELoss()
  849. if self.config.num_labels == 1:
  850. loss = loss_fct(logits.squeeze(), labels.squeeze())
  851. else:
  852. loss = loss_fct(logits, labels)
  853. elif self.config.problem_type == "single_label_classification":
  854. loss_fct = CrossEntropyLoss()
  855. loss = loss_fct(logits.view(-1, self.config.num_labels), labels.view(-1))
  856. elif self.config.problem_type == "multi_label_classification":
  857. loss_fct = BCEWithLogitsLoss()
  858. loss = loss_fct(logits, labels)
  859. return Seq2SeqSequenceClassifierOutput(
  860. loss=loss,
  861. logits=logits,
  862. past_key_values=outputs.past_key_values,
  863. decoder_hidden_states=outputs.decoder_hidden_states,
  864. decoder_attentions=outputs.decoder_attentions,
  865. cross_attentions=outputs.cross_attentions,
  866. encoder_last_hidden_state=outputs.encoder_last_hidden_state,
  867. encoder_hidden_states=outputs.encoder_hidden_states,
  868. encoder_attentions=outputs.encoder_attentions,
  869. )
class PLBartDecoderWrapper(PLBartPreTrainedModel):
    """
    This wrapper class is a helper class to correctly load pretrained checkpoints when the causal language model is
    used in combination with the [`EncoderDecoderModel`] framework.
    """

    def __init__(self, config):
        super().__init__(config)
        # The decoder is stored under the `decoder` attribute, so its parameters are addressed
        # as `<wrapper>.decoder.*` by whatever module holds this wrapper.
        self.decoder = PLBartDecoder(config)
        self.post_init()

    def forward(self, *args, **kwargs):
        # Pure pass-through: all arguments are forwarded to the wrapped decoder unchanged.
        return self.decoder(*args, **kwargs)
  881. @auto_docstring(
  882. custom_intro="""
  883. PLBART decoder with a language modeling head on top (linear layer with weights tied to the input embeddings).
  884. """
  885. )
  886. class PLBartForCausalLM(PLBartPreTrainedModel, GenerationMixin):
  887. _tied_weights_keys = {
  888. "lm_head.weight": "model.decoder.embed_tokens.weight",
  889. }
  890. def __init__(self, config):
  891. config.is_decoder = True
  892. config.is_encoder_decoder = False
  893. super().__init__(config)
  894. self.model = PLBartDecoderWrapper(config)
  895. self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
  896. # Initialize weights and apply final processing
  897. self.post_init()
  898. def get_input_embeddings(self):
  899. return self.model.decoder.embed_tokens
  900. def set_input_embeddings(self, value):
  901. self.model.decoder.embed_tokens = value
  902. @can_return_tuple
  903. @auto_docstring
  904. def forward(
  905. self,
  906. input_ids: torch.LongTensor | None = None,
  907. attention_mask: torch.Tensor | None = None,
  908. encoder_hidden_states: torch.FloatTensor | None = None,
  909. encoder_attention_mask: torch.FloatTensor | None = None,
  910. past_key_values: Cache | None = None,
  911. inputs_embeds: torch.FloatTensor | None = None,
  912. labels: torch.LongTensor | None = None,
  913. use_cache: bool | None = None,
  914. logits_to_keep: int | torch.Tensor = 0,
  915. **kwargs: Unpack[TransformersKwargs],
  916. ) -> tuple | CausalLMOutputWithCrossAttentions:
  917. r"""
  918. labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
  919. Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
  920. config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
  921. (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
  922. Example:
  923. ```python
  924. >>> from transformers import AutoTokenizer, PLBartForCausalLM
  925. >>> tokenizer = AutoTokenizer.from_pretrained("uclanlp/plbart-base")
  926. >>> model = PLBartForCausalLM.from_pretrained("uclanlp/plbart-base")
  927. >>> assert model.config.is_decoder, f"{model.__class__} has to be configured as a decoder."
  928. >>> inputs = tokenizer("Hello, my dog is cute", return_tensors="pt")
  929. >>> outputs = model(**inputs)
  930. >>> logits = outputs.logits
  931. >>> expected_shape = [1, inputs.input_ids.shape[-1], model.config.vocab_size]
  932. >>> list(logits.shape) == expected_shape
  933. True
  934. ```"""
  935. outputs: BaseModelOutputWithPastAndCrossAttentions = self.model.decoder(
  936. input_ids=input_ids,
  937. attention_mask=attention_mask,
  938. encoder_hidden_states=encoder_hidden_states,
  939. encoder_attention_mask=encoder_attention_mask,
  940. past_key_values=past_key_values,
  941. inputs_embeds=inputs_embeds,
  942. use_cache=use_cache,
  943. **kwargs,
  944. )
  945. hidden_states = outputs[0]
  946. # Only compute necessary logits
  947. slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep
  948. logits = self.lm_head(hidden_states[:, slice_indices, :])
  949. loss = None
  950. if labels is not None:
  951. labels = labels.to(logits.device)
  952. loss_fct = CrossEntropyLoss()
  953. loss = loss_fct(logits.view(-1, self.config.vocab_size), labels.view(-1))
  954. return CausalLMOutputWithCrossAttentions(
  955. loss=loss,
  956. logits=logits,
  957. past_key_values=outputs.past_key_values,
  958. hidden_states=outputs.hidden_states,
  959. attentions=outputs.attentions,
  960. cross_attentions=outputs.cross_attentions,
  961. )
# Public names of this module: the only names exported by a star-import of it.
__all__ = [
    "PLBartForCausalLM",
    "PLBartForConditionalGeneration",
    "PLBartForSequenceClassification",
    "PLBartModel",
    "PLBartPreTrainedModel",
]