modeling_distilbert.py 39 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656657658659660661662663664665666667668669670671672673674675676677678679680681682683684685686687688689690691692693694695696697698699700701702703704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792793794795796797798799800801802803804805806807808809810811812813814815816817818819820821822823824825826827828829830831832833834835836837838839840841842843844845846847848849850851852853854855856857858859860861862863864865866867868869870871872873874875876877878879880881882883884885886887888889890891892893894895896897898899900901902903904905906907908909910911912913914915916917918919920921922923924925926927928929930931932933934935936937938939940941942
  1. # Copyright 2019-present, the HuggingFace Inc. team, The Google AI Language Team and Facebook, Inc.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. """
  15. PyTorch DistilBERT model adapted in part from Facebook, Inc XLM model (https://github.com/facebookresearch/XLM) and in
  16. part from HuggingFace PyTorch version of Google AI Bert model (https://github.com/google-research/bert)
  17. """
  18. from collections.abc import Callable
  19. import numpy as np
  20. import torch
  21. from torch import nn
  22. from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
  23. from ... import initialization as init
  24. from ...activations import get_activation
  25. from ...configuration_utils import PreTrainedConfig
  26. from ...integrations.deepspeed import is_deepspeed_zero3_enabled
  27. from ...masking_utils import create_bidirectional_mask
  28. from ...modeling_layers import GradientCheckpointingLayer
  29. from ...modeling_outputs import (
  30. BaseModelOutput,
  31. MaskedLMOutput,
  32. MultipleChoiceModelOutput,
  33. QuestionAnsweringModelOutput,
  34. SequenceClassifierOutput,
  35. TokenClassifierOutput,
  36. )
  37. from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
  38. from ...processing_utils import Unpack
  39. from ...pytorch_utils import (
  40. apply_chunking_to_forward,
  41. )
  42. from ...utils import (
  43. TransformersKwargs,
  44. auto_docstring,
  45. logging,
  46. )
  47. from ...utils.deprecation import deprecate_kwarg
  48. from ...utils.generic import can_return_tuple, merge_with_config_defaults
  49. from ...utils.output_capturing import capture_outputs
  50. from .configuration_distilbert import DistilBertConfig
  51. logger = logging.get_logger(__name__)
  52. # UTILS AND BUILDING BLOCKS OF THE ARCHITECTURE #
  53. def create_sinusoidal_embeddings(n_pos: int, dim: int, out: torch.Tensor):
  54. if is_deepspeed_zero3_enabled():
  55. import deepspeed
  56. with deepspeed.zero.GatheredParameters(out, modifier_rank=0):
  57. if torch.distributed.get_rank() == 0:
  58. return _create_sinusoidal_embeddings(n_pos=n_pos, dim=dim, out=out)
  59. else:
  60. return _create_sinusoidal_embeddings(n_pos=n_pos, dim=dim, out=out)
  61. def _create_sinusoidal_embeddings(n_pos: int, dim: int, out: torch.Tensor):
  62. position_enc = np.array([[pos / np.power(10000, 2 * (j // 2) / dim) for j in range(dim)] for pos in range(n_pos)])
  63. out.requires_grad = False
  64. out[:, 0::2] = torch.FloatTensor(np.sin(position_enc[:, 0::2]))
  65. out[:, 1::2] = torch.FloatTensor(np.cos(position_enc[:, 1::2]))
  66. out.detach_()
  67. return out
  68. class Embeddings(nn.Module):
  69. def __init__(self, config: PreTrainedConfig):
  70. super().__init__()
  71. self.word_embeddings = nn.Embedding(config.vocab_size, config.dim, padding_idx=config.pad_token_id)
  72. self.position_embeddings = nn.Embedding(config.max_position_embeddings, config.dim)
  73. self.LayerNorm = nn.LayerNorm(config.dim, eps=1e-12)
  74. self.dropout = nn.Dropout(config.dropout)
  75. self.register_buffer(
  76. "position_ids", torch.arange(config.max_position_embeddings).expand((1, -1)), persistent=False
  77. )
  78. @deprecate_kwarg("input_embeds", version="5.6.0", new_name="inputs_embeds")
  79. def forward(
  80. self,
  81. input_ids: torch.Tensor,
  82. inputs_embeds: torch.Tensor | None = None,
  83. position_ids: torch.LongTensor | None = None,
  84. ) -> torch.Tensor:
  85. if input_ids is not None:
  86. inputs_embeds = self.word_embeddings(input_ids) # (bs, max_seq_length, dim)
  87. seq_length = inputs_embeds.size(1)
  88. if position_ids is None:
  89. # Setting the position-ids to the registered buffer in constructor, it helps
  90. # when tracing the model without passing position-ids, solves
  91. # issues similar to issue #5664
  92. if hasattr(self, "position_ids"):
  93. position_ids = self.position_ids[:, :seq_length]
  94. else:
  95. position_ids = torch.arange(seq_length, dtype=torch.long, device=input_ids.device) # (max_seq_length)
  96. position_ids = position_ids.unsqueeze(0).expand_as(input_ids) # (bs, max_seq_length)
  97. position_embeddings = self.position_embeddings(position_ids) # (bs, max_seq_length, dim)
  98. embeddings = inputs_embeds + position_embeddings # (bs, max_seq_length, dim)
  99. embeddings = self.LayerNorm(embeddings) # (bs, max_seq_length, dim)
  100. embeddings = self.dropout(embeddings) # (bs, max_seq_length, dim)
  101. return embeddings
  102. # Copied from transformers.models.bert.modeling_bert.eager_attention_forward
  103. def eager_attention_forward(
  104. module: nn.Module,
  105. query: torch.Tensor,
  106. key: torch.Tensor,
  107. value: torch.Tensor,
  108. attention_mask: torch.Tensor | None,
  109. scaling: float | None = None,
  110. dropout: float = 0.0,
  111. **kwargs: Unpack[TransformersKwargs],
  112. ):
  113. if scaling is None:
  114. scaling = query.size(-1) ** -0.5
  115. # Take the dot product between "query" and "key" to get the raw attention scores.
  116. attn_weights = torch.matmul(query, key.transpose(2, 3)) * scaling
  117. if attention_mask is not None:
  118. attn_weights = attn_weights + attention_mask
  119. attn_weights = nn.functional.softmax(attn_weights, dim=-1)
  120. attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training)
  121. attn_output = torch.matmul(attn_weights, value)
  122. attn_output = attn_output.transpose(1, 2).contiguous()
  123. return attn_output, attn_weights
  124. class DistilBertSelfAttention(nn.Module):
  125. def __init__(self, config: PreTrainedConfig):
  126. super().__init__()
  127. self.config = config
  128. self.n_heads = config.n_heads
  129. self.dim = config.dim
  130. self.attention_head_size = self.dim // self.n_heads
  131. self.scaling = self.attention_head_size**-0.5
  132. # Have an even number of multi heads that divide the dimensions
  133. if self.dim % self.n_heads != 0:
  134. # Raise value errors for even multi-head attention nodes
  135. raise ValueError(f"self.n_heads: {self.n_heads} must divide self.dim: {self.dim} evenly")
  136. self.q_lin = nn.Linear(in_features=config.dim, out_features=config.dim)
  137. self.k_lin = nn.Linear(in_features=config.dim, out_features=config.dim)
  138. self.v_lin = nn.Linear(in_features=config.dim, out_features=config.dim)
  139. self.out_lin = nn.Linear(in_features=config.dim, out_features=config.dim)
  140. self.dropout = nn.Dropout(p=config.attention_dropout)
  141. self.is_causal = False
  142. def forward(
  143. self,
  144. hidden_states: torch.Tensor,
  145. attention_mask: torch.FloatTensor | None = None,
  146. **kwargs: Unpack[TransformersKwargs],
  147. ) -> tuple[torch.Tensor]:
  148. input_shape = hidden_states.shape[:-1]
  149. hidden_shape = (*input_shape, -1, self.attention_head_size)
  150. # get all proj
  151. query_layer = self.q_lin(hidden_states).view(*hidden_shape).transpose(1, 2)
  152. key_layer = self.k_lin(hidden_states).view(*hidden_shape).transpose(1, 2)
  153. value_layer = self.v_lin(hidden_states).view(*hidden_shape).transpose(1, 2)
  154. attention_interface: Callable = ALL_ATTENTION_FUNCTIONS.get_interface(
  155. self.config._attn_implementation, eager_attention_forward
  156. )
  157. attn_output, attn_weights = attention_interface(
  158. self,
  159. query_layer,
  160. key_layer,
  161. value_layer,
  162. attention_mask,
  163. dropout=0.0 if not self.training else self.dropout.p,
  164. scaling=self.scaling,
  165. **kwargs,
  166. )
  167. attn_output = attn_output.reshape(*input_shape, -1).contiguous()
  168. attn_output = self.out_lin(attn_output)
  169. return attn_output, attn_weights
  170. class FFN(nn.Module):
  171. def __init__(self, config: PreTrainedConfig):
  172. super().__init__()
  173. self.dropout = nn.Dropout(p=config.dropout)
  174. self.chunk_size_feed_forward = config.chunk_size_feed_forward
  175. self.seq_len_dim = 1
  176. self.lin1 = nn.Linear(in_features=config.dim, out_features=config.hidden_dim)
  177. self.lin2 = nn.Linear(in_features=config.hidden_dim, out_features=config.dim)
  178. self.activation = get_activation(config.activation)
  179. def forward(self, input: torch.Tensor) -> torch.Tensor:
  180. return apply_chunking_to_forward(self.ff_chunk, self.chunk_size_feed_forward, self.seq_len_dim, input)
  181. def ff_chunk(self, input: torch.Tensor) -> torch.Tensor:
  182. x = self.lin1(input)
  183. x = self.activation(x)
  184. x = self.lin2(x)
  185. x = self.dropout(x)
  186. return x
  187. class TransformerBlock(GradientCheckpointingLayer):
  188. def __init__(self, config: PreTrainedConfig):
  189. super().__init__()
  190. # Have an even number of Configure multi-heads
  191. if config.dim % config.n_heads != 0:
  192. raise ValueError(f"config.n_heads {config.n_heads} must divide config.dim {config.dim} evenly")
  193. self.attention = DistilBertSelfAttention(config)
  194. self.sa_layer_norm = nn.LayerNorm(normalized_shape=config.dim, eps=1e-12)
  195. self.ffn = FFN(config)
  196. self.output_layer_norm = nn.LayerNorm(normalized_shape=config.dim, eps=1e-12)
  197. def forward(
  198. self,
  199. hidden_states: torch.Tensor,
  200. attention_mask: torch.Tensor | None = None,
  201. **kwargs: Unpack[TransformersKwargs],
  202. ) -> tuple[torch.Tensor, ...]:
  203. # Self-Attention
  204. attention_output, _ = self.attention(
  205. hidden_states,
  206. attention_mask=attention_mask,
  207. **kwargs,
  208. )
  209. attention_output = self.sa_layer_norm(attention_output + hidden_states)
  210. # Feed Forward Network
  211. ffn_output = self.ffn(attention_output)
  212. ffn_output = self.output_layer_norm(ffn_output + attention_output)
  213. return ffn_output
  214. class Transformer(nn.Module):
  215. def __init__(self, config: PreTrainedConfig):
  216. super().__init__()
  217. self.n_layers = config.n_layers
  218. self.layer = nn.ModuleList([TransformerBlock(config) for _ in range(config.n_layers)])
  219. self.gradient_checkpointing = False
  220. def forward(
  221. self,
  222. hidden_states: torch.Tensor,
  223. attention_mask: torch.Tensor | None = None,
  224. **kwargs: Unpack[TransformersKwargs],
  225. ) -> BaseModelOutput:
  226. for layer_module in self.layer:
  227. hidden_states = layer_module(
  228. hidden_states,
  229. attention_mask,
  230. **kwargs,
  231. )
  232. return BaseModelOutput(last_hidden_state=hidden_states)
  233. # INTERFACE FOR ENCODER AND TASK SPECIFIC MODEL #
  234. @auto_docstring
  235. class DistilBertPreTrainedModel(PreTrainedModel):
  236. config: DistilBertConfig
  237. base_model_prefix = "distilbert"
  238. supports_gradient_checkpointing = True
  239. _supports_flash_attn = True
  240. _supports_sdpa = True
  241. _supports_flex_attn = True
  242. _supports_attention_backend = True
  243. _can_record_outputs = {
  244. "hidden_states": TransformerBlock,
  245. "attentions": DistilBertSelfAttention,
  246. }
  247. @torch.no_grad()
  248. def _init_weights(self, module: nn.Module):
  249. """Initialize the weights."""
  250. super()._init_weights(module)
  251. if isinstance(module, Embeddings):
  252. if self.config.sinusoidal_pos_embds:
  253. init.copy_(
  254. module.position_embeddings.weight,
  255. create_sinusoidal_embeddings(
  256. self.config.max_position_embeddings,
  257. self.config.dim,
  258. torch.empty_like(module.position_embeddings.weight),
  259. ),
  260. )
  261. init.copy_(module.position_ids, torch.arange(module.position_ids.shape[-1]).expand((1, -1)))
  262. @auto_docstring
  263. class DistilBertModel(DistilBertPreTrainedModel):
  264. def __init__(self, config: PreTrainedConfig):
  265. super().__init__(config)
  266. self.embeddings = Embeddings(config) # Embeddings
  267. self.transformer = Transformer(config) # Encoder
  268. # Initialize weights and apply final processing
  269. self.post_init()
  270. def get_position_embeddings(self) -> nn.Embedding:
  271. """
  272. Returns the position embeddings
  273. """
  274. return self.embeddings.position_embeddings
  275. def resize_position_embeddings(self, new_num_position_embeddings: int):
  276. """
  277. Resizes position embeddings of the model if `new_num_position_embeddings != config.max_position_embeddings`.
  278. Arguments:
  279. new_num_position_embeddings (`int`):
  280. The number of new position embedding matrix. If position embeddings are learned, increasing the size
  281. will add newly initialized vectors at the end, whereas reducing the size will remove vectors from the
  282. end. If position embeddings are not learned (*e.g.* sinusoidal position embeddings), increasing the
  283. size will add correct vectors at the end following the position encoding algorithm, whereas reducing
  284. the size will remove vectors from the end.
  285. """
  286. num_position_embeds_diff = new_num_position_embeddings - self.config.max_position_embeddings
  287. # no resizing needs to be done if the length stays the same
  288. if num_position_embeds_diff == 0:
  289. return
  290. logger.info(f"Setting `config.max_position_embeddings={new_num_position_embeddings}`...")
  291. self.config.max_position_embeddings = new_num_position_embeddings
  292. old_position_embeddings_weight = self.embeddings.position_embeddings.weight.clone()
  293. self.embeddings.position_embeddings = nn.Embedding(self.config.max_position_embeddings, self.config.dim)
  294. if self.config.sinusoidal_pos_embds:
  295. create_sinusoidal_embeddings(
  296. n_pos=self.config.max_position_embeddings, dim=self.config.dim, out=self.position_embeddings.weight
  297. )
  298. else:
  299. with torch.no_grad():
  300. if num_position_embeds_diff > 0:
  301. self.embeddings.position_embeddings.weight[:-num_position_embeds_diff] = nn.Parameter(
  302. old_position_embeddings_weight
  303. )
  304. else:
  305. self.embeddings.position_embeddings.weight = nn.Parameter(
  306. old_position_embeddings_weight[:num_position_embeds_diff]
  307. )
  308. # move position_embeddings to correct device
  309. self.embeddings.position_embeddings.to(self.device)
  310. def get_input_embeddings(self) -> nn.Embedding:
  311. return self.embeddings.word_embeddings
  312. def set_input_embeddings(self, new_embeddings: nn.Embedding):
  313. self.embeddings.word_embeddings = new_embeddings
  314. @merge_with_config_defaults
  315. @capture_outputs
  316. @auto_docstring
  317. def forward(
  318. self,
  319. input_ids: torch.Tensor | None = None,
  320. attention_mask: torch.Tensor | None = None,
  321. inputs_embeds: torch.Tensor | None = None,
  322. position_ids: torch.Tensor | None = None,
  323. **kwargs: Unpack[TransformersKwargs],
  324. ) -> BaseModelOutput | tuple[torch.Tensor, ...]:
  325. r"""
  326. input_ids (`torch.LongTensor` of shape `(batch_size, num_choices)`):
  327. Indices of input sequence tokens in the vocabulary.
  328. Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
  329. [`PreTrainedTokenizer.__call__`] for details.
  330. [What are input IDs?](../glossary#input-ids)
  331. inputs_embeds (`torch.FloatTensor` of shape `(batch_size, num_choices, hidden_size)`, *optional*):
  332. Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This
  333. is useful if you want more control over how to convert `input_ids` indices into associated vectors than the
  334. model's internal embedding lookup matrix.
  335. """
  336. if (input_ids is None) ^ (inputs_embeds is not None):
  337. raise ValueError("You must specify exactly one of input_ids or inputs_embeds")
  338. embeddings = self.embeddings(input_ids, inputs_embeds, position_ids)
  339. attention_mask = create_bidirectional_mask(
  340. config=self.config,
  341. inputs_embeds=embeddings,
  342. attention_mask=attention_mask,
  343. )
  344. return self.transformer(
  345. hidden_states=embeddings,
  346. attention_mask=attention_mask,
  347. **kwargs,
  348. )
  349. @auto_docstring(
  350. custom_intro="""
  351. DistilBert Model with a `masked language modeling` head on top.
  352. """
  353. )
  354. class DistilBertForMaskedLM(DistilBertPreTrainedModel):
  355. _tied_weights_keys = {"vocab_projector.weight": "distilbert.embeddings.word_embeddings.weight"}
  356. def __init__(self, config: PreTrainedConfig):
  357. super().__init__(config)
  358. self.activation = get_activation(config.activation)
  359. self.distilbert = DistilBertModel(config)
  360. self.vocab_transform = nn.Linear(config.dim, config.dim)
  361. self.vocab_layer_norm = nn.LayerNorm(config.dim, eps=1e-12)
  362. self.vocab_projector = nn.Linear(config.dim, config.vocab_size)
  363. # Initialize weights and apply final processing
  364. self.post_init()
  365. self.mlm_loss_fct = nn.CrossEntropyLoss()
  366. def get_position_embeddings(self) -> nn.Embedding:
  367. """
  368. Returns the position embeddings
  369. """
  370. return self.distilbert.get_position_embeddings()
  371. def resize_position_embeddings(self, new_num_position_embeddings: int):
  372. """
  373. Resizes position embeddings of the model if `new_num_position_embeddings != config.max_position_embeddings`.
  374. Arguments:
  375. new_num_position_embeddings (`int`):
  376. The number of new position embedding matrix. If position embeddings are learned, increasing the size
  377. will add newly initialized vectors at the end, whereas reducing the size will remove vectors from the
  378. end. If position embeddings are not learned (*e.g.* sinusoidal position embeddings), increasing the
  379. size will add correct vectors at the end following the position encoding algorithm, whereas reducing
  380. the size will remove vectors from the end.
  381. """
  382. self.distilbert.resize_position_embeddings(new_num_position_embeddings)
  383. def get_output_embeddings(self) -> nn.Module:
  384. return self.vocab_projector
  385. def set_output_embeddings(self, new_embeddings: nn.Module):
  386. self.vocab_projector = new_embeddings
  387. @can_return_tuple
  388. @auto_docstring
  389. def forward(
  390. self,
  391. input_ids: torch.Tensor | None = None,
  392. attention_mask: torch.Tensor | None = None,
  393. inputs_embeds: torch.Tensor | None = None,
  394. labels: torch.LongTensor | None = None,
  395. position_ids: torch.Tensor | None = None,
  396. **kwargs: Unpack[TransformersKwargs],
  397. ) -> MaskedLMOutput | tuple[torch.Tensor, ...]:
  398. r"""
  399. input_ids (`torch.LongTensor` of shape `(batch_size, num_choices)`):
  400. Indices of input sequence tokens in the vocabulary.
  401. Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
  402. [`PreTrainedTokenizer.__call__`] for details.
  403. [What are input IDs?](../glossary#input-ids)
  404. inputs_embeds (`torch.FloatTensor` of shape `(batch_size, num_choices, hidden_size)`, *optional*):
  405. Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This
  406. is useful if you want more control over how to convert `input_ids` indices into associated vectors than the
  407. model's internal embedding lookup matrix.
  408. labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
  409. Labels for computing the masked language modeling loss. Indices should be in `[-100, 0, ...,
  410. config.vocab_size]` (see `input_ids` docstring) Tokens with indices set to `-100` are ignored (masked), the
  411. loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
  412. """
  413. dlbrt_output = self.distilbert(
  414. input_ids=input_ids,
  415. attention_mask=attention_mask,
  416. inputs_embeds=inputs_embeds,
  417. position_ids=position_ids,
  418. return_dict=True,
  419. **kwargs,
  420. )
  421. hidden_states = dlbrt_output[0] # (bs, seq_length, dim)
  422. prediction_logits = self.vocab_transform(hidden_states) # (bs, seq_length, dim)
  423. prediction_logits = self.activation(prediction_logits) # (bs, seq_length, dim)
  424. prediction_logits = self.vocab_layer_norm(prediction_logits) # (bs, seq_length, dim)
  425. prediction_logits = self.vocab_projector(prediction_logits) # (bs, seq_length, vocab_size)
  426. mlm_loss = None
  427. if labels is not None:
  428. mlm_loss = self.mlm_loss_fct(prediction_logits.view(-1, prediction_logits.size(-1)), labels.view(-1))
  429. return MaskedLMOutput(
  430. loss=mlm_loss,
  431. logits=prediction_logits,
  432. hidden_states=dlbrt_output.hidden_states,
  433. attentions=dlbrt_output.attentions,
  434. )
  435. @auto_docstring(
  436. custom_intro="""
  437. DistilBert Model transformer with a sequence classification/regression head on top (a linear layer on top of the
  438. pooled output) e.g. for GLUE tasks.
  439. """
  440. )
  441. class DistilBertForSequenceClassification(DistilBertPreTrainedModel):
  442. def __init__(self, config: PreTrainedConfig):
  443. super().__init__(config)
  444. self.num_labels = config.num_labels
  445. self.config = config
  446. self.distilbert = DistilBertModel(config)
  447. self.pre_classifier = nn.Linear(config.dim, config.dim)
  448. self.classifier = nn.Linear(config.dim, config.num_labels)
  449. self.dropout = nn.Dropout(config.seq_classif_dropout)
  450. # Initialize weights and apply final processing
  451. self.post_init()
  452. def get_position_embeddings(self) -> nn.Embedding:
  453. """
  454. Returns the position embeddings
  455. """
  456. return self.distilbert.get_position_embeddings()
  457. def resize_position_embeddings(self, new_num_position_embeddings: int):
  458. """
  459. Resizes position embeddings of the model if `new_num_position_embeddings != config.max_position_embeddings`.
  460. Arguments:
  461. new_num_position_embeddings (`int`):
  462. The number of new position embedding matrix. If position embeddings are learned, increasing the size
  463. will add newly initialized vectors at the end, whereas reducing the size will remove vectors from the
  464. end. If position embeddings are not learned (*e.g.* sinusoidal position embeddings), increasing the
  465. size will add correct vectors at the end following the position encoding algorithm, whereas reducing
  466. the size will remove vectors from the end.
  467. """
  468. self.distilbert.resize_position_embeddings(new_num_position_embeddings)
  469. @can_return_tuple
  470. @auto_docstring
  471. def forward(
  472. self,
  473. input_ids: torch.Tensor | None = None,
  474. attention_mask: torch.Tensor | None = None,
  475. inputs_embeds: torch.Tensor | None = None,
  476. labels: torch.LongTensor | None = None,
  477. position_ids: torch.Tensor | None = None,
  478. **kwargs: Unpack[TransformersKwargs],
  479. ) -> SequenceClassifierOutput | tuple[torch.Tensor, ...]:
  480. r"""
  481. labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
  482. Labels for computing the sequence classification/regression loss. Indices should be in `[0, ...,
  483. config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If
  484. `config.num_labels > 1` a classification loss is computed (Cross-Entropy).
  485. """
  486. distilbert_output = self.distilbert(
  487. input_ids=input_ids,
  488. attention_mask=attention_mask,
  489. inputs_embeds=inputs_embeds,
  490. position_ids=position_ids,
  491. return_dict=True,
  492. **kwargs,
  493. )
  494. hidden_state = distilbert_output[0] # (bs, seq_len, dim)
  495. pooled_output = hidden_state[:, 0] # (bs, dim)
  496. pooled_output = self.pre_classifier(pooled_output) # (bs, dim)
  497. pooled_output = nn.ReLU()(pooled_output) # (bs, dim)
  498. pooled_output = self.dropout(pooled_output) # (bs, dim)
  499. logits = self.classifier(pooled_output) # (bs, num_labels)
  500. loss = None
  501. if labels is not None:
  502. if self.config.problem_type is None:
  503. if self.num_labels == 1:
  504. self.config.problem_type = "regression"
  505. elif self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int):
  506. self.config.problem_type = "single_label_classification"
  507. else:
  508. self.config.problem_type = "multi_label_classification"
  509. if self.config.problem_type == "regression":
  510. loss_fct = MSELoss()
  511. if self.num_labels == 1:
  512. loss = loss_fct(logits.squeeze(), labels.squeeze())
  513. else:
  514. loss = loss_fct(logits, labels)
  515. elif self.config.problem_type == "single_label_classification":
  516. loss_fct = CrossEntropyLoss()
  517. loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
  518. elif self.config.problem_type == "multi_label_classification":
  519. loss_fct = BCEWithLogitsLoss()
  520. loss = loss_fct(logits, labels)
  521. return SequenceClassifierOutput(
  522. loss=loss,
  523. logits=logits,
  524. hidden_states=distilbert_output.hidden_states,
  525. attentions=distilbert_output.attentions,
  526. )
  527. @auto_docstring
  528. class DistilBertForQuestionAnswering(DistilBertPreTrainedModel):
  529. def __init__(self, config: PreTrainedConfig):
  530. super().__init__(config)
  531. self.distilbert = DistilBertModel(config)
  532. self.qa_outputs = nn.Linear(config.dim, config.num_labels)
  533. if config.num_labels != 2:
  534. raise ValueError(f"config.num_labels should be 2, but it is {config.num_labels}")
  535. self.dropout = nn.Dropout(config.qa_dropout)
  536. # Initialize weights and apply final processing
  537. self.post_init()
  538. def get_position_embeddings(self) -> nn.Embedding:
  539. """
  540. Returns the position embeddings
  541. """
  542. return self.distilbert.get_position_embeddings()
  543. def resize_position_embeddings(self, new_num_position_embeddings: int):
  544. """
  545. Resizes position embeddings of the model if `new_num_position_embeddings != config.max_position_embeddings`.
  546. Arguments:
  547. new_num_position_embeddings (`int`):
  548. The number of new position embedding matrix. If position embeddings are learned, increasing the size
  549. will add newly initialized vectors at the end, whereas reducing the size will remove vectors from the
  550. end. If position embeddings are not learned (*e.g.* sinusoidal position embeddings), increasing the
  551. size will add correct vectors at the end following the position encoding algorithm, whereas reducing
  552. the size will remove vectors from the end.
  553. """
  554. self.distilbert.resize_position_embeddings(new_num_position_embeddings)
  555. @can_return_tuple
  556. @auto_docstring
  557. def forward(
  558. self,
  559. input_ids: torch.Tensor | None = None,
  560. attention_mask: torch.Tensor | None = None,
  561. inputs_embeds: torch.Tensor | None = None,
  562. start_positions: torch.Tensor | None = None,
  563. end_positions: torch.Tensor | None = None,
  564. position_ids: torch.Tensor | None = None,
  565. **kwargs: Unpack[TransformersKwargs],
  566. ) -> QuestionAnsweringModelOutput | tuple[torch.Tensor, ...]:
  567. r"""
  568. input_ids (`torch.LongTensor` of shape `(batch_size, num_choices)`):
  569. Indices of input sequence tokens in the vocabulary.
  570. Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
  571. [`PreTrainedTokenizer.__call__`] for details.
  572. [What are input IDs?](../glossary#input-ids)
  573. inputs_embeds (`torch.FloatTensor` of shape `(batch_size, num_choices, hidden_size)`, *optional*):
  574. Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This
  575. is useful if you want more control over how to convert `input_ids` indices into associated vectors than the
  576. model's internal embedding lookup matrix.
  577. """
  578. distilbert_output = self.distilbert(
  579. input_ids=input_ids,
  580. attention_mask=attention_mask,
  581. inputs_embeds=inputs_embeds,
  582. position_ids=position_ids,
  583. return_dict=True,
  584. **kwargs,
  585. )
  586. hidden_states = distilbert_output[0] # (bs, max_query_len, dim)
  587. hidden_states = self.dropout(hidden_states) # (bs, max_query_len, dim)
  588. logits = self.qa_outputs(hidden_states) # (bs, max_query_len, 2)
  589. start_logits, end_logits = logits.split(1, dim=-1)
  590. start_logits = start_logits.squeeze(-1).contiguous() # (bs, max_query_len)
  591. end_logits = end_logits.squeeze(-1).contiguous() # (bs, max_query_len)
  592. total_loss = None
  593. if start_positions is not None and end_positions is not None:
  594. # If we are on multi-GPU, split add a dimension
  595. if len(start_positions.size()) > 1:
  596. start_positions = start_positions.squeeze(-1)
  597. if len(end_positions.size()) > 1:
  598. end_positions = end_positions.squeeze(-1)
  599. # sometimes the start/end positions are outside our model inputs, we ignore these terms
  600. ignored_index = start_logits.size(1)
  601. start_positions = start_positions.clamp(0, ignored_index)
  602. end_positions = end_positions.clamp(0, ignored_index)
  603. loss_fct = nn.CrossEntropyLoss(ignore_index=ignored_index)
  604. start_loss = loss_fct(start_logits, start_positions)
  605. end_loss = loss_fct(end_logits, end_positions)
  606. total_loss = (start_loss + end_loss) / 2
  607. return QuestionAnsweringModelOutput(
  608. loss=total_loss,
  609. start_logits=start_logits,
  610. end_logits=end_logits,
  611. hidden_states=distilbert_output.hidden_states,
  612. attentions=distilbert_output.attentions,
  613. )
  614. @auto_docstring
  615. class DistilBertForTokenClassification(DistilBertPreTrainedModel):
  616. def __init__(self, config: PreTrainedConfig):
  617. super().__init__(config)
  618. self.num_labels = config.num_labels
  619. self.distilbert = DistilBertModel(config)
  620. self.dropout = nn.Dropout(config.dropout)
  621. self.classifier = nn.Linear(config.hidden_size, config.num_labels)
  622. # Initialize weights and apply final processing
  623. self.post_init()
  624. def get_position_embeddings(self) -> nn.Embedding:
  625. """
  626. Returns the position embeddings
  627. """
  628. return self.distilbert.get_position_embeddings()
  629. def resize_position_embeddings(self, new_num_position_embeddings: int):
  630. """
  631. Resizes position embeddings of the model if `new_num_position_embeddings != config.max_position_embeddings`.
  632. Arguments:
  633. new_num_position_embeddings (`int`):
  634. The number of new position embedding matrix. If position embeddings are learned, increasing the size
  635. will add newly initialized vectors at the end, whereas reducing the size will remove vectors from the
  636. end. If position embeddings are not learned (*e.g.* sinusoidal position embeddings), increasing the
  637. size will add correct vectors at the end following the position encoding algorithm, whereas reducing
  638. the size will remove vectors from the end.
  639. """
  640. self.distilbert.resize_position_embeddings(new_num_position_embeddings)
  641. @can_return_tuple
  642. @auto_docstring
  643. def forward(
  644. self,
  645. input_ids: torch.Tensor | None = None,
  646. attention_mask: torch.Tensor | None = None,
  647. inputs_embeds: torch.Tensor | None = None,
  648. labels: torch.LongTensor | None = None,
  649. position_ids: torch.Tensor | None = None,
  650. **kwargs: Unpack[TransformersKwargs],
  651. ) -> TokenClassifierOutput | tuple[torch.Tensor, ...]:
  652. r"""
  653. labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
  654. Labels for computing the token classification loss. Indices should be in `[0, ..., config.num_labels - 1]`.
  655. """
  656. outputs = self.distilbert(
  657. input_ids,
  658. attention_mask=attention_mask,
  659. inputs_embeds=inputs_embeds,
  660. position_ids=position_ids,
  661. return_dict=True,
  662. **kwargs,
  663. )
  664. sequence_output = outputs[0]
  665. sequence_output = self.dropout(sequence_output)
  666. logits = self.classifier(sequence_output)
  667. loss = None
  668. if labels is not None:
  669. loss_fct = CrossEntropyLoss()
  670. loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
  671. return TokenClassifierOutput(
  672. loss=loss,
  673. logits=logits,
  674. hidden_states=outputs.hidden_states,
  675. attentions=outputs.attentions,
  676. )
  677. @auto_docstring
  678. class DistilBertForMultipleChoice(DistilBertPreTrainedModel):
  679. def __init__(self, config: PreTrainedConfig):
  680. super().__init__(config)
  681. self.distilbert = DistilBertModel(config)
  682. self.pre_classifier = nn.Linear(config.dim, config.dim)
  683. self.classifier = nn.Linear(config.dim, 1)
  684. self.dropout = nn.Dropout(config.seq_classif_dropout)
  685. # Initialize weights and apply final processing
  686. self.post_init()
  687. def get_position_embeddings(self) -> nn.Embedding:
  688. """
  689. Returns the position embeddings
  690. """
  691. return self.distilbert.get_position_embeddings()
  692. def resize_position_embeddings(self, new_num_position_embeddings: int):
  693. """
  694. Resizes position embeddings of the model if `new_num_position_embeddings != config.max_position_embeddings`.
  695. Arguments:
  696. new_num_position_embeddings (`int`)
  697. The number of new position embeddings. If position embeddings are learned, increasing the size will add
  698. newly initialized vectors at the end, whereas reducing the size will remove vectors from the end. If
  699. position embeddings are not learned (*e.g.* sinusoidal position embeddings), increasing the size will
  700. add correct vectors at the end following the position encoding algorithm, whereas reducing the size
  701. will remove vectors from the end.
  702. """
  703. self.distilbert.resize_position_embeddings(new_num_position_embeddings)
  704. @can_return_tuple
  705. @auto_docstring
  706. def forward(
  707. self,
  708. input_ids: torch.Tensor | None = None,
  709. attention_mask: torch.Tensor | None = None,
  710. inputs_embeds: torch.Tensor | None = None,
  711. labels: torch.LongTensor | None = None,
  712. position_ids: torch.Tensor | None = None,
  713. **kwargs: Unpack[TransformersKwargs],
  714. ) -> MultipleChoiceModelOutput | tuple[torch.Tensor, ...]:
  715. r"""
  716. input_ids (`torch.LongTensor` of shape `(batch_size, num_choices, sequence_length)`):
  717. Indices of input sequence tokens in the vocabulary.
  718. Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
  719. [`PreTrainedTokenizer.__call__`] for details.
  720. [What are input IDs?](../glossary#input-ids)
  721. inputs_embeds (`torch.FloatTensor` of shape `(batch_size, num_choices, sequence_length, hidden_size)`, *optional*):
  722. Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This
  723. is useful if you want more control over how to convert `input_ids` indices into associated vectors than the
  724. model's internal embedding lookup matrix.
  725. labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
  726. Labels for computing the multiple choice classification loss. Indices should be in `[0, ...,
  727. num_choices-1]` where `num_choices` is the size of the second dimension of the input tensors. (See
  728. `input_ids` above)
  729. Examples:
  730. ```python
  731. >>> from transformers import AutoTokenizer, DistilBertForMultipleChoice
  732. >>> import torch
  733. >>> tokenizer = AutoTokenizer.from_pretrained("distilbert-base-cased")
  734. >>> model = DistilBertForMultipleChoice.from_pretrained("distilbert-base-cased")
  735. >>> prompt = "In Italy, pizza served in formal settings, such as at a restaurant, is presented unsliced."
  736. >>> choice0 = "It is eaten with a fork and a knife."
  737. >>> choice1 = "It is eaten while held in the hand."
  738. >>> labels = torch.tensor(0).unsqueeze(0) # choice0 is correct (according to Wikipedia ;)), batch size 1
  739. >>> encoding = tokenizer([[prompt, choice0], [prompt, choice1]], return_tensors="pt", padding=True)
  740. >>> outputs = model(**{k: v.unsqueeze(0) for k, v in encoding.items()}, labels=labels) # batch size is 1
  741. >>> # the linear classifier still needs to be trained
  742. >>> loss = outputs.loss
  743. >>> logits = outputs.logits
  744. ```"""
  745. num_choices = input_ids.shape[1] if input_ids is not None else inputs_embeds.shape[1]
  746. input_ids = input_ids.view(-1, input_ids.size(-1)) if input_ids is not None else None
  747. attention_mask = attention_mask.view(-1, attention_mask.size(-1)) if attention_mask is not None else None
  748. inputs_embeds = (
  749. inputs_embeds.view(-1, inputs_embeds.size(-2), inputs_embeds.size(-1))
  750. if inputs_embeds is not None
  751. else None
  752. )
  753. outputs = self.distilbert(
  754. input_ids,
  755. attention_mask=attention_mask,
  756. inputs_embeds=inputs_embeds,
  757. position_ids=position_ids,
  758. return_dict=True,
  759. **kwargs,
  760. )
  761. hidden_state = outputs[0] # (bs * num_choices, seq_len, dim)
  762. pooled_output = hidden_state[:, 0] # (bs * num_choices, dim)
  763. pooled_output = self.pre_classifier(pooled_output) # (bs * num_choices, dim)
  764. pooled_output = nn.ReLU()(pooled_output) # (bs * num_choices, dim)
  765. pooled_output = self.dropout(pooled_output) # (bs * num_choices, dim)
  766. logits = self.classifier(pooled_output) # (bs * num_choices, 1)
  767. reshaped_logits = logits.view(-1, num_choices) # (bs, num_choices)
  768. loss = None
  769. if labels is not None:
  770. loss_fct = CrossEntropyLoss()
  771. loss = loss_fct(reshaped_logits, labels)
  772. return MultipleChoiceModelOutput(
  773. loss=loss,
  774. logits=reshaped_logits,
  775. hidden_states=outputs.hidden_states,
  776. attentions=outputs.attentions,
  777. )
  778. __all__ = [
  779. "DistilBertForMaskedLM",
  780. "DistilBertForMultipleChoice",
  781. "DistilBertForQuestionAnswering",
  782. "DistilBertForSequenceClassification",
  783. "DistilBertForTokenClassification",
  784. "DistilBertModel",
  785. "DistilBertPreTrainedModel",
  786. ]