modular_roberta.py 31 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
3704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771
  1. # Copyright 2018 The Google AI Language Team Authors and The HuggingFace Inc. team.
  2. # Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved.
  3. #
  4. # Licensed under the Apache License, Version 2.0 (the "License");
  5. # you may not use this file except in compliance with the License.
  6. # You may obtain a copy of the License at
  7. #
  8. # http://www.apache.org/licenses/LICENSE-2.0
  9. #
  10. # Unless required by applicable law or agreed to in writing, software
  11. # distributed under the License is distributed on an "AS IS" BASIS,
  12. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. # See the License for the specific language governing permissions and
  14. # limitations under the License.
  15. """PyTorch RoBERTa model."""
  16. import torch
  17. import torch.nn as nn
  18. from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
  19. from ... import initialization as init
  20. from ...activations import gelu
  21. from ...generation import GenerationMixin
  22. from ...modeling_outputs import (
  23. BaseModelOutputWithPoolingAndCrossAttentions,
  24. CausalLMOutputWithCrossAttentions,
  25. MaskedLMOutput,
  26. MultipleChoiceModelOutput,
  27. QuestionAnsweringModelOutput,
  28. SequenceClassifierOutput,
  29. TokenClassifierOutput,
  30. )
  31. from ...modeling_utils import PreTrainedModel
  32. from ...processing_utils import Unpack
  33. from ...utils import TransformersKwargs, auto_docstring, logging
  34. from ...utils.generic import can_return_tuple
  35. from ..bert.modeling_bert import BertCrossAttention, BertEmbeddings, BertLayer, BertModel, BertSelfAttention
  36. from .configuration_roberta import RobertaConfig
# Module-level logger; the model constructors below use it to warn about mismatched `is_decoder` settings.
logger = logging.get_logger(__name__)
  38. class RobertaEmbeddings(BertEmbeddings):
  39. def __init__(self, config):
  40. super().__init__(config)
  41. del self.pad_token_id
  42. del self.position_embeddings
  43. self.padding_idx = config.pad_token_id
  44. self.position_embeddings = nn.Embedding(
  45. config.max_position_embeddings, config.hidden_size, padding_idx=self.padding_idx
  46. )
  47. def forward(
  48. self,
  49. input_ids: torch.LongTensor | None = None,
  50. token_type_ids: torch.LongTensor | None = None,
  51. position_ids: torch.LongTensor | None = None,
  52. inputs_embeds: torch.FloatTensor | None = None,
  53. past_key_values_length: int = 0,
  54. ):
  55. if position_ids is None:
  56. if input_ids is not None:
  57. # Create the position ids from the input token ids. Any padded tokens remain padded.
  58. position_ids = self.create_position_ids_from_input_ids(
  59. input_ids, self.padding_idx, past_key_values_length
  60. )
  61. else:
  62. position_ids = self.create_position_ids_from_inputs_embeds(inputs_embeds, self.padding_idx)
  63. if input_ids is not None:
  64. input_shape = input_ids.size()
  65. else:
  66. input_shape = inputs_embeds.size()[:-1]
  67. batch_size, seq_length = input_shape
  68. # Setting the token_type_ids to the registered buffer in constructor where it is all zeros, which usually occurs
  69. # when its auto-generated, registered buffer helps users when tracing the model without passing token_type_ids, solves
  70. # issue #5664
  71. if token_type_ids is None:
  72. if hasattr(self, "token_type_ids"):
  73. # NOTE: We assume either pos ids to have bsz == 1 (broadcastable) or bsz == effective bsz (input_shape[0])
  74. buffered_token_type_ids = self.token_type_ids.expand(position_ids.shape[0], -1)
  75. buffered_token_type_ids = torch.gather(buffered_token_type_ids, dim=1, index=position_ids)
  76. token_type_ids = buffered_token_type_ids.expand(batch_size, seq_length)
  77. else:
  78. token_type_ids = torch.zeros(input_shape, dtype=torch.long, device=self.position_ids.device)
  79. if inputs_embeds is None:
  80. inputs_embeds = self.word_embeddings(input_ids)
  81. token_type_embeddings = self.token_type_embeddings(token_type_ids)
  82. embeddings = inputs_embeds + token_type_embeddings
  83. position_embeddings = self.position_embeddings(position_ids)
  84. embeddings = embeddings + position_embeddings
  85. embeddings = self.LayerNorm(embeddings)
  86. embeddings = self.dropout(embeddings)
  87. return embeddings
  88. @staticmethod
  89. def create_position_ids_from_inputs_embeds(inputs_embeds, padding_idx):
  90. """
  91. We are provided embeddings directly. We cannot infer which are padded so just generate sequential position ids.
  92. Args:
  93. inputs_embeds: torch.Tensor
  94. Returns: torch.Tensor
  95. """
  96. input_shape = inputs_embeds.size()[:-1]
  97. sequence_length = input_shape[1]
  98. position_ids = torch.arange(
  99. padding_idx + 1, sequence_length + padding_idx + 1, dtype=torch.long, device=inputs_embeds.device
  100. )
  101. return position_ids.unsqueeze(0).expand(input_shape)
  102. @staticmethod
  103. def create_position_ids_from_input_ids(input_ids, padding_idx, past_key_values_length=0):
  104. """
  105. Replace non-padding symbols with their position numbers. Position numbers begin at padding_idx+1. Padding symbols
  106. are ignored. This is modified from fairseq's `utils.make_positions`.
  107. Args:
  108. x: torch.Tensor x:
  109. Returns: torch.Tensor
  110. """
  111. # The series of casts and type-conversions here are carefully balanced to both work with ONNX export and XLA.
  112. mask = input_ids.ne(padding_idx).int()
  113. incremental_indices = (torch.cumsum(mask, dim=1).type_as(mask) + past_key_values_length) * mask
  114. return incremental_indices.long() + padding_idx
  115. class RobertaSelfAttention(BertSelfAttention):
  116. pass
  117. class RobertaCrossAttention(BertCrossAttention):
  118. pass
  119. class RobertaLayer(BertLayer):
  120. pass
  121. @auto_docstring
  122. class RobertaPreTrainedModel(PreTrainedModel):
  123. config_class = RobertaConfig
  124. base_model_prefix = "roberta"
  125. supports_gradient_checkpointing = True
  126. _supports_flash_attn = True
  127. _supports_sdpa = True
  128. _supports_flex_attn = True
  129. _supports_attention_backend = True
  130. _can_record_outputs = {
  131. "hidden_states": RobertaLayer,
  132. "attentions": RobertaSelfAttention,
  133. "cross_attentions": RobertaCrossAttention,
  134. }
  135. @torch.no_grad()
  136. def _init_weights(self, module):
  137. """Initialize the weights"""
  138. super()._init_weights(module)
  139. if isinstance(module, RobertaLMHead):
  140. init.zeros_(module.bias)
  141. elif isinstance(module, RobertaEmbeddings):
  142. init.copy_(module.position_ids, torch.arange(module.position_ids.shape[-1]).expand((1, -1)))
  143. init.zeros_(module.token_type_ids)
  144. class RobertaModel(BertModel):
  145. def __init__(self, config, add_pooling_layer=True):
  146. super().__init__(self, config)
  147. @auto_docstring(
  148. custom_intro="""
  149. RoBERTa Model with a `language modeling` head on top for CLM fine-tuning.
  150. """
  151. )
  152. class RobertaForCausalLM(RobertaPreTrainedModel, GenerationMixin):
  153. _tied_weights_keys = {
  154. "lm_head.decoder.weight": "roberta.embeddings.word_embeddings.weight",
  155. "lm_head.decoder.bias": "lm_head.bias",
  156. }
  157. def __init__(self, config):
  158. super().__init__(config)
  159. if not config.is_decoder:
  160. logger.warning("If you want to use `RobertaLMHeadModel` as a standalone, add `is_decoder=True.`")
  161. self.roberta = RobertaModel(config, add_pooling_layer=False)
  162. self.lm_head = RobertaLMHead(config)
  163. # Initialize weights and apply final processing
  164. self.post_init()
  165. def get_output_embeddings(self):
  166. return self.lm_head.decoder
  167. def set_output_embeddings(self, new_embeddings):
  168. self.lm_head.decoder = new_embeddings
  169. @can_return_tuple
  170. @auto_docstring
  171. def forward(
  172. self,
  173. input_ids: torch.LongTensor | None = None,
  174. attention_mask: torch.FloatTensor | None = None,
  175. token_type_ids: torch.LongTensor | None = None,
  176. position_ids: torch.LongTensor | None = None,
  177. inputs_embeds: torch.FloatTensor | None = None,
  178. encoder_hidden_states: torch.FloatTensor | None = None,
  179. encoder_attention_mask: torch.FloatTensor | None = None,
  180. labels: torch.LongTensor | None = None,
  181. past_key_values: tuple[tuple[torch.FloatTensor]] | None = None,
  182. use_cache: bool | None = None,
  183. logits_to_keep: int | torch.Tensor = 0,
  184. **kwargs: Unpack[TransformersKwargs],
  185. ) -> tuple[torch.Tensor] | CausalLMOutputWithCrossAttentions:
  186. r"""
  187. token_type_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
  188. Segment token indices to indicate first and second portions of the inputs. Indices are selected in `[0,1]`:
  189. - 0 corresponds to a *sentence A* token,
  190. - 1 corresponds to a *sentence B* token.
  191. This parameter can only be used when the model is initialized with `type_vocab_size` parameter with value
  192. >= 2. All the value in this tensor should be always < type_vocab_size.
  193. [What are token type IDs?](../glossary#token-type-ids)
  194. labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
  195. Labels for computing the left-to-right language modeling loss (next word prediction). Indices should be in
  196. `[-100, 0, ..., config.vocab_size]` (see `input_ids` docstring) Tokens with indices set to `-100` are
  197. ignored (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`
  198. Example:
  199. ```python
  200. >>> from transformers import AutoTokenizer, RobertaForCausalLM, AutoConfig
  201. >>> import torch
  202. >>> tokenizer = AutoTokenizer.from_pretrained("FacebookAI/roberta-base")
  203. >>> config = AutoConfig.from_pretrained("FacebookAI/roberta-base")
  204. >>> config.is_decoder = True
  205. >>> model = RobertaForCausalLM.from_pretrained("FacebookAI/roberta-base", config=config)
  206. >>> inputs = tokenizer("Hello, my dog is cute", return_tensors="pt")
  207. >>> outputs = model(**inputs)
  208. >>> prediction_logits = outputs.logits
  209. ```"""
  210. if labels is not None:
  211. use_cache = False
  212. outputs: BaseModelOutputWithPoolingAndCrossAttentions = self.roberta(
  213. input_ids,
  214. attention_mask=attention_mask,
  215. token_type_ids=token_type_ids,
  216. position_ids=position_ids,
  217. inputs_embeds=inputs_embeds,
  218. encoder_hidden_states=encoder_hidden_states,
  219. encoder_attention_mask=encoder_attention_mask,
  220. past_key_values=past_key_values,
  221. use_cache=use_cache,
  222. return_dict=True,
  223. **kwargs,
  224. )
  225. hidden_states = outputs.last_hidden_state
  226. # Only compute necessary logits, and do not upcast them to float if we are not computing the loss
  227. slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep
  228. logits = self.lm_head(hidden_states[:, slice_indices, :])
  229. loss = None
  230. if labels is not None:
  231. loss = self.loss_function(logits=logits, labels=labels, vocab_size=self.config.vocab_size, **kwargs)
  232. return CausalLMOutputWithCrossAttentions(
  233. loss=loss,
  234. logits=logits,
  235. past_key_values=outputs.past_key_values,
  236. hidden_states=outputs.hidden_states,
  237. attentions=outputs.attentions,
  238. cross_attentions=outputs.cross_attentions,
  239. )
  240. @auto_docstring
  241. class RobertaForMaskedLM(RobertaPreTrainedModel):
  242. _tied_weights_keys = {
  243. "lm_head.decoder.weight": "roberta.embeddings.word_embeddings.weight",
  244. "lm_head.decoder.bias": "lm_head.bias",
  245. }
  246. def __init__(self, config):
  247. super().__init__(config)
  248. if config.is_decoder:
  249. logger.warning(
  250. "If you want to use `RobertaForMaskedLM` make sure `config.is_decoder=False` for "
  251. "bi-directional self-attention."
  252. )
  253. self.roberta = RobertaModel(config, add_pooling_layer=False)
  254. self.lm_head = RobertaLMHead(config)
  255. # Initialize weights and apply final processing
  256. self.post_init()
  257. def get_output_embeddings(self):
  258. return self.lm_head.decoder
  259. def set_output_embeddings(self, new_embeddings):
  260. self.lm_head.decoder = new_embeddings
  261. @can_return_tuple
  262. @auto_docstring
  263. def forward(
  264. self,
  265. input_ids: torch.LongTensor | None = None,
  266. attention_mask: torch.FloatTensor | None = None,
  267. token_type_ids: torch.LongTensor | None = None,
  268. position_ids: torch.LongTensor | None = None,
  269. inputs_embeds: torch.FloatTensor | None = None,
  270. encoder_hidden_states: torch.FloatTensor | None = None,
  271. encoder_attention_mask: torch.FloatTensor | None = None,
  272. labels: torch.LongTensor | None = None,
  273. **kwargs: Unpack[TransformersKwargs],
  274. ) -> tuple[torch.Tensor] | MaskedLMOutput:
  275. r"""
  276. token_type_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
  277. Segment token indices to indicate first and second portions of the inputs. Indices are selected in `[0,1]`:
  278. - 0 corresponds to a *sentence A* token,
  279. - 1 corresponds to a *sentence B* token.
  280. This parameter can only be used when the model is initialized with `type_vocab_size` parameter with value
  281. >= 2. All the value in this tensor should be always < type_vocab_size.
  282. [What are token type IDs?](../glossary#token-type-ids)
  283. labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
  284. Labels for computing the masked language modeling loss. Indices should be in `[-100, 0, ...,
  285. config.vocab_size]` (see `input_ids` docstring) Tokens with indices set to `-100` are ignored (masked), the
  286. loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`
  287. """
  288. outputs = self.roberta(
  289. input_ids,
  290. attention_mask=attention_mask,
  291. token_type_ids=token_type_ids,
  292. position_ids=position_ids,
  293. inputs_embeds=inputs_embeds,
  294. encoder_hidden_states=encoder_hidden_states,
  295. encoder_attention_mask=encoder_attention_mask,
  296. return_dict=True,
  297. **kwargs,
  298. )
  299. sequence_output = outputs[0]
  300. prediction_scores = self.lm_head(sequence_output)
  301. masked_lm_loss = None
  302. if labels is not None:
  303. # move labels to correct device
  304. labels = labels.to(prediction_scores.device)
  305. loss_fct = CrossEntropyLoss()
  306. masked_lm_loss = loss_fct(prediction_scores.view(-1, self.config.vocab_size), labels.view(-1))
  307. return MaskedLMOutput(
  308. loss=masked_lm_loss,
  309. logits=prediction_scores,
  310. hidden_states=outputs.hidden_states,
  311. attentions=outputs.attentions,
  312. )
  313. class RobertaLMHead(nn.Module):
  314. """Roberta Head for masked language modeling."""
  315. def __init__(self, config):
  316. super().__init__()
  317. self.dense = nn.Linear(config.hidden_size, config.hidden_size)
  318. self.layer_norm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
  319. self.decoder = nn.Linear(config.hidden_size, config.vocab_size)
  320. self.bias = nn.Parameter(torch.zeros(config.vocab_size))
  321. def forward(self, features, **kwargs):
  322. x = self.dense(features)
  323. x = gelu(x)
  324. x = self.layer_norm(x)
  325. # project back to size of vocabulary with bias
  326. x = self.decoder(x)
  327. return x
  328. @auto_docstring(
  329. custom_intro="""
  330. RoBERTa Model transformer with a sequence classification/regression head on top (a linear layer on top of the
  331. pooled output) e.g. for GLUE tasks.
  332. """
  333. )
  334. class RobertaForSequenceClassification(RobertaPreTrainedModel):
  335. def __init__(self, config):
  336. super().__init__(config)
  337. self.num_labels = config.num_labels
  338. self.config = config
  339. self.roberta = RobertaModel(config, add_pooling_layer=False)
  340. self.classifier = RobertaClassificationHead(config)
  341. # Initialize weights and apply final processing
  342. self.post_init()
  343. @can_return_tuple
  344. @auto_docstring
  345. def forward(
  346. self,
  347. input_ids: torch.LongTensor | None = None,
  348. attention_mask: torch.FloatTensor | None = None,
  349. token_type_ids: torch.LongTensor | None = None,
  350. position_ids: torch.LongTensor | None = None,
  351. inputs_embeds: torch.FloatTensor | None = None,
  352. labels: torch.LongTensor | None = None,
  353. **kwargs: Unpack[TransformersKwargs],
  354. ) -> tuple[torch.Tensor] | SequenceClassifierOutput:
  355. r"""
  356. token_type_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
  357. Segment token indices to indicate first and second portions of the inputs. Indices are selected in `[0,1]`:
  358. - 0 corresponds to a *sentence A* token,
  359. - 1 corresponds to a *sentence B* token.
  360. This parameter can only be used when the model is initialized with `type_vocab_size` parameter with value
  361. >= 2. All the value in this tensor should be always < type_vocab_size.
  362. [What are token type IDs?](../glossary#token-type-ids)
  363. labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
  364. Labels for computing the sequence classification/regression loss. Indices should be in `[0, ...,
  365. config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If
  366. `config.num_labels > 1` a classification loss is computed (Cross-Entropy).
  367. """
  368. outputs = self.roberta(
  369. input_ids,
  370. attention_mask=attention_mask,
  371. token_type_ids=token_type_ids,
  372. position_ids=position_ids,
  373. inputs_embeds=inputs_embeds,
  374. return_dict=True,
  375. **kwargs,
  376. )
  377. sequence_output = outputs[0]
  378. logits = self.classifier(sequence_output)
  379. loss = None
  380. if labels is not None:
  381. # move labels to correct device
  382. labels = labels.to(logits.device)
  383. if self.config.problem_type is None:
  384. if self.num_labels == 1:
  385. self.config.problem_type = "regression"
  386. elif self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int):
  387. self.config.problem_type = "single_label_classification"
  388. else:
  389. self.config.problem_type = "multi_label_classification"
  390. if self.config.problem_type == "regression":
  391. loss_fct = MSELoss()
  392. if self.num_labels == 1:
  393. loss = loss_fct(logits.squeeze(), labels.squeeze())
  394. else:
  395. loss = loss_fct(logits, labels)
  396. elif self.config.problem_type == "single_label_classification":
  397. loss_fct = CrossEntropyLoss()
  398. loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
  399. elif self.config.problem_type == "multi_label_classification":
  400. loss_fct = BCEWithLogitsLoss()
  401. loss = loss_fct(logits, labels)
  402. return SequenceClassifierOutput(
  403. loss=loss,
  404. logits=logits,
  405. hidden_states=outputs.hidden_states,
  406. attentions=outputs.attentions,
  407. )
  408. @auto_docstring
  409. class RobertaForMultipleChoice(RobertaPreTrainedModel):
  410. def __init__(self, config):
  411. super().__init__(config)
  412. self.roberta = RobertaModel(config)
  413. self.dropout = nn.Dropout(config.hidden_dropout_prob)
  414. self.classifier = nn.Linear(config.hidden_size, 1)
  415. # Initialize weights and apply final processing
  416. self.post_init()
  417. @can_return_tuple
  418. @auto_docstring
  419. def forward(
  420. self,
  421. input_ids: torch.LongTensor | None = None,
  422. token_type_ids: torch.LongTensor | None = None,
  423. attention_mask: torch.FloatTensor | None = None,
  424. labels: torch.LongTensor | None = None,
  425. position_ids: torch.LongTensor | None = None,
  426. inputs_embeds: torch.FloatTensor | None = None,
  427. **kwargs: Unpack[TransformersKwargs],
  428. ) -> tuple[torch.Tensor] | MultipleChoiceModelOutput:
  429. r"""
  430. input_ids (`torch.LongTensor` of shape `(batch_size, num_choices, sequence_length)`):
  431. Indices of input sequence tokens in the vocabulary.
  432. Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
  433. [`PreTrainedTokenizer.__call__`] for details.
  434. [What are input IDs?](../glossary#input-ids)
  435. token_type_ids (`torch.LongTensor` of shape `(batch_size, num_choices, sequence_length)`, *optional*):
  436. Segment token indices to indicate first and second portions of the inputs. Indices are selected in `[0,1]`:
  437. - 0 corresponds to a *sentence A* token,
  438. - 1 corresponds to a *sentence B* token.
  439. This parameter can only be used when the model is initialized with `type_vocab_size` parameter with value
  440. >= 2. All the value in this tensor should be always < type_vocab_size.
  441. [What are token type IDs?](../glossary#token-type-ids)
  442. labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
  443. Labels for computing the multiple choice classification loss. Indices should be in `[0, ...,
  444. num_choices-1]` where `num_choices` is the size of the second dimension of the input tensors. (See
  445. `input_ids` above)
  446. position_ids (`torch.LongTensor` of shape `(batch_size, num_choices, sequence_length)`, *optional*):
  447. Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0,
  448. config.max_position_embeddings - 1]`.
  449. [What are position IDs?](../glossary#position-ids)
  450. inputs_embeds (`torch.FloatTensor` of shape `(batch_size, num_choices, sequence_length, hidden_size)`, *optional*):
  451. Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This
  452. is useful if you want more control over how to convert `input_ids` indices into associated vectors than the
  453. model's internal embedding lookup matrix.
  454. """
  455. num_choices = input_ids.shape[1] if input_ids is not None else inputs_embeds.shape[1]
  456. flat_input_ids = input_ids.view(-1, input_ids.size(-1)) if input_ids is not None else None
  457. flat_position_ids = position_ids.view(-1, position_ids.size(-1)) if position_ids is not None else None
  458. flat_token_type_ids = token_type_ids.view(-1, token_type_ids.size(-1)) if token_type_ids is not None else None
  459. flat_attention_mask = attention_mask.view(-1, attention_mask.size(-1)) if attention_mask is not None else None
  460. flat_inputs_embeds = (
  461. inputs_embeds.view(-1, inputs_embeds.size(-2), inputs_embeds.size(-1))
  462. if inputs_embeds is not None
  463. else None
  464. )
  465. outputs = self.roberta(
  466. flat_input_ids,
  467. position_ids=flat_position_ids,
  468. token_type_ids=flat_token_type_ids,
  469. attention_mask=flat_attention_mask,
  470. inputs_embeds=flat_inputs_embeds,
  471. return_dict=True,
  472. **kwargs,
  473. )
  474. pooled_output = outputs[1]
  475. pooled_output = self.dropout(pooled_output)
  476. logits = self.classifier(pooled_output)
  477. reshaped_logits = logits.view(-1, num_choices)
  478. loss = None
  479. if labels is not None:
  480. # move labels to correct device
  481. labels = labels.to(reshaped_logits.device)
  482. loss_fct = CrossEntropyLoss()
  483. loss = loss_fct(reshaped_logits, labels)
  484. return MultipleChoiceModelOutput(
  485. loss=loss,
  486. logits=reshaped_logits,
  487. hidden_states=outputs.hidden_states,
  488. attentions=outputs.attentions,
  489. )
  490. @auto_docstring
  491. class RobertaForTokenClassification(RobertaPreTrainedModel):
  492. def __init__(self, config):
  493. super().__init__(config)
  494. self.num_labels = config.num_labels
  495. self.roberta = RobertaModel(config, add_pooling_layer=False)
  496. classifier_dropout = (
  497. config.classifier_dropout if config.classifier_dropout is not None else config.hidden_dropout_prob
  498. )
  499. self.dropout = nn.Dropout(classifier_dropout)
  500. self.classifier = nn.Linear(config.hidden_size, config.num_labels)
  501. # Initialize weights and apply final processing
  502. self.post_init()
  503. @can_return_tuple
  504. @auto_docstring
  505. def forward(
  506. self,
  507. input_ids: torch.LongTensor | None = None,
  508. attention_mask: torch.FloatTensor | None = None,
  509. token_type_ids: torch.LongTensor | None = None,
  510. position_ids: torch.LongTensor | None = None,
  511. inputs_embeds: torch.FloatTensor | None = None,
  512. labels: torch.LongTensor | None = None,
  513. **kwargs: Unpack[TransformersKwargs],
  514. ) -> tuple[torch.Tensor] | TokenClassifierOutput:
  515. r"""
  516. token_type_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
  517. Segment token indices to indicate first and second portions of the inputs. Indices are selected in `[0,1]`:
  518. - 0 corresponds to a *sentence A* token,
  519. - 1 corresponds to a *sentence B* token.
  520. This parameter can only be used when the model is initialized with `type_vocab_size` parameter with value
  521. >= 2. All the value in this tensor should be always < type_vocab_size.
  522. [What are token type IDs?](../glossary#token-type-ids)
  523. labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
  524. Labels for computing the token classification loss. Indices should be in `[0, ..., config.num_labels - 1]`.
  525. """
  526. outputs = self.roberta(
  527. input_ids,
  528. attention_mask=attention_mask,
  529. token_type_ids=token_type_ids,
  530. position_ids=position_ids,
  531. inputs_embeds=inputs_embeds,
  532. return_dict=True,
  533. **kwargs,
  534. )
  535. sequence_output = outputs[0]
  536. sequence_output = self.dropout(sequence_output)
  537. logits = self.classifier(sequence_output)
  538. loss = None
  539. if labels is not None:
  540. # move labels to correct device
  541. labels = labels.to(logits.device)
  542. loss_fct = CrossEntropyLoss()
  543. loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
  544. return TokenClassifierOutput(
  545. loss=loss,
  546. logits=logits,
  547. hidden_states=outputs.hidden_states,
  548. attentions=outputs.attentions,
  549. )
  550. class RobertaClassificationHead(nn.Module):
  551. """Head for sentence-level classification tasks."""
  552. def __init__(self, config):
  553. super().__init__()
  554. self.dense = nn.Linear(config.hidden_size, config.hidden_size)
  555. classifier_dropout = (
  556. config.classifier_dropout if config.classifier_dropout is not None else config.hidden_dropout_prob
  557. )
  558. self.dropout = nn.Dropout(classifier_dropout)
  559. self.out_proj = nn.Linear(config.hidden_size, config.num_labels)
  560. def forward(self, features, **kwargs):
  561. x = features[:, 0, :] # take <s> token (equiv. to [CLS])
  562. x = self.dropout(x)
  563. x = self.dense(x)
  564. x = torch.tanh(x)
  565. x = self.dropout(x)
  566. x = self.out_proj(x)
  567. return x
  568. @auto_docstring
  569. class RobertaForQuestionAnswering(RobertaPreTrainedModel):
  570. def __init__(self, config):
  571. super().__init__(config)
  572. self.num_labels = config.num_labels
  573. self.roberta = RobertaModel(config, add_pooling_layer=False)
  574. self.qa_outputs = nn.Linear(config.hidden_size, config.num_labels)
  575. # Initialize weights and apply final processing
  576. self.post_init()
  577. @can_return_tuple
  578. @auto_docstring
  579. def forward(
  580. self,
  581. input_ids: torch.LongTensor | None = None,
  582. attention_mask: torch.FloatTensor | None = None,
  583. token_type_ids: torch.LongTensor | None = None,
  584. position_ids: torch.LongTensor | None = None,
  585. inputs_embeds: torch.FloatTensor | None = None,
  586. start_positions: torch.LongTensor | None = None,
  587. end_positions: torch.LongTensor | None = None,
  588. **kwargs: Unpack[TransformersKwargs],
  589. ) -> tuple[torch.Tensor] | QuestionAnsweringModelOutput:
  590. r"""
  591. token_type_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
  592. Segment token indices to indicate first and second portions of the inputs. Indices are selected in `[0,1]`:
  593. - 0 corresponds to a *sentence A* token,
  594. - 1 corresponds to a *sentence B* token.
  595. This parameter can only be used when the model is initialized with `type_vocab_size` parameter with value
  596. >= 2. All the value in this tensor should be always < type_vocab_size.
  597. [What are token type IDs?](../glossary#token-type-ids)
  598. """
  599. outputs = self.roberta(
  600. input_ids,
  601. attention_mask=attention_mask,
  602. token_type_ids=token_type_ids,
  603. position_ids=position_ids,
  604. inputs_embeds=inputs_embeds,
  605. return_dict=True,
  606. **kwargs,
  607. )
  608. sequence_output = outputs[0]
  609. logits = self.qa_outputs(sequence_output)
  610. start_logits, end_logits = logits.split(1, dim=-1)
  611. start_logits = start_logits.squeeze(-1).contiguous()
  612. end_logits = end_logits.squeeze(-1).contiguous()
  613. total_loss = None
  614. if start_positions is not None and end_positions is not None:
  615. # If we are on multi-GPU, split add a dimension
  616. if len(start_positions.size()) > 1:
  617. start_positions = start_positions.squeeze(-1)
  618. if len(end_positions.size()) > 1:
  619. end_positions = end_positions.squeeze(-1)
  620. # sometimes the start/end positions are outside our model inputs, we ignore these terms
  621. ignored_index = start_logits.size(1)
  622. start_positions = start_positions.clamp(0, ignored_index)
  623. end_positions = end_positions.clamp(0, ignored_index)
  624. loss_fct = CrossEntropyLoss(ignore_index=ignored_index)
  625. start_loss = loss_fct(start_logits, start_positions)
  626. end_loss = loss_fct(end_logits, end_positions)
  627. total_loss = (start_loss + end_loss) / 2
  628. return QuestionAnsweringModelOutput(
  629. loss=total_loss,
  630. start_logits=start_logits,
  631. end_logits=end_logits,
  632. hidden_states=outputs.hidden_states,
  633. attentions=outputs.attentions,
  634. )
# Public API of this module: the model classes exported by a star import.
__all__ = [
    "RobertaForCausalLM",
    "RobertaForMaskedLM",
    "RobertaForMultipleChoice",
    "RobertaForQuestionAnswering",
    "RobertaForSequenceClassification",
    "RobertaForTokenClassification",
    "RobertaModel",
    "RobertaPreTrainedModel",
]