modeling_gpt2.py 50 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
37047057067077087097107117127137147157167177187197207217227237247257267277287297307317327337347357367377387397407417427437447457467477487497507517527537547557567577587597607617627637647657667677687697707717727737747757767777787797807817827837847857867877887897907917927937947957967977987998008018028038048058068078088098108118128138148158168178188198208218228238248258268278288298308318328338348358368378388398408418428438448458468478488498508518528538548558568578588598608618628638648658668678688698708718728738748758768778788798808818828838848858868878888898908918928938948958968978988999009019029039049059069079089099109119129139149159169179189199209219229239249259269279289299309319329339349359369379389399409419429439449459469479489499509519529539549559569579589599609619629639649659669679689699709719729739749759769779789799809819829839849859869879889899909919929939949959969979989991000100110021003100410051006100710081009101010111012101310141015101610171018101910201021102210231024102510261027102810291030103110321033103410351036103710381039104010411042104310441045104610471048104910501051105210531054105510561057105810591060106110621063106410651066106710681069107010711072107310741075107610771078107910801081108210831084108510861087108810891090109110921093109410951096109710981099110011011102110311041105110611071108110911101111111211131114111511161117111811191120112111221123112411251126112711281129113011311132113311341135113611371138113911401141114211431144
  1. # Copyright 2018 The OpenAI Team Authors and HuggingFace Inc. team.
  2. # Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved.
  3. #
  4. # Licensed under the Apache License, Version 2.0 (the "License");
  5. # you may not use this file except in compliance with the License.
  6. # You may obtain a copy of the License at
  7. #
  8. # http://www.apache.org/licenses/LICENSE-2.0
  9. #
  10. # Unless required by applicable law or agreed to in writing, software
  11. # distributed under the License is distributed on an "AS IS" BASIS,
  12. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. # See the License for the specific language governing permissions and
  14. # limitations under the License.
  15. """PyTorch OpenAI GPT-2 model."""
  16. import math
  17. from collections.abc import Callable
  18. from dataclasses import dataclass
  19. import torch
  20. from torch import nn
  21. from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
  22. from ... import initialization as init
  23. from ...activations import ACT2FN, get_activation
  24. from ...cache_utils import Cache, DynamicCache, EncoderDecoderCache
  25. from ...generation import GenerationMixin
  26. from ...masking_utils import create_bidirectional_mask, create_causal_mask
  27. from ...modeling_layers import GradientCheckpointingLayer
  28. from ...modeling_outputs import (
  29. BaseModelOutputWithPastAndCrossAttentions,
  30. CausalLMOutputWithCrossAttentions,
  31. QuestionAnsweringModelOutput,
  32. SequenceClassifierOutputWithPast,
  33. TokenClassifierOutput,
  34. )
  35. from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
  36. from ...pytorch_utils import Conv1D
  37. from ...utils import (
  38. ModelOutput,
  39. auto_docstring,
  40. can_return_tuple,
  41. logging,
  42. )
  43. from ...utils.generic import maybe_autocast, merge_with_config_defaults
  44. from ...utils.output_capturing import OutputRecorder, capture_outputs
  45. from .configuration_gpt2 import GPT2Config
# Module-level logger, created through the project's `utils.logging.get_logger` helper and named after this module.
logger = logging.get_logger(__name__)
  47. def eager_attention_forward(module, query, key, value, attention_mask, scaling=None, dropout=0.0, **kwargs):
  48. if scaling is None:
  49. scaling = query.size(-1) ** -0.5
  50. attn_weights = torch.matmul(query, key.transpose(-1, -2)) * scaling
  51. if attention_mask is not None:
  52. attn_weights = attn_weights + attention_mask
  53. attn_weights = nn.functional.softmax(attn_weights, dim=-1)
  54. # Downcast (if necessary) back to V's dtype (if in mixed-precision) -- No-Op otherwise
  55. attn_weights = attn_weights.type(value.dtype)
  56. attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training)
  57. attn_output = torch.matmul(attn_weights, value)
  58. attn_output = attn_output.transpose(1, 2)
  59. return attn_output, attn_weights
  60. class GPT2Attention(nn.Module):
  61. def __init__(self, config, is_cross_attention=False, layer_idx=None):
  62. super().__init__()
  63. self.config = config
  64. self.embed_dim = config.hidden_size
  65. self.num_heads = config.num_attention_heads
  66. self.head_dim = self.embed_dim // self.num_heads
  67. self.split_size = self.embed_dim
  68. if self.head_dim * self.num_heads != self.embed_dim:
  69. raise ValueError(
  70. f"`embed_dim` must be divisible by num_heads (got `embed_dim`: {self.embed_dim} and `num_heads`:"
  71. f" {self.num_heads})."
  72. )
  73. self.scale_attn_weights = config.scale_attn_weights
  74. self.scale_attn_by_inverse_layer_idx = config.scale_attn_by_inverse_layer_idx
  75. self.reorder_and_upcast_attn = config.reorder_and_upcast_attn
  76. self.is_cross_attention = is_cross_attention
  77. self.layer_idx = layer_idx
  78. # Precompute unified scaling factor (accounts for both head_dim and layer-wise scaling)
  79. self.scaling = 1.0
  80. if self.scale_attn_weights:
  81. self.scaling = self.head_dim**-0.5
  82. if self.scale_attn_by_inverse_layer_idx:
  83. self.scaling /= float(self.layer_idx + 1)
  84. if self.is_cross_attention:
  85. self.c_attn = Conv1D(2 * self.embed_dim, self.embed_dim)
  86. self.q_attn = Conv1D(self.embed_dim, self.embed_dim)
  87. else:
  88. self.c_attn = Conv1D(3 * self.embed_dim, self.embed_dim)
  89. self.c_proj = Conv1D(self.embed_dim, self.embed_dim)
  90. self.attn_dropout = nn.Dropout(config.attn_pdrop)
  91. self.resid_dropout = nn.Dropout(config.resid_pdrop)
  92. self.is_causal = not is_cross_attention
  93. def _upcast_and_reordered_attn(self, query, key, value, attention_mask=None):
  94. # Use `torch.baddbmm` (a bit more efficient w/ alpha param for scaling -- from Megatron-LM)
  95. bsz, num_heads, q_seq_len, dk = query.size()
  96. _, _, k_seq_len, _ = key.size()
  97. # Preallocate attn_weights for `baddbmm`
  98. attn_weights = torch.empty(bsz * num_heads, q_seq_len, k_seq_len, dtype=torch.float32, device=query.device)
  99. # Upcast (turn off autocast) and reorder (Scale K by 1 / root(dk))
  100. with maybe_autocast(query.device.type, enabled=False):
  101. q, k = query.reshape(-1, q_seq_len, dk), key.transpose(-1, -2).reshape(-1, dk, k_seq_len)
  102. attn_weights = torch.baddbmm(attn_weights, q.float(), k.float(), beta=0, alpha=self.scaling)
  103. attn_weights = attn_weights.reshape(bsz, num_heads, q_seq_len, k_seq_len)
  104. if attention_mask is not None:
  105. # Apply the attention mask
  106. attn_weights = attn_weights + attention_mask
  107. attn_weights = nn.functional.softmax(attn_weights, dim=-1)
  108. # Downcast (if necessary) back to V's dtype (if in mixed-precision) -- No-Op if otherwise
  109. if attn_weights.dtype != torch.float32:
  110. raise RuntimeError("Error with upcasting, attn_weights does not have dtype torch.float32")
  111. attn_weights = attn_weights.type(value.dtype)
  112. attn_weights = self.attn_dropout(attn_weights)
  113. attn_output = torch.matmul(attn_weights, value)
  114. attn_output = attn_output.transpose(1, 2)
  115. return attn_output, attn_weights
  116. def forward(
  117. self,
  118. hidden_states: tuple[torch.FloatTensor] | None,
  119. past_key_values: Cache | None = None,
  120. attention_mask: torch.FloatTensor | None = None,
  121. encoder_hidden_states: torch.Tensor | None = None,
  122. encoder_attention_mask: torch.FloatTensor | None = None,
  123. output_attentions: bool | None = False,
  124. **kwargs,
  125. ) -> tuple[torch.Tensor | tuple[torch.Tensor], ...]:
  126. is_cross_attention = encoder_hidden_states is not None
  127. if past_key_values is not None:
  128. if isinstance(past_key_values, EncoderDecoderCache):
  129. is_updated = past_key_values.is_updated.get(self.layer_idx)
  130. if is_cross_attention:
  131. # after the first generated id, we can subsequently re-use all key/value_layer from cache
  132. curr_past_key_values = past_key_values.cross_attention_cache
  133. else:
  134. curr_past_key_values = past_key_values.self_attention_cache
  135. else:
  136. curr_past_key_values = past_key_values
  137. if is_cross_attention:
  138. if not hasattr(self, "q_attn"):
  139. raise ValueError(
  140. "If class is used as cross attention, the weights `q_attn` have to be defined. "
  141. "Please make sure to instantiate class with `GPT2Attention(..., is_cross_attention=True)`."
  142. )
  143. query_states = self.q_attn(hidden_states)
  144. attention_mask = encoder_attention_mask
  145. # Try to get key/value states from cache if possible
  146. if past_key_values is not None and is_updated:
  147. key_states = curr_past_key_values.layers[self.layer_idx].keys
  148. value_states = curr_past_key_values.layers[self.layer_idx].values
  149. else:
  150. key_states, value_states = self.c_attn(encoder_hidden_states).split(self.split_size, dim=2)
  151. shape_kv = (*key_states.shape[:-1], -1, self.head_dim)
  152. key_states = key_states.view(shape_kv).transpose(1, 2)
  153. value_states = value_states.view(shape_kv).transpose(1, 2)
  154. else:
  155. query_states, key_states, value_states = self.c_attn(hidden_states).split(self.split_size, dim=2)
  156. shape_kv = (*key_states.shape[:-1], -1, self.head_dim)
  157. key_states = key_states.view(shape_kv).transpose(1, 2)
  158. value_states = value_states.view(shape_kv).transpose(1, 2)
  159. shape_q = (*query_states.shape[:-1], -1, self.head_dim)
  160. query_states = query_states.view(shape_q).transpose(1, 2)
  161. if (past_key_values is not None and not is_cross_attention) or (
  162. past_key_values is not None and is_cross_attention and not is_updated
  163. ):
  164. key_states, value_states = curr_past_key_values.update(key_states, value_states, self.layer_idx)
  165. # set flag that curr layer for cross-attn is already updated so we can re-use in subsequent calls
  166. if is_cross_attention:
  167. past_key_values.is_updated[self.layer_idx] = True
  168. using_eager = self.config._attn_implementation == "eager"
  169. attention_interface: Callable = ALL_ATTENTION_FUNCTIONS.get_interface(
  170. self.config._attn_implementation, eager_attention_forward
  171. )
  172. if using_eager and self.reorder_and_upcast_attn:
  173. attn_output, attn_weights = self._upcast_and_reordered_attn(
  174. query_states, key_states, value_states, attention_mask
  175. )
  176. else:
  177. attn_output, attn_weights = attention_interface(
  178. self,
  179. query_states,
  180. key_states,
  181. value_states,
  182. attention_mask,
  183. dropout=self.attn_dropout.p if self.training else 0.0,
  184. scaling=self.scaling,
  185. **kwargs,
  186. )
  187. attn_output = attn_output.reshape(*attn_output.shape[:-2], -1).contiguous()
  188. attn_output = self.c_proj(attn_output)
  189. attn_output = self.resid_dropout(attn_output)
  190. return attn_output, attn_weights
  191. class GPT2MLP(nn.Module):
  192. def __init__(self, intermediate_size, config):
  193. super().__init__()
  194. embed_dim = config.hidden_size
  195. self.c_fc = Conv1D(intermediate_size, embed_dim)
  196. self.c_proj = Conv1D(embed_dim, intermediate_size)
  197. self.act = ACT2FN[config.activation_function]
  198. self.dropout = nn.Dropout(config.resid_pdrop)
  199. def forward(self, hidden_states: tuple[torch.FloatTensor] | None) -> torch.FloatTensor:
  200. hidden_states = self.c_fc(hidden_states)
  201. hidden_states = self.act(hidden_states)
  202. hidden_states = self.c_proj(hidden_states)
  203. hidden_states = self.dropout(hidden_states)
  204. return hidden_states
  205. class GPT2Block(GradientCheckpointingLayer):
  206. def __init__(self, config, layer_idx=None):
  207. super().__init__()
  208. hidden_size = config.hidden_size
  209. inner_dim = config.n_inner if config.n_inner is not None else 4 * hidden_size
  210. self.ln_1 = nn.LayerNorm(hidden_size, eps=config.layer_norm_epsilon)
  211. self.attn = GPT2Attention(config=config, layer_idx=layer_idx)
  212. self.ln_2 = nn.LayerNorm(hidden_size, eps=config.layer_norm_epsilon)
  213. if config.add_cross_attention:
  214. self.crossattention = GPT2Attention(config=config, is_cross_attention=True, layer_idx=layer_idx)
  215. self.ln_cross_attn = nn.LayerNorm(hidden_size, eps=config.layer_norm_epsilon)
  216. self.mlp = GPT2MLP(inner_dim, config)
  217. def forward(
  218. self,
  219. hidden_states: tuple[torch.FloatTensor] | None,
  220. past_key_values: Cache | None = None,
  221. attention_mask: torch.FloatTensor | None = None,
  222. encoder_hidden_states: torch.Tensor | None = None,
  223. encoder_attention_mask: torch.FloatTensor | None = None,
  224. use_cache: bool | None = False,
  225. **kwargs,
  226. ) -> torch.Tensor:
  227. residual = hidden_states
  228. hidden_states = self.ln_1(hidden_states)
  229. attn_output, _ = self.attn(
  230. hidden_states,
  231. past_key_values=past_key_values,
  232. attention_mask=attention_mask,
  233. use_cache=use_cache,
  234. **kwargs,
  235. )
  236. # residual connection
  237. hidden_states = attn_output + residual
  238. if encoder_hidden_states is not None:
  239. # add one self-attention block for cross-attention
  240. if not hasattr(self, "crossattention"):
  241. raise ValueError(
  242. f"If `encoder_hidden_states` are passed, {self} has to be instantiated with "
  243. "cross-attention layers by setting `config.add_cross_attention=True`"
  244. )
  245. residual = hidden_states
  246. hidden_states = self.ln_cross_attn(hidden_states)
  247. cross_attn_output, _ = self.crossattention(
  248. hidden_states,
  249. past_key_values=past_key_values,
  250. attention_mask=attention_mask,
  251. encoder_hidden_states=encoder_hidden_states,
  252. encoder_attention_mask=encoder_attention_mask,
  253. )
  254. # residual connection
  255. hidden_states = residual + cross_attn_output
  256. residual = hidden_states
  257. hidden_states = self.ln_2(hidden_states)
  258. feed_forward_hidden_states = self.mlp(hidden_states)
  259. # residual connection
  260. hidden_states = residual + feed_forward_hidden_states
  261. return hidden_states
  262. # Copied from transformers.models.xlm.modeling_xlm.XLMSequenceSummary with XLM->GPT2
  263. class GPT2SequenceSummary(nn.Module):
  264. r"""
  265. Compute a single vector summary of a sequence hidden states.
  266. Args:
  267. config ([`GPT2Config`]):
  268. The config used by the model. Relevant arguments in the config class of the model are (refer to the actual
  269. config class of your model for the default values it uses):
  270. - **summary_type** (`str`) -- The method to use to make this summary. Accepted values are:
  271. - `"last"` -- Take the last token hidden state (like XLNet)
  272. - `"first"` -- Take the first token hidden state (like Bert)
  273. - `"mean"` -- Take the mean of all tokens hidden states
  274. - `"cls_index"` -- Supply a Tensor of classification token position (GPT/GPT-2)
  275. - `"attn"` -- Not implemented now, use multi-head attention
  276. - **summary_use_proj** (`bool`) -- Add a projection after the vector extraction.
  277. - **summary_proj_to_labels** (`bool`) -- If `True`, the projection outputs to `config.num_labels` classes
  278. (otherwise to `config.hidden_size`).
  279. - **summary_activation** (`Optional[str]`) -- Set to `"tanh"` to add a tanh activation to the output,
  280. another string or `None` will add no activation.
  281. - **summary_first_dropout** (`float`) -- Optional dropout probability before the projection and activation.
  282. - **summary_last_dropout** (`float`)-- Optional dropout probability after the projection and activation.
  283. """
  284. def __init__(self, config: GPT2Config):
  285. super().__init__()
  286. self.summary_type = getattr(config, "summary_type", "last")
  287. if self.summary_type == "attn":
  288. # We should use a standard multi-head attention module with absolute positional embedding for that.
  289. # Cf. https://github.com/zihangdai/xlnet/blob/master/modeling.py#L253-L276
  290. # We can probably just use the multi-head attention module of PyTorch >=1.1.0
  291. raise NotImplementedError
  292. self.summary = nn.Identity()
  293. if hasattr(config, "summary_use_proj") and config.summary_use_proj:
  294. if hasattr(config, "summary_proj_to_labels") and config.summary_proj_to_labels and config.num_labels > 0:
  295. num_classes = config.num_labels
  296. else:
  297. num_classes = config.hidden_size
  298. self.summary = nn.Linear(config.hidden_size, num_classes)
  299. activation_string = getattr(config, "summary_activation", None)
  300. self.activation: Callable = get_activation(activation_string) if activation_string else nn.Identity()
  301. self.first_dropout = nn.Identity()
  302. if hasattr(config, "summary_first_dropout") and config.summary_first_dropout > 0:
  303. self.first_dropout = nn.Dropout(config.summary_first_dropout)
  304. self.last_dropout = nn.Identity()
  305. if hasattr(config, "summary_last_dropout") and config.summary_last_dropout > 0:
  306. self.last_dropout = nn.Dropout(config.summary_last_dropout)
  307. def forward(
  308. self, hidden_states: torch.FloatTensor, cls_index: torch.LongTensor | None = None
  309. ) -> torch.FloatTensor:
  310. """
  311. Compute a single vector summary of a sequence hidden states.
  312. Args:
  313. hidden_states (`torch.FloatTensor` of shape `[batch_size, seq_len, hidden_size]`):
  314. The hidden states of the last layer.
  315. cls_index (`torch.LongTensor` of shape `[batch_size]` or `[batch_size, ...]` where ... are optional leading dimensions of `hidden_states`, *optional*):
  316. Used if `summary_type == "cls_index"` and takes the last token of the sequence as classification token.
  317. Returns:
  318. `torch.FloatTensor`: The summary of the sequence hidden states.
  319. """
  320. if self.summary_type == "last":
  321. output = hidden_states[:, -1]
  322. elif self.summary_type == "first":
  323. output = hidden_states[:, 0]
  324. elif self.summary_type == "mean":
  325. output = hidden_states.mean(dim=1)
  326. elif self.summary_type == "cls_index":
  327. if cls_index is None:
  328. cls_index = torch.full_like(
  329. hidden_states[..., :1, :],
  330. hidden_states.shape[-2] - 1,
  331. dtype=torch.long,
  332. )
  333. else:
  334. cls_index = cls_index.unsqueeze(-1).unsqueeze(-1)
  335. cls_index = cls_index.expand((-1,) * (cls_index.dim() - 1) + (hidden_states.size(-1),))
  336. # shape of cls_index: (bsz, XX, 1, hidden_size) where XX are optional leading dim of hidden_states
  337. output = hidden_states.gather(-2, cls_index).squeeze(-2) # shape (bsz, XX, hidden_size)
  338. elif self.summary_type == "attn":
  339. raise NotImplementedError
  340. output = self.first_dropout(output)
  341. output = self.summary(output)
  342. output = self.activation(output)
  343. output = self.last_dropout(output)
  344. return output
  345. @auto_docstring
  346. class GPT2PreTrainedModel(PreTrainedModel):
  347. config: GPT2Config
  348. base_model_prefix = "transformer"
  349. supports_gradient_checkpointing = True
  350. _no_split_modules = ["GPT2Block"]
  351. _skip_keys_device_placement = "past_key_values"
  352. _supports_flash_attn = True
  353. _supports_sdpa = True
  354. _supports_attention_backend = True
  355. _can_compile_fullgraph = True
  356. _can_record_outputs = {
  357. "hidden_states": GPT2Block,
  358. "attentions": OutputRecorder(GPT2Attention, layer_name=".attn", index=1),
  359. "cross_attentions": OutputRecorder(GPT2Attention, layer_name=".crossattention", index=1),
  360. }
  361. # No longer used as we directly use our masks instead
  362. _keys_to_ignore_on_load_unexpected = ["attn.bias", "crossattention.bias"]
  363. @torch.no_grad()
  364. def _init_weights(self, module):
  365. """Initialize the weights."""
  366. if isinstance(module, (nn.Linear, Conv1D)):
  367. init.normal_(module.weight, mean=0.0, std=self.config.initializer_range)
  368. if module.bias is not None:
  369. init.zeros_(module.bias)
  370. elif isinstance(module, nn.Embedding):
  371. init.normal_(module.weight, mean=0.0, std=self.config.initializer_range)
  372. # Here we need the check explicitly, as we slice the weight in the `zeros_` call, so it looses the flag
  373. if module.padding_idx is not None and not getattr(module.weight, "_is_hf_initialized", False):
  374. init.zeros_(module.weight[module.padding_idx])
  375. elif isinstance(module, nn.LayerNorm):
  376. init.zeros_(module.bias)
  377. init.ones_(module.weight)
  378. # Reinitialize selected weights subject to the OpenAI GPT-2 Paper Scheme:
  379. # > A modified initialization which accounts for the accumulation on the residual path with model depth. Scale
  380. # > the weights of residual layers at initialization by a factor of 1/√N where N is the # of residual layers.
  381. # > -- GPT-2 :: https://openai.com/blog/better-language-models/
  382. #
  383. # Reference (Megatron-LM): https://github.com/NVIDIA/Megatron-LM/blob/main/megatron/model/gpt_model.py
  384. if isinstance(module, PreTrainedModel):
  385. for name, p in module.named_parameters():
  386. if name == "c_proj.weight":
  387. # Special Scaled Initialization --> There are 2 Layer Norms per Transformer Block
  388. init.normal_(p, mean=0.0, std=self.config.initializer_range / math.sqrt(2 * self.config.n_layer))
  389. @dataclass
  390. @auto_docstring(
  391. custom_intro="""
  392. Base class for outputs of models predicting if two sentences are consecutive or not.
  393. """
  394. )
  395. class GPT2DoubleHeadsModelOutput(ModelOutput):
  396. r"""
  397. loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` is provided):
  398. Language modeling loss.
  399. mc_loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `mc_labels` is provided):
  400. Multiple choice classification loss.
  401. logits (`torch.FloatTensor` of shape `(batch_size, num_choices, sequence_length, config.vocab_size)`):
  402. Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax).
  403. mc_logits (`torch.FloatTensor` of shape `(batch_size, num_choices)`):
  404. Prediction scores of the multiple choice classification head (scores for each choice before SoftMax).
  405. past_key_values (`Cache`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
  406. It is a [`~cache_utils.Cache`] instance. For more details, see our [kv cache guide](https://huggingface.co/docs/transformers/en/kv_cache).
  407. Contains pre-computed hidden-states (key and values in the attention blocks) that can be used (see
  408. `past_key_values` input) to speed up sequential decoding.
  409. """
  410. loss: torch.FloatTensor | None = None
  411. mc_loss: torch.FloatTensor | None = None
  412. logits: torch.FloatTensor | None = None
  413. mc_logits: torch.FloatTensor | None = None
  414. past_key_values: Cache | None = None
  415. hidden_states: tuple[torch.FloatTensor] | None = None
  416. attentions: tuple[torch.FloatTensor] | None = None
  417. @auto_docstring
  418. class GPT2Model(GPT2PreTrainedModel):
  419. def __init__(self, config):
  420. super().__init__(config)
  421. self.embed_dim = config.hidden_size
  422. self.wte = nn.Embedding(config.vocab_size, self.embed_dim)
  423. self.wpe = nn.Embedding(config.max_position_embeddings, self.embed_dim)
  424. self.drop = nn.Dropout(config.embd_pdrop)
  425. self.h = nn.ModuleList([GPT2Block(config, layer_idx=i) for i in range(config.num_hidden_layers)])
  426. self.ln_f = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_epsilon)
  427. self.gradient_checkpointing = False
  428. self._attn_implementation = config._attn_implementation
  429. # Initialize weights and apply final processing
  430. self.post_init()
  431. def get_input_embeddings(self):
  432. return self.wte
  433. def set_input_embeddings(self, new_embeddings):
  434. self.wte = new_embeddings
  435. @merge_with_config_defaults
  436. @capture_outputs
  437. @auto_docstring
  438. def forward(
  439. self,
  440. input_ids: torch.LongTensor | None = None,
  441. past_key_values: Cache | None = None,
  442. attention_mask: torch.FloatTensor | None = None,
  443. token_type_ids: torch.LongTensor | None = None,
  444. position_ids: torch.LongTensor | None = None,
  445. inputs_embeds: torch.FloatTensor | None = None,
  446. encoder_hidden_states: torch.Tensor | None = None,
  447. encoder_attention_mask: torch.FloatTensor | None = None,
  448. use_cache: bool | None = None,
  449. **kwargs,
  450. ) -> BaseModelOutputWithPastAndCrossAttentions:
  451. r"""
  452. input_ids (`torch.LongTensor` of shape `(batch_size, input_ids_length)`):
  453. `input_ids_length` = `sequence_length` if `past_key_values` is `None` else
  454. `past_key_values.get_seq_length()` (`sequence_length` of input past key value states). Indices of input
  455. sequence tokens in the vocabulary.
  456. If `past_key_values` is used, only `input_ids` that do not have their past calculated should be passed as
  457. `input_ids`.
  458. Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
  459. [`PreTrainedTokenizer.__call__`] for details.
  460. [What are input IDs?](../glossary#input-ids)
  461. """
  462. kwargs.pop("output_attentions", None)
  463. kwargs.pop("output_hidden_states", None)
  464. if input_ids is not None and inputs_embeds is not None:
  465. raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
  466. elif input_ids is not None:
  467. self.warn_if_padding_and_no_attention_mask(input_ids, attention_mask)
  468. input_shape = input_ids.size()
  469. input_ids = input_ids.view(-1, input_shape[-1])
  470. batch_size = input_ids.shape[0]
  471. elif inputs_embeds is not None:
  472. input_shape = inputs_embeds.size()[:-1]
  473. batch_size = inputs_embeds.shape[0]
  474. else:
  475. raise ValueError("You have to specify either input_ids or inputs_embeds")
  476. if token_type_ids is not None:
  477. token_type_ids = token_type_ids.view(-1, input_shape[-1])
  478. # based on pattern from src/transformers/models/whisper/modeling_whisper.py::WhisperDecoder
  479. if use_cache:
  480. if past_key_values is None:
  481. past_key_values = DynamicCache(config=self.config)
  482. if self.config.add_cross_attention and not isinstance(past_key_values, EncoderDecoderCache):
  483. past_key_values = EncoderDecoderCache(past_key_values, DynamicCache(config=self.config))
  484. if inputs_embeds is None:
  485. inputs_embeds = self.wte(input_ids)
  486. if position_ids is None:
  487. past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
  488. position_ids = torch.arange(inputs_embeds.shape[1], device=inputs_embeds.device) + past_seen_tokens
  489. position_ids = position_ids.unsqueeze(0)
  490. position_embeds = self.wpe(position_ids)
  491. hidden_states = inputs_embeds + position_embeds.to(inputs_embeds.device)
  492. # Attention mask.
  493. if attention_mask is not None and attention_mask.ndim < 4:
  494. attention_mask = attention_mask.view(batch_size, -1)
  495. causal_mask = create_causal_mask(
  496. config=self.config,
  497. inputs_embeds=inputs_embeds,
  498. attention_mask=attention_mask,
  499. past_key_values=past_key_values,
  500. position_ids=position_ids,
  501. )
  502. encoder_attention_mask = None
  503. if encoder_hidden_states is not None:
  504. encoder_attention_mask = create_bidirectional_mask(
  505. config=self.config,
  506. inputs_embeds=inputs_embeds,
  507. attention_mask=encoder_attention_mask,
  508. encoder_hidden_states=encoder_hidden_states,
  509. )
  510. if token_type_ids is not None:
  511. token_type_embeds = self.wte(token_type_ids)
  512. hidden_states = hidden_states + token_type_embeds
  513. hidden_states = self.drop(hidden_states)
  514. output_shape = (-1,) + input_shape[1:] + (hidden_states.size(-1),)
  515. for i, block in enumerate(self.h):
  516. hidden_states = block(
  517. hidden_states,
  518. past_key_values if not (self.gradient_checkpointing and self.training) else None,
  519. causal_mask,
  520. encoder_hidden_states, # as a positional argument for gradient checkpointing
  521. encoder_attention_mask=encoder_attention_mask,
  522. use_cache=use_cache,
  523. position_ids=position_ids,
  524. **kwargs,
  525. )
  526. hidden_states = self.ln_f(hidden_states)
  527. hidden_states = hidden_states.view(output_shape)
  528. past_key_values = past_key_values if use_cache else None
  529. return BaseModelOutputWithPastAndCrossAttentions(
  530. last_hidden_state=hidden_states,
  531. past_key_values=past_key_values,
  532. )
  533. @auto_docstring(
  534. custom_intro="""
  535. The GPT2 Model transformer with a language modeling head on top (linear layer with weights tied to the input
  536. embeddings).
  537. """
  538. )
  539. class GPT2LMHeadModel(GPT2PreTrainedModel, GenerationMixin):
  540. _tied_weights_keys = {"lm_head.weight": "transformer.wte.weight"}
  541. def __init__(self, config):
  542. super().__init__(config)
  543. self.transformer = GPT2Model(config)
  544. self.lm_head = nn.Linear(config.n_embd, config.vocab_size, bias=False)
  545. # Initialize weights and apply final processing
  546. self.post_init()
  547. @can_return_tuple
  548. @auto_docstring
  549. def forward(
  550. self,
  551. input_ids: torch.LongTensor | None = None,
  552. past_key_values: Cache | None = None,
  553. attention_mask: torch.FloatTensor | None = None,
  554. token_type_ids: torch.LongTensor | None = None,
  555. position_ids: torch.LongTensor | None = None,
  556. inputs_embeds: torch.FloatTensor | None = None,
  557. encoder_hidden_states: torch.Tensor | None = None,
  558. encoder_attention_mask: torch.FloatTensor | None = None,
  559. labels: torch.LongTensor | None = None,
  560. use_cache: bool | None = None,
  561. logits_to_keep: int | torch.Tensor = 0,
  562. **kwargs,
  563. ) -> CausalLMOutputWithCrossAttentions:
  564. r"""
  565. input_ids (`torch.LongTensor` of shape `(batch_size, input_ids_length)`):
  566. `input_ids_length` = `sequence_length` if `past_key_values` is `None` else
  567. `past_key_values.get_seq_length()` (`sequence_length` of input past key value states). Indices of input
  568. sequence tokens in the vocabulary.
  569. If `past_key_values` is used, only `input_ids` that do not have their past calculated should be passed as
  570. `input_ids`.
  571. Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
  572. [`PreTrainedTokenizer.__call__`] for details.
  573. [What are input IDs?](../glossary#input-ids)
  574. labels (`torch.LongTensor` of shape `(batch_size, input_ids_length)`, *optional*):
  575. Labels for language modeling. Note that the labels **are shifted** inside the model, i.e. you can set
  576. `labels = input_ids` Indices are selected in `[-100, 0, ..., config.vocab_size]` All labels set to `-100`
  577. are ignored (masked), the loss is only computed for labels in `[0, ..., config.vocab_size]`
  578. """
  579. transformer_outputs: BaseModelOutputWithPastAndCrossAttentions = self.transformer(
  580. input_ids,
  581. past_key_values=past_key_values,
  582. attention_mask=attention_mask,
  583. token_type_ids=token_type_ids,
  584. position_ids=position_ids,
  585. inputs_embeds=inputs_embeds,
  586. encoder_hidden_states=encoder_hidden_states,
  587. encoder_attention_mask=encoder_attention_mask,
  588. use_cache=use_cache,
  589. **kwargs,
  590. )
  591. hidden_states = transformer_outputs.last_hidden_state
  592. slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep
  593. logits = self.lm_head(hidden_states[:, slice_indices, :])
  594. loss = None
  595. if labels is not None:
  596. # Flatten the tokens
  597. loss = self.loss_function(
  598. logits,
  599. labels,
  600. vocab_size=self.config.vocab_size,
  601. **kwargs,
  602. )
  603. return CausalLMOutputWithCrossAttentions(
  604. loss=loss,
  605. logits=logits,
  606. past_key_values=transformer_outputs.past_key_values,
  607. hidden_states=transformer_outputs.hidden_states,
  608. attentions=transformer_outputs.attentions,
  609. cross_attentions=transformer_outputs.cross_attentions,
  610. )
  611. @auto_docstring(
  612. custom_intro="""
  613. The GPT2 Model transformer with a language modeling and a multiple-choice classification head on top e.g. for
  614. RocStories/SWAG tasks. The two heads are two linear layers. The language modeling head has its weights tied to the
  615. input embeddings, the classification head takes as input the input of a specified classification token index in the
  616. input sequence).
  617. """
  618. )
  619. class GPT2DoubleHeadsModel(GPT2PreTrainedModel, GenerationMixin):
  620. _tied_weights_keys = {"lm_head.weight": "transformer.wte.weight"}
  621. def __init__(self, config):
  622. super().__init__(config)
  623. config.num_labels = 1
  624. self.transformer = GPT2Model(config)
  625. self.lm_head = nn.Linear(config.n_embd, config.vocab_size, bias=False)
  626. self.multiple_choice_head = GPT2SequenceSummary(config)
  627. # Initialize weights and apply final processing
  628. self.post_init()
  629. @can_return_tuple
  630. @auto_docstring
  631. def forward(
  632. self,
  633. input_ids: torch.LongTensor | None = None,
  634. past_key_values: Cache | None = None,
  635. attention_mask: torch.FloatTensor | None = None,
  636. token_type_ids: torch.LongTensor | None = None,
  637. position_ids: torch.LongTensor | None = None,
  638. inputs_embeds: torch.FloatTensor | None = None,
  639. mc_token_ids: torch.LongTensor | None = None,
  640. labels: torch.LongTensor | None = None,
  641. mc_labels: torch.LongTensor | None = None,
  642. use_cache: bool | None = None,
  643. **kwargs,
  644. ) -> GPT2DoubleHeadsModelOutput:
  645. r"""
  646. input_ids (`torch.LongTensor` of shape `(batch_size, input_ids_length)`):
  647. `input_ids_length` = `sequence_length` if `past_key_values` is `None` else
  648. `past_key_values.get_seq_length()` (`sequence_length` of input past key value states). Indices of input
  649. sequence tokens in the vocabulary.
  650. If `past_key_values` is used, only `input_ids` that do not have their past calculated should be passed as
  651. `input_ids`.
  652. Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
  653. [`PreTrainedTokenizer.__call__`] for details.
  654. [What are input IDs?](../glossary#input-ids)
  655. mc_token_ids (`torch.LongTensor` of shape `(batch_size, num_choices)`, *optional*, default to index of the last token of the input):
  656. Index of the classification token in each input sequence. Selected in the range `[0, input_ids.size(-1) -
  657. 1]`.
  658. labels (`torch.LongTensor` of shape `(batch_size, input_ids_length)`, *optional*):
  659. Labels for language modeling. Note that the labels **are shifted** inside the model, i.e. you can set
  660. `labels = input_ids`. Indices are selected in `[-100, 0, ..., config.vocab_size - 1]`. All labels set to
  661. `-100` are ignored (masked), the loss is only computed for labels in `[0, ..., config.vocab_size - 1]`
  662. mc_labels (`torch.LongTensor` of shape `(batch_size)`, *optional*):
  663. Labels for computing the multiple choice classification loss. Indices should be in `[0, ..., num_choices]`
  664. where *num_choices* is the size of the second dimension of the input tensors. (see *input_ids* above)
  665. Example:
  666. ```python
  667. >>> import torch
  668. >>> from transformers import AutoTokenizer, GPT2DoubleHeadsModel
  669. >>> tokenizer = AutoTokenizer.from_pretrained("openai-community/gpt2")
  670. >>> model = GPT2DoubleHeadsModel.from_pretrained("openai-community/gpt2")
  671. >>> # Add a [CLS] to the vocabulary (we should train it also!)
  672. >>> num_added_tokens = tokenizer.add_special_tokens({"cls_token": "[CLS]"})
  673. >>> # Update the model embeddings with the new vocabulary size
  674. >>> embedding_layer = model.resize_token_embeddings(len(tokenizer))
  675. >>> choices = ["Hello, my dog is cute [CLS]", "Hello, my cat is cute [CLS]"]
  676. >>> encoded_choices = [tokenizer.encode(s) for s in choices]
  677. >>> cls_token_location = [tokens.index(tokenizer.cls_token_id) for tokens in encoded_choices]
  678. >>> input_ids = torch.tensor(encoded_choices).unsqueeze(0) # Batch size: 1, number of choices: 2
  679. >>> mc_token_ids = torch.tensor([cls_token_location]) # Batch size: 1
  680. >>> outputs = model(input_ids, mc_token_ids=mc_token_ids)
  681. >>> lm_logits = outputs.logits
  682. >>> mc_logits = outputs.mc_logits
  683. ```"""
  684. transformer_outputs: BaseModelOutputWithPastAndCrossAttentions = self.transformer(
  685. input_ids,
  686. past_key_values=past_key_values,
  687. attention_mask=attention_mask,
  688. token_type_ids=token_type_ids,
  689. position_ids=position_ids,
  690. inputs_embeds=inputs_embeds,
  691. use_cache=use_cache,
  692. **kwargs,
  693. )
  694. hidden_states = transformer_outputs.last_hidden_state
  695. lm_logits = self.lm_head(hidden_states)
  696. mc_logits = self.multiple_choice_head(hidden_states, mc_token_ids).squeeze(-1)
  697. mc_loss = None
  698. if mc_labels is not None:
  699. loss_fct = CrossEntropyLoss()
  700. mc_loss = loss_fct(mc_logits.view(-1, mc_logits.size(-1)), mc_labels.view(-1))
  701. lm_loss = None
  702. if labels is not None:
  703. labels = labels.to(lm_logits.device)
  704. shift_logits = lm_logits[..., :-1, :].contiguous()
  705. shift_labels = labels[..., 1:].contiguous()
  706. loss_fct = CrossEntropyLoss()
  707. lm_loss = loss_fct(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1))
  708. return GPT2DoubleHeadsModelOutput(
  709. loss=lm_loss,
  710. mc_loss=mc_loss,
  711. logits=lm_logits,
  712. mc_logits=mc_logits,
  713. past_key_values=transformer_outputs.past_key_values,
  714. hidden_states=transformer_outputs.hidden_states,
  715. attentions=transformer_outputs.attentions,
  716. )
  717. @auto_docstring(
  718. custom_intro="""
  719. The GPT2 Model transformer with a sequence classification head on top (linear layer).
  720. [`GPT2ForSequenceClassification`] uses the last token in order to do the classification, as other causal models
  721. (e.g. GPT-1) do.
  722. Since it does classification on the last token, it requires to know the position of the last token. If a
  723. `pad_token_id` is defined in the configuration, it finds the last token that is not a padding token in each row. If
  724. no `pad_token_id` is defined, it simply takes the last value in each row of the batch. Since it cannot guess the
  725. padding tokens when `inputs_embeds` are passed instead of `input_ids`, it does the same (take the last value in
  726. each row of the batch).
  727. """
  728. )
  729. class GPT2ForSequenceClassification(GPT2PreTrainedModel):
  730. def __init__(self, config):
  731. super().__init__(config)
  732. self.num_labels = config.num_labels
  733. self.transformer = GPT2Model(config)
  734. self.score = nn.Linear(config.n_embd, self.num_labels, bias=False)
  735. # Initialize weights and apply final processing
  736. self.post_init()
  737. @can_return_tuple
  738. @auto_docstring
  739. def forward(
  740. self,
  741. input_ids: torch.LongTensor | None = None,
  742. past_key_values: Cache | None = None,
  743. attention_mask: torch.FloatTensor | None = None,
  744. token_type_ids: torch.LongTensor | None = None,
  745. position_ids: torch.LongTensor | None = None,
  746. inputs_embeds: torch.FloatTensor | None = None,
  747. labels: torch.LongTensor | None = None,
  748. use_cache: bool | None = None,
  749. **kwargs,
  750. ) -> SequenceClassifierOutputWithPast:
  751. r"""
  752. input_ids (`torch.LongTensor` of shape `(batch_size, input_ids_length)`):
  753. `input_ids_length` = `sequence_length` if `past_key_values` is `None` else
  754. `past_key_values.get_seq_length()` (`sequence_length` of input past key value states). Indices of input
  755. sequence tokens in the vocabulary.
  756. If `past_key_values` is used, only `input_ids` that do not have their past calculated should be passed as
  757. `input_ids`.
  758. Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
  759. [`PreTrainedTokenizer.__call__`] for details.
  760. [What are input IDs?](../glossary#input-ids)
  761. labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
  762. Labels for computing the sequence classification/regression loss. Indices should be in `[0, ...,
  763. config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If
  764. `config.num_labels > 1` a classification loss is computed (Cross-Entropy).
  765. """
  766. transformer_outputs: BaseModelOutputWithPastAndCrossAttentions = self.transformer(
  767. input_ids,
  768. past_key_values=past_key_values,
  769. attention_mask=attention_mask,
  770. token_type_ids=token_type_ids,
  771. position_ids=position_ids,
  772. inputs_embeds=inputs_embeds,
  773. use_cache=use_cache,
  774. **kwargs,
  775. )
  776. hidden_states = transformer_outputs.last_hidden_state
  777. logits = self.score(hidden_states)
  778. if input_ids is not None:
  779. batch_size, sequence_length = input_ids.shape[:2]
  780. else:
  781. batch_size, sequence_length = inputs_embeds.shape[:2]
  782. if self.config.pad_token_id is None and batch_size != 1:
  783. raise ValueError("Cannot handle batch sizes > 1 if no padding token is defined.")
  784. if self.config.pad_token_id is None:
  785. last_non_pad_token = -1
  786. elif input_ids is not None:
  787. # To handle both left- and right- padding, we take the rightmost token that is not equal to pad_token_id
  788. non_pad_mask = (input_ids != self.config.pad_token_id).to(logits.device, torch.int32)
  789. token_indices = torch.arange(input_ids.shape[-1], device=logits.device, dtype=torch.int32)
  790. last_non_pad_token = (token_indices * non_pad_mask).argmax(-1)
  791. else:
  792. last_non_pad_token = -1
  793. logger.warning_once(
  794. f"{self.__class__.__name__} will not detect padding tokens in `inputs_embeds`. Results may be "
  795. "unexpected if using padding tokens in conjunction with `inputs_embeds.`"
  796. )
  797. pooled_logits = logits[torch.arange(batch_size, device=logits.device), last_non_pad_token]
  798. loss = None
  799. if labels is not None:
  800. if self.config.problem_type is None:
  801. if self.num_labels == 1:
  802. self.config.problem_type = "regression"
  803. elif self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int):
  804. self.config.problem_type = "single_label_classification"
  805. else:
  806. self.config.problem_type = "multi_label_classification"
  807. if self.config.problem_type == "regression":
  808. loss_fct = MSELoss()
  809. if self.num_labels == 1:
  810. loss = loss_fct(pooled_logits.squeeze(), labels.squeeze())
  811. else:
  812. loss = loss_fct(pooled_logits, labels)
  813. elif self.config.problem_type == "single_label_classification":
  814. loss_fct = CrossEntropyLoss()
  815. loss = loss_fct(pooled_logits.view(-1, self.num_labels), labels.view(-1))
  816. elif self.config.problem_type == "multi_label_classification":
  817. loss_fct = BCEWithLogitsLoss()
  818. loss = loss_fct(pooled_logits, labels)
  819. return SequenceClassifierOutputWithPast(
  820. loss=loss,
  821. logits=pooled_logits,
  822. past_key_values=transformer_outputs.past_key_values,
  823. hidden_states=transformer_outputs.hidden_states,
  824. attentions=transformer_outputs.attentions,
  825. )
  826. @auto_docstring
  827. class GPT2ForTokenClassification(GPT2PreTrainedModel):
  828. def __init__(self, config):
  829. super().__init__(config)
  830. self.num_labels = config.num_labels
  831. self.transformer = GPT2Model(config)
  832. if hasattr(config, "classifier_dropout") and config.classifier_dropout is not None:
  833. classifier_dropout = config.classifier_dropout
  834. elif hasattr(config, "hidden_dropout") and config.hidden_dropout is not None:
  835. classifier_dropout = config.hidden_dropout
  836. else:
  837. classifier_dropout = 0.1
  838. self.dropout = nn.Dropout(classifier_dropout)
  839. self.classifier = nn.Linear(config.hidden_size, config.num_labels)
  840. # Initialize weights and apply final processing
  841. self.post_init()
  842. @can_return_tuple
  843. @auto_docstring
  844. def forward(
  845. self,
  846. input_ids: torch.LongTensor | None = None,
  847. past_key_values: Cache | None = None,
  848. attention_mask: torch.FloatTensor | None = None,
  849. token_type_ids: torch.LongTensor | None = None,
  850. position_ids: torch.LongTensor | None = None,
  851. inputs_embeds: torch.FloatTensor | None = None,
  852. labels: torch.LongTensor | None = None,
  853. use_cache: bool | None = None,
  854. **kwargs,
  855. ) -> TokenClassifierOutput:
  856. r"""
  857. input_ids (`torch.LongTensor` of shape `(batch_size, input_ids_length)`):
  858. `input_ids_length` = `sequence_length` if `past_key_values` is `None` else
  859. `past_key_values.get_seq_length()` (`sequence_length` of input past key value states). Indices of input
  860. sequence tokens in the vocabulary.
  861. If `past_key_values` is used, only `input_ids` that do not have their past calculated should be passed as
  862. `input_ids`.
  863. Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
  864. [`PreTrainedTokenizer.__call__`] for details.
  865. [What are input IDs?](../glossary#input-ids)
  866. labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
  867. Labels for computing the sequence classification/regression loss. Indices should be in `[0, ...,
  868. config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If
  869. `config.num_labels > 1` a classification loss is computed (Cross-Entropy).
  870. """
  871. transformer_outputs: BaseModelOutputWithPastAndCrossAttentions = self.transformer(
  872. input_ids,
  873. past_key_values=past_key_values,
  874. attention_mask=attention_mask,
  875. token_type_ids=token_type_ids,
  876. position_ids=position_ids,
  877. inputs_embeds=inputs_embeds,
  878. use_cache=use_cache,
  879. **kwargs,
  880. )
  881. hidden_states = transformer_outputs.last_hidden_state
  882. hidden_states = self.dropout(hidden_states)
  883. logits = self.classifier(hidden_states)
  884. loss = None
  885. if labels is not None:
  886. labels = labels.to(logits.device)
  887. loss_fct = CrossEntropyLoss()
  888. loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
  889. return TokenClassifierOutput(
  890. loss=loss,
  891. logits=logits,
  892. hidden_states=transformer_outputs.hidden_states,
  893. attentions=transformer_outputs.attentions,
  894. )
  895. @auto_docstring
  896. class GPT2ForQuestionAnswering(GPT2PreTrainedModel):
  897. def __init__(self, config):
  898. super().__init__(config)
  899. self.num_labels = config.num_labels
  900. self.transformer = GPT2Model(config)
  901. self.qa_outputs = nn.Linear(config.hidden_size, 2)
  902. # Initialize weights and apply final processing
  903. self.post_init()
  904. @can_return_tuple
  905. @auto_docstring
  906. def forward(
  907. self,
  908. input_ids: torch.LongTensor | None = None,
  909. attention_mask: torch.FloatTensor | None = None,
  910. token_type_ids: torch.LongTensor | None = None,
  911. position_ids: torch.LongTensor | None = None,
  912. inputs_embeds: torch.FloatTensor | None = None,
  913. start_positions: torch.LongTensor | None = None,
  914. end_positions: torch.LongTensor | None = None,
  915. **kwargs,
  916. ) -> QuestionAnsweringModelOutput:
  917. r"""
  918. input_ids (`torch.LongTensor` of shape `(batch_size, input_ids_length)`):
  919. `input_ids_length` = `sequence_length` if `past_key_values` is `None` else
  920. `past_key_values.get_seq_length()` (`sequence_length` of input past key value states). Indices of input
  921. sequence tokens in the vocabulary.
  922. If `past_key_values` is used, only `input_ids` that do not have their past calculated should be passed as
  923. `input_ids`.
  924. Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
  925. [`PreTrainedTokenizer.__call__`] for details.
  926. [What are input IDs?](../glossary#input-ids)
  927. """
  928. outputs: BaseModelOutputWithPastAndCrossAttentions = self.transformer(
  929. input_ids,
  930. attention_mask=attention_mask,
  931. token_type_ids=token_type_ids,
  932. position_ids=position_ids,
  933. inputs_embeds=inputs_embeds,
  934. **kwargs,
  935. )
  936. sequence_output = outputs.last_hidden_state
  937. logits = self.qa_outputs(sequence_output)
  938. start_logits, end_logits = logits.split(1, dim=-1)
  939. start_logits = start_logits.squeeze(-1).contiguous()
  940. end_logits = end_logits.squeeze(-1).contiguous()
  941. total_loss = None
  942. if start_positions is not None and end_positions is not None:
  943. # If we are on multi-GPU, split add a dimension
  944. if len(start_positions.size()) > 1:
  945. start_positions = start_positions.squeeze(-1).to(start_logits.device)
  946. if len(end_positions.size()) > 1:
  947. end_positions = end_positions.squeeze(-1).to(end_logits.device)
  948. # sometimes the start/end positions are outside our model inputs, we ignore these terms
  949. ignored_index = start_logits.size(1)
  950. start_positions = start_positions.clamp(0, ignored_index)
  951. end_positions = end_positions.clamp(0, ignored_index)
  952. loss_fct = CrossEntropyLoss(ignore_index=ignored_index)
  953. start_loss = loss_fct(start_logits, start_positions)
  954. end_loss = loss_fct(end_logits, end_positions)
  955. total_loss = (start_loss + end_loss) / 2
  956. return QuestionAnsweringModelOutput(
  957. loss=total_loss,
  958. start_logits=start_logits,
  959. end_logits=end_logits,
  960. hidden_states=outputs.hidden_states,
  961. attentions=outputs.attentions,
  962. )
# Public API of this module: the names exported by `from <module> import *`.
__all__ = [
    "GPT2DoubleHeadsModel",
    "GPT2ForQuestionAnswering",
    "GPT2ForSequenceClassification",
    "GPT2ForTokenClassification",
    "GPT2LMHeadModel",
    "GPT2Model",
    "GPT2PreTrainedModel",
]