modular_camembert.py 22 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534
  1. # Copyright 2019 Inria, Facebook AI Research and the HuggingFace Inc. team.
  2. # Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved.
  3. #
  4. # Licensed under the Apache License, Version 2.0 (the "License");
  5. # you may not use this file except in compliance with the License.
  6. # You may obtain a copy of the License at
  7. #
  8. # http://www.apache.org/licenses/LICENSE-2.0
  9. #
  10. # Unless required by applicable law or agreed to in writing, software
  11. # distributed under the License is distributed on an "AS IS" BASIS,
  12. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. # See the License for the specific language governing permissions and
  14. # limitations under the License.
  15. """PyTorch CamemBERT model."""
  16. import torch
  17. from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
  18. from ...modeling_outputs import (
  19. BaseModelOutputWithPoolingAndCrossAttentions,
  20. CausalLMOutputWithCrossAttentions,
  21. MaskedLMOutput,
  22. MultipleChoiceModelOutput,
  23. QuestionAnsweringModelOutput,
  24. SequenceClassifierOutput,
  25. TokenClassifierOutput,
  26. )
  27. from ...processing_utils import Unpack
  28. from ...utils import TransformersKwargs, auto_docstring
  29. from ...utils.generic import can_return_tuple
  30. from ..roberta.modeling_roberta import (
  31. RobertaForCausalLM,
  32. RobertaForMaskedLM,
  33. RobertaForMultipleChoice,
  34. RobertaForQuestionAnswering,
  35. RobertaForSequenceClassification,
  36. RobertaForTokenClassification,
  37. RobertaModel,
  38. RobertaPreTrainedModel,
  39. )
  40. class CamembertPreTrainedModel(RobertaPreTrainedModel):
  41. base_model_prefix = "roberta"
  42. class CamembertModel(RobertaModel):
  43. pass
  44. class CamembertForMaskedLM(RobertaForMaskedLM):
  45. _tied_weights_keys = {
  46. "lm_head.decoder.weight": "roberta.embeddings.word_embeddings.weight",
  47. "lm_head.decoder.bias": "lm_head.bias",
  48. }
  49. def __init__(self, config):
  50. super().__init__(config)
  51. del self.camembert
  52. self.roberta = CamembertModel(config, add_pooling_layer=False)
  53. @can_return_tuple
  54. @auto_docstring
  55. def forward(
  56. self,
  57. input_ids: torch.LongTensor | None = None,
  58. attention_mask: torch.FloatTensor | None = None,
  59. token_type_ids: torch.LongTensor | None = None,
  60. position_ids: torch.LongTensor | None = None,
  61. inputs_embeds: torch.FloatTensor | None = None,
  62. encoder_hidden_states: torch.FloatTensor | None = None,
  63. encoder_attention_mask: torch.FloatTensor | None = None,
  64. labels: torch.LongTensor | None = None,
  65. **kwargs: Unpack[TransformersKwargs],
  66. ) -> tuple[torch.Tensor] | MaskedLMOutput:
  67. r"""
  68. token_type_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
  69. Segment token indices to indicate first and second portions of the inputs. Indices are selected in `[0,1]`:
  70. - 0 corresponds to a *sentence A* token,
  71. - 1 corresponds to a *sentence B* token.
  72. This parameter can only be used when the model is initialized with `type_vocab_size` parameter with value
  73. >= 2. All the value in this tensor should be always < type_vocab_size.
  74. [What are token type IDs?](../glossary#token-type-ids)
  75. labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
  76. Labels for computing the masked language modeling loss. Indices should be in `[-100, 0, ...,
  77. config.vocab_size]` (see `input_ids` docstring) Tokens with indices set to `-100` are ignored (masked), the
  78. loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`
  79. """
  80. outputs = self.roberta(
  81. input_ids,
  82. attention_mask=attention_mask,
  83. token_type_ids=token_type_ids,
  84. position_ids=position_ids,
  85. inputs_embeds=inputs_embeds,
  86. encoder_hidden_states=encoder_hidden_states,
  87. encoder_attention_mask=encoder_attention_mask,
  88. return_dict=True,
  89. **kwargs,
  90. )
  91. sequence_output = outputs[0]
  92. prediction_scores = self.lm_head(sequence_output)
  93. masked_lm_loss = None
  94. if labels is not None:
  95. # move labels to correct device
  96. labels = labels.to(prediction_scores.device)
  97. loss_fct = CrossEntropyLoss()
  98. masked_lm_loss = loss_fct(prediction_scores.view(-1, self.config.vocab_size), labels.view(-1))
  99. return MaskedLMOutput(
  100. loss=masked_lm_loss,
  101. logits=prediction_scores,
  102. hidden_states=outputs.hidden_states,
  103. attentions=outputs.attentions,
  104. )
  105. class CamembertForSequenceClassification(RobertaForSequenceClassification):
  106. def __init__(self, config):
  107. super().__init__(config)
  108. del self.camembert
  109. self.roberta = CamembertModel(config, add_pooling_layer=False)
  110. @can_return_tuple
  111. @auto_docstring
  112. def forward(
  113. self,
  114. input_ids: torch.LongTensor | None = None,
  115. attention_mask: torch.FloatTensor | None = None,
  116. token_type_ids: torch.LongTensor | None = None,
  117. position_ids: torch.LongTensor | None = None,
  118. inputs_embeds: torch.FloatTensor | None = None,
  119. labels: torch.LongTensor | None = None,
  120. **kwargs: Unpack[TransformersKwargs],
  121. ) -> tuple[torch.Tensor] | SequenceClassifierOutput:
  122. r"""
  123. token_type_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
  124. Segment token indices to indicate first and second portions of the inputs. Indices are selected in `[0,1]`:
  125. - 0 corresponds to a *sentence A* token,
  126. - 1 corresponds to a *sentence B* token.
  127. This parameter can only be used when the model is initialized with `type_vocab_size` parameter with value
  128. >= 2. All the value in this tensor should be always < type_vocab_size.
  129. [What are token type IDs?](../glossary#token-type-ids)
  130. labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
  131. Labels for computing the sequence classification/regression loss. Indices should be in `[0, ...,
  132. config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If
  133. `config.num_labels > 1` a classification loss is computed (Cross-Entropy).
  134. """
  135. outputs = self.roberta(
  136. input_ids,
  137. attention_mask=attention_mask,
  138. token_type_ids=token_type_ids,
  139. position_ids=position_ids,
  140. inputs_embeds=inputs_embeds,
  141. return_dict=True,
  142. **kwargs,
  143. )
  144. sequence_output = outputs[0]
  145. logits = self.classifier(sequence_output)
  146. loss = None
  147. if labels is not None:
  148. # move labels to correct device
  149. labels = labels.to(logits.device)
  150. if self.config.problem_type is None:
  151. if self.num_labels == 1:
  152. self.config.problem_type = "regression"
  153. elif self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int):
  154. self.config.problem_type = "single_label_classification"
  155. else:
  156. self.config.problem_type = "multi_label_classification"
  157. if self.config.problem_type == "regression":
  158. loss_fct = MSELoss()
  159. if self.num_labels == 1:
  160. loss = loss_fct(logits.squeeze(), labels.squeeze())
  161. else:
  162. loss = loss_fct(logits, labels)
  163. elif self.config.problem_type == "single_label_classification":
  164. loss_fct = CrossEntropyLoss()
  165. loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
  166. elif self.config.problem_type == "multi_label_classification":
  167. loss_fct = BCEWithLogitsLoss()
  168. loss = loss_fct(logits, labels)
  169. return SequenceClassifierOutput(
  170. loss=loss,
  171. logits=logits,
  172. hidden_states=outputs.hidden_states,
  173. attentions=outputs.attentions,
  174. )
  175. class CamembertForMultipleChoice(RobertaForMultipleChoice):
  176. def __init__(self, config):
  177. super().__init__(config)
  178. del self.camembert
  179. self.roberta = CamembertModel(config, add_pooling_layer=False)
  180. @can_return_tuple
  181. @auto_docstring
  182. def forward(
  183. self,
  184. input_ids: torch.LongTensor | None = None,
  185. token_type_ids: torch.LongTensor | None = None,
  186. attention_mask: torch.FloatTensor | None = None,
  187. labels: torch.LongTensor | None = None,
  188. position_ids: torch.LongTensor | None = None,
  189. inputs_embeds: torch.FloatTensor | None = None,
  190. **kwargs: Unpack[TransformersKwargs],
  191. ) -> tuple[torch.Tensor] | MultipleChoiceModelOutput:
  192. r"""
  193. input_ids (`torch.LongTensor` of shape `(batch_size, num_choices, sequence_length)`):
  194. Indices of input sequence tokens in the vocabulary.
  195. Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
  196. [`PreTrainedTokenizer.__call__`] for details.
  197. [What are input IDs?](../glossary#input-ids)
  198. token_type_ids (`torch.LongTensor` of shape `(batch_size, num_choices, sequence_length)`, *optional*):
  199. Segment token indices to indicate first and second portions of the inputs. Indices are selected in `[0,1]`:
  200. - 0 corresponds to a *sentence A* token,
  201. - 1 corresponds to a *sentence B* token.
  202. This parameter can only be used when the model is initialized with `type_vocab_size` parameter with value
  203. >= 2. All the value in this tensor should be always < type_vocab_size.
  204. [What are token type IDs?](../glossary#token-type-ids)
  205. labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
  206. Labels for computing the multiple choice classification loss. Indices should be in `[0, ...,
  207. num_choices-1]` where `num_choices` is the size of the second dimension of the input tensors. (See
  208. `input_ids` above)
  209. position_ids (`torch.LongTensor` of shape `(batch_size, num_choices, sequence_length)`, *optional*):
  210. Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0,
  211. config.max_position_embeddings - 1]`.
  212. [What are position IDs?](../glossary#position-ids)
  213. inputs_embeds (`torch.FloatTensor` of shape `(batch_size, num_choices, sequence_length, hidden_size)`, *optional*):
  214. Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This
  215. is useful if you want more control over how to convert `input_ids` indices into associated vectors than the
  216. model's internal embedding lookup matrix.
  217. """
  218. num_choices = input_ids.shape[1] if input_ids is not None else inputs_embeds.shape[1]
  219. flat_input_ids = input_ids.view(-1, input_ids.size(-1)) if input_ids is not None else None
  220. flat_position_ids = position_ids.view(-1, position_ids.size(-1)) if position_ids is not None else None
  221. flat_token_type_ids = token_type_ids.view(-1, token_type_ids.size(-1)) if token_type_ids is not None else None
  222. flat_attention_mask = attention_mask.view(-1, attention_mask.size(-1)) if attention_mask is not None else None
  223. flat_inputs_embeds = (
  224. inputs_embeds.view(-1, inputs_embeds.size(-2), inputs_embeds.size(-1))
  225. if inputs_embeds is not None
  226. else None
  227. )
  228. outputs = self.roberta(
  229. flat_input_ids,
  230. position_ids=flat_position_ids,
  231. token_type_ids=flat_token_type_ids,
  232. attention_mask=flat_attention_mask,
  233. inputs_embeds=flat_inputs_embeds,
  234. return_dict=True,
  235. **kwargs,
  236. )
  237. pooled_output = outputs[1]
  238. pooled_output = self.dropout(pooled_output)
  239. logits = self.classifier(pooled_output)
  240. reshaped_logits = logits.view(-1, num_choices)
  241. loss = None
  242. if labels is not None:
  243. # move labels to correct device
  244. labels = labels.to(reshaped_logits.device)
  245. loss_fct = CrossEntropyLoss()
  246. loss = loss_fct(reshaped_logits, labels)
  247. return MultipleChoiceModelOutput(
  248. loss=loss,
  249. logits=reshaped_logits,
  250. hidden_states=outputs.hidden_states,
  251. attentions=outputs.attentions,
  252. )
  253. class CamembertForTokenClassification(RobertaForTokenClassification):
  254. def __init__(self, config):
  255. super().__init__(config)
  256. del self.camembert
  257. self.roberta = CamembertModel(config, add_pooling_layer=False)
  258. @can_return_tuple
  259. @auto_docstring
  260. def forward(
  261. self,
  262. input_ids: torch.LongTensor | None = None,
  263. attention_mask: torch.FloatTensor | None = None,
  264. token_type_ids: torch.LongTensor | None = None,
  265. position_ids: torch.LongTensor | None = None,
  266. inputs_embeds: torch.FloatTensor | None = None,
  267. labels: torch.LongTensor | None = None,
  268. **kwargs: Unpack[TransformersKwargs],
  269. ) -> tuple[torch.Tensor] | TokenClassifierOutput:
  270. r"""
  271. token_type_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
  272. Segment token indices to indicate first and second portions of the inputs. Indices are selected in `[0,1]`:
  273. - 0 corresponds to a *sentence A* token,
  274. - 1 corresponds to a *sentence B* token.
  275. This parameter can only be used when the model is initialized with `type_vocab_size` parameter with value
  276. >= 2. All the value in this tensor should be always < type_vocab_size.
  277. [What are token type IDs?](../glossary#token-type-ids)
  278. labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
  279. Labels for computing the token classification loss. Indices should be in `[0, ..., config.num_labels - 1]`.
  280. """
  281. outputs = self.roberta(
  282. input_ids,
  283. attention_mask=attention_mask,
  284. token_type_ids=token_type_ids,
  285. position_ids=position_ids,
  286. inputs_embeds=inputs_embeds,
  287. return_dict=True,
  288. **kwargs,
  289. )
  290. sequence_output = outputs[0]
  291. sequence_output = self.dropout(sequence_output)
  292. logits = self.classifier(sequence_output)
  293. loss = None
  294. if labels is not None:
  295. # move labels to correct device
  296. labels = labels.to(logits.device)
  297. loss_fct = CrossEntropyLoss()
  298. loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
  299. return TokenClassifierOutput(
  300. loss=loss,
  301. logits=logits,
  302. hidden_states=outputs.hidden_states,
  303. attentions=outputs.attentions,
  304. )
  305. class CamembertForQuestionAnswering(RobertaForQuestionAnswering):
  306. def __init__(self, config):
  307. super().__init__(config)
  308. del self.camembert
  309. self.roberta = CamembertModel(config, add_pooling_layer=False)
  310. @can_return_tuple
  311. @auto_docstring
  312. def forward(
  313. self,
  314. input_ids: torch.LongTensor | None = None,
  315. attention_mask: torch.FloatTensor | None = None,
  316. token_type_ids: torch.LongTensor | None = None,
  317. position_ids: torch.LongTensor | None = None,
  318. inputs_embeds: torch.FloatTensor | None = None,
  319. start_positions: torch.LongTensor | None = None,
  320. end_positions: torch.LongTensor | None = None,
  321. **kwargs: Unpack[TransformersKwargs],
  322. ) -> tuple[torch.Tensor] | QuestionAnsweringModelOutput:
  323. r"""
  324. token_type_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
  325. Segment token indices to indicate first and second portions of the inputs. Indices are selected in `[0,1]`:
  326. - 0 corresponds to a *sentence A* token,
  327. - 1 corresponds to a *sentence B* token.
  328. This parameter can only be used when the model is initialized with `type_vocab_size` parameter with value
  329. >= 2. All the value in this tensor should be always < type_vocab_size.
  330. [What are token type IDs?](../glossary#token-type-ids)
  331. """
  332. outputs = self.roberta(
  333. input_ids,
  334. attention_mask=attention_mask,
  335. token_type_ids=token_type_ids,
  336. position_ids=position_ids,
  337. inputs_embeds=inputs_embeds,
  338. return_dict=True,
  339. **kwargs,
  340. )
  341. sequence_output = outputs[0]
  342. logits = self.qa_outputs(sequence_output)
  343. start_logits, end_logits = logits.split(1, dim=-1)
  344. start_logits = start_logits.squeeze(-1).contiguous()
  345. end_logits = end_logits.squeeze(-1).contiguous()
  346. total_loss = None
  347. if start_positions is not None and end_positions is not None:
  348. # If we are on multi-GPU, split add a dimension
  349. if len(start_positions.size()) > 1:
  350. start_positions = start_positions.squeeze(-1)
  351. if len(end_positions.size()) > 1:
  352. end_positions = end_positions.squeeze(-1)
  353. # sometimes the start/end positions are outside our model inputs, we ignore these terms
  354. ignored_index = start_logits.size(1)
  355. start_positions = start_positions.clamp(0, ignored_index)
  356. end_positions = end_positions.clamp(0, ignored_index)
  357. loss_fct = CrossEntropyLoss(ignore_index=ignored_index)
  358. start_loss = loss_fct(start_logits, start_positions)
  359. end_loss = loss_fct(end_logits, end_positions)
  360. total_loss = (start_loss + end_loss) / 2
  361. return QuestionAnsweringModelOutput(
  362. loss=total_loss,
  363. start_logits=start_logits,
  364. end_logits=end_logits,
  365. hidden_states=outputs.hidden_states,
  366. attentions=outputs.attentions,
  367. )
  368. class CamembertForCausalLM(RobertaForCausalLM):
  369. _tied_weights_keys = {
  370. "lm_head.decoder.weight": "roberta.embeddings.word_embeddings.weight",
  371. "lm_head.decoder.bias": "lm_head.bias",
  372. }
  373. def __init__(self, config):
  374. super().__init__(config)
  375. del self.camembert
  376. self.roberta = CamembertModel(config, add_pooling_layer=False)
  377. @can_return_tuple
  378. @auto_docstring
  379. def forward(
  380. self,
  381. input_ids: torch.LongTensor | None = None,
  382. attention_mask: torch.FloatTensor | None = None,
  383. token_type_ids: torch.LongTensor | None = None,
  384. position_ids: torch.LongTensor | None = None,
  385. inputs_embeds: torch.FloatTensor | None = None,
  386. encoder_hidden_states: torch.FloatTensor | None = None,
  387. encoder_attention_mask: torch.FloatTensor | None = None,
  388. labels: torch.LongTensor | None = None,
  389. past_key_values: tuple[tuple[torch.FloatTensor]] | None = None,
  390. use_cache: bool | None = None,
  391. logits_to_keep: int | torch.Tensor = 0,
  392. **kwargs: Unpack[TransformersKwargs],
  393. ) -> tuple[torch.Tensor] | CausalLMOutputWithCrossAttentions:
  394. r"""
  395. token_type_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
  396. Segment token indices to indicate first and second portions of the inputs. Indices are selected in `[0,1]`:
  397. - 0 corresponds to a *sentence A* token,
  398. - 1 corresponds to a *sentence B* token.
  399. This parameter can only be used when the model is initialized with `type_vocab_size` parameter with value
  400. >= 2. All the value in this tensor should be always < type_vocab_size.
  401. [What are token type IDs?](../glossary#token-type-ids)
  402. labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
  403. Labels for computing the left-to-right language modeling loss (next word prediction). Indices should be in
  404. `[-100, 0, ..., config.vocab_size]` (see `input_ids` docstring) Tokens with indices set to `-100` are
  405. ignored (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`
  406. Example:
  407. ```python
  408. >>> from transformers import AutoTokenizer, CamembertForCausalLM, AutoConfig
  409. >>> import torch
  410. >>> tokenizer = AutoTokenizer.from_pretrained("almanach/camembert-base")
  411. >>> config = AutoConfig.from_pretrained("almanach/camembert-base")
  412. >>> config.is_decoder = True
  413. >>> model = CamembertForCausalLM.from_pretrained("almanach/camembert-base", config=config)
  414. >>> inputs = tokenizer("Hello, my dog is cute", return_tensors="pt")
  415. >>> outputs = model(**inputs)
  416. >>> prediction_logits = outputs.logits
  417. ```"""
  418. if labels is not None:
  419. use_cache = False
  420. outputs: BaseModelOutputWithPoolingAndCrossAttentions = self.roberta(
  421. input_ids,
  422. attention_mask=attention_mask,
  423. token_type_ids=token_type_ids,
  424. position_ids=position_ids,
  425. inputs_embeds=inputs_embeds,
  426. encoder_hidden_states=encoder_hidden_states,
  427. encoder_attention_mask=encoder_attention_mask,
  428. past_key_values=past_key_values,
  429. use_cache=use_cache,
  430. return_dict=True,
  431. **kwargs,
  432. )
  433. hidden_states = outputs.last_hidden_state
  434. # Only compute necessary logits, and do not upcast them to float if we are not computing the loss
  435. slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep
  436. logits = self.lm_head(hidden_states[:, slice_indices, :])
  437. loss = None
  438. if labels is not None:
  439. loss = self.loss_function(logits=logits, labels=labels, vocab_size=self.config.vocab_size, **kwargs)
  440. return CausalLMOutputWithCrossAttentions(
  441. loss=loss,
  442. logits=logits,
  443. past_key_values=outputs.past_key_values,
  444. hidden_states=outputs.hidden_states,
  445. attentions=outputs.attentions,
  446. cross_attentions=outputs.cross_attentions,
  447. )
  448. __all__ = [
  449. "CamembertForCausalLM",
  450. "CamembertForMaskedLM",
  451. "CamembertForMultipleChoice",
  452. "CamembertForQuestionAnswering",
  453. "CamembertForSequenceClassification",
  454. "CamembertForTokenClassification",
  455. "CamembertModel",
  456. "CamembertPreTrainedModel",
  457. ]