modeling_albert.py 38 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
3704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792793794795796797798799800801802803804805806807808809810811812813814815816817818819820821822823824825826827828829830831832833834835836837838839840841842843844845846847848849850851852853854855856857858859860861862863864865866867868869870871872873874875876877878879880881882883884885886887888889890891892893894895896897898899900901902903904905906907908909910911912913914915916917918919920921922923924925926927928929930931932933934935936937938939940941942943944945946947948949950951952953954955956957958959960961962963964965966967968969970971972973974975976
  1. # Copyright 2018 Google AI, Google Brain and the HuggingFace Inc. team.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. """PyTorch ALBERT model."""
  15. from collections.abc import Callable
  16. from dataclasses import dataclass
  17. import torch
  18. from torch import nn
  19. from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
  20. from ... import initialization as init
  21. from ...activations import ACT2FN
  22. from ...masking_utils import create_bidirectional_mask
  23. from ...modeling_outputs import (
  24. BaseModelOutput,
  25. BaseModelOutputWithPooling,
  26. MaskedLMOutput,
  27. MultipleChoiceModelOutput,
  28. QuestionAnsweringModelOutput,
  29. SequenceClassifierOutput,
  30. TokenClassifierOutput,
  31. )
  32. from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
  33. from ...processing_utils import Unpack
  34. from ...pytorch_utils import (
  35. apply_chunking_to_forward,
  36. )
  37. from ...utils import ModelOutput, TransformersKwargs, auto_docstring, logging
  38. from ...utils.generic import can_return_tuple, merge_with_config_defaults
  39. from ...utils.output_capturing import capture_outputs
  40. from .configuration_albert import AlbertConfig
  41. logger = logging.get_logger(__name__)
  42. class AlbertEmbeddings(nn.Module):
  43. """
  44. Construct the embeddings from word, position and token_type embeddings.
  45. """
  46. def __init__(self, config: AlbertConfig):
  47. super().__init__()
  48. self.word_embeddings = nn.Embedding(config.vocab_size, config.embedding_size, padding_idx=config.pad_token_id)
  49. self.position_embeddings = nn.Embedding(config.max_position_embeddings, config.embedding_size)
  50. self.token_type_embeddings = nn.Embedding(config.type_vocab_size, config.embedding_size)
  51. self.LayerNorm = nn.LayerNorm(config.embedding_size, eps=config.layer_norm_eps)
  52. self.dropout = nn.Dropout(config.hidden_dropout_prob)
  53. # position_ids (1, len position emb) is contiguous in memory and exported when serialized
  54. self.register_buffer(
  55. "position_ids", torch.arange(config.max_position_embeddings).expand((1, -1)), persistent=False
  56. )
  57. self.register_buffer(
  58. "token_type_ids", torch.zeros(self.position_ids.size(), dtype=torch.long), persistent=False
  59. )
  60. def forward(
  61. self,
  62. input_ids: torch.LongTensor | None = None,
  63. token_type_ids: torch.LongTensor | None = None,
  64. position_ids: torch.LongTensor | None = None,
  65. inputs_embeds: torch.FloatTensor | None = None,
  66. ) -> torch.Tensor:
  67. if input_ids is not None:
  68. input_shape = input_ids.size()
  69. else:
  70. input_shape = inputs_embeds.size()[:-1]
  71. batch_size, seq_length = input_shape
  72. if position_ids is None:
  73. position_ids = self.position_ids[:, :seq_length]
  74. # Setting the token_type_ids to the registered buffer in constructor where it is all zeros, which usually occurs
  75. # when its auto-generated, registered buffer helps users when tracing the model without passing token_type_ids, solves
  76. # issue #5664
  77. if token_type_ids is None:
  78. if hasattr(self, "token_type_ids"):
  79. # NOTE: We assume either pos ids to have bsz == 1 (broadcastable) or bsz == effective bsz (input_shape[0])
  80. buffered_token_type_ids = self.token_type_ids.expand(position_ids.shape[0], -1)
  81. buffered_token_type_ids = torch.gather(buffered_token_type_ids, dim=1, index=position_ids)
  82. token_type_ids = buffered_token_type_ids.expand(batch_size, seq_length)
  83. else:
  84. token_type_ids = torch.zeros(input_shape, dtype=torch.long, device=self.position_ids.device)
  85. if inputs_embeds is None:
  86. inputs_embeds = self.word_embeddings(input_ids)
  87. token_type_embeddings = self.token_type_embeddings(token_type_ids)
  88. embeddings = inputs_embeds + token_type_embeddings
  89. position_embeddings = self.position_embeddings(position_ids)
  90. embeddings = embeddings + position_embeddings
  91. embeddings = self.LayerNorm(embeddings)
  92. embeddings = self.dropout(embeddings)
  93. return embeddings
  94. # Copied from transformers.models.bert.modeling_bert.eager_attention_forward
  95. def eager_attention_forward(
  96. module: nn.Module,
  97. query: torch.Tensor,
  98. key: torch.Tensor,
  99. value: torch.Tensor,
  100. attention_mask: torch.Tensor | None,
  101. scaling: float | None = None,
  102. dropout: float = 0.0,
  103. **kwargs: Unpack[TransformersKwargs],
  104. ):
  105. if scaling is None:
  106. scaling = query.size(-1) ** -0.5
  107. # Take the dot product between "query" and "key" to get the raw attention scores.
  108. attn_weights = torch.matmul(query, key.transpose(2, 3)) * scaling
  109. if attention_mask is not None:
  110. attn_weights = attn_weights + attention_mask
  111. attn_weights = nn.functional.softmax(attn_weights, dim=-1)
  112. attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training)
  113. attn_output = torch.matmul(attn_weights, value)
  114. attn_output = attn_output.transpose(1, 2).contiguous()
  115. return attn_output, attn_weights
  116. class AlbertAttention(nn.Module):
  117. def __init__(self, config: AlbertConfig):
  118. super().__init__()
  119. if config.hidden_size % config.num_attention_heads != 0 and not hasattr(config, "embedding_size"):
  120. raise ValueError(
  121. f"The hidden size ({config.hidden_size}) is not a multiple of the number of attention "
  122. f"heads ({config.num_attention_heads}"
  123. )
  124. self.config = config
  125. self.num_attention_heads = config.num_attention_heads
  126. self.hidden_size = config.hidden_size
  127. self.attention_head_size = config.hidden_size // config.num_attention_heads
  128. self.all_head_size = self.num_attention_heads * self.attention_head_size
  129. self.scaling = self.attention_head_size**-0.5
  130. self.attention_dropout = nn.Dropout(config.attention_probs_dropout_prob)
  131. self.output_dropout = nn.Dropout(config.hidden_dropout_prob)
  132. self.query = nn.Linear(config.hidden_size, self.all_head_size)
  133. self.key = nn.Linear(config.hidden_size, self.all_head_size)
  134. self.value = nn.Linear(config.hidden_size, self.all_head_size)
  135. self.dense = nn.Linear(config.hidden_size, config.hidden_size)
  136. self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
  137. self.is_causal = False
  138. def forward(
  139. self,
  140. hidden_states: torch.Tensor,
  141. attention_mask: torch.FloatTensor | None = None,
  142. **kwargs: Unpack[TransformersKwargs],
  143. ) -> tuple[torch.Tensor, torch.Tensor]:
  144. input_shape = hidden_states.shape[:-1]
  145. hidden_shape = (*input_shape, -1, self.attention_head_size)
  146. # get all proj
  147. query_layer = self.query(hidden_states).view(*hidden_shape).transpose(1, 2)
  148. key_layer = self.key(hidden_states).view(*hidden_shape).transpose(1, 2)
  149. value_layer = self.value(hidden_states).view(*hidden_shape).transpose(1, 2)
  150. attention_interface: Callable = ALL_ATTENTION_FUNCTIONS.get_interface(
  151. self.config._attn_implementation, eager_attention_forward
  152. )
  153. attn_output, attn_weights = attention_interface(
  154. self,
  155. query_layer,
  156. key_layer,
  157. value_layer,
  158. attention_mask,
  159. dropout=0.0 if not self.training else self.attention_dropout.p,
  160. scaling=self.scaling,
  161. **kwargs,
  162. )
  163. attn_output = attn_output.reshape(*input_shape, -1).contiguous()
  164. attn_output = self.dense(attn_output)
  165. attn_output = self.output_dropout(attn_output)
  166. attn_output = self.LayerNorm(hidden_states + attn_output)
  167. return attn_output, attn_weights
  168. class AlbertLayer(nn.Module):
  169. def __init__(self, config: AlbertConfig):
  170. super().__init__()
  171. self.config = config
  172. self.chunk_size_feed_forward = config.chunk_size_feed_forward
  173. self.seq_len_dim = 1
  174. self.full_layer_layer_norm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
  175. self.attention = AlbertAttention(config)
  176. self.ffn = nn.Linear(config.hidden_size, config.intermediate_size)
  177. self.ffn_output = nn.Linear(config.intermediate_size, config.hidden_size)
  178. self.activation = ACT2FN[config.hidden_act]
  179. self.dropout = nn.Dropout(config.hidden_dropout_prob)
  180. def forward(
  181. self,
  182. hidden_states: torch.Tensor,
  183. attention_mask: torch.FloatTensor | None = None,
  184. **kwargs: Unpack[TransformersKwargs],
  185. ) -> tuple[torch.Tensor, torch.Tensor]:
  186. attention_output, _ = self.attention(hidden_states, attention_mask, **kwargs)
  187. ffn_output = apply_chunking_to_forward(
  188. self.ff_chunk,
  189. self.chunk_size_feed_forward,
  190. self.seq_len_dim,
  191. attention_output,
  192. )
  193. hidden_states = self.full_layer_layer_norm(ffn_output + attention_output)
  194. return hidden_states
  195. def ff_chunk(self, attention_output: torch.Tensor) -> torch.Tensor:
  196. ffn_output = self.ffn(attention_output)
  197. ffn_output = self.activation(ffn_output)
  198. ffn_output = self.ffn_output(ffn_output)
  199. return ffn_output
  200. class AlbertLayerGroup(nn.Module):
  201. def __init__(self, config: AlbertConfig):
  202. super().__init__()
  203. self.albert_layers = nn.ModuleList([AlbertLayer(config) for _ in range(config.inner_group_num)])
  204. def forward(
  205. self,
  206. hidden_states: torch.Tensor,
  207. attention_mask: torch.FloatTensor | None = None,
  208. **kwargs: Unpack[TransformersKwargs],
  209. ) -> tuple[torch.Tensor | tuple[torch.Tensor], ...]:
  210. for layer_index, albert_layer in enumerate(self.albert_layers):
  211. hidden_states = albert_layer(hidden_states, attention_mask, **kwargs)
  212. return hidden_states
  213. class AlbertTransformer(nn.Module):
  214. def __init__(self, config: AlbertConfig):
  215. super().__init__()
  216. self.config = config
  217. self.embedding_hidden_mapping_in = nn.Linear(config.embedding_size, config.hidden_size)
  218. self.albert_layer_groups = nn.ModuleList([AlbertLayerGroup(config) for _ in range(config.num_hidden_groups)])
  219. def forward(
  220. self,
  221. hidden_states: torch.Tensor,
  222. attention_mask: torch.FloatTensor | None = None,
  223. **kwargs: Unpack[TransformersKwargs],
  224. ) -> BaseModelOutput | tuple:
  225. hidden_states = self.embedding_hidden_mapping_in(hidden_states)
  226. for i in range(self.config.num_hidden_layers):
  227. # Index of the hidden group
  228. group_idx = int(i / (self.config.num_hidden_layers / self.config.num_hidden_groups))
  229. hidden_states = self.albert_layer_groups[group_idx](
  230. hidden_states,
  231. attention_mask,
  232. **kwargs,
  233. )
  234. return BaseModelOutput(last_hidden_state=hidden_states)
  235. @auto_docstring
  236. class AlbertPreTrainedModel(PreTrainedModel):
  237. config_class = AlbertConfig
  238. base_model_prefix = "albert"
  239. _supports_flash_attn = True
  240. _supports_sdpa = True
  241. _supports_flex_attn = True
  242. _supports_attention_backend = True
  243. _can_record_outputs = {
  244. "hidden_states": AlbertLayer,
  245. "attentions": AlbertAttention,
  246. }
  247. @torch.no_grad()
  248. def _init_weights(self, module):
  249. """Initialize the weights."""
  250. if isinstance(module, nn.Linear):
  251. init.normal_(module.weight, mean=0.0, std=self.config.initializer_range)
  252. if module.bias is not None:
  253. init.zeros_(module.bias)
  254. elif isinstance(module, nn.Embedding):
  255. init.normal_(module.weight, mean=0.0, std=self.config.initializer_range)
  256. # Here we need the check explicitly, as we slice the weight in the `zeros_` call, so it looses the flag
  257. if module.padding_idx is not None and not getattr(module.weight, "_is_hf_initialized", False):
  258. init.zeros_(module.weight[module.padding_idx])
  259. elif isinstance(module, nn.LayerNorm):
  260. init.zeros_(module.bias)
  261. init.ones_(module.weight)
  262. elif isinstance(module, AlbertMLMHead):
  263. init.zeros_(module.bias)
  264. elif isinstance(module, AlbertEmbeddings):
  265. init.copy_(module.position_ids, torch.arange(module.position_ids.shape[-1]).expand((1, -1)))
  266. init.zeros_(module.token_type_ids)
  267. @dataclass
  268. @auto_docstring(
  269. custom_intro="""
  270. Output type of [`AlbertForPreTraining`].
  271. """
  272. )
  273. class AlbertForPreTrainingOutput(ModelOutput):
  274. r"""
  275. loss (*optional*, returned when `labels` is provided, `torch.FloatTensor` of shape `(1,)`):
  276. Total loss as the sum of the masked language modeling loss and the next sequence prediction
  277. (classification) loss.
  278. prediction_logits (`torch.FloatTensor` of shape `(batch_size, sequence_length, config.vocab_size)`):
  279. Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax).
  280. sop_logits (`torch.FloatTensor` of shape `(batch_size, 2)`):
  281. Prediction scores of the next sequence prediction (classification) head (scores of True/False continuation
  282. before SoftMax).
  283. """
  284. loss: torch.FloatTensor | None = None
  285. prediction_logits: torch.FloatTensor | None = None
  286. sop_logits: torch.FloatTensor | None = None
  287. hidden_states: tuple[torch.FloatTensor] | None = None
  288. attentions: tuple[torch.FloatTensor] | None = None
  289. @auto_docstring
  290. class AlbertModel(AlbertPreTrainedModel):
  291. config_class = AlbertConfig
  292. base_model_prefix = "albert"
  293. def __init__(self, config: AlbertConfig, add_pooling_layer: bool = True):
  294. r"""
  295. add_pooling_layer (bool, *optional*, defaults to `True`):
  296. Whether to add a pooling layer
  297. """
  298. super().__init__(config)
  299. self.config = config
  300. self.embeddings = AlbertEmbeddings(config)
  301. self.encoder = AlbertTransformer(config)
  302. if add_pooling_layer:
  303. self.pooler = nn.Linear(config.hidden_size, config.hidden_size)
  304. self.pooler_activation = nn.Tanh()
  305. else:
  306. self.pooler = None
  307. self.pooler_activation = None
  308. self.attn_implementation = config._attn_implementation
  309. # Initialize weights and apply final processing
  310. self.post_init()
  311. def get_input_embeddings(self) -> nn.Embedding:
  312. return self.embeddings.word_embeddings
  313. def set_input_embeddings(self, value: nn.Embedding) -> None:
  314. self.embeddings.word_embeddings = value
  315. @merge_with_config_defaults
  316. @capture_outputs
  317. @auto_docstring
  318. def forward(
  319. self,
  320. input_ids: torch.LongTensor | None = None,
  321. attention_mask: torch.FloatTensor | None = None,
  322. token_type_ids: torch.LongTensor | None = None,
  323. position_ids: torch.LongTensor | None = None,
  324. inputs_embeds: torch.FloatTensor | None = None,
  325. **kwargs: Unpack[TransformersKwargs],
  326. ) -> BaseModelOutputWithPooling | tuple:
  327. if (input_ids is None) ^ (inputs_embeds is not None):
  328. raise ValueError("You must specify exactly one of input_ids or inputs_embeds")
  329. embedding_output = self.embeddings(
  330. input_ids, position_ids=position_ids, token_type_ids=token_type_ids, inputs_embeds=inputs_embeds
  331. )
  332. attention_mask = create_bidirectional_mask(
  333. config=self.config,
  334. inputs_embeds=embedding_output,
  335. attention_mask=attention_mask,
  336. )
  337. encoder_outputs = self.encoder(
  338. embedding_output,
  339. attention_mask,
  340. position_ids=position_ids,
  341. **kwargs,
  342. )
  343. sequence_output = encoder_outputs[0]
  344. pooled_output = self.pooler_activation(self.pooler(sequence_output[:, 0])) if self.pooler is not None else None
  345. return BaseModelOutputWithPooling(
  346. last_hidden_state=sequence_output,
  347. pooler_output=pooled_output,
  348. )
  349. @auto_docstring(
  350. custom_intro="""
  351. Albert Model with two heads on top as done during the pretraining: a `masked language modeling` head and a
  352. `sentence order prediction (classification)` head.
  353. """
  354. )
  355. class AlbertForPreTraining(AlbertPreTrainedModel):
  356. _tied_weights_keys = {
  357. "predictions.decoder.weight": "albert.embeddings.word_embeddings.weight",
  358. "predictions.decoder.bias": "predictions.bias",
  359. }
  360. def __init__(self, config: AlbertConfig):
  361. super().__init__(config)
  362. self.albert = AlbertModel(config)
  363. self.predictions = AlbertMLMHead(config)
  364. self.sop_classifier = AlbertSOPHead(config)
  365. # Initialize weights and apply final processing
  366. self.post_init()
  367. def get_output_embeddings(self) -> nn.Linear:
  368. return self.predictions.decoder
  369. def set_output_embeddings(self, new_embeddings: nn.Linear) -> None:
  370. self.predictions.decoder = new_embeddings
  371. def get_input_embeddings(self) -> nn.Embedding:
  372. return self.albert.embeddings.word_embeddings
  373. @can_return_tuple
  374. @auto_docstring
  375. def forward(
  376. self,
  377. input_ids: torch.LongTensor | None = None,
  378. attention_mask: torch.FloatTensor | None = None,
  379. token_type_ids: torch.LongTensor | None = None,
  380. position_ids: torch.LongTensor | None = None,
  381. inputs_embeds: torch.FloatTensor | None = None,
  382. labels: torch.LongTensor | None = None,
  383. sentence_order_label: torch.LongTensor | None = None,
  384. **kwargs: Unpack[TransformersKwargs],
  385. ) -> AlbertForPreTrainingOutput | tuple:
  386. r"""
  387. labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
  388. Labels for computing the masked language modeling loss. Indices should be in `[-100, 0, ...,
  389. config.vocab_size]` (see `input_ids` docstring) Tokens with indices set to `-100` are ignored (masked), the
  390. loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`
  391. sentence_order_label (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
  392. Labels for computing the next sequence prediction (classification) loss. Input should be a sequence pair
  393. (see `input_ids` docstring) Indices should be in `[0, 1]`. `0` indicates original order (sequence A, then
  394. sequence B), `1` indicates switched order (sequence B, then sequence A).
  395. Example:
  396. ```python
  397. >>> from transformers import AutoTokenizer, AlbertForPreTraining
  398. >>> import torch
  399. >>> tokenizer = AutoTokenizer.from_pretrained("albert/albert-base-v2")
  400. >>> model = AlbertForPreTraining.from_pretrained("albert/albert-base-v2")
  401. >>> input_ids = torch.tensor(tokenizer.encode("Hello, my dog is cute", add_special_tokens=True)).unsqueeze(0)
  402. >>> # Batch size 1
  403. >>> outputs = model(input_ids)
  404. >>> prediction_logits = outputs.prediction_logits
  405. >>> sop_logits = outputs.sop_logits
  406. ```"""
  407. outputs = self.albert(
  408. input_ids,
  409. attention_mask=attention_mask,
  410. token_type_ids=token_type_ids,
  411. position_ids=position_ids,
  412. inputs_embeds=inputs_embeds,
  413. return_dict=True,
  414. **kwargs,
  415. )
  416. sequence_output, pooled_output = outputs[:2]
  417. prediction_scores = self.predictions(sequence_output)
  418. sop_scores = self.sop_classifier(pooled_output)
  419. total_loss = None
  420. if labels is not None and sentence_order_label is not None:
  421. loss_fct = CrossEntropyLoss()
  422. masked_lm_loss = loss_fct(prediction_scores.view(-1, self.config.vocab_size), labels.view(-1))
  423. sentence_order_loss = loss_fct(sop_scores.view(-1, 2), sentence_order_label.view(-1))
  424. total_loss = masked_lm_loss + sentence_order_loss
  425. return AlbertForPreTrainingOutput(
  426. loss=total_loss,
  427. prediction_logits=prediction_scores,
  428. sop_logits=sop_scores,
  429. hidden_states=outputs.hidden_states,
  430. attentions=outputs.attentions,
  431. )
  432. class AlbertMLMHead(nn.Module):
  433. def __init__(self, config: AlbertConfig):
  434. super().__init__()
  435. self.LayerNorm = nn.LayerNorm(config.embedding_size, eps=config.layer_norm_eps)
  436. self.bias = nn.Parameter(torch.zeros(config.vocab_size))
  437. self.dense = nn.Linear(config.hidden_size, config.embedding_size)
  438. self.decoder = nn.Linear(config.embedding_size, config.vocab_size)
  439. self.activation = ACT2FN[config.hidden_act]
  440. def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
  441. hidden_states = self.dense(hidden_states)
  442. hidden_states = self.activation(hidden_states)
  443. hidden_states = self.LayerNorm(hidden_states)
  444. hidden_states = self.decoder(hidden_states)
  445. prediction_scores = hidden_states
  446. return prediction_scores
  447. class AlbertSOPHead(nn.Module):
  448. def __init__(self, config: AlbertConfig):
  449. super().__init__()
  450. self.dropout = nn.Dropout(config.classifier_dropout_prob)
  451. self.classifier = nn.Linear(config.hidden_size, config.num_labels)
  452. def forward(self, pooled_output: torch.Tensor) -> torch.Tensor:
  453. dropout_pooled_output = self.dropout(pooled_output)
  454. logits = self.classifier(dropout_pooled_output)
  455. return logits
  456. @auto_docstring
  457. class AlbertForMaskedLM(AlbertPreTrainedModel):
  458. _tied_weights_keys = {
  459. "predictions.decoder.weight": "albert.embeddings.word_embeddings.weight",
  460. "predictions.decoder.bias": "predictions.bias",
  461. }
  462. def __init__(self, config):
  463. super().__init__(config)
  464. self.albert = AlbertModel(config, add_pooling_layer=False)
  465. self.predictions = AlbertMLMHead(config)
  466. # Initialize weights and apply final processing
  467. self.post_init()
  468. def get_output_embeddings(self) -> nn.Linear:
  469. return self.predictions.decoder
  470. def set_output_embeddings(self, new_embeddings: nn.Linear) -> None:
  471. self.predictions.decoder = new_embeddings
  472. self.predictions.bias = new_embeddings.bias
  473. def get_input_embeddings(self) -> nn.Embedding:
  474. return self.albert.embeddings.word_embeddings
  475. @can_return_tuple
  476. @auto_docstring
  477. def forward(
  478. self,
  479. input_ids: torch.LongTensor | None = None,
  480. attention_mask: torch.FloatTensor | None = None,
  481. token_type_ids: torch.LongTensor | None = None,
  482. position_ids: torch.LongTensor | None = None,
  483. inputs_embeds: torch.FloatTensor | None = None,
  484. labels: torch.LongTensor | None = None,
  485. **kwargs: Unpack[TransformersKwargs],
  486. ) -> MaskedLMOutput | tuple:
  487. r"""
  488. labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
  489. Labels for computing the masked language modeling loss. Indices should be in `[-100, 0, ...,
  490. config.vocab_size]` (see `input_ids` docstring) Tokens with indices set to `-100` are ignored (masked), the
  491. loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`
  492. Example:
  493. ```python
  494. >>> import torch
  495. >>> from transformers import AutoTokenizer, AlbertForMaskedLM
  496. >>> tokenizer = AutoTokenizer.from_pretrained("albert/albert-base-v2")
  497. >>> model = AlbertForMaskedLM.from_pretrained("albert/albert-base-v2")
  498. >>> # add mask_token
  499. >>> inputs = tokenizer("The capital of [MASK] is Paris.", return_tensors="pt")
  500. >>> with torch.no_grad():
  501. ... logits = model(**inputs).logits
  502. >>> # retrieve index of [MASK]
  503. >>> mask_token_index = (inputs.input_ids == tokenizer.mask_token_id)[0].nonzero(as_tuple=True)[0]
  504. >>> predicted_token_id = logits[0, mask_token_index].argmax(axis=-1)
  505. >>> tokenizer.decode(predicted_token_id)
  506. 'france'
  507. ```
  508. ```python
  509. >>> labels = tokenizer("The capital of France is Paris.", return_tensors="pt")["input_ids"]
  510. >>> labels = torch.where(inputs.input_ids == tokenizer.mask_token_id, labels, -100)
  511. >>> outputs = model(**inputs, labels=labels)
  512. >>> round(outputs.loss.item(), 2)
  513. 0.81
  514. ```
  515. """
  516. outputs = self.albert(
  517. input_ids=input_ids,
  518. attention_mask=attention_mask,
  519. token_type_ids=token_type_ids,
  520. position_ids=position_ids,
  521. inputs_embeds=inputs_embeds,
  522. return_dict=True,
  523. **kwargs,
  524. )
  525. sequence_outputs = outputs[0]
  526. prediction_scores = self.predictions(sequence_outputs)
  527. masked_lm_loss = None
  528. if labels is not None:
  529. loss_fct = CrossEntropyLoss()
  530. masked_lm_loss = loss_fct(prediction_scores.view(-1, self.config.vocab_size), labels.view(-1))
  531. return MaskedLMOutput(
  532. loss=masked_lm_loss,
  533. logits=prediction_scores,
  534. hidden_states=outputs.hidden_states,
  535. attentions=outputs.attentions,
  536. )
  537. @auto_docstring(
  538. custom_intro="""
  539. Albert Model transformer with a sequence classification/regression head on top (a linear layer on top of the pooled
  540. output) e.g. for GLUE tasks.
  541. """
  542. )
  543. class AlbertForSequenceClassification(AlbertPreTrainedModel):
  544. def __init__(self, config: AlbertConfig):
  545. super().__init__(config)
  546. self.num_labels = config.num_labels
  547. self.config = config
  548. self.albert = AlbertModel(config)
  549. self.dropout = nn.Dropout(config.classifier_dropout_prob)
  550. self.classifier = nn.Linear(config.hidden_size, self.config.num_labels)
  551. # Initialize weights and apply final processing
  552. self.post_init()
  553. @can_return_tuple
  554. @auto_docstring
  555. def forward(
  556. self,
  557. input_ids: torch.LongTensor | None = None,
  558. attention_mask: torch.FloatTensor | None = None,
  559. token_type_ids: torch.LongTensor | None = None,
  560. position_ids: torch.LongTensor | None = None,
  561. inputs_embeds: torch.FloatTensor | None = None,
  562. labels: torch.LongTensor | None = None,
  563. **kwargs: Unpack[TransformersKwargs],
  564. ) -> SequenceClassifierOutput | tuple:
  565. r"""
  566. labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
  567. Labels for computing the sequence classification/regression loss. Indices should be in `[0, ...,
  568. config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If
  569. `config.num_labels > 1` a classification loss is computed (Cross-Entropy).
  570. """
  571. outputs = self.albert(
  572. input_ids=input_ids,
  573. attention_mask=attention_mask,
  574. token_type_ids=token_type_ids,
  575. position_ids=position_ids,
  576. inputs_embeds=inputs_embeds,
  577. return_dict=True,
  578. **kwargs,
  579. )
  580. pooled_output = outputs[1]
  581. pooled_output = self.dropout(pooled_output)
  582. logits = self.classifier(pooled_output)
  583. loss = None
  584. if labels is not None:
  585. if self.config.problem_type is None:
  586. if self.num_labels == 1:
  587. self.config.problem_type = "regression"
  588. elif self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int):
  589. self.config.problem_type = "single_label_classification"
  590. else:
  591. self.config.problem_type = "multi_label_classification"
  592. if self.config.problem_type == "regression":
  593. loss_fct = MSELoss()
  594. if self.num_labels == 1:
  595. loss = loss_fct(logits.squeeze(), labels.squeeze())
  596. else:
  597. loss = loss_fct(logits, labels)
  598. elif self.config.problem_type == "single_label_classification":
  599. loss_fct = CrossEntropyLoss()
  600. loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
  601. elif self.config.problem_type == "multi_label_classification":
  602. loss_fct = BCEWithLogitsLoss()
  603. loss = loss_fct(logits, labels)
  604. return SequenceClassifierOutput(
  605. loss=loss,
  606. logits=logits,
  607. hidden_states=outputs.hidden_states,
  608. attentions=outputs.attentions,
  609. )
  610. @auto_docstring
  611. class AlbertForTokenClassification(AlbertPreTrainedModel):
  612. def __init__(self, config: AlbertConfig):
  613. super().__init__(config)
  614. self.num_labels = config.num_labels
  615. self.albert = AlbertModel(config, add_pooling_layer=False)
  616. classifier_dropout_prob = (
  617. config.classifier_dropout_prob
  618. if config.classifier_dropout_prob is not None
  619. else config.hidden_dropout_prob
  620. )
  621. self.dropout = nn.Dropout(classifier_dropout_prob)
  622. self.classifier = nn.Linear(config.hidden_size, self.config.num_labels)
  623. # Initialize weights and apply final processing
  624. self.post_init()
  625. @can_return_tuple
  626. @auto_docstring
  627. def forward(
  628. self,
  629. input_ids: torch.LongTensor | None = None,
  630. attention_mask: torch.FloatTensor | None = None,
  631. token_type_ids: torch.LongTensor | None = None,
  632. position_ids: torch.LongTensor | None = None,
  633. inputs_embeds: torch.FloatTensor | None = None,
  634. labels: torch.LongTensor | None = None,
  635. **kwargs: Unpack[TransformersKwargs],
  636. ) -> TokenClassifierOutput | tuple:
  637. r"""
  638. labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
  639. Labels for computing the token classification loss. Indices should be in `[0, ..., config.num_labels - 1]`.
  640. """
  641. outputs = self.albert(
  642. input_ids,
  643. attention_mask=attention_mask,
  644. token_type_ids=token_type_ids,
  645. position_ids=position_ids,
  646. inputs_embeds=inputs_embeds,
  647. return_dict=True,
  648. **kwargs,
  649. )
  650. sequence_output = outputs[0]
  651. sequence_output = self.dropout(sequence_output)
  652. logits = self.classifier(sequence_output)
  653. loss = None
  654. if labels is not None:
  655. loss_fct = CrossEntropyLoss()
  656. loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
  657. return TokenClassifierOutput(
  658. loss=loss,
  659. logits=logits,
  660. hidden_states=outputs.hidden_states,
  661. attentions=outputs.attentions,
  662. )
  663. @auto_docstring
  664. class AlbertForQuestionAnswering(AlbertPreTrainedModel):
  665. def __init__(self, config: AlbertConfig):
  666. super().__init__(config)
  667. self.num_labels = config.num_labels
  668. self.albert = AlbertModel(config, add_pooling_layer=False)
  669. self.qa_outputs = nn.Linear(config.hidden_size, config.num_labels)
  670. # Initialize weights and apply final processing
  671. self.post_init()
  672. @can_return_tuple
  673. @auto_docstring
  674. def forward(
  675. self,
  676. input_ids: torch.LongTensor | None = None,
  677. attention_mask: torch.FloatTensor | None = None,
  678. token_type_ids: torch.LongTensor | None = None,
  679. position_ids: torch.LongTensor | None = None,
  680. inputs_embeds: torch.FloatTensor | None = None,
  681. start_positions: torch.LongTensor | None = None,
  682. end_positions: torch.LongTensor | None = None,
  683. **kwargs: Unpack[TransformersKwargs],
  684. ) -> AlbertForPreTrainingOutput | tuple:
  685. outputs = self.albert(
  686. input_ids=input_ids,
  687. attention_mask=attention_mask,
  688. token_type_ids=token_type_ids,
  689. position_ids=position_ids,
  690. inputs_embeds=inputs_embeds,
  691. return_dict=True,
  692. **kwargs,
  693. )
  694. sequence_output = outputs[0]
  695. logits: torch.Tensor = self.qa_outputs(sequence_output)
  696. start_logits, end_logits = logits.split(1, dim=-1)
  697. start_logits = start_logits.squeeze(-1).contiguous()
  698. end_logits = end_logits.squeeze(-1).contiguous()
  699. total_loss = None
  700. if start_positions is not None and end_positions is not None:
  701. # If we are on multi-GPU, split add a dimension
  702. if len(start_positions.size()) > 1:
  703. start_positions = start_positions.squeeze(-1)
  704. if len(end_positions.size()) > 1:
  705. end_positions = end_positions.squeeze(-1)
  706. # sometimes the start/end positions are outside our model inputs, we ignore these terms
  707. ignored_index = start_logits.size(1)
  708. start_positions = start_positions.clamp(0, ignored_index)
  709. end_positions = end_positions.clamp(0, ignored_index)
  710. loss_fct = CrossEntropyLoss(ignore_index=ignored_index)
  711. start_loss = loss_fct(start_logits, start_positions)
  712. end_loss = loss_fct(end_logits, end_positions)
  713. total_loss = (start_loss + end_loss) / 2
  714. return QuestionAnsweringModelOutput(
  715. loss=total_loss,
  716. start_logits=start_logits,
  717. end_logits=end_logits,
  718. hidden_states=outputs.hidden_states,
  719. attentions=outputs.attentions,
  720. )
  721. @auto_docstring
  722. class AlbertForMultipleChoice(AlbertPreTrainedModel):
  723. def __init__(self, config: AlbertConfig):
  724. super().__init__(config)
  725. self.albert = AlbertModel(config)
  726. self.dropout = nn.Dropout(config.classifier_dropout_prob)
  727. self.classifier = nn.Linear(config.hidden_size, 1)
  728. # Initialize weights and apply final processing
  729. self.post_init()
  730. @can_return_tuple
  731. @auto_docstring
  732. def forward(
  733. self,
  734. input_ids: torch.LongTensor | None = None,
  735. attention_mask: torch.FloatTensor | None = None,
  736. token_type_ids: torch.LongTensor | None = None,
  737. position_ids: torch.LongTensor | None = None,
  738. inputs_embeds: torch.FloatTensor | None = None,
  739. labels: torch.LongTensor | None = None,
  740. **kwargs: Unpack[TransformersKwargs],
  741. ) -> AlbertForPreTrainingOutput | tuple:
  742. r"""
  743. input_ids (`torch.LongTensor` of shape `(batch_size, num_choices, sequence_length)`):
  744. Indices of input sequence tokens in the vocabulary.
  745. Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.__call__`] and
  746. [`PreTrainedTokenizer.encode`] for details.
  747. [What are input IDs?](../glossary#input-ids)
  748. token_type_ids (`torch.LongTensor` of shape `(batch_size, num_choices, sequence_length)`, *optional*):
  749. Segment token indices to indicate first and second portions of the inputs. Indices are selected in `[0,
  750. 1]`:
  751. - 0 corresponds to a *sentence A* token,
  752. - 1 corresponds to a *sentence B* token.
  753. [What are token type IDs?](../glossary#token-type-ids)
  754. position_ids (`torch.LongTensor` of shape `(batch_size, num_choices, sequence_length)`, *optional*):
  755. Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0,
  756. config.max_position_embeddings - 1]`.
  757. [What are position IDs?](../glossary#position-ids)
  758. inputs_embeds (`torch.FloatTensor` of shape `(batch_size, num_choices, sequence_length, hidden_size)`, *optional*):
  759. Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This
  760. is useful if you want more control over how to convert `input_ids` indices into associated vectors than the
  761. model's internal embedding lookup matrix.
  762. labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
  763. Labels for computing the multiple choice classification loss. Indices should be in `[0, ...,
  764. num_choices-1]` where *num_choices* is the size of the second dimension of the input tensors. (see
  765. *input_ids* above)
  766. """
  767. num_choices = input_ids.shape[1] if input_ids is not None else inputs_embeds.shape[1]
  768. input_ids = input_ids.view(-1, input_ids.size(-1)) if input_ids is not None else None
  769. attention_mask = attention_mask.view(-1, attention_mask.size(-1)) if attention_mask is not None else None
  770. token_type_ids = token_type_ids.view(-1, token_type_ids.size(-1)) if token_type_ids is not None else None
  771. position_ids = position_ids.view(-1, position_ids.size(-1)) if position_ids is not None else None
  772. inputs_embeds = (
  773. inputs_embeds.view(-1, inputs_embeds.size(-2), inputs_embeds.size(-1))
  774. if inputs_embeds is not None
  775. else None
  776. )
  777. outputs = self.albert(
  778. input_ids,
  779. attention_mask=attention_mask,
  780. token_type_ids=token_type_ids,
  781. position_ids=position_ids,
  782. inputs_embeds=inputs_embeds,
  783. return_dict=True,
  784. **kwargs,
  785. )
  786. pooled_output = outputs[1]
  787. pooled_output = self.dropout(pooled_output)
  788. logits: torch.Tensor = self.classifier(pooled_output)
  789. reshaped_logits = logits.view(-1, num_choices)
  790. loss = None
  791. if labels is not None:
  792. loss_fct = CrossEntropyLoss()
  793. loss = loss_fct(reshaped_logits, labels)
  794. return MultipleChoiceModelOutput(
  795. loss=loss,
  796. logits=reshaped_logits,
  797. hidden_states=outputs.hidden_states,
  798. attentions=outputs.attentions,
  799. )
# Public API of this module: the names exported by a star-import of this file.
__all__ = [
    "AlbertPreTrainedModel",
    "AlbertModel",
    "AlbertForPreTraining",
    "AlbertForMaskedLM",
    "AlbertForSequenceClassification",
    "AlbertForTokenClassification",
    "AlbertForQuestionAnswering",
    "AlbertForMultipleChoice",
]