modeling_gptj.py 36 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656657658659660661662663664665666667668669670671672673674675676677678679680681682683684685686687688689690691692693694695696697698699700701702703704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792793794795796797798799800801802803804805806807808809810811812813814815816817818819820821822823824825826827828829830831832833834835836837838839840841842843844845846847848849850851852853854855856
  1. # Copyright 2021 The EleutherAI and HuggingFace Teams. All rights reserved.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. """PyTorch GPT-J model."""
  15. import math
  16. import torch
  17. from torch import nn
  18. from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
  19. from ... import initialization as init
  20. from ...activations import ACT2FN
  21. from ...cache_utils import Cache, DynamicCache
  22. from ...generation import GenerationMixin
  23. from ...masking_utils import create_causal_mask
  24. from ...modeling_flash_attention_utils import flash_attn_supports_top_left_mask, is_flash_attn_available
  25. from ...modeling_layers import GradientCheckpointingLayer
  26. from ...modeling_outputs import (
  27. BaseModelOutputWithPast,
  28. CausalLMOutputWithPast,
  29. QuestionAnsweringModelOutput,
  30. SequenceClassifierOutputWithPast,
  31. )
  32. from ...modeling_utils import PreTrainedModel
  33. from ...utils import auto_docstring, logging
  34. from .configuration_gptj import GPTJConfig
  35. if is_flash_attn_available():
  36. from ...modeling_flash_attention_utils import _flash_attention_forward
  37. logger = logging.get_logger(__name__)
  38. def create_sinusoidal_positions(num_pos: int, dim: int) -> torch.Tensor:
  39. inv_freq = 1.0 / (10000 ** (torch.arange(0, dim, 2, dtype=torch.int64) / dim))
  40. sinusoid_inp = torch.einsum("i , j -> i j", torch.arange(num_pos, dtype=torch.int64).float(), inv_freq).float()
  41. return torch.cat((torch.sin(sinusoid_inp), torch.cos(sinusoid_inp)), dim=1)
  42. def get_embed_positions(embed_positions, position_ids):
  43. return embed_positions.to(position_ids.device).repeat(position_ids.shape[0], 1, 1)
  44. def rotate_every_two(x: torch.Tensor) -> torch.Tensor:
  45. x1 = x[:, :, :, ::2]
  46. x2 = x[:, :, :, 1::2]
  47. x = torch.stack((-x2, x1), dim=-1)
  48. return x.flatten(-2) # in einsum notation: rearrange(x, '... d j -> ... (d j)')
  49. def apply_rotary_pos_emb(tensor: torch.Tensor, sin: torch.Tensor, cos: torch.Tensor) -> torch.Tensor:
  50. sin = torch.repeat_interleave(sin[:, :, None, :], 2, 3)
  51. cos = torch.repeat_interleave(cos[:, :, None, :], 2, 3)
  52. return (tensor * cos) + (rotate_every_two(tensor) * sin)
  53. class GPTJAttention(nn.Module):
  54. def __init__(self, config, layer_idx=None):
  55. super().__init__()
  56. self.config = config
  57. self.max_positions = config.max_position_embeddings
  58. self.attn_dropout = nn.Dropout(config.attn_pdrop)
  59. self.resid_dropout = nn.Dropout(config.resid_pdrop)
  60. self.is_causal = True
  61. self.layer_idx = layer_idx
  62. if layer_idx is None:
  63. logger.warning_once(
  64. f"Instantiating {self.__class__.__name__} without passing a `layer_idx` is not recommended and will "
  65. "lead to errors during the forward call if caching is used. Please make sure to provide a `layer_idx` "
  66. "when creating this class."
  67. )
  68. self.embed_dim = config.hidden_size
  69. self.num_attention_heads = config.num_attention_heads
  70. self.head_dim = self.embed_dim // self.num_attention_heads
  71. if self.head_dim * self.num_attention_heads != self.embed_dim:
  72. raise ValueError(
  73. f"embed_dim must be divisible by num_attention_heads (got `embed_dim`: {self.embed_dim} and"
  74. f" `num_attention_heads`: {self.num_attention_heads})."
  75. )
  76. self.scale_attn = math.sqrt(self.head_dim)
  77. self.k_proj = nn.Linear(self.embed_dim, self.embed_dim, bias=False)
  78. self.v_proj = nn.Linear(self.embed_dim, self.embed_dim, bias=False)
  79. self.q_proj = nn.Linear(self.embed_dim, self.embed_dim, bias=False)
  80. self.out_proj = nn.Linear(self.embed_dim, self.embed_dim, bias=False)
  81. self.rotary_dim = config.rotary_dim
  82. self.pos_embd_dim = self.rotary_dim or self.embed_dim
  83. self.register_buffer(
  84. "embed_positions", create_sinusoidal_positions(self.max_positions, self.pos_embd_dim), persistent=False
  85. )
  86. def _split_heads(self, tensor, num_attention_heads, attn_head_size, rotary):
  87. """
  88. Splits hidden dim into attn_head_size and num_attention_heads
  89. """
  90. new_shape = tensor.size()[:-1] + (num_attention_heads, attn_head_size)
  91. tensor = tensor.view(new_shape)
  92. if rotary:
  93. return tensor
  94. if len(tensor.shape) == 5:
  95. return tensor.permute(0, 1, 3, 2, 4) # (batch, blocks, head, block_length, head_features)
  96. elif len(tensor.shape) == 4:
  97. return tensor.permute(0, 2, 1, 3) # (batch, head, seq_length, head_features)
  98. else:
  99. raise ValueError(f"Input tensor rank should be one of [4, 5], but is: {len(tensor.shape)}")
  100. def _merge_heads(self, tensor, num_attention_heads, attn_head_size):
  101. """
  102. Merges attn_head_size dim and num_attn_heads dim into hidden dim
  103. """
  104. if len(tensor.shape) == 5:
  105. tensor = tensor.permute(0, 1, 3, 2, 4).contiguous()
  106. elif len(tensor.shape) == 4:
  107. tensor = tensor.permute(0, 2, 1, 3).contiguous()
  108. else:
  109. raise ValueError(f"Input tensor rank should be one of [4, 5], but is: {len(tensor.shape)}")
  110. new_shape = tensor.size()[:-2] + (num_attention_heads * attn_head_size,)
  111. return tensor.view(new_shape)
  112. def _attn(
  113. self,
  114. query,
  115. key,
  116. value,
  117. attention_mask=None,
  118. ):
  119. # Keep the attention weights computation in fp32 to avoid overflow issues
  120. query = query.to(torch.float32)
  121. key = key.to(torch.float32)
  122. attn_weights = torch.matmul(query, key.transpose(-1, -2))
  123. attn_weights = attn_weights / self.scale_attn
  124. if attention_mask is not None:
  125. attn_weights = attn_weights + attention_mask
  126. attn_weights = nn.functional.softmax(attn_weights, dim=-1)
  127. attn_weights = attn_weights.to(value.dtype)
  128. attn_weights = self.attn_dropout(attn_weights)
  129. attn_output = torch.matmul(attn_weights, value)
  130. return attn_output, attn_weights
  131. def _get_embed_positions(self, position_ids):
  132. embed_positions = self.embed_positions
  133. if embed_positions.device != position_ids.device:
  134. embed_positions = embed_positions.to(position_ids.device)
  135. self.embed_positions = embed_positions
  136. return embed_positions.repeat(position_ids.shape[0], 1, 1)
  137. def forward(
  138. self,
  139. hidden_states: torch.FloatTensor,
  140. layer_past: Cache | None = None,
  141. attention_mask: torch.FloatTensor | None = None,
  142. position_ids: torch.LongTensor | None = None,
  143. use_cache: bool | None = False,
  144. output_attentions: bool | None = False,
  145. **kwargs,
  146. ) -> (
  147. tuple[torch.Tensor, tuple[torch.Tensor]]
  148. | tuple[torch.Tensor, tuple[torch.Tensor], tuple[torch.Tensor, ...]]
  149. | None
  150. ):
  151. query = self.q_proj(hidden_states)
  152. key = self.k_proj(hidden_states)
  153. value = self.v_proj(hidden_states)
  154. query = self._split_heads(query, self.num_attention_heads, self.head_dim, True)
  155. key = self._split_heads(key, self.num_attention_heads, self.head_dim, True)
  156. value = self._split_heads(value, self.num_attention_heads, self.head_dim, False)
  157. embed_positions = self._get_embed_positions(position_ids)
  158. repeated_position_ids = position_ids.unsqueeze(-1).repeat(1, 1, embed_positions.shape[-1])
  159. sincos = torch.gather(embed_positions, 1, repeated_position_ids).to(key.dtype)
  160. sin, cos = torch.split(sincos, sincos.shape[-1] // 2, dim=-1)
  161. if self.rotary_dim is not None:
  162. k_rot = key[:, :, :, : self.rotary_dim]
  163. k_pass = key[:, :, :, self.rotary_dim :]
  164. q_rot = query[:, :, :, : self.rotary_dim]
  165. q_pass = query[:, :, :, self.rotary_dim :]
  166. k_rot = apply_rotary_pos_emb(k_rot, sin, cos)
  167. q_rot = apply_rotary_pos_emb(q_rot, sin, cos)
  168. key = torch.cat([k_rot, k_pass], dim=-1)
  169. query = torch.cat([q_rot, q_pass], dim=-1)
  170. else:
  171. key = apply_rotary_pos_emb(key, sin, cos)
  172. query = apply_rotary_pos_emb(query, sin, cos)
  173. key = key.permute(0, 2, 1, 3)
  174. query = query.permute(0, 2, 1, 3)
  175. if layer_past is not None:
  176. key, value = layer_past.update(key, value, self.layer_idx)
  177. # compute self-attention: V x Softmax(QK^T)
  178. attn_output, attn_weights = self._attn(query, key, value, attention_mask)
  179. attn_output = self._merge_heads(attn_output, self.num_attention_heads, self.head_dim)
  180. attn_output = self.out_proj(attn_output)
  181. attn_output = self.resid_dropout(attn_output)
  182. return attn_output, attn_weights
  183. class GPTJFlashAttention2(GPTJAttention):
  184. """
  185. GPTJ flash attention module. This module inherits from `GPTJAttention` as the weights of the module stays
  186. untouched. The only required change would be on the forward pass where it needs to correctly call the public API of
  187. flash attention and deal with padding tokens in case the input contains any of them.
  188. """
  189. def __init__(self, *args, **kwargs):
  190. super().__init__(*args, **kwargs)
  191. # TODO: Should be removed once Flash Attention for RoCm is bumped to 2.1.
  192. # flash_attn<2.1 generates top-left aligned causal mask, while what is needed here is bottom-right alignment, that was made default for flash_attn>=2.1. This attribute is used to handle this difference. Reference: https://github.com/Dao-AILab/flash-attention/releases/tag/v2.1.0.
  193. # Beware that with flash_attn<2.1, using q_seqlen != k_seqlen (except for the case q_seqlen == 1) produces a wrong mask (top-left).
  194. self._flash_attn_uses_top_left_mask = flash_attn_supports_top_left_mask()
  195. def forward(
  196. self,
  197. hidden_states: torch.FloatTensor,
  198. layer_past: Cache | None = None,
  199. attention_mask: torch.FloatTensor | None = None,
  200. position_ids: torch.LongTensor | None = None,
  201. use_cache: bool | None = False,
  202. output_attentions: bool | None = False,
  203. **kwargs,
  204. ) -> (
  205. tuple[torch.Tensor, tuple[torch.Tensor]]
  206. | tuple[torch.Tensor, tuple[torch.Tensor], tuple[torch.Tensor, ...]]
  207. | None
  208. ):
  209. query = self.q_proj(hidden_states)
  210. key = self.k_proj(hidden_states)
  211. value = self.v_proj(hidden_states)
  212. query = self._split_heads(query, self.num_attention_heads, self.head_dim, True)
  213. key = self._split_heads(key, self.num_attention_heads, self.head_dim, True)
  214. value = self._split_heads(value, self.num_attention_heads, self.head_dim, False)
  215. embed_positions = self._get_embed_positions(position_ids)
  216. repeated_position_ids = position_ids.unsqueeze(-1).repeat(1, 1, embed_positions.shape[-1])
  217. sincos = torch.gather(embed_positions, 1, repeated_position_ids).to(key.dtype)
  218. sin, cos = torch.split(sincos, sincos.shape[-1] // 2, dim=-1)
  219. if self.rotary_dim is not None:
  220. k_rot = key[:, :, :, : self.rotary_dim]
  221. k_pass = key[:, :, :, self.rotary_dim :]
  222. q_rot = query[:, :, :, : self.rotary_dim]
  223. q_pass = query[:, :, :, self.rotary_dim :]
  224. k_rot = apply_rotary_pos_emb(k_rot, sin, cos)
  225. q_rot = apply_rotary_pos_emb(q_rot, sin, cos)
  226. key = torch.cat([k_rot, k_pass], dim=-1)
  227. query = torch.cat([q_rot, q_pass], dim=-1)
  228. else:
  229. key = apply_rotary_pos_emb(key, sin, cos)
  230. query = apply_rotary_pos_emb(query, sin, cos)
  231. # tanspose to have the desired shape
  232. # before transpose: batch_size x seq_length x num_attention_heads x head_dim
  233. # after transpose: batch_size x num_attention_heads x seq_length x head_dim
  234. key = key.permute(0, 2, 1, 3)
  235. query = query.permute(0, 2, 1, 3)
  236. # value: batch_size x num_attention_heads x seq_length x head_dim
  237. if layer_past is not None:
  238. key, value = layer_past.update(key, value, self.layer_idx)
  239. # The Flash attention requires the input to have the shape
  240. # batch_size x seq_length x head_dim x hidden_dim
  241. # therefore we need to keep the original shape for query and key, and reshape value
  242. # to have the correct shape.
  243. key = key.permute(0, 2, 1, 3).contiguous()
  244. query = query.permute(0, 2, 1, 3).contiguous()
  245. value = value.permute(0, 2, 1, 3).contiguous()
  246. # In PEFT, usually we cast the layer norms in float32 for training stability reasons
  247. # therefore the input hidden states gets silently casted in float32. Hence, we need
  248. # cast them back in the correct dtype just to be sure everything works as expected.
  249. # This might slowdown training & inference so it is recommended to not cast the LayerNorms
  250. # in fp32. (LlamaRMSNorm handles it correctly)
  251. input_dtype = query.dtype
  252. device_type = query.device.type if query.device.type != "mps" else "cpu"
  253. if input_dtype == torch.float32:
  254. if torch.is_autocast_enabled(device_type):
  255. target_dtype = torch.get_autocast_dtype(device_type)
  256. # Handle the case where the model is quantized
  257. elif hasattr(self.config, "_is_quantized"):
  258. target_dtype = self.config.dtype
  259. else:
  260. target_dtype = self.q_proj.weight.dtype
  261. logger.warning_once(
  262. f"The input hidden states seems to be silently casted in float32, this might be related to"
  263. f" the fact you have upcasted embedding or layer norm layers in float32. We will cast back the input in"
  264. f" {target_dtype}."
  265. )
  266. query = query.to(target_dtype)
  267. key = key.to(target_dtype)
  268. value = value.to(target_dtype)
  269. attention_dropout = self.config.attn_pdrop if self.training else 0.0 # attn_pdrop in gptj
  270. query_length = query.shape[1]
  271. # Compute attention
  272. attn_weights = _flash_attention_forward(
  273. query,
  274. key,
  275. value,
  276. attention_mask,
  277. query_length,
  278. dropout=attention_dropout,
  279. is_causal=self.is_causal,
  280. use_top_left_mask=self._flash_attn_uses_top_left_mask,
  281. )
  282. # Reshape outputs
  283. attn_output = attn_weights.reshape(
  284. attn_weights.shape[0], attn_weights.shape[1], attn_weights.shape[2] * attn_weights.shape[3]
  285. )
  286. attn_output = self.out_proj(attn_output)
  287. attn_output = self.resid_dropout(attn_output)
  288. return attn_output, attn_weights
  289. GPTJ_ATTENTION_CLASSES = {
  290. "eager": GPTJAttention,
  291. "flash_attention_2": GPTJFlashAttention2,
  292. }
  293. class GPTJMLP(nn.Module):
  294. def __init__(self, intermediate_size, config): # in MLP: intermediate_size= 4 * embed_dim
  295. super().__init__()
  296. embed_dim = config.n_embd
  297. self.fc_in = nn.Linear(embed_dim, intermediate_size)
  298. self.fc_out = nn.Linear(intermediate_size, embed_dim)
  299. self.act = ACT2FN[config.activation_function]
  300. self.dropout = nn.Dropout(config.resid_pdrop)
  301. def forward(self, hidden_states: torch.FloatTensor | None) -> torch.FloatTensor:
  302. hidden_states = self.fc_in(hidden_states)
  303. hidden_states = self.act(hidden_states)
  304. hidden_states = self.fc_out(hidden_states)
  305. hidden_states = self.dropout(hidden_states)
  306. return hidden_states
  307. class GPTJBlock(GradientCheckpointingLayer):
  308. def __init__(self, config, layer_idx=None):
  309. super().__init__()
  310. inner_dim = config.n_inner if config.n_inner is not None else 4 * config.n_embd
  311. self.ln_1 = nn.LayerNorm(config.n_embd, eps=config.layer_norm_epsilon)
  312. self.attn = GPTJ_ATTENTION_CLASSES[config._attn_implementation](config, layer_idx)
  313. self.mlp = GPTJMLP(inner_dim, config)
  314. def forward(
  315. self,
  316. hidden_states: torch.FloatTensor | None,
  317. layer_past: Cache | None = None,
  318. attention_mask: torch.FloatTensor | None = None,
  319. position_ids: torch.LongTensor | None = None,
  320. use_cache: bool | None = False,
  321. output_attentions: bool | None = False,
  322. **kwargs,
  323. ) -> tuple[torch.Tensor] | tuple[torch.Tensor, tuple[torch.FloatTensor, ...]] | None:
  324. residual = hidden_states
  325. hidden_states = self.ln_1(hidden_states)
  326. attn_outputs, attn_weights = self.attn(
  327. hidden_states=hidden_states,
  328. layer_past=layer_past,
  329. attention_mask=attention_mask,
  330. position_ids=position_ids,
  331. use_cache=use_cache,
  332. output_attentions=output_attentions,
  333. )
  334. feed_forward_hidden_states = self.mlp(hidden_states)
  335. hidden_states = attn_outputs + feed_forward_hidden_states + residual
  336. return hidden_states, attn_weights
  337. @auto_docstring
  338. class GPTJPreTrainedModel(PreTrainedModel):
  339. config: GPTJConfig
  340. base_model_prefix = "transformer"
  341. supports_gradient_checkpointing = True
  342. _no_split_modules = ["GPTJBlock"]
  343. _skip_keys_device_placement = "past_key_values"
  344. _supports_flash_attn = True
  345. _can_compile_fullgraph = True
  346. def _init_weights(self, module):
  347. super()._init_weights(module)
  348. if isinstance(module, GPTJAttention):
  349. init.copy_(module.embed_positions, create_sinusoidal_positions(module.max_positions, module.pos_embd_dim))
  350. @auto_docstring
  351. class GPTJModel(GPTJPreTrainedModel):
  352. def __init__(self, config):
  353. super().__init__(config)
  354. self.embed_dim = config.n_embd
  355. self.vocab_size = config.vocab_size
  356. self.wte = nn.Embedding(config.vocab_size, self.embed_dim)
  357. self.drop = nn.Dropout(config.embd_pdrop)
  358. self.h = nn.ModuleList([GPTJBlock(config, layer_idx=i) for i in range(config.n_layer)])
  359. self.ln_f = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_epsilon)
  360. self.gradient_checkpointing = False
  361. # Initialize weights and apply final processing
  362. self.post_init()
  363. def get_input_embeddings(self):
  364. return self.wte
  365. def set_input_embeddings(self, new_embeddings):
  366. self.wte = new_embeddings
  367. @auto_docstring
  368. def forward(
  369. self,
  370. input_ids: torch.LongTensor | None = None,
  371. past_key_values: Cache | None = None,
  372. attention_mask: torch.FloatTensor | None = None,
  373. token_type_ids: torch.LongTensor | None = None,
  374. position_ids: torch.LongTensor | None = None,
  375. inputs_embeds: torch.FloatTensor | None = None,
  376. use_cache: bool | None = None,
  377. output_attentions: bool | None = None,
  378. output_hidden_states: bool | None = None,
  379. return_dict: bool | None = None,
  380. **kwargs,
  381. ) -> tuple | BaseModelOutputWithPast:
  382. r"""
  383. inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_dim)`, *optional*):
  384. Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This
  385. is useful if you want more control over how to convert *input_ids* indices into associated vectors than the
  386. model's internal embedding lookup matrix.
  387. """
  388. output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
  389. output_hidden_states = (
  390. output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
  391. )
  392. use_cache = use_cache if use_cache is not None else self.config.use_cache
  393. return_dict = return_dict if return_dict is not None else self.config.return_dict
  394. if (input_ids is None) ^ (inputs_embeds is not None):
  395. raise ValueError("You must specify exactly one of input_ids or inputs_embeds")
  396. if self.gradient_checkpointing and self.training:
  397. if use_cache:
  398. logger.warning_once(
  399. "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
  400. )
  401. use_cache = False
  402. if inputs_embeds is None:
  403. inputs_embeds = self.wte(input_ids)
  404. if use_cache and past_key_values is None:
  405. past_key_values = DynamicCache(config=self.config)
  406. seq_length = inputs_embeds.shape[1]
  407. if position_ids is None:
  408. past_key_values_length = past_key_values.get_seq_length() if past_key_values is not None else 0
  409. position_ids = torch.arange(seq_length, device=inputs_embeds.device) + past_key_values_length
  410. position_ids = position_ids.unsqueeze(0)
  411. causal_mask = create_causal_mask(
  412. config=self.config,
  413. inputs_embeds=inputs_embeds,
  414. attention_mask=attention_mask,
  415. past_key_values=past_key_values,
  416. position_ids=position_ids,
  417. )
  418. hidden_states = inputs_embeds
  419. if token_type_ids is not None:
  420. token_type_ids = token_type_ids.view(-1, seq_length)
  421. token_type_embeds = self.wte(token_type_ids)
  422. hidden_states = hidden_states + token_type_embeds
  423. hidden_states = self.drop(hidden_states)
  424. output_shape = (-1, seq_length, hidden_states.size(-1))
  425. all_self_attentions = () if output_attentions else None
  426. all_hidden_states = () if output_hidden_states else None
  427. for i, block in enumerate(self.h):
  428. if output_hidden_states:
  429. all_hidden_states = all_hidden_states + (hidden_states,)
  430. outputs = block(
  431. hidden_states,
  432. layer_past=past_key_values,
  433. attention_mask=causal_mask,
  434. position_ids=position_ids,
  435. use_cache=use_cache,
  436. output_attentions=output_attentions,
  437. )
  438. hidden_states = outputs[0]
  439. if output_attentions:
  440. all_self_attentions = all_self_attentions + (outputs[1],)
  441. hidden_states = self.ln_f(hidden_states)
  442. hidden_states = hidden_states.view(output_shape)
  443. # Add last hidden state
  444. if output_hidden_states:
  445. all_hidden_states = all_hidden_states + (hidden_states,)
  446. if not return_dict:
  447. return tuple(
  448. v for v in [hidden_states, past_key_values, all_hidden_states, all_self_attentions] if v is not None
  449. )
  450. return BaseModelOutputWithPast(
  451. last_hidden_state=hidden_states,
  452. past_key_values=past_key_values,
  453. hidden_states=all_hidden_states,
  454. attentions=all_self_attentions,
  455. )
  456. @auto_docstring(
  457. custom_intro="""
  458. The GPT-J Model transformer with a language modeling head on top.
  459. """
  460. )
  461. class GPTJForCausalLM(GPTJPreTrainedModel, GenerationMixin):
  462. _tied_weights_keys = {"lm_head.weight": "transformer.wte.weight"}
  463. def __init__(self, config):
  464. super().__init__(config)
  465. self.transformer = GPTJModel(config)
  466. self.lm_head = nn.Linear(config.n_embd, config.vocab_size)
  467. # Initialize weights and apply final processing
  468. self.post_init()
  469. @auto_docstring
  470. def forward(
  471. self,
  472. input_ids: torch.LongTensor | None = None,
  473. past_key_values: Cache | None = None,
  474. attention_mask: torch.FloatTensor | None = None,
  475. token_type_ids: torch.LongTensor | None = None,
  476. position_ids: torch.LongTensor | None = None,
  477. inputs_embeds: torch.FloatTensor | None = None,
  478. labels: torch.LongTensor | None = None,
  479. use_cache: bool | None = None,
  480. output_attentions: bool | None = None,
  481. output_hidden_states: bool | None = None,
  482. return_dict: bool | None = None,
  483. logits_to_keep: int | torch.Tensor = 0,
  484. **kwargs,
  485. ) -> tuple | CausalLMOutputWithPast:
  486. r"""
  487. inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_dim)`, *optional*):
  488. Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This
  489. is useful if you want more control over how to convert *input_ids* indices into associated vectors than the
  490. model's internal embedding lookup matrix.
  491. labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
  492. Labels for language modeling. Note that the labels **are shifted** inside the model, i.e. you can set
  493. `labels = input_ids` Indices are selected in `[-100, 0, ..., config.vocab_size]` All labels set to `-100`
  494. are ignored (masked), the loss is only computed for labels in `[0, ..., config.vocab_size]`
  495. """
  496. return_dict = return_dict if return_dict is not None else self.config.return_dict
  497. transformer_outputs = self.transformer(
  498. input_ids,
  499. past_key_values=past_key_values,
  500. attention_mask=attention_mask,
  501. token_type_ids=token_type_ids,
  502. position_ids=position_ids,
  503. inputs_embeds=inputs_embeds,
  504. use_cache=use_cache,
  505. output_attentions=output_attentions,
  506. output_hidden_states=output_hidden_states,
  507. return_dict=return_dict,
  508. )
  509. hidden_states = transformer_outputs[0]
  510. # Only compute necessary logits, and do not upcast them to float if we are not computing the loss
  511. slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep
  512. logits = self.lm_head(hidden_states[:, slice_indices, :])
  513. loss = None
  514. if labels is not None:
  515. loss = self.loss_function(logits=logits, labels=labels, vocab_size=self.config.vocab_size, **kwargs)
  516. if not return_dict:
  517. output = (logits,) + transformer_outputs[1:]
  518. return ((loss,) + output) if loss is not None else output
  519. return CausalLMOutputWithPast(
  520. loss=loss,
  521. logits=logits,
  522. past_key_values=transformer_outputs.past_key_values,
  523. hidden_states=transformer_outputs.hidden_states,
  524. attentions=transformer_outputs.attentions,
  525. )
  526. @auto_docstring(
  527. custom_intro="""
  528. The GPT-J Model transformer with a sequence classification head on top (linear layer).
  529. [`GPTJForSequenceClassification`] uses the last token in order to do the classification, as other causal models
  530. (e.g. GPT, GPT-2, GPT-Neo) do.
  531. Since it does classification on the last token, it requires to know the position of the last token. If a
  532. `pad_token_id` is defined in the configuration, it finds the last token that is not a padding token in each row. If
  533. no `pad_token_id` is defined, it simply takes the last value in each row of the batch. Since it cannot guess the
  534. padding tokens when `inputs_embeds` are passed instead of `input_ids`, it does the same (take the last value in
  535. each row of the batch).
  536. """
  537. )
  538. class GPTJForSequenceClassification(GPTJPreTrainedModel):
  539. def __init__(self, config):
  540. super().__init__(config)
  541. self.num_labels = config.num_labels
  542. self.transformer = GPTJModel(config)
  543. self.score = nn.Linear(config.n_embd, self.num_labels, bias=False)
  544. # Initialize weights and apply final processing
  545. self.post_init()
  546. @auto_docstring
  547. def forward(
  548. self,
  549. input_ids: torch.LongTensor | None = None,
  550. past_key_values: Cache | None = None,
  551. attention_mask: torch.FloatTensor | None = None,
  552. token_type_ids: torch.LongTensor | None = None,
  553. position_ids: torch.LongTensor | None = None,
  554. inputs_embeds: torch.FloatTensor | None = None,
  555. labels: torch.LongTensor | None = None,
  556. use_cache: bool | None = None,
  557. output_attentions: bool | None = None,
  558. output_hidden_states: bool | None = None,
  559. return_dict: bool | None = None,
  560. **kwargs,
  561. ) -> tuple | SequenceClassifierOutputWithPast:
  562. r"""
  563. inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_dim)`, *optional*):
  564. Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This
  565. is useful if you want more control over how to convert *input_ids* indices into associated vectors than the
  566. model's internal embedding lookup matrix.
  567. labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
  568. Labels for computing the sequence classification/regression loss. Indices should be in `[0, ...,
  569. config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If
  570. `config.num_labels > 1` a classification loss is computed (Cross-Entropy).
  571. """
  572. return_dict = return_dict if return_dict is not None else self.config.return_dict
  573. transformer_outputs = self.transformer(
  574. input_ids,
  575. past_key_values=past_key_values,
  576. attention_mask=attention_mask,
  577. token_type_ids=token_type_ids,
  578. position_ids=position_ids,
  579. inputs_embeds=inputs_embeds,
  580. use_cache=use_cache,
  581. output_attentions=output_attentions,
  582. output_hidden_states=output_hidden_states,
  583. return_dict=return_dict,
  584. )
  585. hidden_states = transformer_outputs[0]
  586. logits = self.score(hidden_states)
  587. if input_ids is not None:
  588. batch_size = input_ids.shape[0]
  589. else:
  590. batch_size = inputs_embeds.shape[0]
  591. if self.config.pad_token_id is None and batch_size != 1:
  592. raise ValueError("Cannot handle batch sizes > 1 if no padding token is defined.")
  593. if self.config.pad_token_id is None:
  594. last_non_pad_token = -1
  595. elif input_ids is not None:
  596. # To handle both left- and right- padding, we take the rightmost token that is not equal to pad_token_id
  597. non_pad_mask = (input_ids != self.config.pad_token_id).to(logits.device, torch.int32)
  598. token_indices = torch.arange(input_ids.shape[-1], device=logits.device, dtype=torch.int32)
  599. last_non_pad_token = (token_indices * non_pad_mask).argmax(-1)
  600. else:
  601. last_non_pad_token = -1
  602. logger.warning_once(
  603. f"{self.__class__.__name__} will not detect padding tokens in `inputs_embeds`. Results may be "
  604. "unexpected if using padding tokens in conjunction with `inputs_embeds.`"
  605. )
  606. pooled_logits = logits[torch.arange(batch_size, device=logits.device), last_non_pad_token]
  607. loss = None
  608. if labels is not None:
  609. labels = labels.to(pooled_logits.device)
  610. if self.config.problem_type is None:
  611. if self.num_labels == 1:
  612. self.config.problem_type = "regression"
  613. elif self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int):
  614. self.config.problem_type = "single_label_classification"
  615. else:
  616. self.config.problem_type = "multi_label_classification"
  617. if self.config.problem_type == "regression":
  618. loss_fct = MSELoss()
  619. if self.num_labels == 1:
  620. loss = loss_fct(pooled_logits.squeeze(), labels.squeeze())
  621. else:
  622. loss = loss_fct(pooled_logits, labels)
  623. elif self.config.problem_type == "single_label_classification":
  624. loss_fct = CrossEntropyLoss()
  625. loss = loss_fct(pooled_logits.view(-1, self.num_labels), labels.view(-1))
  626. elif self.config.problem_type == "multi_label_classification":
  627. loss_fct = BCEWithLogitsLoss()
  628. loss = loss_fct(pooled_logits, labels)
  629. if not return_dict:
  630. output = (pooled_logits,) + transformer_outputs[1:]
  631. return ((loss,) + output) if loss is not None else output
  632. return SequenceClassifierOutputWithPast(
  633. loss=loss,
  634. logits=pooled_logits,
  635. past_key_values=transformer_outputs.past_key_values,
  636. hidden_states=transformer_outputs.hidden_states,
  637. attentions=transformer_outputs.attentions,
  638. )
  639. @auto_docstring
  640. class GPTJForQuestionAnswering(GPTJPreTrainedModel):
  641. def __init__(self, config):
  642. super().__init__(config)
  643. self.num_labels = config.num_labels
  644. self.transformer = GPTJModel(config)
  645. self.qa_outputs = nn.Linear(config.hidden_size, config.num_labels)
  646. # Initialize weights and apply final processing
  647. self.post_init()
  648. @auto_docstring
  649. def forward(
  650. self,
  651. input_ids: torch.LongTensor | None = None,
  652. attention_mask: torch.FloatTensor | None = None,
  653. token_type_ids: torch.LongTensor | None = None,
  654. position_ids: torch.LongTensor | None = None,
  655. inputs_embeds: torch.FloatTensor | None = None,
  656. start_positions: torch.LongTensor | None = None,
  657. end_positions: torch.LongTensor | None = None,
  658. output_attentions: bool | None = None,
  659. output_hidden_states: bool | None = None,
  660. return_dict: bool | None = None,
  661. **kwargs,
  662. ) -> tuple | QuestionAnsweringModelOutput:
  663. r"""
  664. inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_dim)`, *optional*):
  665. Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This
  666. is useful if you want more control over how to convert *input_ids* indices into associated vectors than the
  667. model's internal embedding lookup matrix.
  668. """
  669. return_dict = return_dict if return_dict is not None else self.config.return_dict
  670. outputs = self.transformer(
  671. input_ids,
  672. attention_mask=attention_mask,
  673. token_type_ids=token_type_ids,
  674. position_ids=position_ids,
  675. inputs_embeds=inputs_embeds,
  676. output_attentions=output_attentions,
  677. output_hidden_states=output_hidden_states,
  678. return_dict=return_dict,
  679. )
  680. sequence_output = outputs[0]
  681. logits = self.qa_outputs(sequence_output)
  682. start_logits, end_logits = logits.split(1, dim=-1)
  683. start_logits = start_logits.squeeze(-1).contiguous()
  684. end_logits = end_logits.squeeze(-1).contiguous()
  685. total_loss = None
  686. if start_positions is not None and end_positions is not None:
  687. # If we are on multi-GPU, split add a dimension
  688. if len(start_positions.size()) > 1:
  689. start_positions = start_positions.squeeze(-1).to(start_logits.device)
  690. if len(end_positions.size()) > 1:
  691. end_positions = end_positions.squeeze(-1).to(end_logits.device)
  692. # sometimes the start/end positions are outside our model inputs, we ignore these terms
  693. ignored_index = start_logits.size(1)
  694. start_positions = start_positions.clamp(0, ignored_index)
  695. end_positions = end_positions.clamp(0, ignored_index)
  696. loss_fct = CrossEntropyLoss(ignore_index=ignored_index)
  697. start_loss = loss_fct(start_logits, start_positions)
  698. end_loss = loss_fct(end_logits, end_positions)
  699. total_loss = (start_loss + end_loss) / 2
  700. if not return_dict:
  701. output = (start_logits, end_logits) + outputs[2:]
  702. return ((total_loss,) + output) if total_loss is not None else output
  703. return QuestionAnsweringModelOutput(
  704. loss=total_loss,
  705. start_logits=start_logits,
  706. end_logits=end_logits,
  707. hidden_states=outputs.hidden_states,
  708. attentions=outputs.attentions,
  709. )
  710. __all__ = [
  711. "GPTJForCausalLM",
  712. "GPTJForQuestionAnswering",
  713. "GPTJForSequenceClassification",
  714. "GPTJModel",
  715. "GPTJPreTrainedModel",
  716. ]