modeling_openai.py 31 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
3704705706707708709710711712713714715716717718719720721722723724725726727728
  1. # Copyright 2018 The OpenAI Team Authors and HuggingFace Inc. team.
  2. # Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved.
  3. #
  4. # Licensed under the Apache License, Version 2.0 (the "License");
  5. # you may not use this file except in compliance with the License.
  6. # You may obtain a copy of the License at
  7. #
  8. # http://www.apache.org/licenses/LICENSE-2.0
  9. #
  10. # Unless required by applicable law or agreed to in writing, software
  11. # distributed under the License is distributed on an "AS IS" BASIS,
  12. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. # See the License for the specific language governing permissions and
  14. # limitations under the License.
  15. """PyTorch OpenAI GPT model."""
  16. import math
  17. from collections.abc import Callable
  18. from dataclasses import dataclass
  19. from typing import Any
  20. import torch
  21. from torch import nn
  22. from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
  23. from ... import initialization as init
  24. from ...activations import gelu_new, get_activation, silu
  25. from ...generation import GenerationMixin
  26. from ...modeling_outputs import BaseModelOutput, CausalLMOutput, SequenceClassifierOutput
  27. from ...modeling_utils import PreTrainedModel
  28. from ...pytorch_utils import Conv1D
  29. from ...utils import (
  30. ModelOutput,
  31. auto_docstring,
  32. logging,
  33. )
  34. from .configuration_openai import OpenAIGPTConfig
  35. logger = logging.get_logger(__name__)
  36. ACT_FNS = {"relu": nn.ReLU(), "silu": silu, "gelu": gelu_new, "swish": silu}
  37. class Attention(nn.Module):
  38. def __init__(self, nx, n_positions, config, scale=False):
  39. super().__init__()
  40. self.n_positions = n_positions
  41. n_state = nx # in Attention: n_state=768 (nx=n_embd)
  42. if n_state % config.n_head != 0:
  43. raise ValueError(f"Attention n_state shape: {n_state} must be divisible by config.n_head {config.n_head}")
  44. self.register_buffer(
  45. "bias",
  46. torch.tril(torch.ones(n_positions, n_positions)).view(1, 1, n_positions, n_positions),
  47. persistent=False,
  48. )
  49. self.n_head = config.n_head
  50. self.split_size = n_state
  51. self.scale = scale
  52. self.c_attn = Conv1D(n_state * 3, nx)
  53. self.c_proj = Conv1D(n_state, nx)
  54. self.attn_dropout = nn.Dropout(config.attn_pdrop)
  55. self.resid_dropout = nn.Dropout(config.resid_pdrop)
  56. def _attn(self, q, k, v, attention_mask=None, output_attentions=False):
  57. w = torch.matmul(q, k)
  58. if self.scale:
  59. w = w / math.sqrt(v.size(-1))
  60. # XD: self.b may be larger than w, so we need to crop it
  61. b = self.bias[:, :, : w.size(-2), : w.size(-1)]
  62. w = w * b + -1e4 * (1 - b)
  63. if attention_mask is not None:
  64. # Apply the attention mask
  65. w = w + attention_mask
  66. w = nn.functional.softmax(w, dim=-1)
  67. w = self.attn_dropout(w)
  68. outputs = [torch.matmul(w, v)]
  69. if output_attentions:
  70. outputs.append(w)
  71. return outputs
  72. def merge_heads(self, x):
  73. x = x.permute(0, 2, 1, 3).contiguous()
  74. new_x_shape = x.size()[:-2] + (x.size(-2) * x.size(-1),)
  75. return x.view(*new_x_shape)
  76. def split_heads(self, x, k=False):
  77. new_x_shape = x.size()[:-1] + (self.n_head, x.size(-1) // self.n_head)
  78. x = x.view(*new_x_shape)
  79. if k:
  80. return x.permute(0, 2, 3, 1)
  81. else:
  82. return x.permute(0, 2, 1, 3)
  83. def forward(self, x, attention_mask=None, output_attentions=False):
  84. x = self.c_attn(x)
  85. query, key, value = x.split(self.split_size, dim=2)
  86. query = self.split_heads(query)
  87. key = self.split_heads(key, k=True)
  88. value = self.split_heads(value)
  89. attn_outputs = self._attn(query, key, value, attention_mask, output_attentions)
  90. a = attn_outputs[0]
  91. a = self.merge_heads(a)
  92. a = self.c_proj(a)
  93. a = self.resid_dropout(a)
  94. outputs = [a] + attn_outputs[1:]
  95. return outputs # a, (attentions)
  96. class MLP(nn.Module):
  97. def __init__(self, n_state, config): # in MLP: n_state=3072 (4 * n_embd)
  98. super().__init__()
  99. nx = config.n_embd
  100. self.c_fc = Conv1D(n_state, nx)
  101. self.c_proj = Conv1D(nx, n_state)
  102. self.act = ACT_FNS[config.afn]
  103. self.dropout = nn.Dropout(config.resid_pdrop)
  104. def forward(self, x):
  105. h = self.act(self.c_fc(x))
  106. h2 = self.c_proj(h)
  107. return self.dropout(h2)
  108. class Block(nn.Module):
  109. def __init__(self, n_positions, config, scale=False):
  110. super().__init__()
  111. nx = config.n_embd
  112. self.attn = Attention(nx, n_positions, config, scale)
  113. self.ln_1 = nn.LayerNorm(nx, eps=config.layer_norm_epsilon)
  114. self.mlp = MLP(4 * nx, config)
  115. self.ln_2 = nn.LayerNorm(nx, eps=config.layer_norm_epsilon)
  116. def forward(self, x, attention_mask=None, output_attentions=False):
  117. attn_outputs = self.attn(
  118. x,
  119. attention_mask=attention_mask,
  120. output_attentions=output_attentions,
  121. )
  122. a = attn_outputs[0]
  123. n = self.ln_1(x + a)
  124. m = self.mlp(n)
  125. h = self.ln_2(n + m)
  126. outputs = [h] + attn_outputs[1:]
  127. return outputs
  128. # Copied from transformers.models.xlm.modeling_xlm.XLMSequenceSummary with XLM->OpenAIGPT
  129. class OpenAIGPTSequenceSummary(nn.Module):
  130. r"""
  131. Compute a single vector summary of a sequence hidden states.
  132. Args:
  133. config ([`OpenAIGPTConfig`]):
  134. The config used by the model. Relevant arguments in the config class of the model are (refer to the actual
  135. config class of your model for the default values it uses):
  136. - **summary_type** (`str`) -- The method to use to make this summary. Accepted values are:
  137. - `"last"` -- Take the last token hidden state (like XLNet)
  138. - `"first"` -- Take the first token hidden state (like Bert)
  139. - `"mean"` -- Take the mean of all tokens hidden states
  140. - `"cls_index"` -- Supply a Tensor of classification token position (GPT/GPT-2)
  141. - `"attn"` -- Not implemented now, use multi-head attention
  142. - **summary_use_proj** (`bool`) -- Add a projection after the vector extraction.
  143. - **summary_proj_to_labels** (`bool`) -- If `True`, the projection outputs to `config.num_labels` classes
  144. (otherwise to `config.hidden_size`).
  145. - **summary_activation** (`Optional[str]`) -- Set to `"tanh"` to add a tanh activation to the output,
  146. another string or `None` will add no activation.
  147. - **summary_first_dropout** (`float`) -- Optional dropout probability before the projection and activation.
  148. - **summary_last_dropout** (`float`)-- Optional dropout probability after the projection and activation.
  149. """
  150. def __init__(self, config: OpenAIGPTConfig):
  151. super().__init__()
  152. self.summary_type = getattr(config, "summary_type", "last")
  153. if self.summary_type == "attn":
  154. # We should use a standard multi-head attention module with absolute positional embedding for that.
  155. # Cf. https://github.com/zihangdai/xlnet/blob/master/modeling.py#L253-L276
  156. # We can probably just use the multi-head attention module of PyTorch >=1.1.0
  157. raise NotImplementedError
  158. self.summary = nn.Identity()
  159. if hasattr(config, "summary_use_proj") and config.summary_use_proj:
  160. if hasattr(config, "summary_proj_to_labels") and config.summary_proj_to_labels and config.num_labels > 0:
  161. num_classes = config.num_labels
  162. else:
  163. num_classes = config.hidden_size
  164. self.summary = nn.Linear(config.hidden_size, num_classes)
  165. activation_string = getattr(config, "summary_activation", None)
  166. self.activation: Callable = get_activation(activation_string) if activation_string else nn.Identity()
  167. self.first_dropout = nn.Identity()
  168. if hasattr(config, "summary_first_dropout") and config.summary_first_dropout > 0:
  169. self.first_dropout = nn.Dropout(config.summary_first_dropout)
  170. self.last_dropout = nn.Identity()
  171. if hasattr(config, "summary_last_dropout") and config.summary_last_dropout > 0:
  172. self.last_dropout = nn.Dropout(config.summary_last_dropout)
  173. def forward(
  174. self, hidden_states: torch.FloatTensor, cls_index: torch.LongTensor | None = None
  175. ) -> torch.FloatTensor:
  176. """
  177. Compute a single vector summary of a sequence hidden states.
  178. Args:
  179. hidden_states (`torch.FloatTensor` of shape `[batch_size, seq_len, hidden_size]`):
  180. The hidden states of the last layer.
  181. cls_index (`torch.LongTensor` of shape `[batch_size]` or `[batch_size, ...]` where ... are optional leading dimensions of `hidden_states`, *optional*):
  182. Used if `summary_type == "cls_index"` and takes the last token of the sequence as classification token.
  183. Returns:
  184. `torch.FloatTensor`: The summary of the sequence hidden states.
  185. """
  186. if self.summary_type == "last":
  187. output = hidden_states[:, -1]
  188. elif self.summary_type == "first":
  189. output = hidden_states[:, 0]
  190. elif self.summary_type == "mean":
  191. output = hidden_states.mean(dim=1)
  192. elif self.summary_type == "cls_index":
  193. if cls_index is None:
  194. cls_index = torch.full_like(
  195. hidden_states[..., :1, :],
  196. hidden_states.shape[-2] - 1,
  197. dtype=torch.long,
  198. )
  199. else:
  200. cls_index = cls_index.unsqueeze(-1).unsqueeze(-1)
  201. cls_index = cls_index.expand((-1,) * (cls_index.dim() - 1) + (hidden_states.size(-1),))
  202. # shape of cls_index: (bsz, XX, 1, hidden_size) where XX are optional leading dim of hidden_states
  203. output = hidden_states.gather(-2, cls_index).squeeze(-2) # shape (bsz, XX, hidden_size)
  204. elif self.summary_type == "attn":
  205. raise NotImplementedError
  206. output = self.first_dropout(output)
  207. output = self.summary(output)
  208. output = self.activation(output)
  209. output = self.last_dropout(output)
  210. return output
  211. @auto_docstring
  212. class OpenAIGPTPreTrainedModel(PreTrainedModel):
  213. config: OpenAIGPTConfig
  214. base_model_prefix = "transformer"
  215. def _init_weights(self, module):
  216. super()._init_weights(module)
  217. if isinstance(module, Attention):
  218. n_positions = module.n_positions
  219. init.copy_(
  220. module.bias, torch.tril(torch.ones(n_positions, n_positions)).view(1, 1, n_positions, n_positions)
  221. )
  222. elif isinstance(module, OpenAIGPTModel):
  223. init.copy_(module.position_ids, torch.arange(module.config.n_positions))
  224. @dataclass
  225. @auto_docstring(
  226. custom_intro="""
  227. Base class for outputs of models predicting if two sentences are consecutive or not.
  228. """
  229. )
  230. class OpenAIGPTDoubleHeadsModelOutput(ModelOutput):
  231. r"""
  232. loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` is provided):
  233. Language modeling loss.
  234. mc_loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `mc_labels` is provided):
  235. Multiple choice classification loss.
  236. logits (`torch.FloatTensor` of shape `(batch_size, num_choices, sequence_length, config.vocab_size)`):
  237. Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax).
  238. mc_logits (`torch.FloatTensor` of shape `(batch_size, num_choices)`):
  239. Prediction scores of the multiple choice classification head (scores for each choice before SoftMax).
  240. """
  241. loss: torch.FloatTensor | None = None
  242. mc_loss: torch.FloatTensor | None = None
  243. logits: torch.FloatTensor | None = None
  244. mc_logits: torch.FloatTensor | None = None
  245. hidden_states: tuple[torch.FloatTensor] | None = None
  246. attentions: tuple[torch.FloatTensor] | None = None
  247. @auto_docstring
  248. class OpenAIGPTModel(OpenAIGPTPreTrainedModel):
  249. def __init__(self, config):
  250. super().__init__(config)
  251. self.tokens_embed = nn.Embedding(config.vocab_size, config.n_embd)
  252. self.positions_embed = nn.Embedding(config.n_positions, config.n_embd)
  253. self.drop = nn.Dropout(config.embd_pdrop)
  254. self.h = nn.ModuleList([Block(config.n_positions, config, scale=True) for _ in range(config.n_layer)])
  255. self.register_buffer("position_ids", torch.arange(config.n_positions), persistent=False)
  256. # Initialize weights and apply final processing
  257. self.post_init()
  258. def get_input_embeddings(self):
  259. return self.tokens_embed
  260. def set_input_embeddings(self, new_embeddings):
  261. self.tokens_embed = new_embeddings
  262. @auto_docstring
  263. def forward(
  264. self,
  265. input_ids: torch.LongTensor | None = None,
  266. attention_mask: torch.FloatTensor | None = None,
  267. token_type_ids: torch.LongTensor | None = None,
  268. position_ids: torch.LongTensor | None = None,
  269. inputs_embeds: torch.FloatTensor | None = None,
  270. output_attentions: bool | None = None,
  271. output_hidden_states: bool | None = None,
  272. return_dict: bool | None = None,
  273. **kwargs,
  274. ) -> tuple[torch.Tensor] | BaseModelOutput:
  275. output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
  276. output_hidden_states = (
  277. output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
  278. )
  279. return_dict = return_dict if return_dict is not None else self.config.return_dict
  280. if input_ids is not None and inputs_embeds is not None:
  281. raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
  282. elif input_ids is not None:
  283. self.warn_if_padding_and_no_attention_mask(input_ids, attention_mask)
  284. input_shape = input_ids.size()
  285. input_ids = input_ids.view(-1, input_shape[-1])
  286. elif inputs_embeds is not None:
  287. input_shape = inputs_embeds.size()[:-1]
  288. else:
  289. raise ValueError("You have to specify either input_ids or inputs_embeds")
  290. if position_ids is None:
  291. # Code is different from when we had a single embedding matrix from position and token embeddings
  292. position_ids = self.position_ids[None, : input_shape[-1]]
  293. # Attention mask.
  294. if attention_mask is not None:
  295. # We create a 3D attention mask from a 2D tensor mask.
  296. # Sizes are [batch_size, 1, 1, to_seq_length]
  297. # So we can broadcast to [batch_size, num_heads, from_seq_length, to_seq_length]
  298. # this attention mask is more simple than the triangular masking of causal attention
  299. # used in OpenAI GPT, we just need to prepare the broadcast dimension here.
  300. attention_mask = attention_mask.unsqueeze(1).unsqueeze(2)
  301. # Since attention_mask is 1.0 for positions we want to attend and 0.0 for
  302. # masked positions, this operation will create a tensor which is 0.0 for
  303. # positions we want to attend and the dtype's smallest value for masked positions.
  304. # Since we are adding it to the raw scores before the softmax, this is
  305. # effectively the same as removing these entirely.
  306. attention_mask = attention_mask.to(dtype=next(self.parameters()).dtype) # fp16 compatibility
  307. attention_mask = (1.0 - attention_mask) * torch.finfo(self.dtype).min
  308. if inputs_embeds is None:
  309. inputs_embeds = self.tokens_embed(input_ids)
  310. position_embeds = self.positions_embed(position_ids)
  311. if token_type_ids is not None:
  312. token_type_ids = token_type_ids.view(-1, token_type_ids.size(-1))
  313. token_type_embeds = self.tokens_embed(token_type_ids)
  314. else:
  315. token_type_embeds = 0
  316. hidden_states = inputs_embeds + position_embeds + token_type_embeds
  317. hidden_states = self.drop(hidden_states)
  318. output_shape = input_shape + (hidden_states.size(-1),)
  319. all_attentions = () if output_attentions else None
  320. all_hidden_states = () if output_hidden_states else None
  321. for i, block in enumerate(self.h):
  322. if output_hidden_states:
  323. all_hidden_states = all_hidden_states + (hidden_states,)
  324. outputs = block(hidden_states, attention_mask, output_attentions=output_attentions)
  325. hidden_states = outputs[0]
  326. if output_attentions:
  327. all_attentions = all_attentions + (outputs[1],)
  328. hidden_states = hidden_states.view(*output_shape)
  329. # Add last layer
  330. if output_hidden_states:
  331. all_hidden_states = all_hidden_states + (hidden_states,)
  332. if not return_dict:
  333. return tuple(v for v in [hidden_states, all_hidden_states, all_attentions] if v is not None)
  334. return BaseModelOutput(
  335. last_hidden_state=hidden_states,
  336. hidden_states=all_hidden_states,
  337. attentions=all_attentions,
  338. )
  339. @auto_docstring(
  340. custom_intro="""
  341. OpenAI GPT Model transformer with a language modeling head on top (linear layer with weights tied to the input
  342. embeddings).
  343. """
  344. )
  345. class OpenAIGPTLMHeadModel(OpenAIGPTPreTrainedModel, GenerationMixin):
  346. _tied_weights_keys = {"lm_head.weight": "transformer.tokens_embed.weight"}
  347. def __init__(self, config):
  348. super().__init__(config)
  349. self.transformer = OpenAIGPTModel(config)
  350. self.lm_head = nn.Linear(config.n_embd, config.vocab_size, bias=False)
  351. # Initialize weights and apply final processing
  352. self.post_init()
  353. @auto_docstring
  354. def forward(
  355. self,
  356. input_ids: torch.LongTensor | None = None,
  357. attention_mask: torch.FloatTensor | None = None,
  358. token_type_ids: torch.LongTensor | None = None,
  359. position_ids: torch.LongTensor | None = None,
  360. inputs_embeds: torch.FloatTensor | None = None,
  361. labels: torch.LongTensor | None = None,
  362. output_attentions: bool | None = None,
  363. output_hidden_states: bool | None = None,
  364. return_dict: bool | None = None,
  365. logits_to_keep: int | torch.Tensor = 0,
  366. **kwargs,
  367. ) -> tuple[torch.Tensor] | CausalLMOutput:
  368. r"""
  369. labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
  370. Labels for language modeling. Note that the labels **are shifted** inside the model, i.e. you can set
  371. `labels = input_ids` Indices are selected in `[-100, 0, ..., config.vocab_size]` All labels set to `-100`
  372. are ignored (masked), the loss is only computed for labels in `[0, ..., config.vocab_size]`
  373. """
  374. return_dict = return_dict if return_dict is not None else self.config.return_dict
  375. transformer_outputs = self.transformer(
  376. input_ids,
  377. attention_mask=attention_mask,
  378. token_type_ids=token_type_ids,
  379. position_ids=position_ids,
  380. inputs_embeds=inputs_embeds,
  381. output_attentions=output_attentions,
  382. output_hidden_states=output_hidden_states,
  383. return_dict=return_dict,
  384. )
  385. hidden_states = transformer_outputs[0]
  386. # Only compute necessary logits, and do not upcast them to float if we are not computing the loss
  387. slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep
  388. logits = self.lm_head(hidden_states[:, slice_indices, :])
  389. loss = None
  390. if labels is not None:
  391. loss = self.loss_function(logits=logits, labels=labels, vocab_size=self.config.vocab_size, **kwargs)
  392. if not return_dict:
  393. output = (logits,) + transformer_outputs[1:]
  394. return ((loss,) + output) if loss is not None else output
  395. return CausalLMOutput(
  396. loss=loss,
  397. logits=logits,
  398. hidden_states=transformer_outputs.hidden_states,
  399. attentions=transformer_outputs.attentions,
  400. )
  401. def prepare_inputs_for_generation(self, input_ids: torch.LongTensor, **kwargs) -> dict[str, Any]:
  402. # Overwritten -- old model with reduced inputs
  403. model_inputs = {"input_ids": input_ids}
  404. # Forward ALL kwargs that are uninitialized (e.g. `use_cache`).
  405. for key, value in kwargs.items():
  406. if key not in model_inputs:
  407. model_inputs[key] = value
  408. return model_inputs
  409. @auto_docstring(
  410. custom_intro="""
  411. OpenAI GPT Model transformer with a language modeling and a multiple-choice classification head on top e.g. for
  412. RocStories/SWAG tasks. The two heads are two linear layers. The language modeling head has its weights tied to the
  413. input embeddings, the classification head takes as input the input of a specified classification token index in the
  414. input sequence).
  415. """
  416. )
  417. class OpenAIGPTDoubleHeadsModel(OpenAIGPTPreTrainedModel):
  418. _tied_weights_keys = {"transformer.tokens_embed.weight": "lm_head.weight"}
  419. def __init__(self, config):
  420. super().__init__(config)
  421. config.num_labels = 1
  422. self.transformer = OpenAIGPTModel(config)
  423. self.lm_head = nn.Linear(config.n_embd, config.vocab_size, bias=False)
  424. self.multiple_choice_head = OpenAIGPTSequenceSummary(config)
  425. # Initialize weights and apply final processing
  426. self.post_init()
  427. @auto_docstring
  428. def forward(
  429. self,
  430. input_ids: torch.LongTensor | None = None,
  431. attention_mask: torch.FloatTensor | None = None,
  432. token_type_ids: torch.LongTensor | None = None,
  433. position_ids: torch.LongTensor | None = None,
  434. inputs_embeds: torch.FloatTensor | None = None,
  435. mc_token_ids: torch.LongTensor | None = None,
  436. labels: torch.LongTensor | None = None,
  437. mc_labels: torch.LongTensor | None = None,
  438. output_attentions: bool | None = None,
  439. output_hidden_states: bool | None = None,
  440. return_dict: bool | None = None,
  441. **kwargs,
  442. ) -> tuple[torch.Tensor] | OpenAIGPTDoubleHeadsModelOutput:
  443. r"""
  444. mc_token_ids (`torch.LongTensor` of shape `(batch_size, num_choices)`, *optional*, default to index of the last token of the input):
  445. Index of the classification token in each input sequence. Selected in the range `[0, input_ids.size(-1) -
  446. 1]`.
  447. labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
  448. Labels for language modeling. Note that the labels **are shifted** inside the model, i.e. you can set
  449. `labels = input_ids` Indices are selected in `[-1, 0, ..., config.vocab_size]` All labels set to `-100` are
  450. ignored (masked), the loss is only computed for labels in `[0, ..., config.vocab_size]`
  451. mc_labels (`torch.LongTensor` of shape `(batch_size)`, *optional*):
  452. Labels for computing the multiple choice classification loss. Indices should be in `[0, ..., num_choices]`
  453. where *num_choices* is the size of the second dimension of the input tensors. (see *input_ids* above)
  454. Examples:
  455. ```python
  456. >>> from transformers import AutoTokenizer, OpenAIGPTDoubleHeadsModel
  457. >>> import torch
  458. >>> tokenizer = AutoTokenizer.from_pretrained("openai-community/openai-gpt")
  459. >>> model = OpenAIGPTDoubleHeadsModel.from_pretrained("openai-community/openai-gpt")
  460. >>> tokenizer.add_special_tokens(
  461. ... {"cls_token": "[CLS]"}
  462. ... ) # Add a [CLS] to the vocabulary (we should train it also!)
  463. >>> model.resize_token_embeddings(len(tokenizer))
  464. >>> choices = ["Hello, my dog is cute [CLS]", "Hello, my cat is cute [CLS]"]
  465. >>> input_ids = torch.tensor([tokenizer.encode(s) for s in choices]).unsqueeze(0) # Batch size 1, 2 choices
  466. >>> mc_token_ids = torch.tensor([input_ids.size(-1) - 1, input_ids.size(-1) - 1]).unsqueeze(0) # Batch size 1
  467. >>> outputs = model(input_ids, mc_token_ids=mc_token_ids)
  468. >>> lm_logits = outputs.logits
  469. >>> mc_logits = outputs.mc_logits
  470. ```"""
  471. return_dict = return_dict if return_dict is not None else self.config.return_dict
  472. transformer_outputs = self.transformer(
  473. input_ids,
  474. attention_mask=attention_mask,
  475. token_type_ids=token_type_ids,
  476. position_ids=position_ids,
  477. inputs_embeds=inputs_embeds,
  478. output_attentions=output_attentions,
  479. output_hidden_states=output_hidden_states,
  480. return_dict=return_dict,
  481. )
  482. hidden_states = transformer_outputs[0]
  483. lm_logits = self.lm_head(hidden_states)
  484. mc_logits = self.multiple_choice_head(hidden_states, mc_token_ids).squeeze(-1)
  485. lm_loss, mc_loss = None, None
  486. if mc_labels is not None:
  487. loss_fct = CrossEntropyLoss()
  488. mc_loss = loss_fct(mc_logits.view(-1, mc_logits.size(-1)), mc_labels.view(-1))
  489. if labels is not None:
  490. shift_logits = lm_logits[..., :-1, :].contiguous()
  491. shift_labels = labels[..., 1:].contiguous()
  492. loss_fct = CrossEntropyLoss()
  493. lm_loss = loss_fct(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1))
  494. if not return_dict:
  495. output = (lm_logits, mc_logits) + transformer_outputs[1:]
  496. if mc_loss is not None:
  497. output = (mc_loss,) + output
  498. return ((lm_loss,) + output) if lm_loss is not None else output
  499. return OpenAIGPTDoubleHeadsModelOutput(
  500. loss=lm_loss,
  501. mc_loss=mc_loss,
  502. logits=lm_logits,
  503. mc_logits=mc_logits,
  504. hidden_states=transformer_outputs.hidden_states,
  505. attentions=transformer_outputs.attentions,
  506. )
  507. @auto_docstring(
  508. custom_intro="""
  509. The Original OpenAI GPT Model transformer with a sequence classification head on top (linear layer).
  510. [`OpenAIGPTForSequenceClassification`] uses the last token in order to do the classification, as other causal
  511. models (e.g. GPT-2) do. Since it does classification on the last token, it requires to know the position of the
  512. last token. If a `pad_token_id` is defined in the configuration, it finds the last token that is not a padding
  513. token in each row. If no `pad_token_id` is defined, it simply takes the last value in each row of the batch. Since
  514. it cannot guess the padding tokens when `inputs_embeds` are passed instead of `input_ids`, it does the same (take
  515. the last value in each row of the batch).
  516. """
  517. )
  518. class OpenAIGPTForSequenceClassification(OpenAIGPTPreTrainedModel):
  519. def __init__(self, config):
  520. super().__init__(config)
  521. self.num_labels = config.num_labels
  522. self.transformer = OpenAIGPTModel(config)
  523. self.score = nn.Linear(config.n_embd, self.num_labels, bias=False)
  524. # Initialize weights and apply final processing
  525. self.post_init()
  526. @auto_docstring
  527. def forward(
  528. self,
  529. input_ids: torch.LongTensor | None = None,
  530. attention_mask: torch.FloatTensor | None = None,
  531. token_type_ids: torch.LongTensor | None = None,
  532. position_ids: torch.LongTensor | None = None,
  533. inputs_embeds: torch.FloatTensor | None = None,
  534. labels: torch.LongTensor | None = None,
  535. output_attentions: bool | None = None,
  536. output_hidden_states: bool | None = None,
  537. return_dict: bool | None = None,
  538. **kwargs,
  539. ) -> tuple[torch.Tensor] | SequenceClassifierOutput:
  540. r"""
  541. labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
  542. Labels for computing the sequence classification/regression loss. Indices should be in `[0, ...,
  543. config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If
  544. `config.num_labels > 1` a classification loss is computed (Cross-Entropy).
  545. """
  546. return_dict = return_dict if return_dict is not None else self.config.return_dict
  547. transformer_outputs = self.transformer(
  548. input_ids,
  549. attention_mask=attention_mask,
  550. token_type_ids=token_type_ids,
  551. position_ids=position_ids,
  552. inputs_embeds=inputs_embeds,
  553. output_attentions=output_attentions,
  554. output_hidden_states=output_hidden_states,
  555. return_dict=return_dict,
  556. )
  557. hidden_states = transformer_outputs[0]
  558. logits = self.score(hidden_states)
  559. if input_ids is not None:
  560. batch_size, sequence_length = input_ids.shape[:2]
  561. else:
  562. batch_size, sequence_length = inputs_embeds.shape[:2]
  563. # Ensure the batch size is > 1 if there is no padding.
  564. if self.config.pad_token_id is None and batch_size != 1:
  565. raise ValueError("Cannot handle batch sizes > 1 if no padding token is defined.")
  566. if self.config.pad_token_id is None:
  567. last_non_pad_token = -1
  568. elif input_ids is not None:
  569. # To handle both left- and right- padding, we take the rightmost token that is not equal to pad_token_id
  570. non_pad_mask = (input_ids != self.config.pad_token_id).to(logits.device, torch.int32)
  571. token_indices = torch.arange(input_ids.shape[-1], device=logits.device, dtype=torch.int32)
  572. last_non_pad_token = (token_indices * non_pad_mask).argmax(-1)
  573. else:
  574. last_non_pad_token = -1
  575. logger.warning_once(
  576. f"{self.__class__.__name__} will not detect padding tokens in `inputs_embeds`. Results may be "
  577. "unexpected if using padding tokens in conjunction with `inputs_embeds.`"
  578. )
  579. pooled_logits = logits[torch.arange(batch_size, device=logits.device), last_non_pad_token]
  580. loss = None
  581. if labels is not None:
  582. if self.config.problem_type is None:
  583. if self.num_labels == 1:
  584. self.config.problem_type = "regression"
  585. elif self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int):
  586. self.config.problem_type = "single_label_classification"
  587. else:
  588. self.config.problem_type = "multi_label_classification"
  589. if self.config.problem_type == "regression":
  590. loss_fct = MSELoss()
  591. if self.num_labels == 1:
  592. loss = loss_fct(pooled_logits.squeeze(), labels.squeeze())
  593. else:
  594. loss = loss_fct(pooled_logits, labels)
  595. elif self.config.problem_type == "single_label_classification":
  596. loss_fct = CrossEntropyLoss()
  597. loss = loss_fct(pooled_logits.view(-1, self.num_labels), labels.view(-1))
  598. elif self.config.problem_type == "multi_label_classification":
  599. loss_fct = BCEWithLogitsLoss()
  600. loss = loss_fct(pooled_logits, labels)
  601. if not return_dict:
  602. output = (pooled_logits,) + transformer_outputs[1:]
  603. return ((loss,) + output) if loss is not None else output
  604. return SequenceClassifierOutput(
  605. loss=loss,
  606. logits=pooled_logits,
  607. hidden_states=transformer_outputs.hidden_states,
  608. attentions=transformer_outputs.attentions,
  609. )
# Public model classes of this module (the names exported by `from ... import *`).
__all__ = [
    "OpenAIGPTDoubleHeadsModel",
    "OpenAIGPTForSequenceClassification",
    "OpenAIGPTLMHeadModel",
    "OpenAIGPTModel",
    "OpenAIGPTPreTrainedModel",
]