modeling_ernie.py 61 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
37047057067077087097107117127137147157167177187197207217227237247257267277287297307317327337347357367377387397407417427437447457467477487497507517527537547557567577587597607617627637647657667677687697707717727737747757767777787797807817827837847857867877887897907917927937947957967977987998008018028038048058068078088098108118128138148158168178188198208218228238248258268278288298308318328338348358368378388398408418428438448458468478488498508518528538548558568578588598608618628638648658668678688698708718728738748758768778788798808818828838848858868878888898908918928938948958968978988999009019029039049059069079089099109119129139149159169179189199209219229239249259269279289299309319329339349359369379389399409419429439449459469479489499509519529539549559569579589599609619629639649659669679689699709719729739749759769779789799809819829839849859869879889899909919929939949959969979989991000100110021003100410051006100710081009101010111012101310141015101610171018101910201021102210231024102510261027102810291030103110321033103410351036103710381039104010411042104310441045104610471048104910501051105210531054105510561057105810591060106110621063106410651066106710681069107010711072107310741075107610771078107910801081108210831084108510861087108810891090109110921093109410951096109710981099110011011102110311041105110611071108110911101111111211131114111511161117111811191120112111221123112411251126112711281129113011311132113311341135113611371138113911401141114211431144114511461147114811491150115111521153115411551156115711581159116011611162116311641165116611671168116911701171117211731174117511761177117811791180118111821183118411851186118711881189119011911192119311941195119611971198119912001201120212031204120512061207120812091210121112121213121412151216121712181219122012211222122312241225122612271228122912301231123212331234123512361237123812391240124112421243124412451246124712481249125012511252125312541255125612571258125912601261126212631264126512661267126812691270127112721273127412751276127
71278127912801281128212831284128512861287128812891290129112921293129412951296129712981299130013011302130313041305130613071308130913101311131213131314131513161317131813191320132113221323132413251326132713281329133013311332133313341335133613371338133913401341134213431344134513461347134813491350135113521353135413551356135713581359136013611362136313641365136613671368136913701371137213731374137513761377137813791380138113821383138413851386138713881389139013911392139313941395139613971398139914001401140214031404140514061407140814091410141114121413141414151416141714181419142014211422142314241425142614271428142914301431143214331434143514361437143814391440144114421443144414451446144714481449145014511452145314541455145614571458145914601461146214631464146514661467146814691470147114721473147414751476147714781479
  1. # 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
  2. # This file was automatically generated from src/transformers/models/ernie/modular_ernie.py.
  3. # Do NOT edit this file manually as any edits will be overwritten by the generation of
  4. # the file from the modular. If any change should be done, please apply the change to the
  5. # modular_ernie.py file directly. One of our CI enforces this.
  6. # 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
  7. # Copyright 2022 The HuggingFace Inc. team.
  8. #
  9. # Licensed under the Apache License, Version 2.0 (the "License");
  10. # you may not use this file except in compliance with the License.
  11. # You may obtain a copy of the License at
  12. #
  13. # http://www.apache.org/licenses/LICENSE-2.0
  14. #
  15. # Unless required by applicable law or agreed to in writing, software
  16. # distributed under the License is distributed on an "AS IS" BASIS,
  17. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  18. # See the License for the specific language governing permissions and
  19. # limitations under the License.
  20. from collections.abc import Callable
  21. from dataclasses import dataclass
  22. import torch
  23. import torch.nn as nn
  24. from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
  25. from ... import initialization as init
  26. from ...activations import ACT2FN
  27. from ...cache_utils import Cache, DynamicCache, EncoderDecoderCache
  28. from ...generation import GenerationMixin
  29. from ...masking_utils import create_bidirectional_mask, create_causal_mask
  30. from ...modeling_layers import GradientCheckpointingLayer
  31. from ...modeling_outputs import (
  32. BaseModelOutputWithPastAndCrossAttentions,
  33. BaseModelOutputWithPoolingAndCrossAttentions,
  34. CausalLMOutputWithCrossAttentions,
  35. MaskedLMOutput,
  36. MultipleChoiceModelOutput,
  37. NextSentencePredictorOutput,
  38. QuestionAnsweringModelOutput,
  39. SequenceClassifierOutput,
  40. TokenClassifierOutput,
  41. )
  42. from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
  43. from ...processing_utils import Unpack
  44. from ...pytorch_utils import apply_chunking_to_forward
  45. from ...utils import ModelOutput, TransformersKwargs, auto_docstring, logging
  46. from ...utils.generic import can_return_tuple, merge_with_config_defaults
  47. from ...utils.output_capturing import capture_outputs
  48. from .configuration_ernie import ErnieConfig
  49. logger = logging.get_logger(__name__)
  50. class ErnieEmbeddings(nn.Module):
  51. """Construct the embeddings from word, position and token_type embeddings."""
  52. def __init__(self, config):
  53. super().__init__()
  54. self.word_embeddings = nn.Embedding(config.vocab_size, config.hidden_size, padding_idx=config.pad_token_id)
  55. self.position_embeddings = nn.Embedding(config.max_position_embeddings, config.hidden_size)
  56. self.token_type_embeddings = nn.Embedding(config.type_vocab_size, config.hidden_size)
  57. self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
  58. self.dropout = nn.Dropout(config.hidden_dropout_prob)
  59. # position_ids (1, len position emb) is contiguous in memory and exported when serialized
  60. self.register_buffer(
  61. "position_ids", torch.arange(config.max_position_embeddings).expand((1, -1)), persistent=False
  62. )
  63. self.register_buffer(
  64. "token_type_ids", torch.zeros(self.position_ids.size(), dtype=torch.long), persistent=False
  65. )
  66. self.use_task_id = config.use_task_id
  67. if config.use_task_id:
  68. self.task_type_embeddings = nn.Embedding(config.task_type_vocab_size, config.hidden_size)
  69. def forward(
  70. self,
  71. input_ids: torch.LongTensor | None = None,
  72. token_type_ids: torch.LongTensor | None = None,
  73. task_type_ids: torch.LongTensor | None = None,
  74. position_ids: torch.LongTensor | None = None,
  75. inputs_embeds: torch.FloatTensor | None = None,
  76. past_key_values_length: int = 0,
  77. ) -> torch.Tensor:
  78. if input_ids is not None:
  79. input_shape = input_ids.size()
  80. else:
  81. input_shape = inputs_embeds.size()[:-1]
  82. batch_size, seq_length = input_shape
  83. if position_ids is None:
  84. position_ids = self.position_ids[:, past_key_values_length : seq_length + past_key_values_length]
  85. # Setting the token_type_ids to the registered buffer in constructor where it is all zeros, which usually occurs
  86. # when its auto-generated, registered buffer helps users when tracing the model without passing token_type_ids, solves
  87. # issue #5664
  88. if token_type_ids is None:
  89. if hasattr(self, "token_type_ids"):
  90. # NOTE: We assume either pos ids to have bsz == 1 (broadcastable) or bsz == effective bsz (input_shape[0])
  91. buffered_token_type_ids = self.token_type_ids.expand(position_ids.shape[0], -1)
  92. buffered_token_type_ids = torch.gather(buffered_token_type_ids, dim=1, index=position_ids)
  93. token_type_ids = buffered_token_type_ids.expand(batch_size, seq_length)
  94. else:
  95. token_type_ids = torch.zeros(input_shape, dtype=torch.long, device=self.position_ids.device)
  96. if inputs_embeds is None:
  97. inputs_embeds = self.word_embeddings(input_ids)
  98. token_type_embeddings = self.token_type_embeddings(token_type_ids)
  99. # .to is better than using _no_split_modules on ErnieEmbeddings as it's the first module and >1/2 the model size
  100. inputs_embeds = inputs_embeds.to(token_type_embeddings.device)
  101. embeddings = inputs_embeds + token_type_embeddings
  102. position_embeddings = self.position_embeddings(position_ids)
  103. embeddings = embeddings + position_embeddings
  104. # add `task_type_id` for ERNIE model
  105. if self.use_task_id:
  106. if task_type_ids is None:
  107. task_type_ids = torch.zeros(input_shape, dtype=torch.long, device=self.position_ids.device)
  108. task_type_embeddings = self.task_type_embeddings(task_type_ids)
  109. embeddings += task_type_embeddings
  110. embeddings = self.LayerNorm(embeddings)
  111. embeddings = self.dropout(embeddings)
  112. return embeddings
  113. def eager_attention_forward(
  114. module: nn.Module,
  115. query: torch.Tensor,
  116. key: torch.Tensor,
  117. value: torch.Tensor,
  118. attention_mask: torch.Tensor | None,
  119. scaling: float | None = None,
  120. dropout: float = 0.0,
  121. **kwargs: Unpack[TransformersKwargs],
  122. ):
  123. if scaling is None:
  124. scaling = query.size(-1) ** -0.5
  125. # Take the dot product between "query" and "key" to get the raw attention scores.
  126. attn_weights = torch.matmul(query, key.transpose(2, 3)) * scaling
  127. if attention_mask is not None:
  128. attn_weights = attn_weights + attention_mask
  129. attn_weights = nn.functional.softmax(attn_weights, dim=-1)
  130. attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training)
  131. attn_output = torch.matmul(attn_weights, value)
  132. attn_output = attn_output.transpose(1, 2).contiguous()
  133. return attn_output, attn_weights
  134. class ErnieSelfAttention(nn.Module):
  135. def __init__(self, config, is_causal=False, layer_idx=None):
  136. super().__init__()
  137. if config.hidden_size % config.num_attention_heads != 0 and not hasattr(config, "embedding_size"):
  138. raise ValueError(
  139. f"The hidden size ({config.hidden_size}) is not a multiple of the number of attention "
  140. f"heads ({config.num_attention_heads})"
  141. )
  142. self.config = config
  143. self.num_attention_heads = config.num_attention_heads
  144. self.attention_head_size = int(config.hidden_size / config.num_attention_heads)
  145. self.all_head_size = self.num_attention_heads * self.attention_head_size
  146. self.scaling = self.attention_head_size**-0.5
  147. self.query = nn.Linear(config.hidden_size, self.all_head_size)
  148. self.key = nn.Linear(config.hidden_size, self.all_head_size)
  149. self.value = nn.Linear(config.hidden_size, self.all_head_size)
  150. self.dropout = nn.Dropout(config.attention_probs_dropout_prob)
  151. self.is_decoder = config.is_decoder
  152. self.is_causal = is_causal
  153. self.layer_idx = layer_idx
  154. def forward(
  155. self,
  156. hidden_states: torch.Tensor,
  157. attention_mask: torch.FloatTensor | None = None,
  158. past_key_values: Cache | None = None,
  159. **kwargs: Unpack[TransformersKwargs],
  160. ) -> tuple[torch.Tensor]:
  161. input_shape = hidden_states.shape[:-1]
  162. hidden_shape = (*input_shape, -1, self.attention_head_size)
  163. # get all proj
  164. query_layer = self.query(hidden_states).view(*hidden_shape).transpose(1, 2)
  165. key_layer = self.key(hidden_states).view(*hidden_shape).transpose(1, 2)
  166. value_layer = self.value(hidden_states).view(*hidden_shape).transpose(1, 2)
  167. if past_key_values is not None:
  168. # decoder-only ernie can have a simple dynamic cache for example
  169. current_past_key_values = past_key_values
  170. if isinstance(past_key_values, EncoderDecoderCache):
  171. current_past_key_values = past_key_values.self_attention_cache
  172. # save all key/value_layer to cache to be re-used for fast auto-regressive generation
  173. key_layer, value_layer = current_past_key_values.update(key_layer, value_layer, self.layer_idx)
  174. attention_interface: Callable = ALL_ATTENTION_FUNCTIONS.get_interface(
  175. self.config._attn_implementation, eager_attention_forward
  176. )
  177. attn_output, attn_weights = attention_interface(
  178. self,
  179. query_layer,
  180. key_layer,
  181. value_layer,
  182. attention_mask,
  183. dropout=0.0 if not self.training else self.dropout.p,
  184. scaling=self.scaling,
  185. **kwargs,
  186. )
  187. attn_output = attn_output.reshape(*input_shape, -1).contiguous()
  188. return attn_output, attn_weights
  189. class ErnieCrossAttention(nn.Module):
  190. def __init__(self, config, is_causal=False, layer_idx=None):
  191. super().__init__()
  192. if config.hidden_size % config.num_attention_heads != 0 and not hasattr(config, "embedding_size"):
  193. raise ValueError(
  194. f"The hidden size ({config.hidden_size}) is not a multiple of the number of attention "
  195. f"heads ({config.num_attention_heads})"
  196. )
  197. self.config = config
  198. self.num_attention_heads = config.num_attention_heads
  199. self.attention_head_size = int(config.hidden_size / config.num_attention_heads)
  200. self.all_head_size = self.num_attention_heads * self.attention_head_size
  201. self.scaling = self.attention_head_size**-0.5
  202. self.query = nn.Linear(config.hidden_size, self.all_head_size)
  203. self.key = nn.Linear(config.hidden_size, self.all_head_size)
  204. self.value = nn.Linear(config.hidden_size, self.all_head_size)
  205. self.dropout = nn.Dropout(config.attention_probs_dropout_prob)
  206. self.is_causal = is_causal
  207. self.layer_idx = layer_idx
  208. def forward(
  209. self,
  210. hidden_states: torch.Tensor,
  211. encoder_hidden_states: torch.FloatTensor | None = None,
  212. attention_mask: torch.FloatTensor | None = None,
  213. past_key_values: EncoderDecoderCache | None = None,
  214. **kwargs: Unpack[TransformersKwargs],
  215. ) -> tuple[torch.Tensor]:
  216. # determine input shapes
  217. input_shape = hidden_states.shape[:-1]
  218. hidden_shape = (*input_shape, -1, self.attention_head_size)
  219. # get query proj
  220. query_layer = self.query(hidden_states).view(hidden_shape).transpose(1, 2)
  221. is_updated = past_key_values.is_updated.get(self.layer_idx) if past_key_values is not None else False
  222. if past_key_values is not None and is_updated:
  223. # reuse k,v, cross_attentions
  224. key_layer = past_key_values.cross_attention_cache.layers[self.layer_idx].keys
  225. value_layer = past_key_values.cross_attention_cache.layers[self.layer_idx].values
  226. else:
  227. kv_shape = (*encoder_hidden_states.shape[:-1], -1, self.attention_head_size)
  228. key_layer = self.key(encoder_hidden_states).view(kv_shape).transpose(1, 2)
  229. value_layer = self.value(encoder_hidden_states).view(kv_shape).transpose(1, 2)
  230. if past_key_values is not None:
  231. # save all states to the cache
  232. key_layer, value_layer = past_key_values.cross_attention_cache.update(
  233. key_layer, value_layer, self.layer_idx
  234. )
  235. # set flag that curr layer for cross-attn is already updated so we can re-use in subsequent calls
  236. past_key_values.is_updated[self.layer_idx] = True
  237. attention_interface: Callable = ALL_ATTENTION_FUNCTIONS.get_interface(
  238. self.config._attn_implementation, eager_attention_forward
  239. )
  240. attn_output, attn_weights = attention_interface(
  241. self,
  242. query_layer,
  243. key_layer,
  244. value_layer,
  245. attention_mask,
  246. dropout=0.0 if not self.training else self.dropout.p,
  247. scaling=self.scaling,
  248. **kwargs,
  249. )
  250. attn_output = attn_output.reshape(*input_shape, -1).contiguous()
  251. return attn_output, attn_weights
  252. class ErnieSelfOutput(nn.Module):
  253. def __init__(self, config):
  254. super().__init__()
  255. self.dense = nn.Linear(config.hidden_size, config.hidden_size)
  256. self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
  257. self.dropout = nn.Dropout(config.hidden_dropout_prob)
  258. def forward(self, hidden_states: torch.Tensor, input_tensor: torch.Tensor) -> torch.Tensor:
  259. hidden_states = self.dense(hidden_states)
  260. hidden_states = self.dropout(hidden_states)
  261. hidden_states = self.LayerNorm(hidden_states + input_tensor)
  262. return hidden_states
  263. class ErnieAttention(nn.Module):
  264. def __init__(self, config, is_causal=False, layer_idx=None, is_cross_attention=False):
  265. super().__init__()
  266. self.is_cross_attention = is_cross_attention
  267. attention_class = ErnieCrossAttention if is_cross_attention else ErnieSelfAttention
  268. self.self = attention_class(config, is_causal=is_causal, layer_idx=layer_idx)
  269. self.output = ErnieSelfOutput(config)
  270. def forward(
  271. self,
  272. hidden_states: torch.Tensor,
  273. attention_mask: torch.FloatTensor | None = None,
  274. encoder_hidden_states: torch.FloatTensor | None = None,
  275. encoder_attention_mask: torch.FloatTensor | None = None,
  276. past_key_values: Cache | None = None,
  277. **kwargs: Unpack[TransformersKwargs],
  278. ) -> tuple[torch.Tensor]:
  279. attention_mask = attention_mask if not self.is_cross_attention else encoder_attention_mask
  280. attention_output, attn_weights = self.self(
  281. hidden_states,
  282. encoder_hidden_states=encoder_hidden_states,
  283. attention_mask=attention_mask,
  284. past_key_values=past_key_values,
  285. **kwargs,
  286. )
  287. attention_output = self.output(attention_output, hidden_states)
  288. return attention_output, attn_weights
  289. class ErnieIntermediate(nn.Module):
  290. def __init__(self, config):
  291. super().__init__()
  292. self.dense = nn.Linear(config.hidden_size, config.intermediate_size)
  293. if isinstance(config.hidden_act, str):
  294. self.intermediate_act_fn = ACT2FN[config.hidden_act]
  295. else:
  296. self.intermediate_act_fn = config.hidden_act
  297. def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
  298. hidden_states = self.dense(hidden_states)
  299. hidden_states = self.intermediate_act_fn(hidden_states)
  300. return hidden_states
  301. class ErnieOutput(nn.Module):
  302. def __init__(self, config):
  303. super().__init__()
  304. self.dense = nn.Linear(config.intermediate_size, config.hidden_size)
  305. self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
  306. self.dropout = nn.Dropout(config.hidden_dropout_prob)
  307. def forward(self, hidden_states: torch.Tensor, input_tensor: torch.Tensor) -> torch.Tensor:
  308. hidden_states = self.dense(hidden_states)
  309. hidden_states = self.dropout(hidden_states)
  310. hidden_states = self.LayerNorm(hidden_states + input_tensor)
  311. return hidden_states
  312. class ErnieLayer(GradientCheckpointingLayer):
  313. def __init__(self, config, layer_idx=None):
  314. super().__init__()
  315. self.chunk_size_feed_forward = config.chunk_size_feed_forward
  316. self.seq_len_dim = 1
  317. self.attention = ErnieAttention(config, is_causal=config.is_decoder, layer_idx=layer_idx)
  318. self.is_decoder = config.is_decoder
  319. self.add_cross_attention = config.add_cross_attention
  320. if self.add_cross_attention:
  321. if not self.is_decoder:
  322. raise ValueError(f"{self} should be used as a decoder model if cross attention is added")
  323. self.crossattention = ErnieAttention(
  324. config,
  325. is_causal=False,
  326. layer_idx=layer_idx,
  327. is_cross_attention=True,
  328. )
  329. self.intermediate = ErnieIntermediate(config)
  330. self.output = ErnieOutput(config)
  331. def forward(
  332. self,
  333. hidden_states: torch.Tensor,
  334. attention_mask: torch.FloatTensor | None = None,
  335. encoder_hidden_states: torch.FloatTensor | None = None,
  336. encoder_attention_mask: torch.FloatTensor | None = None,
  337. past_key_values: Cache | None = None,
  338. **kwargs: Unpack[TransformersKwargs],
  339. ) -> torch.Tensor:
  340. self_attention_output, _ = self.attention(
  341. hidden_states,
  342. attention_mask,
  343. past_key_values=past_key_values,
  344. **kwargs,
  345. )
  346. attention_output = self_attention_output
  347. if self.is_decoder and encoder_hidden_states is not None:
  348. if not hasattr(self, "crossattention"):
  349. raise ValueError(
  350. f"If `encoder_hidden_states` are passed, {self} has to be instantiated with cross-attention layers"
  351. " by setting `config.add_cross_attention=True`"
  352. )
  353. cross_attention_output, _ = self.crossattention(
  354. self_attention_output,
  355. None, # attention_mask
  356. encoder_hidden_states,
  357. encoder_attention_mask,
  358. past_key_values=past_key_values,
  359. **kwargs,
  360. )
  361. attention_output = cross_attention_output
  362. layer_output = apply_chunking_to_forward(
  363. self.feed_forward_chunk, self.chunk_size_feed_forward, self.seq_len_dim, attention_output
  364. )
  365. return layer_output
  366. def feed_forward_chunk(self, attention_output):
  367. intermediate_output = self.intermediate(attention_output)
  368. layer_output = self.output(intermediate_output, attention_output)
  369. return layer_output
  370. class ErniePooler(nn.Module):
  371. def __init__(self, config):
  372. super().__init__()
  373. self.dense = nn.Linear(config.hidden_size, config.hidden_size)
  374. self.activation = nn.Tanh()
  375. def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
  376. # We "pool" the model by simply taking the hidden state corresponding
  377. # to the first token.
  378. first_token_tensor = hidden_states[:, 0]
  379. pooled_output = self.dense(first_token_tensor)
  380. pooled_output = self.activation(pooled_output)
  381. return pooled_output
  382. class ErniePredictionHeadTransform(nn.Module):
  383. def __init__(self, config):
  384. super().__init__()
  385. self.dense = nn.Linear(config.hidden_size, config.hidden_size)
  386. if isinstance(config.hidden_act, str):
  387. self.transform_act_fn = ACT2FN[config.hidden_act]
  388. else:
  389. self.transform_act_fn = config.hidden_act
  390. self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
  391. def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
  392. hidden_states = self.dense(hidden_states)
  393. hidden_states = self.transform_act_fn(hidden_states)
  394. hidden_states = self.LayerNorm(hidden_states)
  395. return hidden_states
  396. class ErnieLMPredictionHead(nn.Module):
  397. def __init__(self, config):
  398. super().__init__()
  399. self.transform = ErniePredictionHeadTransform(config)
  400. # The output weights are the same as the input embeddings, but there is
  401. # an output-only bias for each token.
  402. self.decoder = nn.Linear(config.hidden_size, config.vocab_size, bias=True)
  403. self.bias = nn.Parameter(torch.zeros(config.vocab_size))
  404. def forward(self, hidden_states):
  405. hidden_states = self.transform(hidden_states)
  406. hidden_states = self.decoder(hidden_states)
  407. return hidden_states
  408. class ErnieEncoder(nn.Module):
  409. def __init__(self, config):
  410. super().__init__()
  411. self.config = config
  412. self.layer = nn.ModuleList([ErnieLayer(config, layer_idx=i) for i in range(config.num_hidden_layers)])
  413. def forward(
  414. self,
  415. hidden_states: torch.Tensor,
  416. attention_mask: torch.FloatTensor | None = None,
  417. encoder_hidden_states: torch.FloatTensor | None = None,
  418. encoder_attention_mask: torch.FloatTensor | None = None,
  419. past_key_values: Cache | None = None,
  420. use_cache: bool | None = None,
  421. **kwargs: Unpack[TransformersKwargs],
  422. ) -> tuple[torch.Tensor] | BaseModelOutputWithPastAndCrossAttentions:
  423. for i, layer_module in enumerate(self.layer):
  424. hidden_states = layer_module(
  425. hidden_states,
  426. attention_mask,
  427. encoder_hidden_states, # as a positional argument for gradient checkpointing
  428. encoder_attention_mask=encoder_attention_mask,
  429. past_key_values=past_key_values,
  430. **kwargs,
  431. )
  432. return BaseModelOutputWithPastAndCrossAttentions(
  433. last_hidden_state=hidden_states,
  434. past_key_values=past_key_values if use_cache else None,
  435. )
  436. @auto_docstring
  437. class ErniePreTrainedModel(PreTrainedModel):
  438. config_class = ErnieConfig
  439. base_model_prefix = "ernie"
  440. supports_gradient_checkpointing = True
  441. _supports_flash_attn = True
  442. _supports_sdpa = True
  443. _supports_flex_attn = True
  444. _supports_attention_backend = True
  445. _can_record_outputs = {
  446. "hidden_states": ErnieLayer,
  447. "attentions": ErnieSelfAttention,
  448. "cross_attentions": ErnieCrossAttention,
  449. }
  450. @torch.no_grad()
  451. def _init_weights(self, module):
  452. """Initialize the weights"""
  453. super()._init_weights(module)
  454. if isinstance(module, ErnieLMPredictionHead):
  455. init.zeros_(module.bias)
  456. elif isinstance(module, ErnieEmbeddings):
  457. init.copy_(module.position_ids, torch.arange(module.position_ids.shape[-1]).expand((1, -1)))
  458. init.zeros_(module.token_type_ids)
  459. @auto_docstring(
  460. custom_intro="""
  461. The model can behave as an encoder (with only self-attention) as well as a decoder, in which case a layer of
  462. cross-attention is added between the self-attention layers, following the architecture described in [Attention is
  463. all you need](https://huggingface.co/papers/1706.03762) by Ashish Vaswani, Noam Shazeer, Niki Parmar, Jakob Uszkoreit,
  464. Llion Jones, Aidan N. Gomez, Lukasz Kaiser and Illia Polosukhin.
  465. To behave as an decoder the model needs to be initialized with the `is_decoder` argument of the configuration set
  466. to `True`. To be used in a Seq2Seq model, the model needs to initialized with both `is_decoder` argument and
  467. `add_cross_attention` set to `True`; an `encoder_hidden_states` is then expected as an input to the forward pass.
  468. """
  469. )
  470. class ErnieModel(ErniePreTrainedModel):
  471. _no_split_modules = ["ErnieLayer"]
  472. def __init__(self, config, add_pooling_layer=True):
  473. r"""
  474. add_pooling_layer (bool, *optional*, defaults to `True`):
  475. Whether to add a pooling layer
  476. """
  477. super().__init__(config)
  478. self.config = config
  479. self.gradient_checkpointing = False
  480. self.embeddings = ErnieEmbeddings(config)
  481. self.encoder = ErnieEncoder(config)
  482. self.pooler = ErniePooler(config) if add_pooling_layer else None
  483. # Initialize weights and apply final processing
  484. self.post_init()
  485. def get_input_embeddings(self):
  486. return self.embeddings.word_embeddings
  487. def set_input_embeddings(self, value):
  488. self.embeddings.word_embeddings = value
  489. @merge_with_config_defaults
  490. @capture_outputs
  491. @auto_docstring
  492. def forward(
  493. self,
  494. input_ids: torch.Tensor | None = None,
  495. attention_mask: torch.Tensor | None = None,
  496. token_type_ids: torch.Tensor | None = None,
  497. task_type_ids: torch.Tensor | None = None,
  498. position_ids: torch.Tensor | None = None,
  499. inputs_embeds: torch.Tensor | None = None,
  500. encoder_hidden_states: torch.Tensor | None = None,
  501. encoder_attention_mask: torch.Tensor | None = None,
  502. past_key_values: Cache | None = None,
  503. use_cache: bool | None = None,
  504. **kwargs: Unpack[TransformersKwargs],
  505. ) -> tuple[torch.Tensor] | BaseModelOutputWithPoolingAndCrossAttentions:
  506. r"""
  507. task_type_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
  508. Task type embedding is a special embedding to represent the characteristic of different tasks, such as
  509. word-aware pre-training task, structure-aware pre-training task and semantic-aware pre-training task. We
  510. assign a `task_type_id` to each task and the `task_type_id` is in the range `[0,
  511. config.task_type_vocab_size-1]
  512. """
  513. if (input_ids is None) ^ (inputs_embeds is not None):
  514. raise ValueError("You must specify exactly one of input_ids or inputs_embeds")
  515. if self.config.is_decoder:
  516. use_cache = use_cache if use_cache is not None else self.config.use_cache
  517. else:
  518. use_cache = False
  519. if use_cache and past_key_values is None:
  520. past_key_values = (
  521. EncoderDecoderCache(DynamicCache(config=self.config), DynamicCache(config=self.config))
  522. if encoder_hidden_states is not None or self.config.is_encoder_decoder
  523. else DynamicCache(config=self.config)
  524. )
  525. past_key_values_length = past_key_values.get_seq_length() if past_key_values is not None else 0
  526. embedding_output = self.embeddings(
  527. input_ids=input_ids,
  528. position_ids=position_ids,
  529. token_type_ids=token_type_ids,
  530. # specific to ernie
  531. task_type_ids=task_type_ids,
  532. inputs_embeds=inputs_embeds,
  533. past_key_values_length=past_key_values_length,
  534. )
  535. attention_mask, encoder_attention_mask = self._create_attention_masks(
  536. attention_mask=attention_mask,
  537. encoder_attention_mask=encoder_attention_mask,
  538. embedding_output=embedding_output,
  539. encoder_hidden_states=encoder_hidden_states,
  540. past_key_values=past_key_values,
  541. )
  542. encoder_outputs = self.encoder(
  543. embedding_output,
  544. attention_mask=attention_mask,
  545. encoder_hidden_states=encoder_hidden_states,
  546. encoder_attention_mask=encoder_attention_mask,
  547. past_key_values=past_key_values,
  548. use_cache=use_cache,
  549. position_ids=position_ids,
  550. **kwargs,
  551. )
  552. sequence_output = encoder_outputs[0]
  553. pooled_output = self.pooler(sequence_output) if self.pooler is not None else None
  554. return BaseModelOutputWithPoolingAndCrossAttentions(
  555. last_hidden_state=sequence_output,
  556. pooler_output=pooled_output,
  557. past_key_values=encoder_outputs.past_key_values,
  558. )
  559. def _create_attention_masks(
  560. self,
  561. attention_mask,
  562. encoder_attention_mask,
  563. embedding_output,
  564. encoder_hidden_states,
  565. past_key_values,
  566. ):
  567. if self.config.is_decoder:
  568. attention_mask = create_causal_mask(
  569. config=self.config,
  570. inputs_embeds=embedding_output,
  571. attention_mask=attention_mask,
  572. past_key_values=past_key_values,
  573. )
  574. else:
  575. attention_mask = create_bidirectional_mask(
  576. config=self.config,
  577. inputs_embeds=embedding_output,
  578. attention_mask=attention_mask,
  579. )
  580. if encoder_attention_mask is not None:
  581. encoder_attention_mask = create_bidirectional_mask(
  582. config=self.config,
  583. inputs_embeds=embedding_output,
  584. attention_mask=encoder_attention_mask,
  585. encoder_hidden_states=encoder_hidden_states,
  586. )
  587. return attention_mask, encoder_attention_mask
@dataclass
@auto_docstring(
    custom_intro="""
    Output type of [`ErnieForPreTraining`].
    """
)
class ErnieForPreTrainingOutput(ModelOutput):
    r"""
    loss (*optional*, returned when `labels` is provided, `torch.FloatTensor` of shape `(1,)`):
        Total loss as the sum of the masked language modeling loss and the next sequence prediction
        (classification) loss. [`ErnieForPreTraining`] only computes it when both `labels` and
        `next_sentence_label` are passed; otherwise it stays `None`.
    prediction_logits (`torch.FloatTensor` of shape `(batch_size, sequence_length, config.vocab_size)`):
        Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax).
    seq_relationship_logits (`torch.FloatTensor` of shape `(batch_size, 2)`):
        Prediction scores of the next sequence prediction (classification) head (scores of True/False continuation
        before SoftMax).
    """

    # Every field defaults to `None`, so an instance can be built with any subset of them.
    loss: torch.FloatTensor | None = None
    prediction_logits: torch.FloatTensor | None = None
    seq_relationship_logits: torch.FloatTensor | None = None
    # Filled by `ErnieForPreTraining.forward` from the base model output's `hidden_states`.
    hidden_states: tuple[torch.FloatTensor] | None = None
    # Filled by `ErnieForPreTraining.forward` from the base model output's `attentions`.
    attentions: tuple[torch.FloatTensor] | None = None
  610. class ErniePreTrainingHeads(nn.Module):
  611. def __init__(self, config):
  612. super().__init__()
  613. self.predictions = ErnieLMPredictionHead(config)
  614. self.seq_relationship = nn.Linear(config.hidden_size, 2)
  615. def forward(self, sequence_output, pooled_output):
  616. prediction_scores = self.predictions(sequence_output)
  617. seq_relationship_score = self.seq_relationship(pooled_output)
  618. return prediction_scores, seq_relationship_score
  619. @auto_docstring(
  620. custom_intro="""
  621. Ernie Model with two heads on top as done during the pretraining: a `masked language modeling` head and a `next
  622. sentence prediction (classification)` head.
  623. """
  624. )
  625. class ErnieForPreTraining(ErniePreTrainedModel):
  626. _tied_weights_keys = {
  627. "cls.predictions.decoder.bias": "cls.predictions.bias",
  628. "cls.predictions.decoder.weight": "ernie.embeddings.word_embeddings.weight",
  629. }
  630. def __init__(self, config):
  631. super().__init__(config)
  632. self.ernie = ErnieModel(config)
  633. self.cls = ErniePreTrainingHeads(config)
  634. # Initialize weights and apply final processing
  635. self.post_init()
  636. def get_output_embeddings(self):
  637. return self.cls.predictions.decoder
  638. def set_output_embeddings(self, new_embeddings):
  639. self.cls.predictions.decoder = new_embeddings
  640. self.cls.predictions.bias = new_embeddings.bias
  641. @can_return_tuple
  642. @auto_docstring
  643. def forward(
  644. self,
  645. input_ids: torch.Tensor | None = None,
  646. attention_mask: torch.Tensor | None = None,
  647. token_type_ids: torch.Tensor | None = None,
  648. task_type_ids: torch.Tensor | None = None,
  649. position_ids: torch.Tensor | None = None,
  650. inputs_embeds: torch.Tensor | None = None,
  651. labels: torch.Tensor | None = None,
  652. next_sentence_label: torch.Tensor | None = None,
  653. **kwargs: Unpack[TransformersKwargs],
  654. ) -> tuple[torch.Tensor] | ErnieForPreTrainingOutput:
  655. r"""
  656. task_type_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
  657. Task type embedding is a special embedding to represent the characteristic of different tasks, such as
  658. word-aware pre-training task, structure-aware pre-training task and semantic-aware pre-training task. We
  659. assign a `task_type_id` to each task and the `task_type_id` is in the range `[0,
  660. config.task_type_vocab_size-1]
  661. labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
  662. Labels for computing the masked language modeling loss. Indices should be in `[-100, 0, ...,
  663. config.vocab_size]` (see `input_ids` docstring) Tokens with indices set to `-100` are ignored (masked),
  664. the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`
  665. next_sentence_label (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
  666. Labels for computing the next sequence prediction (classification) loss. Input should be a sequence
  667. pair (see `input_ids` docstring) Indices should be in `[0, 1]`:
  668. - 0 indicates sequence B is a continuation of sequence A,
  669. - 1 indicates sequence B is a random sequence.
  670. Example:
  671. ```python
  672. >>> from transformers import AutoTokenizer, ErnieForPreTraining
  673. >>> import torch
  674. >>> tokenizer = AutoTokenizer.from_pretrained("nghuyong/ernie-1.0-base-zh")
  675. >>> model = ErnieForPreTraining.from_pretrained("nghuyong/ernie-1.0-base-zh")
  676. >>> inputs = tokenizer("Hello, my dog is cute", return_tensors="pt")
  677. >>> outputs = model(**inputs)
  678. >>> prediction_logits = outputs.prediction_logits
  679. >>> seq_relationship_logits = outputs.seq_relationship_logits
  680. ```
  681. """
  682. outputs = self.ernie(
  683. input_ids,
  684. attention_mask=attention_mask,
  685. token_type_ids=token_type_ids,
  686. task_type_ids=task_type_ids,
  687. position_ids=position_ids,
  688. inputs_embeds=inputs_embeds,
  689. return_dict=True,
  690. **kwargs,
  691. )
  692. sequence_output, pooled_output = outputs[:2]
  693. prediction_scores, seq_relationship_score = self.cls(sequence_output, pooled_output)
  694. total_loss = None
  695. if labels is not None and next_sentence_label is not None:
  696. loss_fct = CrossEntropyLoss()
  697. masked_lm_loss = loss_fct(prediction_scores.view(-1, self.config.vocab_size), labels.view(-1))
  698. next_sentence_loss = loss_fct(seq_relationship_score.view(-1, 2), next_sentence_label.view(-1))
  699. total_loss = masked_lm_loss + next_sentence_loss
  700. return ErnieForPreTrainingOutput(
  701. loss=total_loss,
  702. prediction_logits=prediction_scores,
  703. seq_relationship_logits=seq_relationship_score,
  704. hidden_states=outputs.hidden_states,
  705. attentions=outputs.attentions,
  706. )
  707. class ErnieOnlyMLMHead(nn.Module):
  708. def __init__(self, config):
  709. super().__init__()
  710. self.predictions = ErnieLMPredictionHead(config)
  711. def forward(self, sequence_output: torch.Tensor) -> torch.Tensor:
  712. prediction_scores = self.predictions(sequence_output)
  713. return prediction_scores
  714. @auto_docstring(
  715. custom_intro="""
  716. Ernie Model with a `language modeling` head on top for CLM fine-tuning.
  717. """
  718. )
  719. class ErnieForCausalLM(ErniePreTrainedModel, GenerationMixin):
  720. _tied_weights_keys = {
  721. "cls.predictions.decoder.weight": "ernie.embeddings.word_embeddings.weight",
  722. "cls.predictions.decoder.bias": "cls.predictions.bias",
  723. }
  724. def __init__(self, config):
  725. super().__init__(config)
  726. if not config.is_decoder:
  727. logger.warning("If you want to use `ErnieForCausalLM` as a standalone, add `is_decoder=True.`")
  728. self.ernie = ErnieModel(config, add_pooling_layer=False)
  729. self.cls = ErnieOnlyMLMHead(config)
  730. # Initialize weights and apply final processing
  731. self.post_init()
  732. def get_output_embeddings(self):
  733. return self.cls.predictions.decoder
  734. def set_output_embeddings(self, new_embeddings):
  735. self.cls.predictions.decoder = new_embeddings
  736. self.cls.predictions.bias = new_embeddings.bias
  737. @can_return_tuple
  738. @auto_docstring
  739. def forward(
  740. self,
  741. input_ids: torch.Tensor | None = None,
  742. attention_mask: torch.Tensor | None = None,
  743. token_type_ids: torch.Tensor | None = None,
  744. task_type_ids: torch.Tensor | None = None,
  745. position_ids: torch.Tensor | None = None,
  746. inputs_embeds: torch.Tensor | None = None,
  747. encoder_hidden_states: torch.Tensor | None = None,
  748. encoder_attention_mask: torch.Tensor | None = None,
  749. labels: torch.Tensor | None = None,
  750. past_key_values: list[torch.Tensor] | None = None,
  751. use_cache: bool | None = None,
  752. logits_to_keep: int | torch.Tensor = 0,
  753. **kwargs: Unpack[TransformersKwargs],
  754. ) -> tuple[torch.Tensor] | CausalLMOutputWithCrossAttentions:
  755. r"""
  756. task_type_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
  757. Task type embedding is a special embedding to represent the characteristic of different tasks, such as
  758. word-aware pre-training task, structure-aware pre-training task and semantic-aware pre-training task. We
  759. assign a `task_type_id` to each task and the `task_type_id` is in the range `[0,
  760. config.task_type_vocab_size-1]
  761. labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
  762. Labels for computing the left-to-right language modeling loss (next word prediction). Indices should be in
  763. `[-100, 0, ..., config.vocab_size]` (see `input_ids` docstring) Tokens with indices set to `-100` are
  764. ignored (masked), the loss is only computed for the tokens with labels n `[0, ..., config.vocab_size]`
  765. """
  766. if labels is not None:
  767. use_cache = False
  768. outputs: BaseModelOutputWithPoolingAndCrossAttentions = self.ernie(
  769. input_ids,
  770. attention_mask=attention_mask,
  771. token_type_ids=token_type_ids,
  772. task_type_ids=task_type_ids,
  773. position_ids=position_ids,
  774. inputs_embeds=inputs_embeds,
  775. encoder_hidden_states=encoder_hidden_states,
  776. encoder_attention_mask=encoder_attention_mask,
  777. past_key_values=past_key_values,
  778. use_cache=use_cache,
  779. return_dict=True,
  780. **kwargs,
  781. )
  782. hidden_states = outputs.last_hidden_state
  783. # Only compute necessary logits, and do not upcast them to float if we are not computing the loss
  784. slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep
  785. logits = self.cls(hidden_states[:, slice_indices, :])
  786. loss = None
  787. if labels is not None:
  788. loss = self.loss_function(logits=logits, labels=labels, vocab_size=self.config.vocab_size, **kwargs)
  789. return CausalLMOutputWithCrossAttentions(
  790. loss=loss,
  791. logits=logits,
  792. past_key_values=outputs.past_key_values,
  793. hidden_states=outputs.hidden_states,
  794. attentions=outputs.attentions,
  795. cross_attentions=outputs.cross_attentions,
  796. )
  797. @auto_docstring
  798. class ErnieForMaskedLM(ErniePreTrainedModel):
  799. _tied_weights_keys = {
  800. "cls.predictions.decoder.bias": "cls.predictions.bias",
  801. "cls.predictions.decoder.weight": "ernie.embeddings.word_embeddings.weight",
  802. }
  803. def __init__(self, config):
  804. super().__init__(config)
  805. if config.is_decoder:
  806. logger.warning(
  807. "If you want to use `ErnieForMaskedLM` make sure `config.is_decoder=False` for "
  808. "bi-directional self-attention."
  809. )
  810. self.ernie = ErnieModel(config, add_pooling_layer=False)
  811. self.cls = ErnieOnlyMLMHead(config)
  812. # Initialize weights and apply final processing
  813. self.post_init()
  814. def get_output_embeddings(self):
  815. return self.cls.predictions.decoder
  816. def set_output_embeddings(self, new_embeddings):
  817. self.cls.predictions.decoder = new_embeddings
  818. self.cls.predictions.bias = new_embeddings.bias
  819. @can_return_tuple
  820. @auto_docstring
  821. def forward(
  822. self,
  823. input_ids: torch.Tensor | None = None,
  824. attention_mask: torch.Tensor | None = None,
  825. token_type_ids: torch.Tensor | None = None,
  826. task_type_ids: torch.Tensor | None = None,
  827. position_ids: torch.Tensor | None = None,
  828. inputs_embeds: torch.Tensor | None = None,
  829. encoder_hidden_states: torch.Tensor | None = None,
  830. encoder_attention_mask: torch.Tensor | None = None,
  831. labels: torch.Tensor | None = None,
  832. **kwargs: Unpack[TransformersKwargs],
  833. ) -> tuple[torch.Tensor] | MaskedLMOutput:
  834. r"""
  835. task_type_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
  836. Task type embedding is a special embedding to represent the characteristic of different tasks, such as
  837. word-aware pre-training task, structure-aware pre-training task and semantic-aware pre-training task. We
  838. assign a `task_type_id` to each task and the `task_type_id` is in the range `[0,
  839. config.task_type_vocab_size-1]
  840. labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
  841. Labels for computing the masked language modeling loss. Indices should be in `[-100, 0, ...,
  842. config.vocab_size]` (see `input_ids` docstring) Tokens with indices set to `-100` are ignored (masked), the
  843. loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`
  844. """
  845. outputs = self.ernie(
  846. input_ids,
  847. attention_mask=attention_mask,
  848. token_type_ids=token_type_ids,
  849. task_type_ids=task_type_ids,
  850. position_ids=position_ids,
  851. inputs_embeds=inputs_embeds,
  852. encoder_hidden_states=encoder_hidden_states,
  853. encoder_attention_mask=encoder_attention_mask,
  854. return_dict=True,
  855. **kwargs,
  856. )
  857. sequence_output = outputs[0]
  858. prediction_scores = self.cls(sequence_output)
  859. masked_lm_loss = None
  860. if labels is not None:
  861. loss_fct = CrossEntropyLoss() # -100 index = padding token
  862. masked_lm_loss = loss_fct(prediction_scores.view(-1, self.config.vocab_size), labels.view(-1))
  863. return MaskedLMOutput(
  864. loss=masked_lm_loss,
  865. logits=prediction_scores,
  866. hidden_states=outputs.hidden_states,
  867. attentions=outputs.attentions,
  868. )
  869. class ErnieOnlyNSPHead(nn.Module):
  870. def __init__(self, config):
  871. super().__init__()
  872. self.seq_relationship = nn.Linear(config.hidden_size, 2)
  873. def forward(self, pooled_output):
  874. seq_relationship_score = self.seq_relationship(pooled_output)
  875. return seq_relationship_score
  876. @auto_docstring(
  877. custom_intro="""
  878. Ernie Model with a `next sentence prediction (classification)` head on top.
  879. """
  880. )
  881. class ErnieForNextSentencePrediction(ErniePreTrainedModel):
  882. def __init__(self, config):
  883. super().__init__(config)
  884. self.ernie = ErnieModel(config)
  885. self.cls = ErnieOnlyNSPHead(config)
  886. # Initialize weights and apply final processing
  887. self.post_init()
  888. @can_return_tuple
  889. @auto_docstring
  890. def forward(
  891. self,
  892. input_ids: torch.Tensor | None = None,
  893. attention_mask: torch.Tensor | None = None,
  894. token_type_ids: torch.Tensor | None = None,
  895. task_type_ids: torch.Tensor | None = None,
  896. position_ids: torch.Tensor | None = None,
  897. inputs_embeds: torch.Tensor | None = None,
  898. labels: torch.Tensor | None = None,
  899. **kwargs: Unpack[TransformersKwargs],
  900. ) -> tuple[torch.Tensor] | NextSentencePredictorOutput:
  901. r"""
  902. task_type_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
  903. Task type embedding is a special embedding to represent the characteristic of different tasks, such as
  904. word-aware pre-training task, structure-aware pre-training task and semantic-aware pre-training task. We
  905. assign a `task_type_id` to each task and the `task_type_id` is in the range `[0,
  906. config.task_type_vocab_size-1]
  907. labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
  908. Labels for computing the next sequence prediction (classification) loss. Input should be a sequence pair
  909. (see `input_ids` docstring). Indices should be in `[0, 1]`:
  910. - 0 indicates sequence B is a continuation of sequence A,
  911. - 1 indicates sequence B is a random sequence.
  912. Example:
  913. ```python
  914. >>> from transformers import AutoTokenizer, ErnieForNextSentencePrediction
  915. >>> import torch
  916. >>> tokenizer = AutoTokenizer.from_pretrained("nghuyong/ernie-1.0-base-zh")
  917. >>> model = ErnieForNextSentencePrediction.from_pretrained("nghuyong/ernie-1.0-base-zh")
  918. >>> prompt = "In Italy, pizza served in formal settings, such as at a restaurant, is presented unsliced."
  919. >>> next_sentence = "The sky is blue due to the shorter wavelength of blue light."
  920. >>> encoding = tokenizer(prompt, next_sentence, return_tensors="pt")
  921. >>> outputs = model(**encoding, labels=torch.LongTensor([1]))
  922. >>> logits = outputs.logits
  923. >>> assert logits[0, 0] < logits[0, 1] # next sentence was random
  924. ```
  925. """
  926. outputs = self.ernie(
  927. input_ids,
  928. attention_mask=attention_mask,
  929. token_type_ids=token_type_ids,
  930. task_type_ids=task_type_ids,
  931. position_ids=position_ids,
  932. inputs_embeds=inputs_embeds,
  933. return_dict=True,
  934. **kwargs,
  935. )
  936. pooled_output = outputs[1]
  937. seq_relationship_scores = self.cls(pooled_output)
  938. next_sentence_loss = None
  939. if labels is not None:
  940. loss_fct = CrossEntropyLoss()
  941. next_sentence_loss = loss_fct(seq_relationship_scores.view(-1, 2), labels.view(-1))
  942. return NextSentencePredictorOutput(
  943. loss=next_sentence_loss,
  944. logits=seq_relationship_scores,
  945. hidden_states=outputs.hidden_states,
  946. attentions=outputs.attentions,
  947. )
  948. @auto_docstring(
  949. custom_intro="""
  950. Ernie Model transformer with a sequence classification/regression head on top (a linear layer on top of the pooled
  951. output) e.g. for GLUE tasks.
  952. """
  953. )
  954. class ErnieForSequenceClassification(ErniePreTrainedModel):
  955. def __init__(self, config):
  956. super().__init__(config)
  957. self.num_labels = config.num_labels
  958. self.config = config
  959. self.ernie = ErnieModel(config)
  960. classifier_dropout = (
  961. config.classifier_dropout if config.classifier_dropout is not None else config.hidden_dropout_prob
  962. )
  963. self.dropout = nn.Dropout(classifier_dropout)
  964. self.classifier = nn.Linear(config.hidden_size, config.num_labels)
  965. # Initialize weights and apply final processing
  966. self.post_init()
  967. @can_return_tuple
  968. @auto_docstring
  969. def forward(
  970. self,
  971. input_ids: torch.Tensor | None = None,
  972. attention_mask: torch.Tensor | None = None,
  973. token_type_ids: torch.Tensor | None = None,
  974. task_type_ids: torch.Tensor | None = None,
  975. position_ids: torch.Tensor | None = None,
  976. inputs_embeds: torch.Tensor | None = None,
  977. labels: torch.Tensor | None = None,
  978. **kwargs: Unpack[TransformersKwargs],
  979. ) -> tuple[torch.Tensor] | SequenceClassifierOutput:
  980. r"""
  981. task_type_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
  982. Task type embedding is a special embedding to represent the characteristic of different tasks, such as
  983. word-aware pre-training task, structure-aware pre-training task and semantic-aware pre-training task. We
  984. assign a `task_type_id` to each task and the `task_type_id` is in the range `[0,
  985. config.task_type_vocab_size-1]
  986. labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
  987. Labels for computing the sequence classification/regression loss. Indices should be in `[0, ...,
  988. config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If
  989. `config.num_labels > 1` a classification loss is computed (Cross-Entropy).
  990. """
  991. outputs = self.ernie(
  992. input_ids,
  993. attention_mask=attention_mask,
  994. token_type_ids=token_type_ids,
  995. task_type_ids=task_type_ids,
  996. position_ids=position_ids,
  997. inputs_embeds=inputs_embeds,
  998. return_dict=True,
  999. **kwargs,
  1000. )
  1001. pooled_output = outputs[1]
  1002. pooled_output = self.dropout(pooled_output)
  1003. logits = self.classifier(pooled_output)
  1004. loss = None
  1005. if labels is not None:
  1006. if self.config.problem_type is None:
  1007. if self.num_labels == 1:
  1008. self.config.problem_type = "regression"
  1009. elif self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int):
  1010. self.config.problem_type = "single_label_classification"
  1011. else:
  1012. self.config.problem_type = "multi_label_classification"
  1013. if self.config.problem_type == "regression":
  1014. loss_fct = MSELoss()
  1015. if self.num_labels == 1:
  1016. loss = loss_fct(logits.squeeze(), labels.squeeze())
  1017. else:
  1018. loss = loss_fct(logits, labels)
  1019. elif self.config.problem_type == "single_label_classification":
  1020. loss_fct = CrossEntropyLoss()
  1021. loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
  1022. elif self.config.problem_type == "multi_label_classification":
  1023. loss_fct = BCEWithLogitsLoss()
  1024. loss = loss_fct(logits, labels)
  1025. return SequenceClassifierOutput(
  1026. loss=loss,
  1027. logits=logits,
  1028. hidden_states=outputs.hidden_states,
  1029. attentions=outputs.attentions,
  1030. )
  1031. @auto_docstring
  1032. class ErnieForMultipleChoice(ErniePreTrainedModel):
  1033. def __init__(self, config):
  1034. super().__init__(config)
  1035. self.ernie = ErnieModel(config)
  1036. classifier_dropout = (
  1037. config.classifier_dropout if config.classifier_dropout is not None else config.hidden_dropout_prob
  1038. )
  1039. self.dropout = nn.Dropout(classifier_dropout)
  1040. self.classifier = nn.Linear(config.hidden_size, 1)
  1041. # Initialize weights and apply final processing
  1042. self.post_init()
  1043. @can_return_tuple
  1044. @auto_docstring
  1045. def forward(
  1046. self,
  1047. input_ids: torch.Tensor | None = None,
  1048. attention_mask: torch.Tensor | None = None,
  1049. token_type_ids: torch.Tensor | None = None,
  1050. task_type_ids: torch.Tensor | None = None,
  1051. position_ids: torch.Tensor | None = None,
  1052. inputs_embeds: torch.Tensor | None = None,
  1053. labels: torch.Tensor | None = None,
  1054. **kwargs: Unpack[TransformersKwargs],
  1055. ) -> tuple[torch.Tensor] | MultipleChoiceModelOutput:
  1056. r"""
  1057. input_ids (`torch.LongTensor` of shape `(batch_size, num_choices, sequence_length)`):
  1058. Indices of input sequence tokens in the vocabulary.
  1059. Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
  1060. [`PreTrainedTokenizer.__call__`] for details.
  1061. [What are input IDs?](../glossary#input-ids)
  1062. token_type_ids (`torch.LongTensor` of shape `(batch_size, num_choices, sequence_length)`, *optional*):
  1063. Segment token indices to indicate first and second portions of the inputs. Indices are selected in `[0,
  1064. 1]`:
  1065. - 0 corresponds to a *sentence A* token,
  1066. - 1 corresponds to a *sentence B* token.
  1067. [What are token type IDs?](../glossary#token-type-ids)
  1068. task_type_ids (`torch.LongTensor` of shape `(batch_size, num_choices, sequence_length)`, *optional*):
  1069. Task type embedding is a special embedding to represent the characteristic of different tasks, such as
  1070. word-aware pre-training task, structure-aware pre-training task and semantic-aware pre-training task. We
  1071. assign a `task_type_id` to each task and the `task_type_id` is in the range `[0,
  1072. config.task_type_vocab_size-1]
  1073. position_ids (`torch.LongTensor` of shape `(batch_size, num_choices, sequence_length)`, *optional*):
  1074. Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0,
  1075. config.max_position_embeddings - 1]`.
  1076. [What are position IDs?](../glossary#position-ids)
  1077. inputs_embeds (`torch.FloatTensor` of shape `(batch_size, num_choices, sequence_length, hidden_size)`, *optional*):
  1078. Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This
  1079. is useful if you want more control over how to convert `input_ids` indices into associated vectors than the
  1080. model's internal embedding lookup matrix.
  1081. labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
  1082. Labels for computing the multiple choice classification loss. Indices should be in `[0, ...,
  1083. num_choices-1]` where `num_choices` is the size of the second dimension of the input tensors. (See
  1084. `input_ids` above)
  1085. """
  1086. num_choices = input_ids.shape[1] if input_ids is not None else inputs_embeds.shape[1]
  1087. input_ids = input_ids.view(-1, input_ids.size(-1)) if input_ids is not None else None
  1088. attention_mask = attention_mask.view(-1, attention_mask.size(-1)) if attention_mask is not None else None
  1089. token_type_ids = token_type_ids.view(-1, token_type_ids.size(-1)) if token_type_ids is not None else None
  1090. position_ids = position_ids.view(-1, position_ids.size(-1)) if position_ids is not None else None
  1091. inputs_embeds = (
  1092. inputs_embeds.view(-1, inputs_embeds.size(-2), inputs_embeds.size(-1))
  1093. if inputs_embeds is not None
  1094. else None
  1095. )
  1096. outputs = self.ernie(
  1097. input_ids,
  1098. attention_mask=attention_mask,
  1099. token_type_ids=token_type_ids,
  1100. task_type_ids=task_type_ids,
  1101. position_ids=position_ids,
  1102. inputs_embeds=inputs_embeds,
  1103. return_dict=True,
  1104. **kwargs,
  1105. )
  1106. pooled_output = outputs[1]
  1107. pooled_output = self.dropout(pooled_output)
  1108. logits = self.classifier(pooled_output)
  1109. reshaped_logits = logits.view(-1, num_choices)
  1110. loss = None
  1111. if labels is not None:
  1112. loss_fct = CrossEntropyLoss()
  1113. loss = loss_fct(reshaped_logits, labels)
  1114. return MultipleChoiceModelOutput(
  1115. loss=loss,
  1116. logits=reshaped_logits,
  1117. hidden_states=outputs.hidden_states,
  1118. attentions=outputs.attentions,
  1119. )
  1120. @auto_docstring
  1121. class ErnieForTokenClassification(ErniePreTrainedModel):
  1122. def __init__(self, config):
  1123. super().__init__(config)
  1124. self.num_labels = config.num_labels
  1125. self.ernie = ErnieModel(config, add_pooling_layer=False)
  1126. classifier_dropout = (
  1127. config.classifier_dropout if config.classifier_dropout is not None else config.hidden_dropout_prob
  1128. )
  1129. self.dropout = nn.Dropout(classifier_dropout)
  1130. self.classifier = nn.Linear(config.hidden_size, config.num_labels)
  1131. # Initialize weights and apply final processing
  1132. self.post_init()
  1133. @can_return_tuple
  1134. @auto_docstring
  1135. def forward(
  1136. self,
  1137. input_ids: torch.Tensor | None = None,
  1138. attention_mask: torch.Tensor | None = None,
  1139. token_type_ids: torch.Tensor | None = None,
  1140. task_type_ids: torch.Tensor | None = None,
  1141. position_ids: torch.Tensor | None = None,
  1142. inputs_embeds: torch.Tensor | None = None,
  1143. labels: torch.Tensor | None = None,
  1144. **kwargs: Unpack[TransformersKwargs],
  1145. ) -> tuple[torch.Tensor] | TokenClassifierOutput:
  1146. r"""
  1147. task_type_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
  1148. Task type embedding is a special embedding to represent the characteristic of different tasks, such as
  1149. word-aware pre-training task, structure-aware pre-training task and semantic-aware pre-training task. We
  1150. assign a `task_type_id` to each task and the `task_type_id` is in the range `[0,
  1151. config.task_type_vocab_size-1]
  1152. labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
  1153. Labels for computing the token classification loss. Indices should be in `[0, ..., config.num_labels - 1]`.
  1154. """
  1155. outputs = self.ernie(
  1156. input_ids,
  1157. attention_mask=attention_mask,
  1158. token_type_ids=token_type_ids,
  1159. task_type_ids=task_type_ids,
  1160. position_ids=position_ids,
  1161. inputs_embeds=inputs_embeds,
  1162. return_dict=True,
  1163. **kwargs,
  1164. )
  1165. sequence_output = outputs[0]
  1166. sequence_output = self.dropout(sequence_output)
  1167. logits = self.classifier(sequence_output)
  1168. loss = None
  1169. if labels is not None:
  1170. loss_fct = CrossEntropyLoss()
  1171. loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
  1172. return TokenClassifierOutput(
  1173. loss=loss,
  1174. logits=logits,
  1175. hidden_states=outputs.hidden_states,
  1176. attentions=outputs.attentions,
  1177. )
  1178. @auto_docstring
  1179. class ErnieForQuestionAnswering(ErniePreTrainedModel):
  1180. def __init__(self, config):
  1181. super().__init__(config)
  1182. self.num_labels = config.num_labels
  1183. self.ernie = ErnieModel(config, add_pooling_layer=False)
  1184. self.qa_outputs = nn.Linear(config.hidden_size, config.num_labels)
  1185. # Initialize weights and apply final processing
  1186. self.post_init()
  1187. @can_return_tuple
  1188. @auto_docstring
  1189. def forward(
  1190. self,
  1191. input_ids: torch.Tensor | None = None,
  1192. attention_mask: torch.Tensor | None = None,
  1193. token_type_ids: torch.Tensor | None = None,
  1194. task_type_ids: torch.Tensor | None = None,
  1195. position_ids: torch.Tensor | None = None,
  1196. inputs_embeds: torch.Tensor | None = None,
  1197. start_positions: torch.Tensor | None = None,
  1198. end_positions: torch.Tensor | None = None,
  1199. **kwargs: Unpack[TransformersKwargs],
  1200. ) -> tuple[torch.Tensor] | QuestionAnsweringModelOutput:
  1201. r"""
  1202. task_type_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
  1203. Task type embedding is a special embedding to represent the characteristic of different tasks, such as
  1204. word-aware pre-training task, structure-aware pre-training task and semantic-aware pre-training task. We
  1205. assign a `task_type_id` to each task and the `task_type_id` is in the range `[0,
  1206. config.task_type_vocab_size-1]
  1207. """
  1208. outputs = self.ernie(
  1209. input_ids,
  1210. attention_mask=attention_mask,
  1211. token_type_ids=token_type_ids,
  1212. task_type_ids=task_type_ids,
  1213. position_ids=position_ids,
  1214. inputs_embeds=inputs_embeds,
  1215. return_dict=True,
  1216. **kwargs,
  1217. )
  1218. sequence_output = outputs[0]
  1219. logits = self.qa_outputs(sequence_output)
  1220. start_logits, end_logits = logits.split(1, dim=-1)
  1221. start_logits = start_logits.squeeze(-1).contiguous()
  1222. end_logits = end_logits.squeeze(-1).contiguous()
  1223. total_loss = None
  1224. if start_positions is not None and end_positions is not None:
  1225. # If we are on multi-GPU, split add a dimension
  1226. if len(start_positions.size()) > 1:
  1227. start_positions = start_positions.squeeze(-1)
  1228. if len(end_positions.size()) > 1:
  1229. end_positions = end_positions.squeeze(-1)
  1230. # sometimes the start/end positions are outside our model inputs, we ignore these terms
  1231. ignored_index = start_logits.size(1)
  1232. start_positions = start_positions.clamp(0, ignored_index)
  1233. end_positions = end_positions.clamp(0, ignored_index)
  1234. loss_fct = CrossEntropyLoss(ignore_index=ignored_index)
  1235. start_loss = loss_fct(start_logits, start_positions)
  1236. end_loss = loss_fct(end_logits, end_positions)
  1237. total_loss = (start_loss + end_loss) / 2
  1238. return QuestionAnsweringModelOutput(
  1239. loss=total_loss,
  1240. start_logits=start_logits,
  1241. end_logits=end_logits,
  1242. hidden_states=outputs.hidden_states,
  1243. attentions=outputs.attentions,
  1244. )
# Names exported by a star-import of this module.
__all__ = [
    "ErnieForCausalLM",
    "ErnieForMaskedLM",
    "ErnieForMultipleChoice",
    "ErnieForNextSentencePrediction",
    "ErnieForPreTraining",
    "ErnieForQuestionAnswering",
    "ErnieForSequenceClassification",
    "ErnieForTokenClassification",
    "ErnieModel",
    "ErniePreTrainedModel",
]