modular_ernie.py 35 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
3704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792793794795796797798799800801802803804805806807808809810811812813814815816817818819820821822823824825826827828829830831832833834835836837838839840841842843844845846847848849850851
  1. # Copyright 2022 The HuggingFace Inc. team.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. """PyTorch ERNIE model."""
  15. import torch
  16. import torch.nn as nn
  17. from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
  18. from ... import initialization as init
  19. from ...cache_utils import Cache, DynamicCache, EncoderDecoderCache
  20. from ...modeling_outputs import (
  21. BaseModelOutputWithPoolingAndCrossAttentions,
  22. CausalLMOutputWithCrossAttentions,
  23. MaskedLMOutput,
  24. MultipleChoiceModelOutput,
  25. NextSentencePredictorOutput,
  26. QuestionAnsweringModelOutput,
  27. SequenceClassifierOutput,
  28. TokenClassifierOutput,
  29. )
  30. from ...modeling_utils import PreTrainedModel
  31. from ...processing_utils import Unpack
  32. from ...utils import TransformersKwargs, auto_docstring, logging
  33. from ...utils.generic import can_return_tuple, merge_with_config_defaults
  34. from ...utils.output_capturing import capture_outputs
  35. from ..bert.modeling_bert import (
  36. BertCrossAttention,
  37. BertEmbeddings,
  38. BertEncoder,
  39. BertForMaskedLM,
  40. BertForMultipleChoice,
  41. BertForNextSentencePrediction,
  42. BertForPreTraining,
  43. BertForPreTrainingOutput,
  44. BertForQuestionAnswering,
  45. BertForSequenceClassification,
  46. BertForTokenClassification,
  47. BertLayer,
  48. BertLMHeadModel,
  49. BertLMPredictionHead,
  50. BertModel,
  51. BertPooler,
  52. BertSelfAttention,
  53. )
  54. from .configuration_ernie import ErnieConfig
  # Module-level logger obtained from the transformers logging utilities, keyed by this module's `__name__`.
  logger = logging.get_logger(__name__)
  class ErnieEmbeddings(BertEmbeddings):
      """
      Construct the embeddings from word, position and token_type embeddings, plus task_type embeddings when
      `config.use_task_id` is enabled.
      """

      def __init__(self, config):
          super().__init__(config)
          self.use_task_id = config.use_task_id
          # ERNIE-specific extra embedding table, only created when the config asks for task ids.
          if config.use_task_id:
              self.task_type_embeddings = nn.Embedding(config.task_type_vocab_size, config.hidden_size)

      def forward(
          self,
          input_ids: torch.LongTensor | None = None,
          token_type_ids: torch.LongTensor | None = None,
          task_type_ids: torch.LongTensor | None = None,
          position_ids: torch.LongTensor | None = None,
          inputs_embeds: torch.FloatTensor | None = None,
          past_key_values_length: int = 0,
      ) -> torch.Tensor:
          """
          Sum word (or supplied `inputs_embeds`), token-type, position and optional task-type embeddings, then
          apply LayerNorm and dropout.

          Args:
              input_ids: token indices of shape `(batch_size, seq_length)`; used when `inputs_embeds` is None.
              token_type_ids: segment ids; defaults to the all-zeros `token_type_ids` buffer when present,
                  otherwise to freshly created zeros.
              task_type_ids: task ids, only read when `use_task_id` is set; default to zeros.
              position_ids: position indices; default to the `position_ids` buffer sliced from
                  `past_key_values_length` to `past_key_values_length + seq_length`.
              inputs_embeds: precomputed word embeddings of shape `(batch_size, seq_length, dim)`.
              past_key_values_length: number of already-cached positions, used to offset default position ids.

          Returns:
              The summed, normalized and dropout-applied embedding tensor.
          """
          if input_ids is not None:
              input_shape = input_ids.size()
          else:
              input_shape = inputs_embeds.size()[:-1]

          batch_size, seq_length = input_shape

          if position_ids is None:
              position_ids = self.position_ids[:, past_key_values_length : seq_length + past_key_values_length]

          # Setting the token_type_ids to the registered buffer in constructor where it is all zeros, which usually occurs
          # when it's auto-generated, registered buffer helps users when tracing the model without passing token_type_ids, solves
          # issue #5664
          if token_type_ids is None:
              if hasattr(self, "token_type_ids"):
                  # NOTE: We assume either pos ids to have bsz == 1 (broadcastable) or bsz == effective bsz (input_shape[0])
                  buffered_token_type_ids = self.token_type_ids.expand(position_ids.shape[0], -1)
                  buffered_token_type_ids = torch.gather(buffered_token_type_ids, dim=1, index=position_ids)
                  token_type_ids = buffered_token_type_ids.expand(batch_size, seq_length)
              else:
                  token_type_ids = torch.zeros(input_shape, dtype=torch.long, device=self.position_ids.device)

          if inputs_embeds is None:
              inputs_embeds = self.word_embeddings(input_ids)
          token_type_embeddings = self.token_type_embeddings(token_type_ids)

          # .to is better than using _no_split_modules on ErnieEmbeddings as it's the first module and >1/2 the model size
          inputs_embeds = inputs_embeds.to(token_type_embeddings.device)
          embeddings = inputs_embeds + token_type_embeddings

          position_embeddings = self.position_embeddings(position_ids)
          embeddings = embeddings + position_embeddings

          # add `task_type_id` for ERNIE model
          if self.use_task_id:
              if task_type_ids is None:
                  task_type_ids = torch.zeros(input_shape, dtype=torch.long, device=self.position_ids.device)
              task_type_embeddings = self.task_type_embeddings(task_type_ids)
              embeddings += task_type_embeddings

          embeddings = self.LayerNorm(embeddings)
          embeddings = self.dropout(embeddings)
          return embeddings
  class ErnieSelfAttention(BertSelfAttention):
      # Behaves exactly like BertSelfAttention; re-declared under the Ernie name. It is also the module type
      # registered for "attentions" in ErniePreTrainedModel._can_record_outputs.
      pass
  class ErnieCrossAttention(BertCrossAttention):
      # Behaves exactly like BertCrossAttention; re-declared under the Ernie name. It is also the module type
      # registered for "cross_attentions" in ErniePreTrainedModel._can_record_outputs.
      pass
  class ErnieLayer(BertLayer):
      # Behaves exactly like BertLayer; re-declared under the Ernie name. ErnieModel lists it in
      # `_no_split_modules`, and it is registered for "hidden_states" in _can_record_outputs.
      pass
  class ErniePooler(BertPooler):
      # Behaves exactly like BertPooler; re-declared under the Ernie name.
      pass
  class ErnieLMPredictionHead(BertLMPredictionHead):
      # Behaves exactly like BertLMPredictionHead; re-declared under the Ernie name so that
      # ErniePreTrainedModel._init_weights can match it with isinstance and zero its bias.
      pass
  class ErnieEncoder(BertEncoder):
      # Behaves exactly like BertEncoder; re-declared under the Ernie name.
      pass
  @auto_docstring
  class ErniePreTrainedModel(PreTrainedModel):
      config_class = ErnieConfig
      base_model_prefix = "ernie"
      supports_gradient_checkpointing = True

      # Attention-implementation support flags.
      _supports_flash_attn = True
      _supports_sdpa = True
      _supports_flex_attn = True
      _supports_attention_backend = True

      # Output name -> submodule type whose outputs are recorded for that name.
      _can_record_outputs = {
          "hidden_states": ErnieLayer,
          "attentions": ErnieSelfAttention,
          "cross_attentions": ErnieCrossAttention,
      }

      @torch.no_grad()
      def _init_weights(self, module):
          """Apply the base initialization to `module`, then reset Ernie-specific buffers and biases."""
          super()._init_weights(module)
          # The two isinstance checks below target unrelated classes, so at most one branch applies.
          if isinstance(module, ErnieEmbeddings):
              # Rebuild the position index buffer as [[0, 1, ..., N - 1]] (N = buffer length) and clear
              # the default token-type buffer.
              num_positions = module.position_ids.shape[-1]
              init.copy_(module.position_ids, torch.arange(num_positions).expand((1, -1)))
              init.zeros_(module.token_type_ids)
          elif isinstance(module, ErnieLMPredictionHead):
              init.zeros_(module.bias)
  class ErnieModel(BertModel):
      _no_split_modules = ["ErnieLayer"]

      def __init__(self, config, add_pooling_layer=True):
          r"""
          add_pooling_layer (bool, *optional*, defaults to `True`):
              Whether to add a pooling layer
          """
          # Forward `config` (and the pooling flag) to the parent initializer. `self` must not be passed
          # explicitly: doing so binds the model instance to `config` and the config to `add_pooling_layer`.
          super().__init__(config, add_pooling_layer=add_pooling_layer)
          self.config = config
          self.gradient_checkpointing = False

          # Install the Ernie submodules (ErnieEmbeddings adds optional task-type embeddings).
          self.embeddings = ErnieEmbeddings(config)
          self.encoder = ErnieEncoder(config)

          self.pooler = ErniePooler(config) if add_pooling_layer else None

          # Initialize weights and apply final processing
          self.post_init()

      @merge_with_config_defaults
      @capture_outputs
      @auto_docstring
      def forward(
          self,
          input_ids: torch.Tensor | None = None,
          attention_mask: torch.Tensor | None = None,
          token_type_ids: torch.Tensor | None = None,
          task_type_ids: torch.Tensor | None = None,
          position_ids: torch.Tensor | None = None,
          inputs_embeds: torch.Tensor | None = None,
          encoder_hidden_states: torch.Tensor | None = None,
          encoder_attention_mask: torch.Tensor | None = None,
          past_key_values: Cache | None = None,
          use_cache: bool | None = None,
          **kwargs: Unpack[TransformersKwargs],
      ) -> tuple[torch.Tensor] | BaseModelOutputWithPoolingAndCrossAttentions:
          r"""
          task_type_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
              Task type embedding is a special embedding to represent the characteristic of different tasks, such as
              word-aware pre-training task, structure-aware pre-training task and semantic-aware pre-training task. We
              assign a `task_type_id` to each task and the `task_type_id` is in the range
              `[0, config.task_type_vocab_size-1]`.
          """
          if (input_ids is None) ^ (inputs_embeds is not None):
              raise ValueError("You must specify exactly one of input_ids or inputs_embeds")

          # Caching is only enabled when the config marks the model as a decoder.
          if self.config.is_decoder:
              use_cache = use_cache if use_cache is not None else self.config.use_cache
          else:
              use_cache = False

          if use_cache and past_key_values is None:
              # Two DynamicCaches wrapped in an EncoderDecoderCache when encoder states are supplied or the config is
              # encoder-decoder; a single DynamicCache otherwise.
              past_key_values = (
                  EncoderDecoderCache(DynamicCache(config=self.config), DynamicCache(config=self.config))
                  if encoder_hidden_states is not None or self.config.is_encoder_decoder
                  else DynamicCache(config=self.config)
              )

          # Offset for default position ids when continuing from cached positions.
          past_key_values_length = past_key_values.get_seq_length() if past_key_values is not None else 0

          embedding_output = self.embeddings(
              input_ids=input_ids,
              position_ids=position_ids,
              token_type_ids=token_type_ids,
              # specific to ernie
              task_type_ids=task_type_ids,
              inputs_embeds=inputs_embeds,
              past_key_values_length=past_key_values_length,
          )

          attention_mask, encoder_attention_mask = self._create_attention_masks(
              attention_mask=attention_mask,
              encoder_attention_mask=encoder_attention_mask,
              embedding_output=embedding_output,
              encoder_hidden_states=encoder_hidden_states,
              past_key_values=past_key_values,
          )

          encoder_outputs = self.encoder(
              embedding_output,
              attention_mask=attention_mask,
              encoder_hidden_states=encoder_hidden_states,
              encoder_attention_mask=encoder_attention_mask,
              past_key_values=past_key_values,
              use_cache=use_cache,
              position_ids=position_ids,
              **kwargs,
          )
          sequence_output = encoder_outputs[0]
          pooled_output = self.pooler(sequence_output) if self.pooler is not None else None

          return BaseModelOutputWithPoolingAndCrossAttentions(
              last_hidden_state=sequence_output,
              pooler_output=pooled_output,
              past_key_values=encoder_outputs.past_key_values,
          )
  class ErnieForPreTrainingOutput(BertForPreTrainingOutput):
      # Same fields as BertForPreTrainingOutput; re-declared under the Ernie name and returned by
      # ErnieForPreTraining.forward. Intentionally no docstring here, so the parent's field docs are kept.
      pass
  class ErnieForPreTraining(BertForPreTraining):
      # Weight-tying map: the MLM decoder weight shares the word embeddings, and the decoder bias shares the
      # prediction-head bias.
      _tied_weights_keys = {
          "cls.predictions.decoder.bias": "cls.predictions.bias",
          "cls.predictions.decoder.weight": "ernie.embeddings.word_embeddings.weight",
      }

      @can_return_tuple
      @auto_docstring
      def forward(
          self,
          input_ids: torch.Tensor | None = None,
          attention_mask: torch.Tensor | None = None,
          token_type_ids: torch.Tensor | None = None,
          task_type_ids: torch.Tensor | None = None,
          position_ids: torch.Tensor | None = None,
          inputs_embeds: torch.Tensor | None = None,
          labels: torch.Tensor | None = None,
          next_sentence_label: torch.Tensor | None = None,
          **kwargs: Unpack[TransformersKwargs],
      ) -> tuple[torch.Tensor] | ErnieForPreTrainingOutput:
          r"""
          task_type_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
              Task type embedding is a special embedding to represent the characteristic of different tasks, such as
              word-aware pre-training task, structure-aware pre-training task and semantic-aware pre-training task. We
              assign a `task_type_id` to each task and the `task_type_id` is in the range
              `[0, config.task_type_vocab_size-1]`.
          labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
              Labels for computing the masked language modeling loss. Indices should be in `[-100, 0, ...,
              config.vocab_size]` (see `input_ids` docstring) Tokens with indices set to `-100` are ignored (masked),
              the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`
          next_sentence_label (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
              Labels for computing the next sequence prediction (classification) loss. Input should be a sequence
              pair (see `input_ids` docstring) Indices should be in `[0, 1]`:

              - 0 indicates sequence B is a continuation of sequence A,
              - 1 indicates sequence B is a random sequence.

          Example:

          ```python
          >>> from transformers import AutoTokenizer, ErnieForPreTraining
          >>> import torch

          >>> tokenizer = AutoTokenizer.from_pretrained("nghuyong/ernie-1.0-base-zh")
          >>> model = ErnieForPreTraining.from_pretrained("nghuyong/ernie-1.0-base-zh")

          >>> inputs = tokenizer("Hello, my dog is cute", return_tensors="pt")
          >>> outputs = model(**inputs)

          >>> prediction_logits = outputs.prediction_logits
          >>> seq_relationship_logits = outputs.seq_relationship_logits
          ```
          """
          outputs = self.ernie(
              input_ids,
              attention_mask=attention_mask,
              token_type_ids=token_type_ids,
              task_type_ids=task_type_ids,
              position_ids=position_ids,
              inputs_embeds=inputs_embeds,
              return_dict=True,
              **kwargs,
          )

          sequence_output, pooled_output = outputs[:2]
          # `self.cls` returns per-token MLM scores and the 2-way next-sentence scores.
          prediction_scores, seq_relationship_score = self.cls(sequence_output, pooled_output)

          total_loss = None
          # The joint pre-training loss is only computed when both label sets are supplied.
          if labels is not None and next_sentence_label is not None:
              loss_fct = CrossEntropyLoss()
              masked_lm_loss = loss_fct(prediction_scores.view(-1, self.config.vocab_size), labels.view(-1))
              next_sentence_loss = loss_fct(seq_relationship_score.view(-1, 2), next_sentence_label.view(-1))
              total_loss = masked_lm_loss + next_sentence_loss

          return ErnieForPreTrainingOutput(
              loss=total_loss,
              prediction_logits=prediction_scores,
              seq_relationship_logits=seq_relationship_score,
              hidden_states=outputs.hidden_states,
              attentions=outputs.attentions,
          )
  class ErnieForCausalLM(BertLMHeadModel):
      @can_return_tuple
      @auto_docstring
      def forward(
          self,
          input_ids: torch.Tensor | None = None,
          attention_mask: torch.Tensor | None = None,
          token_type_ids: torch.Tensor | None = None,
          task_type_ids: torch.Tensor | None = None,
          position_ids: torch.Tensor | None = None,
          inputs_embeds: torch.Tensor | None = None,
          encoder_hidden_states: torch.Tensor | None = None,
          encoder_attention_mask: torch.Tensor | None = None,
          labels: torch.Tensor | None = None,
          past_key_values: Cache | None = None,
          use_cache: bool | None = None,
          logits_to_keep: int | torch.Tensor = 0,
          **kwargs: Unpack[TransformersKwargs],
      ) -> tuple[torch.Tensor] | CausalLMOutputWithCrossAttentions:
          r"""
          task_type_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
              Task type embedding is a special embedding to represent the characteristic of different tasks, such as
              word-aware pre-training task, structure-aware pre-training task and semantic-aware pre-training task. We
              assign a `task_type_id` to each task and the `task_type_id` is in the range
              `[0, config.task_type_vocab_size-1]`.
          labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
              Labels for computing the left-to-right language modeling loss (next word prediction). Indices should be in
              `[-100, 0, ..., config.vocab_size]` (see `input_ids` docstring) Tokens with indices set to `-100` are
              ignored (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`
          """
          # No cache is built or returned when computing a training loss.
          if labels is not None:
              use_cache = False

          outputs: BaseModelOutputWithPoolingAndCrossAttentions = self.ernie(
              input_ids,
              attention_mask=attention_mask,
              token_type_ids=token_type_ids,
              task_type_ids=task_type_ids,
              position_ids=position_ids,
              inputs_embeds=inputs_embeds,
              encoder_hidden_states=encoder_hidden_states,
              encoder_attention_mask=encoder_attention_mask,
              past_key_values=past_key_values,
              use_cache=use_cache,
              return_dict=True,
              **kwargs,
          )

          hidden_states = outputs.last_hidden_state
          # Only compute the logits that are needed: an int keeps the last `logits_to_keep` positions (0 keeps all,
          # since slice(-0, None) == slice(0, None)); a tensor selects explicit sequence indices.
          slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep
          logits = self.cls(hidden_states[:, slice_indices, :])

          loss = None
          if labels is not None:
              loss = self.loss_function(logits=logits, labels=labels, vocab_size=self.config.vocab_size, **kwargs)

          return CausalLMOutputWithCrossAttentions(
              loss=loss,
              logits=logits,
              past_key_values=outputs.past_key_values,
              hidden_states=outputs.hidden_states,
              attentions=outputs.attentions,
              cross_attentions=outputs.cross_attentions,
          )
  class ErnieForMaskedLM(BertForMaskedLM):
      # Weight-tying map: the MLM decoder weight shares the word embeddings, and the decoder bias shares the
      # prediction-head bias.
      _tied_weights_keys = {
          "cls.predictions.decoder.bias": "cls.predictions.bias",
          "cls.predictions.decoder.weight": "ernie.embeddings.word_embeddings.weight",
      }

      @can_return_tuple
      @auto_docstring
      def forward(
          self,
          input_ids: torch.Tensor | None = None,
          attention_mask: torch.Tensor | None = None,
          token_type_ids: torch.Tensor | None = None,
          task_type_ids: torch.Tensor | None = None,
          position_ids: torch.Tensor | None = None,
          inputs_embeds: torch.Tensor | None = None,
          encoder_hidden_states: torch.Tensor | None = None,
          encoder_attention_mask: torch.Tensor | None = None,
          labels: torch.Tensor | None = None,
          **kwargs: Unpack[TransformersKwargs],
      ) -> tuple[torch.Tensor] | MaskedLMOutput:
          r"""
          task_type_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
              Task type embedding is a special embedding to represent the characteristic of different tasks, such as
              word-aware pre-training task, structure-aware pre-training task and semantic-aware pre-training task. We
              assign a `task_type_id` to each task and the `task_type_id` is in the range
              `[0, config.task_type_vocab_size-1]`.
          labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
              Labels for computing the masked language modeling loss. Indices should be in `[-100, 0, ...,
              config.vocab_size]` (see `input_ids` docstring) Tokens with indices set to `-100` are ignored (masked), the
              loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`
          """
          outputs = self.ernie(
              input_ids,
              attention_mask=attention_mask,
              token_type_ids=token_type_ids,
              task_type_ids=task_type_ids,
              position_ids=position_ids,
              inputs_embeds=inputs_embeds,
              encoder_hidden_states=encoder_hidden_states,
              encoder_attention_mask=encoder_attention_mask,
              return_dict=True,
              **kwargs,
          )

          sequence_output = outputs[0]
          prediction_scores = self.cls(sequence_output)

          masked_lm_loss = None
          if labels is not None:
              # CrossEntropyLoss's default ignore_index is -100, so positions labelled -100 do not contribute.
              loss_fct = CrossEntropyLoss()
              masked_lm_loss = loss_fct(prediction_scores.view(-1, self.config.vocab_size), labels.view(-1))

          return MaskedLMOutput(
              loss=masked_lm_loss,
              logits=prediction_scores,
              hidden_states=outputs.hidden_states,
              attentions=outputs.attentions,
          )
  class ErnieForNextSentencePrediction(BertForNextSentencePrediction):
      @can_return_tuple
      @auto_docstring
      def forward(
          self,
          input_ids: torch.Tensor | None = None,
          attention_mask: torch.Tensor | None = None,
          token_type_ids: torch.Tensor | None = None,
          task_type_ids: torch.Tensor | None = None,
          position_ids: torch.Tensor | None = None,
          inputs_embeds: torch.Tensor | None = None,
          labels: torch.Tensor | None = None,
          **kwargs: Unpack[TransformersKwargs],
      ) -> tuple[torch.Tensor] | NextSentencePredictorOutput:
          r"""
          task_type_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
              Task type embedding is a special embedding to represent the characteristic of different tasks, such as
              word-aware pre-training task, structure-aware pre-training task and semantic-aware pre-training task. We
              assign a `task_type_id` to each task and the `task_type_id` is in the range
              `[0, config.task_type_vocab_size-1]`.
          labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
              Labels for computing the next sequence prediction (classification) loss. Input should be a sequence pair
              (see `input_ids` docstring). Indices should be in `[0, 1]`:

              - 0 indicates sequence B is a continuation of sequence A,
              - 1 indicates sequence B is a random sequence.

          Example:

          ```python
          >>> from transformers import AutoTokenizer, ErnieForNextSentencePrediction
          >>> import torch

          >>> tokenizer = AutoTokenizer.from_pretrained("nghuyong/ernie-1.0-base-zh")
          >>> model = ErnieForNextSentencePrediction.from_pretrained("nghuyong/ernie-1.0-base-zh")

          >>> prompt = "In Italy, pizza served in formal settings, such as at a restaurant, is presented unsliced."
          >>> next_sentence = "The sky is blue due to the shorter wavelength of blue light."
          >>> encoding = tokenizer(prompt, next_sentence, return_tensors="pt")

          >>> outputs = model(**encoding, labels=torch.LongTensor([1]))
          >>> logits = outputs.logits
          >>> assert logits[0, 0] < logits[0, 1]  # next sentence was random
          ```
          """
          outputs = self.ernie(
              input_ids,
              attention_mask=attention_mask,
              token_type_ids=token_type_ids,
              task_type_ids=task_type_ids,
              position_ids=position_ids,
              inputs_embeds=inputs_embeds,
              return_dict=True,
              **kwargs,
          )

          # The classifier consumes the pooled (sequence-level) output, not per-token states.
          pooled_output = outputs[1]
          seq_relationship_scores = self.cls(pooled_output)

          next_sentence_loss = None
          if labels is not None:
              loss_fct = CrossEntropyLoss()
              next_sentence_loss = loss_fct(seq_relationship_scores.view(-1, 2), labels.view(-1))

          return NextSentencePredictorOutput(
              loss=next_sentence_loss,
              logits=seq_relationship_scores,
              hidden_states=outputs.hidden_states,
              attentions=outputs.attentions,
          )
  class ErnieForSequenceClassification(BertForSequenceClassification):
      @can_return_tuple
      @auto_docstring
      def forward(
          self,
          input_ids: torch.Tensor | None = None,
          attention_mask: torch.Tensor | None = None,
          token_type_ids: torch.Tensor | None = None,
          task_type_ids: torch.Tensor | None = None,
          position_ids: torch.Tensor | None = None,
          inputs_embeds: torch.Tensor | None = None,
          labels: torch.Tensor | None = None,
          **kwargs: Unpack[TransformersKwargs],
      ) -> tuple[torch.Tensor] | SequenceClassifierOutput:
          r"""
          task_type_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
              Task type embedding is a special embedding to represent the characteristic of different tasks, such as
              word-aware pre-training task, structure-aware pre-training task and semantic-aware pre-training task. We
              assign a `task_type_id` to each task and the `task_type_id` is in the range
              `[0, config.task_type_vocab_size-1]`.
          labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
              Labels for computing the sequence classification/regression loss. Indices should be in `[0, ...,
              config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If
              `config.num_labels > 1` a classification loss is computed (Cross-Entropy).
          """
          outputs = self.ernie(
              input_ids,
              attention_mask=attention_mask,
              token_type_ids=token_type_ids,
              task_type_ids=task_type_ids,
              position_ids=position_ids,
              inputs_embeds=inputs_embeds,
              return_dict=True,
              **kwargs,
          )

          pooled_output = outputs[1]
          pooled_output = self.dropout(pooled_output)
          logits = self.classifier(pooled_output)

          loss = None
          if labels is not None:
              # Infer the problem type from num_labels and the label dtype on first use; the result is stored on
              # the config and reused by later calls.
              if self.config.problem_type is None:
                  if self.num_labels == 1:
                      self.config.problem_type = "regression"
                  elif self.num_labels > 1 and labels.dtype in (torch.long, torch.int):
                      self.config.problem_type = "single_label_classification"
                  else:
                      self.config.problem_type = "multi_label_classification"

              if self.config.problem_type == "regression":
                  loss_fct = MSELoss()
                  if self.num_labels == 1:
                      loss = loss_fct(logits.squeeze(), labels.squeeze())
                  else:
                      loss = loss_fct(logits, labels)
              elif self.config.problem_type == "single_label_classification":
                  loss_fct = CrossEntropyLoss()
                  loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
              elif self.config.problem_type == "multi_label_classification":
                  loss_fct = BCEWithLogitsLoss()
                  loss = loss_fct(logits, labels)

          return SequenceClassifierOutput(
              loss=loss,
              logits=logits,
              hidden_states=outputs.hidden_states,
              attentions=outputs.attentions,
          )
  class ErnieForMultipleChoice(BertForMultipleChoice):
      @can_return_tuple
      @auto_docstring
      def forward(
          self,
          input_ids: torch.Tensor | None = None,
          attention_mask: torch.Tensor | None = None,
          token_type_ids: torch.Tensor | None = None,
          task_type_ids: torch.Tensor | None = None,
          position_ids: torch.Tensor | None = None,
          inputs_embeds: torch.Tensor | None = None,
          labels: torch.Tensor | None = None,
          **kwargs: Unpack[TransformersKwargs],
      ) -> tuple[torch.Tensor] | MultipleChoiceModelOutput:
          r"""
          input_ids (`torch.LongTensor` of shape `(batch_size, num_choices, sequence_length)`):
              Indices of input sequence tokens in the vocabulary.

              Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
              [`PreTrainedTokenizer.__call__`] for details.

              [What are input IDs?](../glossary#input-ids)
          token_type_ids (`torch.LongTensor` of shape `(batch_size, num_choices, sequence_length)`, *optional*):
              Segment token indices to indicate first and second portions of the inputs. Indices are selected in `[0,
              1]`:

              - 0 corresponds to a *sentence A* token,
              - 1 corresponds to a *sentence B* token.

              [What are token type IDs?](../glossary#token-type-ids)
          task_type_ids (`torch.LongTensor` of shape `(batch_size, num_choices, sequence_length)`, *optional*):
              Task type embedding is a special embedding to represent the characteristic of different tasks, such as
              word-aware pre-training task, structure-aware pre-training task and semantic-aware pre-training task. We
              assign a `task_type_id` to each task and the `task_type_id` is in the range
              `[0, config.task_type_vocab_size-1]`.
          position_ids (`torch.LongTensor` of shape `(batch_size, num_choices, sequence_length)`, *optional*):
              Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0,
              config.max_position_embeddings - 1]`.

              [What are position IDs?](../glossary#position-ids)
          inputs_embeds (`torch.FloatTensor` of shape `(batch_size, num_choices, sequence_length, hidden_size)`, *optional*):
              Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This
              is useful if you want more control over how to convert `input_ids` indices into associated vectors than the
              model's internal embedding lookup matrix.
          labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
              Labels for computing the multiple choice classification loss. Indices should be in `[0, ...,
              num_choices-1]` where `num_choices` is the size of the second dimension of the input tensors. (See
              `input_ids` above)
          """
          num_choices = input_ids.shape[1] if input_ids is not None else inputs_embeds.shape[1]

          # Fold the choice dimension into the batch: (batch, num_choices, seq) -> (batch * num_choices, seq).
          # Every per-token input must be flattened the same way, including the ERNIE-specific task_type_ids;
          # otherwise its embedding keeps an extra dimension and cannot be added to the flattened embeddings.
          input_ids = input_ids.view(-1, input_ids.size(-1)) if input_ids is not None else None
          attention_mask = attention_mask.view(-1, attention_mask.size(-1)) if attention_mask is not None else None
          token_type_ids = token_type_ids.view(-1, token_type_ids.size(-1)) if token_type_ids is not None else None
          task_type_ids = task_type_ids.view(-1, task_type_ids.size(-1)) if task_type_ids is not None else None
          position_ids = position_ids.view(-1, position_ids.size(-1)) if position_ids is not None else None
          inputs_embeds = (
              inputs_embeds.view(-1, inputs_embeds.size(-2), inputs_embeds.size(-1))
              if inputs_embeds is not None
              else None
          )

          outputs = self.ernie(
              input_ids,
              attention_mask=attention_mask,
              token_type_ids=token_type_ids,
              task_type_ids=task_type_ids,
              position_ids=position_ids,
              inputs_embeds=inputs_embeds,
              return_dict=True,
              **kwargs,
          )

          pooled_output = outputs[1]
          pooled_output = self.dropout(pooled_output)
          logits = self.classifier(pooled_output)
          # Unfold back to one score per choice: (batch, num_choices).
          reshaped_logits = logits.view(-1, num_choices)

          loss = None
          if labels is not None:
              loss_fct = CrossEntropyLoss()
              loss = loss_fct(reshaped_logits, labels)

          return MultipleChoiceModelOutput(
              loss=loss,
              logits=reshaped_logits,
              hidden_states=outputs.hidden_states,
              attentions=outputs.attentions,
          )
  616. class ErnieForTokenClassification(BertForTokenClassification):
  617. @can_return_tuple
  618. @auto_docstring
  619. def forward(
  620. self,
  621. input_ids: torch.Tensor | None = None,
  622. attention_mask: torch.Tensor | None = None,
  623. token_type_ids: torch.Tensor | None = None,
  624. task_type_ids: torch.Tensor | None = None,
  625. position_ids: torch.Tensor | None = None,
  626. inputs_embeds: torch.Tensor | None = None,
  627. labels: torch.Tensor | None = None,
  628. **kwargs: Unpack[TransformersKwargs],
  629. ) -> tuple[torch.Tensor] | TokenClassifierOutput:
  630. r"""
  631. task_type_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
  632. Task type embedding is a special embedding to represent the characteristic of different tasks, such as
  633. word-aware pre-training task, structure-aware pre-training task and semantic-aware pre-training task. We
  634. assign a `task_type_id` to each task and the `task_type_id` is in the range `[0,
  635. config.task_type_vocab_size-1]
  636. labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
  637. Labels for computing the token classification loss. Indices should be in `[0, ..., config.num_labels - 1]`.
  638. """
  639. outputs = self.ernie(
  640. input_ids,
  641. attention_mask=attention_mask,
  642. token_type_ids=token_type_ids,
  643. task_type_ids=task_type_ids,
  644. position_ids=position_ids,
  645. inputs_embeds=inputs_embeds,
  646. return_dict=True,
  647. **kwargs,
  648. )
  649. sequence_output = outputs[0]
  650. sequence_output = self.dropout(sequence_output)
  651. logits = self.classifier(sequence_output)
  652. loss = None
  653. if labels is not None:
  654. loss_fct = CrossEntropyLoss()
  655. loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
  656. return TokenClassifierOutput(
  657. loss=loss,
  658. logits=logits,
  659. hidden_states=outputs.hidden_states,
  660. attentions=outputs.attentions,
  661. )
  662. class ErnieForQuestionAnswering(BertForQuestionAnswering):
  663. @can_return_tuple
  664. @auto_docstring
  665. def forward(
  666. self,
  667. input_ids: torch.Tensor | None = None,
  668. attention_mask: torch.Tensor | None = None,
  669. token_type_ids: torch.Tensor | None = None,
  670. task_type_ids: torch.Tensor | None = None,
  671. position_ids: torch.Tensor | None = None,
  672. inputs_embeds: torch.Tensor | None = None,
  673. start_positions: torch.Tensor | None = None,
  674. end_positions: torch.Tensor | None = None,
  675. **kwargs: Unpack[TransformersKwargs],
  676. ) -> tuple[torch.Tensor] | QuestionAnsweringModelOutput:
  677. r"""
  678. task_type_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
  679. Task type embedding is a special embedding to represent the characteristic of different tasks, such as
  680. word-aware pre-training task, structure-aware pre-training task and semantic-aware pre-training task. We
  681. assign a `task_type_id` to each task and the `task_type_id` is in the range `[0,
  682. config.task_type_vocab_size-1]
  683. """
  684. outputs = self.ernie(
  685. input_ids,
  686. attention_mask=attention_mask,
  687. token_type_ids=token_type_ids,
  688. task_type_ids=task_type_ids,
  689. position_ids=position_ids,
  690. inputs_embeds=inputs_embeds,
  691. return_dict=True,
  692. **kwargs,
  693. )
  694. sequence_output = outputs[0]
  695. logits = self.qa_outputs(sequence_output)
  696. start_logits, end_logits = logits.split(1, dim=-1)
  697. start_logits = start_logits.squeeze(-1).contiguous()
  698. end_logits = end_logits.squeeze(-1).contiguous()
  699. total_loss = None
  700. if start_positions is not None and end_positions is not None:
  701. # If we are on multi-GPU, split add a dimension
  702. if len(start_positions.size()) > 1:
  703. start_positions = start_positions.squeeze(-1)
  704. if len(end_positions.size()) > 1:
  705. end_positions = end_positions.squeeze(-1)
  706. # sometimes the start/end positions are outside our model inputs, we ignore these terms
  707. ignored_index = start_logits.size(1)
  708. start_positions = start_positions.clamp(0, ignored_index)
  709. end_positions = end_positions.clamp(0, ignored_index)
  710. loss_fct = CrossEntropyLoss(ignore_index=ignored_index)
  711. start_loss = loss_fct(start_logits, start_positions)
  712. end_loss = loss_fct(end_logits, end_positions)
  713. total_loss = (start_loss + end_loss) / 2
  714. return QuestionAnsweringModelOutput(
  715. loss=total_loss,
  716. start_logits=start_logits,
  717. end_logits=end_logits,
  718. hidden_states=outputs.hidden_states,
  719. attentions=outputs.attentions,
  720. )
# Public API of this module: the names exported by `from <module> import *`.
# Kept in alphabetical order; add new public ERNIE classes here when they are defined in this file.
__all__ = [
    "ErnieForCausalLM",
    "ErnieForMaskedLM",
    "ErnieForMultipleChoice",
    "ErnieForNextSentencePrediction",
    "ErnieForPreTraining",
    "ErnieForQuestionAnswering",
    "ErnieForSequenceClassification",
    "ErnieForTokenClassification",
    "ErnieModel",
    "ErniePreTrainedModel",
]