modular_plbart.py 17 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391
  1. # Copyright 2022, UCLA NLP, The Facebook AI Research Team and The HuggingFace Inc. team. All rights reserved.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. """PyTorch PLBART model."""
  15. import math
  16. import torch
  17. from torch import nn
  18. from torch.nn import CrossEntropyLoss
  19. from ... import initialization as init
  20. from ...cache_utils import Cache
  21. from ...generation import GenerationMixin
  22. from ...modeling_outputs import (
  23. BaseModelOutput,
  24. Seq2SeqLMOutput,
  25. Seq2SeqModelOutput,
  26. )
  27. from ...modeling_utils import PreTrainedModel
  28. from ...processing_utils import Unpack
  29. from ...utils import TransformersKwargs, auto_docstring, can_return_tuple
  30. from ...utils.generic import merge_with_config_defaults
  31. from ...utils.output_capturing import capture_outputs
  32. from ..bart.modeling_bart import (
  33. BartClassificationHead,
  34. BartDecoder,
  35. BartEncoder,
  36. BartForCausalLM,
  37. BartScaledWordEmbedding,
  38. )
  39. from ..bigbird_pegasus.modeling_bigbird_pegasus import BigBirdPegasusForSequenceClassification
  40. from ..mbart.modeling_mbart import shift_tokens_right
  41. from .configuration_plbart import PLBartConfig
# PLBart-named subclass of `BartScaledWordEmbedding`; it adds no attributes or
# methods of its own.
# NOTE(review): presumably a modular-transformers placeholder that the code
# generator expands into a renamed copy of the Bart class -- confirm.
class PLBartScaledWordEmbedding(BartScaledWordEmbedding):
    pass
# Common base class for the PLBart models defined in this file: it carries the
# class-level framework flags and the PLBart-specific weight initialization.
@auto_docstring
class PLBartPreTrainedModel(PreTrainedModel):
    config: PLBartConfig
    # The head model below stores its `PLBartModel` under `self.model`.
    base_model_prefix = "model"
    supports_gradient_checkpointing = True
    # Layer class names, given as strings because the classes are not defined in this file.
    _no_split_modules = ["PLBartDecoderLayer", "PLBartEncoderLayer"]
    # Attention-backend support flags read by the base class.
    _supports_flash_attn = True
    _supports_sdpa = True
    _supports_flex_attn = True

    def _init_weights(self, module):
        """Initialize `module`, additionally zeroing the LM head's `final_logits_bias`.

        Args:
            module: The module whose weights should be initialized.
        """
        # Generic initialization is delegated to the parent class.
        super()._init_weights(module)
        # `final_logits_bias` is a buffer of the conditional-generation model and
        # is reset to zeros explicitly here.
        if isinstance(module, PLBartForConditionalGeneration):
            init.zeros_(module.final_logits_bias)
# PLBart-named subclass of `BartEncoder`; no behavior is overridden here.
# NOTE(review): presumably expanded by the modular code generator -- confirm.
class PLBartEncoder(BartEncoder):
    pass
# PLBart-named subclass of `BartDecoder`; no behavior is overridden here.
# NOTE(review): presumably expanded by the modular code generator -- confirm.
class PLBartDecoder(BartDecoder):
    pass
  61. @auto_docstring
  62. class PLBartModel(PLBartPreTrainedModel):
  63. _tied_weights_keys = {
  64. "encoder.embed_tokens.weight": "shared.weight",
  65. "decoder.embed_tokens.weight": "shared.weight",
  66. }
  67. def __init__(self, config: PLBartConfig):
  68. super().__init__(config)
  69. padding_idx, vocab_size = config.pad_token_id, config.vocab_size
  70. embed_scale = math.sqrt(config.d_model) if config.scale_embedding else 1.0
  71. self.shared = PLBartScaledWordEmbedding(vocab_size, config.d_model, padding_idx, embed_scale=embed_scale)
  72. self.encoder = PLBartEncoder(config)
  73. self.decoder = PLBartDecoder(config)
  74. self.post_init()
  75. def get_input_embeddings(self):
  76. return self.shared
  77. def set_input_embeddings(self, value):
  78. self.shared = value
  79. self.encoder.embed_tokens = self.shared
  80. self.decoder.embed_tokens = self.shared
  81. @merge_with_config_defaults
  82. @capture_outputs
  83. @auto_docstring
  84. def forward(
  85. self,
  86. input_ids: torch.LongTensor | None = None,
  87. attention_mask: torch.LongTensor | None = None,
  88. decoder_input_ids: torch.LongTensor | None = None,
  89. decoder_attention_mask: torch.Tensor | None = None,
  90. encoder_outputs: list[torch.FloatTensor] | None = None,
  91. past_key_values: Cache | None = None,
  92. inputs_embeds: torch.FloatTensor | None = None,
  93. decoder_inputs_embeds: torch.FloatTensor | None = None,
  94. use_cache: bool | None = None,
  95. **kwargs: Unpack[TransformersKwargs],
  96. ) -> tuple[torch.Tensor] | Seq2SeqModelOutput:
  97. r"""
  98. decoder_input_ids (`torch.LongTensor` of shape `(batch_size, target_sequence_length)`, *optional*):
  99. Indices of decoder input sequence tokens in the vocabulary.
  100. Indices can be obtained using [`AutoTokenizer`] or [`PLBartMultiTokenizer`] depending on the checkpoint.
  101. See [`PreTrainedTokenizer.encode`] and [`PreTrainedTokenizer.__call__`] for details.
  102. [What are decoder input IDs?](../glossary#decoder-input-ids)
  103. PLBart uses a specific language id token as the starting token for `decoder_input_ids` generation that
  104. varies according to source and target language, *e.g.* 50003 for *en_XX*, and 50001 for *java*. If
  105. `past_key_values` is used, optionally only the last `decoder_input_ids` have to be input (see
  106. `past_key_values`).
  107. For translation and summarization training, `decoder_input_ids` should be provided. If no
  108. `decoder_input_ids` is provided, the model will create this tensor by shifting the `input_ids` to the right
  109. for denoising pre-training following the paper.
  110. decoder_attention_mask (:
  111. obj:*torch.LongTensor* of shape `(batch_size, target_sequence_length)`, *optional*):
  112. Default behavior:
  113. generate a tensor that ignores pad tokens in `decoder_input_ids`. Causal mask will also be used by default.
  114. """
  115. # different to other models, PLBart automatically creates decoder_input_ids from
  116. # input_ids if no decoder_input_ids are provided
  117. if decoder_input_ids is None and decoder_inputs_embeds is None:
  118. decoder_input_ids = shift_tokens_right(input_ids, self.config.pad_token_id)
  119. if encoder_outputs is None:
  120. encoder_outputs: BaseModelOutput = self.encoder(
  121. input_ids=input_ids,
  122. attention_mask=attention_mask,
  123. inputs_embeds=inputs_embeds,
  124. **kwargs,
  125. )
  126. elif not isinstance(encoder_outputs, BaseModelOutput):
  127. encoder_outputs = BaseModelOutput(
  128. last_hidden_state=encoder_outputs[0],
  129. hidden_states=encoder_outputs[1] if len(encoder_outputs) > 1 else None,
  130. attentions=encoder_outputs[2] if len(encoder_outputs) > 2 else None,
  131. )
  132. decoder_outputs = self.decoder(
  133. input_ids=decoder_input_ids,
  134. attention_mask=decoder_attention_mask,
  135. encoder_hidden_states=encoder_outputs[0],
  136. encoder_attention_mask=attention_mask,
  137. past_key_values=past_key_values,
  138. inputs_embeds=decoder_inputs_embeds,
  139. use_cache=use_cache,
  140. **kwargs,
  141. )
  142. return Seq2SeqModelOutput(
  143. last_hidden_state=decoder_outputs.last_hidden_state,
  144. past_key_values=decoder_outputs.past_key_values,
  145. decoder_hidden_states=decoder_outputs.hidden_states,
  146. decoder_attentions=decoder_outputs.attentions,
  147. cross_attentions=decoder_outputs.cross_attentions,
  148. encoder_last_hidden_state=encoder_outputs.last_hidden_state,
  149. encoder_hidden_states=encoder_outputs.hidden_states,
  150. encoder_attentions=encoder_outputs.attentions,
  151. )
  152. @auto_docstring(
  153. custom_intro="""
  154. The PLBART Model with a language modeling head. Can be used for code-to-text, text-to-code and code-to-code.
  155. """
  156. )
  157. class PLBartForConditionalGeneration(PLBartPreTrainedModel, GenerationMixin):
  158. base_model_prefix = "model"
  159. _keys_to_ignore_on_load_missing = ["final_logits_bias"]
  160. _tied_weights_keys = {
  161. "lm_head.weight": "model.shared.weight",
  162. }
  163. def __init__(self, config: PLBartConfig):
  164. super().__init__(config)
  165. self.model = PLBartModel(config)
  166. self.register_buffer("final_logits_bias", torch.zeros((1, self.model.shared.num_embeddings)))
  167. self.lm_head = nn.Linear(config.d_model, self.model.shared.num_embeddings, bias=False)
  168. self.post_init()
  169. def resize_token_embeddings(
  170. self, new_num_tokens: int, pad_to_multiple_of: int | None = None, mean_resizing: bool = True
  171. ) -> nn.Embedding:
  172. new_embeddings = super().resize_token_embeddings(new_num_tokens, pad_to_multiple_of, mean_resizing)
  173. self._resize_final_logits_bias(new_embeddings.weight.shape[0])
  174. return new_embeddings
  175. def _resize_final_logits_bias(self, new_num_tokens: int) -> None:
  176. old_num_tokens = self.final_logits_bias.shape[-1]
  177. if new_num_tokens <= old_num_tokens:
  178. new_bias = self.final_logits_bias[:, :new_num_tokens]
  179. else:
  180. extra_bias = torch.zeros((1, new_num_tokens - old_num_tokens), device=self.final_logits_bias.device)
  181. new_bias = torch.cat([self.final_logits_bias, extra_bias], dim=1)
  182. self.register_buffer("final_logits_bias", new_bias)
  183. @merge_with_config_defaults
  184. @capture_outputs
  185. @auto_docstring
  186. def forward(
  187. self,
  188. input_ids: torch.LongTensor | None = None,
  189. attention_mask: torch.LongTensor | None = None,
  190. decoder_input_ids: torch.LongTensor | None = None,
  191. decoder_attention_mask: torch.Tensor | None = None,
  192. encoder_outputs: list[torch.FloatTensor] | None = None,
  193. past_key_values: Cache | None = None,
  194. inputs_embeds: torch.FloatTensor | None = None,
  195. decoder_inputs_embeds: torch.FloatTensor | None = None,
  196. labels: torch.Tensor | None = None,
  197. use_cache: bool | None = None,
  198. **kwargs: Unpack[TransformersKwargs],
  199. ) -> tuple[torch.Tensor] | Seq2SeqLMOutput:
  200. r"""
  201. decoder_input_ids (`torch.LongTensor` of shape `(batch_size, target_sequence_length)`, *optional*):
  202. Indices of decoder input sequence tokens in the vocabulary.
  203. Indices can be obtained using [`AutoTokenizer`] or [`PLBartMultiTokenizer`] depending on the checkpoint.
  204. See [`PreTrainedTokenizer.encode`] and [`PreTrainedTokenizer.__call__`] for details.
  205. [What are decoder input IDs?](../glossary#decoder-input-ids)
  206. PLBart uses a specific language id token as the starting token for `decoder_input_ids` generation that
  207. varies according to source and target language, *e.g.* 50003 for *en_XX*, and 50001 for *java*. If
  208. `past_key_values` is used, optionally only the last `decoder_input_ids` have to be input (see
  209. `past_key_values`).
  210. For translation and summarization training, `decoder_input_ids` should be provided. If no
  211. `decoder_input_ids` is provided, the model will create this tensor by shifting the `input_ids` to the right
  212. for denoising pre-training following the paper.
  213. decoder_attention_mask (:
  214. obj:*torch.LongTensor* of shape `(batch_size, target_sequence_length)`, *optional*):
  215. Default behavior:
  216. generate a tensor that ignores pad tokens in `decoder_input_ids`. Causal mask will also be used by default.
  217. labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
  218. Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
  219. config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
  220. (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
  221. Example Mask-filling:
  222. ```python
  223. >>> from transformers import AutoTokenizer, PLBartForConditionalGeneration
  224. >>> model = PLBartForConditionalGeneration.from_pretrained("uclanlp/plbart-base")
  225. >>> tokenizer = AutoTokenizer.from_pretrained("uclanlp/plbart-base")
  226. >>> # en_XX is the language symbol id <LID> for English
  227. >>> TXT = "<s> Is 0 the <mask> Fibonacci number ? </s> en_XX"
  228. >>> input_ids = tokenizer([TXT], add_special_tokens=False, return_tensors="pt").input_ids
  229. >>> logits = model(input_ids).logits
  230. >>> masked_index = (input_ids[0] == tokenizer.mask_token_id).nonzero().item()
  231. >>> probs = logits[0, masked_index].softmax(dim=0)
  232. >>> values, predictions = probs.topk(5)
  233. >>> tokenizer.decode(predictions).split()
  234. ['first', 'same', 'highest', 'result', 'number']
  235. ```
  236. """
  237. if labels is not None:
  238. if decoder_input_ids is None and decoder_inputs_embeds is None:
  239. decoder_input_ids = shift_tokens_right(labels, self.config.pad_token_id)
  240. outputs: Seq2SeqModelOutput = self.model(
  241. input_ids,
  242. attention_mask=attention_mask,
  243. decoder_input_ids=decoder_input_ids,
  244. encoder_outputs=encoder_outputs,
  245. decoder_attention_mask=decoder_attention_mask,
  246. past_key_values=past_key_values,
  247. inputs_embeds=inputs_embeds,
  248. decoder_inputs_embeds=decoder_inputs_embeds,
  249. use_cache=use_cache,
  250. **kwargs,
  251. )
  252. lm_logits = self.lm_head(outputs.last_hidden_state)
  253. lm_logits = lm_logits + self.final_logits_bias.to(lm_logits.device)
  254. masked_lm_loss = None
  255. if labels is not None:
  256. loss_fct = CrossEntropyLoss()
  257. masked_lm_loss = loss_fct(lm_logits.view(-1, self.config.vocab_size), labels.view(-1))
  258. return Seq2SeqLMOutput(
  259. loss=masked_lm_loss,
  260. logits=lm_logits,
  261. past_key_values=outputs.past_key_values,
  262. decoder_hidden_states=outputs.decoder_hidden_states,
  263. decoder_attentions=outputs.decoder_attentions,
  264. cross_attentions=outputs.cross_attentions,
  265. encoder_last_hidden_state=outputs.encoder_last_hidden_state,
  266. encoder_hidden_states=outputs.encoder_hidden_states,
  267. encoder_attentions=outputs.encoder_attentions,
  268. )
  269. def prepare_decoder_input_ids_from_labels(self, labels: torch.Tensor):
  270. return shift_tokens_right(labels, self.config.pad_token_id)
# PLBart-named subclass of `BartClassificationHead`; no behavior is overridden here.
# NOTE(review): presumably expanded by the modular code generator -- confirm.
class PLBartClassificationHead(BartClassificationHead):
    pass
# Sequence-classification model derived from `BigBirdPegasusForSequenceClassification`;
# only the `forward` docstring is customized in this file.
class PLBartForSequenceClassification(BigBirdPegasusForSequenceClassification):
    # NOTE(review): `forward(**super_kwargs)` has no `self` and does not return the
    # parent's result, so it looks like a modular-transformers placeholder whose
    # signature and body are filled in from the parent by the code generator --
    # confirm before treating it as directly executable code.
    def forward(**super_kwargs):
        r"""
        decoder_input_ids (`torch.LongTensor` of shape `(batch_size, target_sequence_length)`, *optional*):
            Indices of decoder input sequence tokens in the vocabulary.

            Indices can be obtained using [`AutoTokenizer`] or [`PLBartMultiTokenizer`] depending on the checkpoint.
            See [`PreTrainedTokenizer.encode`] and [`PreTrainedTokenizer.__call__`] for details.

            [What are decoder input IDs?](../glossary#decoder-input-ids)

            PLBart uses a specific language id token as the starting token for `decoder_input_ids` generation that
            varies according to source and target language, *e.g.* 50003 for *en_XX*, and 50001 for *java*. If
            `past_key_values` is used, optionally only the last `decoder_input_ids` have to be input (see
            `past_key_values`).

            For translation and summarization training, `decoder_input_ids` should be provided. If no
            `decoder_input_ids` is provided, the model will create this tensor by shifting the `input_ids` to the right
            for denoising pre-training following the paper.
        decoder_attention_mask (`torch.LongTensor` of shape `(batch_size, target_sequence_length)`, *optional*):
            Default behavior: generate a tensor that ignores pad tokens in `decoder_input_ids`. Causal mask will also
            be used by default.
        labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
            Labels for computing the sequence classification/regression loss. Indices should be in `[0, ...,
            config.num_labels - 1]`. If `config.num_labels > 1` a classification loss is computed (Cross-Entropy).
        """
        # All keyword arguments are handed unchanged to the parent implementation.
        super().forward(**super_kwargs)
# Causal-LM model derived from `BartForCausalLM`; only the `forward` decorators
# and docstring are customized in this file.
class PLBartForCausalLM(BartForCausalLM):
    # NOTE(review): `forward(**super_kwargs)` has no `self` and does not return the
    # parent's result, so it looks like a modular-transformers placeholder whose
    # signature and body are filled in from the parent by the code generator --
    # confirm before treating it as directly executable code.
    @can_return_tuple
    @auto_docstring
    def forward(**super_kwargs):
        r"""
        labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
            Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
            config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
            (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.

        Example:

        ```python
        >>> from transformers import AutoTokenizer, PLBartForCausalLM

        >>> tokenizer = AutoTokenizer.from_pretrained("uclanlp/plbart-base")
        >>> model = PLBartForCausalLM.from_pretrained("uclanlp/plbart-base")
        >>> assert model.config.is_decoder, f"{model.__class__} has to be configured as a decoder."
        >>> inputs = tokenizer("Hello, my dog is cute", return_tensors="pt")
        >>> outputs = model(**inputs)

        >>> logits = outputs.logits
        >>> expected_shape = [1, inputs.input_ids.shape[-1], model.config.vocab_size]
        >>> list(logits.shape) == expected_shape
        True
        ```"""
        # All keyword arguments are handed unchanged to the parent implementation.
        super().forward(**super_kwargs)
# Names exported by `from <module> import *`. The encoder, decoder, embedding
# and classification-head classes defined above are not listed here.
__all__ = [
    "PLBartForCausalLM",
    "PLBartForConditionalGeneration",
    "PLBartForSequenceClassification",
    "PLBartModel",
    "PLBartPreTrainedModel",
]